未验证 提交 ec4b3b89 编写于 作者: 王肖 提交者: GitHub

Update cls.py

上级 32281ad9
......@@ -98,7 +98,7 @@ class Classify(Head):
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
label = 0 if self._preds[i][0] > self._preds[i][1] else 1
label = np.argmax(np.array(self._preds[i]))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result)
writer.write(result+'\n')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册