提交 bb9803a0 编写于 作者: W wangxiao1021

remove dropout in predict, fix #77, update postprocess

上级 82874d8f
...@@ -42,8 +42,8 @@ class BERT(Backbone): ...@@ -42,8 +42,8 @@ class BERT(Backbone):
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -45,8 +45,8 @@ class ERNIE(Backbone): ...@@ -45,8 +45,8 @@ class ERNIE(Backbone):
self._task_types = task_type_vocab_size self._task_types = task_type_vocab_size
self._hidden_act = hidden_act self._hidden_act = hidden_act
self._prepostprocess_dropout = hidden_dropout_prob self._prepostprocess_dropout = 0. if phase == 'predict' else hidden_dropout_prob
self._attention_dropout = attention_probs_dropout_prob self._attention_dropout = 0. if phase == 'predict' else attention_probs_dropout_prob
self._word_emb_name = "word_embedding" self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding" self._pos_emb_name = "pos_embedding"
......
...@@ -94,14 +94,17 @@ class Classify(Head): ...@@ -94,14 +94,17 @@ class Classify(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') for i in range(len(self._preds)):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: label = int(np.argmax(np.array(self._preds[i])))
for i in range(len(self._preds)): result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
label = int(np.argmax(np.array(self._preds[i]))) results.append(result)
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} if output_dir is not None:
result = json.dumps(result) with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
writer.write(result+'\n') for result in results:
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -174,15 +174,18 @@ class Match(Head): ...@@ -174,15 +174,18 @@ class Match(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') for i in range(len(self._preds)):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: if self._learning_strategy == 'pointwise':
for i in range(len(self._preds)): label = int(np.argmax(np.array(self._preds[i])))
if self._learning_strategy == 'pointwise': result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
label = int(np.argmax(np.array(self._preds[i]))) elif self._learning_strategy == 'pairwise':
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'probs': self._preds[i][0]}
elif self._learning_strategy == 'pairwise': results.append(result)
result = {'index': i, 'probs': self._preds[i][0]} if output_dir is not None:
result = json.dumps(result, ensure_ascii=False) with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
writer.write(result+'\n') for result in results:
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) result = json.dumps(result, ensure_ascii=False)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -128,13 +128,15 @@ class MaskLM(Head): ...@@ -128,13 +128,15 @@ class MaskLM(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: results = []
for p in self._preds: for i in range(len(self._preds)):
print(p) result = {'index': i, 'word_id': self._preds[i]}
else: results.append(result)
if output_dir is not None:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for p in self._preds: for result in results:
writer.write(str(p)+'\n') result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
return results
...@@ -154,21 +154,21 @@ class MRC(Head): ...@@ -154,21 +154,21 @@ class MRC(Head):
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process.""" """(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') examples = post_inputs['reader']['examples']
examples = post_inputs['reader']['examples'] features = post_inputs['reader']['features']
features = post_inputs['reader']['features'] if not os.path.exists(output_dir):
if not os.path.exists(output_dir): os.makedirs(output_dir)
os.makedirs(output_dir) output_prediction_file = os.path.join(output_dir, "predictions.json")
output_prediction_file = os.path.join(output_dir, "predictions.json") output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")
output_nbest_file = os.path.join(output_dir, "nbest_predictions.json") output_null_log_odds_file = os.path.join(output_dir, "null_odds.json")
output_null_log_odds_file = os.path.join(output_dir, "null_odds.json") _write_predictions(examples, features, self._pred_results,
_write_predictions(examples, features, self._pred_results, self._n_best_size, self._max_answer_length,
self._n_best_size, self._max_answer_length, self._do_lower_case, output_prediction_file,
self._do_lower_case, output_prediction_file, output_nbest_file, output_null_log_odds_file,
output_nbest_file, output_null_log_odds_file, self._with_negative,
self._with_negative, self._null_score_diff_threshold, self._verbose)
self._null_score_diff_threshold, self._verbose) return self._pred_results
def _write_predictions(all_examples, all_features, all_results, n_best_size, def _write_predictions(all_examples, all_features, all_results, n_best_size,
......
...@@ -118,9 +118,9 @@ class SequenceLabel(Head): ...@@ -118,9 +118,9 @@ class SequenceLabel(Head):
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
if output_dir is None: if output_dir is not None:
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for p in self._preds:
for p in self._preds: writer.write(str(p)+'\n')
writer.write(str(p)+'\n') print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) return self._preds
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册