diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index 499c0f77d82f36872b745f94b94bd7ef89bf1727..66117ac8810b9844f9ee2f73972b1090aa470122 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -98,7 +98,7 @@ class Classify(Head): raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = json.dumps(result) writer.write(result+'\n')