diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py
index 4c014dd9a9ac0ffcba66aeae79b1819bb984d830..37b6396dd80a6e158bf06e295894a2094dfd16f6 100644
--- a/paddlepalm/reader/utils/reader4ernie.py
+++ b/paddlepalm/reader/utils/reader4ernie.py
@@ -639,7 +639,8 @@ class MRCReader(BaseReader):
                  for_cn=True,
                  task_id=0,
                  doc_stride=128,
-                 max_query_length=64):
+                 max_query_length=64,
+                 remove_noanswer=True):
         self.max_seq_len = max_seq_len
         self.tokenizer = tokenization.FullTokenizer(
             vocab_file=vocab_path, do_lower_case=do_lower_case)
@@ -654,6 +655,7 @@ class MRCReader(BaseReader):
         self.max_query_length = max_query_length
         self.examples = {}
         self.features = {}
+        self.remove_noanswer = remove_noanswer

         if random_seed is not None:
             np.random.seed(random_seed)
@@ -758,7 +760,7 @@ class MRCReader(BaseReader):
         return cur_span_index == best_span_index

     def _convert_example_to_feature(self, examples, max_seq_length, tokenizer,
-                                    is_training):
+                                    is_training, remove_noanswer=True):
         features = []
         unique_id = 1000000000

@@ -845,6 +847,8 @@ class MRCReader(BaseReader):
                     if out_of_span:
                         start_position = 0
                         end_position = 0
+                        if remove_noanswer:
+                            continue
                     else:
                         doc_offset = len(query_tokens) + 2
                         start_position = tok_start_position - doc_start + doc_offset
@@ -958,7 +962,7 @@ class MRCReader(BaseReader):
         if not examples:
             examples = self._read_json(input_file, phase == "train")
             features = self._convert_example_to_feature(
-                examples, self.max_seq_len, self.tokenizer, phase == "train")
+                examples, self.max_seq_len, self.tokenizer, phase == "train", remove_noanswer=self.remove_noanswer)

             self.examples[phase] = examples
             self.features[phase] = features
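
The sketch below is not PaddlePALM code; the function name spans_to_features and the (doc_start, doc_end) tuples are made up for illustration. It only shows what the new remove_noanswer flag changes during feature conversion: for a sliding window that does not contain the gold answer span, the old behaviour kept the feature with start/end forced to 0, while remove_noanswer=True skips that window entirely. The real reader also shifts positions by a doc_offset of len(query_tokens) + 2; that detail is left out here.

# Hypothetical, self-contained illustration of the remove_noanswer behaviour
# added in this patch (simplified; names and offsets do not match the reader).
def spans_to_features(doc_spans, tok_start, tok_end, remove_noanswer=True):
    """doc_spans: list of (doc_start, doc_end) token ranges produced by the
    sliding window; tok_start/tok_end: answer position in the full document."""
    features = []
    for doc_start, doc_end in doc_spans:
        out_of_span = not (tok_start >= doc_start and tok_end <= doc_end)
        if out_of_span:
            if remove_noanswer:
                continue  # new behaviour: drop the unanswerable window
            start_position = end_position = 0  # old behaviour: keep a "no answer" feature
        else:
            start_position = tok_start - doc_start
            end_position = tok_end - doc_start
        features.append((doc_start, doc_end, start_position, end_position))
    return features

# The answer lives at tokens 130..135, so only the second window contains it.
print(spans_to_features([(0, 128), (64, 192)], 130, 135))
# -> [(64, 192, 66, 71)]
print(spans_to_features([(0, 128), (64, 192)], 130, 135, remove_noanswer=False))
# -> [(0, 128, 0, 0), (64, 192, 66, 71)]

Since the diff threads the flag through with a default of True, constructing MRCReader with remove_noanswer=False restores the previous behaviour of keeping zero-position features for out-of-span windows.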