未验证 提交 0f062464 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #79 from wangxiao1021/api

remove dropout in predict, fix #77, update postprocess
...@@ -42,8 +42,8 @@ class BERT(Backbone): ...@@ -42,8 +42,8 @@ class BERT(Backbone):
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -45,8 +45,8 @@ class ERNIE(Backbone): ...@@ -45,8 +45,8 @@ class ERNIE(Backbone):
self._task_types = task_type_vocab_size self._task_types = task_type_vocab_size
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -94,14 +94,17 @@ class Classify(Head): ...@@ -94,14 +94,17 @@ class Classify(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
label = int(np.argmax(np.array(self._preds[i]))) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for result in results:
result = json.dumps(result) result = json.dumps(result)
writer.write(result+'\n') writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -174,15 +174,18 @@ class Match(Head): ...@@ -174,15 +174,18 @@ class Match(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise': if self._learning_strategy == 'pointwise':
label = int(np.argmax(np.array(self._preds[i]))) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise': elif self._learning_strategy == 'pairwise':
result = {'index': i, 'probs': self._preds[i][0]} result = {'index': i, 'probs': self._preds[i][0]}
results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for result in results:
result = json.dumps(result, ensure_ascii=False) result = json.dumps(result, ensure_ascii=False)
writer.write(result+'\n') writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -128,13 +128,15 @@ class MaskLM(Head): ...@@ -128,13 +128,15 @@ class MaskLM(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
for p in self._preds: for i in range(len(self._preds)):
print(p) result = {'index': i, 'word_id': self._preds[i]}
else: results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for p in self._preds: for result in results:
writer.write(str(p)+'\n') result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -154,8 +154,7 @@ class MRC(Head): ...@@ -154,8 +154,7 @@ class MRC(Head):
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process.""" """(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
examples = post_inputs['reader']['examples'] examples = post_inputs['reader']['examples']
features = post_inputs['reader']['features'] features = post_inputs['reader']['features']
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
...@@ -169,6 +168,7 @@ class MRC(Head): ...@@ -169,6 +168,7 @@ class MRC(Head):
output_nbest_file, output_null_log_odds_file, output_nbest_file, output_null_log_odds_file,
self._with_negative, self._with_negative,
self._null_score_diff_threshold, self._verbose) self._null_score_diff_threshold, self._verbose)
return self._pred_results
def _write_predictions(all_examples, all_features, all_results, n_best_size, def _write_predictions(all_examples, all_features, all_results, n_best_size,
......
...@@ -118,9 +118,9 @@ class SequenceLabel(Head): ...@@ -118,9 +118,9 @@ class SequenceLabel(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for p in self._preds: for p in self._preds:
writer.write(str(p)+'\n') writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return self._preds
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册