提交 7a7e7551 编写于 作者: X xixiaoyao

add remove_noanswer

上级 fc69141e
...@@ -639,7 +639,8 @@ class MRCReader(BaseReader): ...@@ -639,7 +639,8 @@ class MRCReader(BaseReader):
for_cn=True, for_cn=True,
task_id=0, task_id=0,
doc_stride=128, doc_stride=128,
max_query_length=64): max_query_length=64,
remove_noanswer=True):
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer( self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case) vocab_file=vocab_path, do_lower_case=do_lower_case)
...@@ -654,6 +655,7 @@ class MRCReader(BaseReader): ...@@ -654,6 +655,7 @@ class MRCReader(BaseReader):
self.max_query_length = max_query_length self.max_query_length = max_query_length
self.examples = {} self.examples = {}
self.features = {} self.features = {}
self.remove_noanswer = remove_noanswer
if random_seed is not None: if random_seed is not None:
np.random.seed(random_seed) np.random.seed(random_seed)
...@@ -758,7 +760,7 @@ class MRCReader(BaseReader): ...@@ -758,7 +760,7 @@ class MRCReader(BaseReader):
return cur_span_index == best_span_index return cur_span_index == best_span_index
def _convert_example_to_feature(self, examples, max_seq_length, tokenizer, def _convert_example_to_feature(self, examples, max_seq_length, tokenizer,
is_training): is_training, remove_noanswer=True):
features = [] features = []
unique_id = 1000000000 unique_id = 1000000000
...@@ -845,6 +847,8 @@ class MRCReader(BaseReader): ...@@ -845,6 +847,8 @@ class MRCReader(BaseReader):
if out_of_span: if out_of_span:
start_position = 0 start_position = 0
end_position = 0 end_position = 0
if remove_noanswer:
continue
else: else:
doc_offset = len(query_tokens) + 2 doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset start_position = tok_start_position - doc_start + doc_offset
...@@ -958,7 +962,7 @@ class MRCReader(BaseReader): ...@@ -958,7 +962,7 @@ class MRCReader(BaseReader):
if not examples: if not examples:
examples = self._read_json(input_file, phase == "train") examples = self._read_json(input_file, phase == "train")
features = self._convert_example_to_feature( features = self._convert_example_to_feature(
examples, self.max_seq_len, self.tokenizer, phase == "train") examples, self.max_seq_len, self.tokenizer, phase == "train", remove_noanswer=self.remove_noanswer)
self.examples[phase] = examples self.examples[phase] = examples
self.features[phase] = features self.features[phase] = features
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册