未验证 提交 813d728f 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #69 from wangxiao1021/api

fix bugs of head/match, head/cls
...@@ -98,7 +98,7 @@ class Classify(Head): ...@@ -98,7 +98,7 @@ class Classify(Head):
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
label = 0 if self._preds[i][0] > self._preds[i][1] else 1 label = np.argmax(np.array(self._preds[i]))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result) result = json.dumps(result)
writer.write(result+'\n') writer.write(result+'\n')
......
...@@ -179,11 +179,10 @@ class Match(Head): ...@@ -179,11 +179,10 @@ class Match(Head):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise': if self._learning_strategy == 'pointwise':
label = 0 if self._preds[i][0] > self._preds[i][1] else 1 label = np.argmax(np.array(self._preds[i]))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise': elif self._learning_strategy == 'pairwise':
label = 0 if self._preds[i][0] < 0.5 else 1 result = {'index': i, 'probs': self._preds[i][0]}
result = {'index': i, 'label': label, 'probs': self._preds[i][0]}
result = json.dumps(result, ensure_ascii=False) result = json.dumps(result, ensure_ascii=False)
writer.write(result+'\n') writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册