From 32281ad9f1a3bdb64a0a4a4836e6a367ea7c8c28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=82=96?= Date: Fri, 13 Mar 2020 10:36:30 +0800 Subject: [PATCH] Update match.py --- paddlepalm/head/match.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlepalm/head/match.py b/paddlepalm/head/match.py index ec6bf40..7742f4a 100644 --- a/paddlepalm/head/match.py +++ b/paddlepalm/head/match.py @@ -179,10 +179,9 @@ class Match(Head): with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): if self._learning_strategy == 'pointwise': - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} elif self._learning_strategy == 'pairwise': - label = 0 if self._preds[i][0] < 0.5 else 1 result = {'index': i, 'label': label, 'probs': self._preds[i][0]} result = json.dumps(result, ensure_ascii=False) writer.write(result+'\n') -- GitLab