[Relation Extraction: mre-in-one-pass] Loading the Data (Part 1)

2021-03-22 10:27:40

# Model training command

```bash
python run_classifier.py \
        --task_name=semeval \
        --do_train=true \
        --do_eval=false \
        --do_predict=false \
        --data_dir=$DATA_DIR/semeval2018/multi \
        --vocab_file=$BERT_BASE_DIR/vocab.txt \
        --bert_config_file=$BERT_BASE_DIR/bert_config.json \
        --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
        --max_seq_length=128 \
        --train_batch_size=4 \
        --learning_rate=2e-5 \
        --num_train_epochs=30 \
        --max_distance=2 \
        --max_num_relations=12 \
        --output_dir=<path to store the checkpoint>
```

We start from this command because it shows the parameters we will need later on.

# How the data is generated

In the main() function of run_classifier.py, a dictionary is defined that maps task names to their data processors:

```python
    processors = {
        "semeval": SemEvalProcessor,
        "ace": ACEProcessor
    }
```

Tokenization is set up next. The script first checks that the do_lower_case flag (whether the input should be lowercased) is consistent with the pre-trained model given by init_checkpoint (the path to the checkpoint):

```python
tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                              FLAGS.init_checkpoint)
```
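The tokenizer itself is then built the same way as in the stock BERT run_classifier.py (a sketch under that assumption; the FullTokenizer in this repo is modified so that tokenize() additionally returns a word-piece mapping, which we will see used below):

```python
# assumed construction, as in the original BERT run_classifier.py
tokenizer = tokenization.FullTokenizer(
    vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
```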

The processor for the specified dataset is then instantiated via the task name:

```python
    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
```
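The label list is presumably fetched from the processor right afterwards, since it is needed later by file_based_convert_examples_to_features(). Note that get_labels() in this repo takes a data_dir argument, unlike the stock BERT processors (a hedged sketch):

```python
# assumed call; get_labels() reads SemEval.label.tsv from the data directory
label_list = processor.get_labels(FLAGS.data_dir)
```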

The SemEvalProcessor class used for the semeval task looks like this:

```python
class SemEvalProcessor(DataProcessor):
    """Processor for the SemEval data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "SemEval.train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "SemEval.test.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "SemEval.test.tsv")), "test")

    def get_labels(self, data_dir):
        """See base class."""
        label_list = []
        filein = open(os.path.join(data_dir, "SemEval.label.tsv"))
        for line in filein:
            label = line.strip()
            label_list.append(tokenization.convert_to_unicode(label))
        return label_list

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []

        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[0])
            num_relations = int((len(line) - 1) / 7)
            locations = list()
            labels = list()
            for j in range(num_relations):
                label = tokenization.convert_to_unicode(line[j * 7 + 1])
                labels.append(label)
                # (lo, hi)
                entity_pos1 = (int(line[j * 7 + 2]) + 1, int(line[j * 7 + 3]) + 1)
                entity_pos2 = (int(line[j * 7 + 5]) + 1, int(line[j * 7 + 6]) + 1)
                # [((lo1,hi1), (lo2, hi2))]
                locations.append((entity_pos1, entity_pos2))
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None,
                             locations=locations, labels=labels, num_relations=num_relations))

        return examples
```
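The _read_tsv helper used above is inherited from the DataProcessor base class. In the stock BERT run_classifier.py it is just a thin wrapper around csv.reader (reproduced here for reference under the assumption that this repo keeps it unchanged; it relies on the csv module and tf.gfile):

```python
    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines
```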

The data used is shown below.

The rows are first read from the tsv file, and then _create_examples(self, lines, set_type) is called on them. Let's look at the label data first:

```
COMPARE
MODEL-FEATURE
PART_WHOLE
RESULT
TOPIC
USAGE
```

And the training data:

```
a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices	USAGE	12	12	H01-1001.7	5	7	H01-1001.5	USAGE	18	19	H01-1001.9	23	23	H01-1001.10	PART_WHOLE	36	37	H01-1001.15	34	34	H01-1001.14
funding the development of a distributed message-passing infrastructure for dialogue systems which all Communicator participants are	MODEL-FEATURE	5	7	H01-1017.4	9	10	H01-1017.5
Lincoln Laboratory ) . The CCLINC Korean-to-English translation system consists of two core modules , language understanding and generation ( i ) Robust efficient parsing of Korean ( a verb final language with overt case markers , relatively free word order ( ii ) High quality translation via word sense disambiguation and accurate word order generation	PART_WHOLE	12	13	H01-1041.4	5	8	H01-1041.3	USAGE	24	24	H01-1041.8	26	26	H01-1041.9	MODEL-FEATURE	33	35	H01-1041.11	29	31	H01-1041.10	USAGE	48	50	H01-1041.15	46	46	H01-1041.14
test the efficacy of applying automated evaluation techniques , originally devised for the evaluation of human language learners , to the output of machine translation ( MT experiments , looks at the intelligibility of MT output . A language learning experiment	USAGE	5	7	H01-1042.1	21	21	H01-1042.3	MODEL-FEATURE	32	32	H01-1042.10	34	35	H01-1042.11
sources . We integrate a spoken language understanding system with intelligent mobile agents that mediate between users and	PART_WHOLE	10	12	H01-1049.4	5	8	H01-1049.3
. We find that simple interpolation methods , like log-linear and linear interpolation , improve the performance but fall short of the word string and selects the word string with the best performance ( typically , word or word strings , where each word string has been obtained by using a different LM . Actually , the oracle acts like a dynamic combiner with hard decisions using the reference . We provide experimental results show the need for a dynamic language model combination to improve the performance further . We suggest a The method amounts to tagging LMs with confidence measures and picking the best hypothesis corresponding to the LM with the best confidence .	RESULT	5	6	H01-1058.2	16	16	H01-1058.4	RESULT	27	28	H01-1058.9	32	32	H01-1058.10	RESULT	52	52	H01-1058.14	43	44	H01-1058.13	USAGE	68	68	H01-1058.18	61	62	H01-1058.16	RESULT	79	82	H01-1058.19	86	86	H01-1058.20	MODEL-FEATURE	99	100	H01-1058.25	97	97	H01-1058.24	MODEL-FEATURE	113	113	H01-1058.28	109	109	H01-1058.27
approach employing n-gram models and error-correction rules for Thai key prediction and Thai-English language identification . also proposes rule-reduction algorithm applying mutual information to reduce the error-correction rules . Our algorithm reported more than 99 % accuracy in both language identification and key prediction .	USAGE	5	6	H01-1070.2	8	10	H01-1070.3	USAGE	21	22	H01-1070.6	26	27	H01-1070.7	RESULT	39	40	H01-1070.9	36	36	H01-1070.8
potentially large list of possible sentence plans for a given text-plan input . Second , the sentence-plan-ranker plan . The SPR uses ranking rules automatically learned from training data . We show that the SPR learns to select a sentence plan whose rating on average is only 5 % worse than the top human-ranked sentence plan .	MODEL-FEATURE	5	6	N01-1003.12	10	11	N01-1003.13	USAGE	27	28	N01-1003.19	22	23	N01-1003.18	COMPARE	39	40	N01-1003.21	52	55	N01-1003.22
and segment contiguity on the retrieval performance of a translation memory system . We take a selection of both bag-of-words and segment order-sensitive string comparison methods , and run each over both character- and word-segmented data , in combination with a datasets , we find that indexing according to simple character bigrams produces a retrieval accuracy superior in their optimum configuration , bag-of-words methods are shown to be equivalent to segment order-sensitive methods in terms of retrieval accuracy	RESULT	9	11	P01-1004.5	5	6	P01-1004.4	USAGE	19	25	P01-1004.6	32	35	P01-1004.7	USAGE	50	51	P01-1004.12	46	46	P01-1004.11	COMPARE	62	63	P01-1004.16	70	72	P01-1004.17
The theoretical study of the range concatenation grammar [ RCG ] formalism has revealed many attractive properties which may be used in NLP . In particular , range concatenation languages [ RCL ] can be parsed in polynomial time and many classical grammatical formalisms an equivalent RCG , any tree adjoining grammar can be parsed in O ( n6 ) time . In this paper , we study a parsing technique whose purpose is to improve the practical efficiency of RCL parsers . The non-deterministic parsing choices L are directed by a guide which uses the shared derivation forest output by a prior RCL parser for a suitable superset of L . The results of a	USAGE	5	11	P01-1007.1	22	22	P01-1007.2	MODEL-FEATURE	37	38	P01-1007.4	27	32	P01-1007.3	MODEL-FEATURE	56	60	P01-1007.11	49	51	P01-1007.10	USAGE	69	70	P01-1007.12	80	81	P01-1007.13	USAGE	96	98	P01-1007.18	92	92	P01-1007.17	USAGE	103	104	P01-1007.19	108	110	P01-1007.20
from a Parallel Corpus While paraphrasing is critical both for interpretation and generation of natural language , current systems use manual paraphrases . We present an unsupervised learning algorithm for identification of paraphrases from a corpus of multiple	USAGE	5	5	P01-1008.1	10	15	P01-1008.2	USAGE	26	28	P01-1008.4	30	32	P01-1008.5
Retrieval This paper presents a formal analysis for a large class of words called alternative markers , which includes other ( . I show that the performance of a search engine can be improved dramatically by incorporating an approximation of the formal analysis that is compatible with the approach is that as the operational semantics of natural language applications improve , even larger improvements	TOPIC	5	6	P01-1009.1	14	15	P01-1009.3	RESULT	41	42	P01-1009.14	26	26	P01-1009.12	PART_WHOLE	53	54	P01-1009.17	56	58	P01-1009.18
sensitive logic , and a learning algorithm from structured data ( based on a typing-algorithm	USAGE	8	9	P01-1047.11	5	6	P01-1047.10

```

Here is a printout showing what _create_examples() produces for the first line:

```
0 ['a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices', 'USAGE', '12', '12', 'H01-1001.7', '5', '7', 'H01-1001.5', 'USAGE', '18', '19', 'H01-1001.9', '23', '23', 'H01-1001.10', 'PART_WHOLE', '36', '37', 'H01-1001.15', '34', '34', 'H01-1001.14']
text_a= a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices
num_relations= 3
labels= ['USAGE', 'USAGE', 'PART_WHOLE']
locations= [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]
```
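These numbers can be re-derived by hand. A minimal sketch for the first relation of the printed row, using the literal field values shown above (the field layout is explained right after):

```python
# first 8 fields of the printed row (sentence text abbreviated)
line = ['a large database ...', 'USAGE', '12', '12', 'H01-1001.7', '5', '7', 'H01-1001.5']
# the full row has 1 + 3 * 7 = 22 fields, hence num_relations = (22 - 1) // 7 = 3
label = line[1]                                      # 'USAGE'
entity_pos1 = (int(line[2]) + 1, int(line[3]) + 1)   # (13, 13), +1 for the future [CLS]
entity_pos2 = (int(line[5]) + 1, int(line[6]) + 1)   # (6, 8)
print(label, entity_pos1, entity_pos2)
```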

Explanation: one relation is stored as relation label, entity 1 start, entity 1 end, entity 1 id, entity 2 start, entity 2 end, entity 2 id, i.e. 7 fields per relation, so the number of relations in a sentence can be computed from the number of fields in the row. Because a [CLS] token will later be prepended to the sentence, every entity index is shifted by +1. Finally this information is wrapped in an InputExample and appended to the examples list. After that, the script calls:

```python
train_examples = processor.get_train_examples(FLAGS.data_dir)
```

which yields the list of training examples. Next comes:

```python
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
file_based_convert_examples_to_features(
    train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
```
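For context: in the stock BERT run_classifier.py the TFRecord written here is then consumed during training roughly as sketched below. This repo's input_fn builder additionally has to parse loc, mas, e1_mas, e2_mas and cls_mask, so the real code will differ:

```python
# standard BERT training flow (sketch only; not this repo's exact code)
train_input_fn = file_based_input_fn_builder(
    input_file=train_file,
    seq_length=FLAGS.max_seq_length,
    is_training=True,
    drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
```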

Let's take a look at the file_based_convert_examples_to_features() function:

```python
def file_based_convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_file):
    """Convert a set of `InputExample`s to a TFRecord file."""

    writer = tf.python_io.TFRecordWriter(output_file)

    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer)

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["loc"] = create_int_feature(feature.loc)
        features["mas"] = create_int_feature(feature.mas)
        features["e1_mas"] = create_int_feature(feature.e1_mas)
        features["e2_mas"] = create_int_feature(feature.e2_mas)
        features["cls_mask"] = create_int_feature(feature.cls_mask)
        features["label_ids"] = create_int_feature(feature.label_id)

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
    writer.close()
```

Inside it, for every sample (example):

```python
feature = convert_single_example(ex_index, example, label_list,
                                 max_seq_length, tokenizer)
```

convert_single_example() is called. Let's look at this function:

```python
def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""

    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            label_id=0,
            is_real_example=False)

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    tokens_a, mapping_a = tokenizer.tokenize(example.text_a)
    tokens_b = None
    if example.text_b:
        tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)
    label_id = [label_map[label] for label in example.labels]
    label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
    cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)

    np.set_printoptions(edgeitems=15)
    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("loc:")
        tf.logging.info("n"   str(loc))
        tf.logging.info("mas:")
        tf.logging.info("n"   str(mas))
        tf.logging.info("e1_mas:")
        tf.logging.info("n"   str(e1_mas))
        tf.logging.info("e2_mas:")
        tf.logging.info("n"   str(e2_mas))
        tf.logging.info("cls_mask:")
        tf.logging.info("n"   str(cls_mask))
        tf.logging.info("labels: %s" % " ".join([str(x) for x in label_id]))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        loc=loc.flatten(),
        mas=mas.flatten(),
        e1_mas=e1_mas.flatten(),
        e2_mas=e2_mas.flatten(),
        cls_mask=cls_mask,
        label_id=label_id)
    return feature
```

Notes:

  • Map the labels to integer ids (a short worked example follows this list):
```python
    for (i, label) in enumerate(label_list):
        label_map[label] = i
```
  • Tokenize the sentence of each example (the ids are produced later with convert_tokens_to_ids); the modified tokenizer also returns a word-piece mapping. For the first example this gives:
```python
tokens_a, mapping_a = tokenizer.tokenize(example.text_a)
['a', 'large', 'database', '.', 'Traditional', 'information', 're', '##tri', '##eval', 'techniques', 'use', 'a', 'his', '##to', '##gram', 'of', 'key', '##words', 'as', 'the', 'document', 'representation', 'but', 'oral', 'communication', 'may', 'offer', 'additional', 'in', '##dices', 'such', 'as', 'the', 'time', 'and', 'is', 'shown', 'on', 'a', 'large', 'database', 'of', 'TV', 'shows', '.', 'Em', '##otion', '##s', 'and', 'other', 'in', '##dices']
```
  • Truncate sentences that exceed the maximum length:
```python
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]
```
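As promised above, with the six labels from SemEval.label.tsv the label map from the first bullet comes out as follows (a minimal sketch):

```python
label_list = ['COMPARE', 'MODEL-FEATURE', 'PART_WHOLE', 'RESULT', 'TOPIC', 'USAGE']
label_map = {}
for (i, label) in enumerate(label_list):
    label_map[label] = i
# {'COMPARE': 0, 'MODEL-FEATURE': 1, 'PART_WHOLE': 2, 'RESULT': 3, 'TOPIC': 4, 'USAGE': 5}
```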

Because the two special tokens [CLS] and [SEP] are added, the maximum length is reduced by 2. The construction of input_ids, segment_ids and input_mask that follows is the standard BERT input pipeline, so there is nothing special to cover there. The part worth a closer look is this:

```python
loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)
label_id = [label_map[label] for label in example.labels]
label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)
```
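The last two lines pad label_id and cls_mask up to max_num_relations (12 in the training command). Using the first example above and the label order from SemEval.label.tsv, this works out to (an illustrative check, not output from the repo):

```python
# labels of the first example: ['USAGE', 'USAGE', 'PART_WHOLE']
# with the label file order COMPARE..USAGE, USAGE -> 5 and PART_WHOLE -> 2
label_id = [5, 5, 2] + [0] * (12 - 3)   # [5, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cls_mask = [1] * 3 + [0] * (12 - 3)     # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```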

This calls the prepare_extra_data() function; let's look at what it does:

```python
def prepare_extra_data(mapping, locs, max_distance):
  # [128, 128]
  res = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  # [128, 128]
  mas = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  # [12, 128]
  e1_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
  # [12, 128]
  e2_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)

  entities = set()
  # collect the set of distinct entity spans
  for loc in locs:
    entities.add(loc[0])
    entities.add(loc[1])
  # iterate over each entity
  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # col
    res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
    mas[1:, sub_lo1:sub_hi2+1] = 1

  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # row
    res[sub_lo1:sub_hi2+1, :] = relative_position
    mas[sub_lo1:sub_hi2+1, 1:] = 1

  for idx, (e1,e2) in enumerate(locs):
    # e1
    (lo, hi) = e1
    _, mask = convert_entity_row(mapping, e1, max_distance)
    e1_mas[idx] = mask
    # e2
    (lo, hi) = e2
    _, mask = convert_entity_row(mapping, e2, max_distance)
    e2_mas[idx] = mask

  return res, mas, e1_mas, e2_mas
```

Inside, it calls two further functions: convert_entity_row and find_lo_hi.

The next part concerns the Entity-Aware Self-Attention based on Relative Distance, which I will leave for the next post.

Code source: https://sourcegraph.com/github.com/helloeve/mre-in-one-pass/-/blob/run_classifier.py#L379
