From ec4b3b89e3c240c3a17807be16180c0c3d3b6bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=82=96?= Date: Fri, 13 Mar 2020 10:36:51 +0800 Subject: [PATCH] Update cls.py --- paddlepalm/head/cls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index 499c0f7..66117ac 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -98,7 +98,7 @@ class Classify(Head): raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = json.dumps(result) writer.write(result+'\n') -- GitLab