未验证 提交 32281ad9 编写于 作者: 王肖 提交者: GitHub

Update match.py

上级 65688371
...@@ -179,10 +179,9 @@ class Match(Head): ...@@ -179,10 +179,9 @@ class Match(Head):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise': if self._learning_strategy == 'pointwise':
label = 0 if self._preds[i][0] > self._preds[i][1] else 1 label = np.argmax(np.array(self._preds[i]))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise': elif self._learning_strategy == 'pairwise':
label = 0 if self._preds[i][0] < 0.5 else 1
result = {'index': i, 'label': label, 'probs': self._preds[i][0]} result = {'index': i, 'label': label, 'probs': self._preds[i][0]}
result = json.dumps(result, ensure_ascii=False) result = json.dumps(result, ensure_ascii=False)
writer.write(result+'\n') writer.write(result+'\n')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册