提交 c77d0dee 编写于 作者: S Steffy-zxf

fix return_result bug

上级 8666b7b8
......@@ -150,7 +150,7 @@ class ClassifierTask(BaseTask):
results = []
for batch_state in run_states:
batch_result = batch_state.run_results
batch_infer = np.argmax(batch_result, axis=2)[0]
batch_infer = np.argmax(batch_result[0], axis=1)
results += [id2label[sample_infer] for sample_infer in batch_infer]
return results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册