diff --git a/paddlepalm/head/match.py b/paddlepalm/head/match.py index ec6bf40d9fa878d6b304d345026371a656dd935b..7742f4a11aff065450fc926f487aa9d48537c57f 100644 --- a/paddlepalm/head/match.py +++ b/paddlepalm/head/match.py @@ -179,10 +179,9 @@ class Match(Head): with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): if self._learning_strategy == 'pointwise': - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} elif self._learning_strategy == 'pairwise': - label = 0 if self._preds[i][0] < 0.5 else 1 result = {'index': i, 'label': label, 'probs': self._preds[i][0]} result = json.dumps(result, ensure_ascii=False) writer.write(result+'\n')