From f7dd889ca443aaf1248947a1af65107b9779370d Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 16 Nov 2020 13:52:31 +0800 Subject: [PATCH] Support squeezed label as input in paddle.metric.Accuracy (#28535) * Support squeezed label as input in paddle.metric.Accuracy * Revert cifar and fix UT --- python/paddle/metric/metrics.py | 1 + python/paddle/tests/test_metrics.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index fed659562cb..510b99c0300 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -244,6 +244,7 @@ class Accuracy(Metric): Tensor: Correct mask, a tensor with shape [batch_size, topk]. """ pred = paddle.argsort(pred, descending=True)[:, :self.maxk] + label = paddle.reshape(label, (-1, 1)) correct = pred == label return paddle.cast(correct, dtype='float32') diff --git a/python/paddle/tests/test_metrics.py b/python/paddle/tests/test_metrics.py index f05cdf9c6da..b1f53168e62 100644 --- a/python/paddle/tests/test_metrics.py +++ b/python/paddle/tests/test_metrics.py @@ -28,6 +28,7 @@ from paddle.hapi.model import to_list def accuracy(pred, label, topk=(1, )): maxk = max(topk) pred = np.argsort(pred)[:, ::-1][:, :maxk] + label = label.reshape(-1, 1) correct = (pred == np.repeat(label, maxk, 1)) batch_size = label.shape[0] @@ -47,13 +48,18 @@ def convert_to_one_hot(y, C): class TestAccuracy(unittest.TestCase): - def test_acc(self): + def test_acc(self, squeeze_y=False): paddle.disable_static() x = paddle.to_tensor( np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2], [0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]])) - y = paddle.to_tensor(np.array([[0], [1], [2], [3]])) + + y = np.array([[0], [1], [2], [3]]) + if squeeze_y: + y = y.squeeze() + + y = paddle.to_tensor(y) m = paddle.metric.Accuracy(name='my_acc') @@ -61,7 +67,8 @@ class TestAccuracy(unittest.TestCase): self.assertEqual(m.name(), ['my_acc']) correct = m.compute(x, y) - # check results + # check shape and results + self.assertEqual(correct.shape, [4, 1]) self.assertEqual(m.update(correct), 0.75) self.assertEqual(m.accumulate(), 0.75) @@ -80,6 +87,9 @@ class TestAccuracy(unittest.TestCase): self.assertEqual(m.count[0], 0.0) paddle.enable_static() + def test_1d_label(self): + self.test_acc(True) + class TestAccuracyDynamic(unittest.TestCase): def setUp(self): @@ -87,12 +97,15 @@ class TestAccuracyDynamic(unittest.TestCase): self.class_num = 5 self.sample_num = 1000 self.name = None + self.squeeze_label = False def random_pred_label(self): label = np.random.randint(0, self.class_num, (self.sample_num, 1)).astype('int64') pred = np.random.randint(0, self.class_num, (self.sample_num, 1)).astype('int32') + if self.squeeze_label: + label = label.squeeze() pred_one_hot = convert_to_one_hot(pred, self.class_num) pred_one_hot = pred_one_hot.astype('float32') @@ -123,9 +136,17 @@ class TestAccuracyDynamicMultiTopk(TestAccuracyDynamic): self.class_num = 10 self.sample_num = 1000 self.name = "accuracy" + self.squeeze_label = True class TestAccuracyStatic(TestAccuracyDynamic): + def setUp(self): + self.topk = (1, ) + self.class_num = 5 + self.sample_num = 1000 + self.name = None + self.squeeze_label = True + def test_main(self): main_prog = fluid.Program() startup_prog = fluid.Program() @@ -164,6 +185,7 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic): self.class_num = 10 self.sample_num = 100 self.name = "accuracy" + self.squeeze_label = False class TestPrecision(unittest.TestCase): -- GitLab