diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index ac9f048bab91662e4d5f3f96ddf048c5c1e25a79..0784775b6695eefb091bf0643a0b5c12d4b4664f 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -246,16 +246,27 @@ class Accuracy(Metric): Compute the top-k (maxinum value in `topk`) indices. Args: - pred (Tensor): The predicted value is a Tensor wit type - float32 or float64. - label (Tensor): The ground truth value is a 2D Tensor, its - shape is [batch_size, 1] and type is int64. - + pred (Tensor): The predicted value is a Tensor with dtype + float32 or float64. Shape is [batch_size, d0, ..., dN]. + label (Tensor): The ground truth value is Tensor with dtype + int64. Shape is [batch_size, d0, ..., 1], or + [batch_size, d0, ..., num_classes] in one hot representation. + Return: Tensor: Correct mask, a tensor with shape [batch_size, topk]. """ - pred = paddle.argsort(pred, descending=True)[:, :self.maxk] - label = paddle.reshape(label, (-1, 1)) + pred = paddle.argsort(pred, descending=True) + pred = paddle.slice( + pred, axes=[len(pred.shape) - 1], starts=[0], ends=[self.maxk]) + if (len(label.shape) == 1) or \ + (len(label.shape) == 2 and label.shape[-1] == 1): + # In static mode, the real label data shape may be different + # from shape defined by paddle.static.InputSpec in model + # building, reshape to the right shape. + label = paddle.reshape(label, (-1, 1)) + elif label.shape[-1] != 1: + # one-hot label + label = paddle.argmax(label, axis=-1, keepdim=True) correct = pred == label return paddle.cast(correct, dtype='float32') @@ -273,10 +284,10 @@ class Accuracy(Metric): """ if isinstance(correct, paddle.Tensor): correct = correct.numpy() + num_samples = np.prod(np.array(correct.shape[:-1])) accs = [] for i, k in enumerate(self.topk): - num_corrects = correct[:, :k].sum() - num_samples = len(correct) + num_corrects = correct[..., :k].sum() accs.append(float(num_corrects) / num_samples) self.total[i] += num_corrects self.count[i] += num_samples diff --git a/python/paddle/tests/test_metrics.py b/python/paddle/tests/test_metrics.py index b1f53168e62ceca11eea74a0cd9bbbd415bc06aa..0cf52b35e444b9ba7951a96f10aa6eadf44441da 100644 --- a/python/paddle/tests/test_metrics.py +++ b/python/paddle/tests/test_metrics.py @@ -25,17 +25,28 @@ import paddle.fluid as fluid from paddle.hapi.model import to_list +def one_hot(x, n_class): + res = np.eye(n_class)[np.array(x).reshape(-1)] + res = res.reshape(list(x.shape) + [n_class]) + return res + + def accuracy(pred, label, topk=(1, )): maxk = max(topk) - pred = np.argsort(pred)[:, ::-1][:, :maxk] - label = label.reshape(-1, 1) - correct = (pred == np.repeat(label, maxk, 1)) + pred = np.argsort(pred)[..., ::-1][..., :maxk] + if len(label.shape) == 1: + label = label.reshape(-1, 1) + elif label.shape[-1] != 1: + label = np.argmax(label, axis=-1) + label = label[..., np.newaxis] + correct = (pred == np.repeat(label, maxk, -1)) + + total = np.prod(np.array(label.shape[:-1])) - batch_size = label.shape[0] res = [] for k in topk: - correct_k = correct[:, :k].sum() - res.append(float(correct_k) / batch_size) + correct_k = correct[..., :k].sum() + res.append(float(correct_k) / total) return res @@ -49,8 +60,6 @@ def convert_to_one_hot(y, C): class TestAccuracy(unittest.TestCase): def test_acc(self, squeeze_y=False): - paddle.disable_static() - x = paddle.to_tensor( np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2], [0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]])) @@ -85,11 +94,36 @@ class TestAccuracy(unittest.TestCase): m.reset() self.assertEqual(m.total[0], 0.0) self.assertEqual(m.count[0], 0.0) - paddle.enable_static() def test_1d_label(self): self.test_acc(True) + def compare(self, x_np, y_np, k=(1, )): + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + + m = paddle.metric.Accuracy(name='my_acc', topk=k) + correct = m.compute(x, y) + + acc_np = accuracy(x_np, y_np, k) + acc_np = acc_np[0] if len(acc_np) == 1 else acc_np + + # check shape and results + self.assertEqual(correct.shape, list(x_np.shape)[:-1] + [max(k)]) + self.assertEqual(m.update(correct), acc_np) + self.assertEqual(m.accumulate(), acc_np) + + def test_3d(self): + x_np = np.random.rand(2, 3, 4) + y_np = np.random.randint(4, size=(2, 3, 1)) + self.compare(x_np, y_np) + + def test_one_hot(self): + x_np = np.random.rand(2, 3, 4) + y_np = np.random.randint(4, size=(2, 3)) + y_one_hot_np = one_hot(y_np, 4) + self.compare(x_np, y_one_hot_np, (1, 2)) + class TestAccuracyDynamic(unittest.TestCase): def setUp(self): @@ -148,6 +182,8 @@ class TestAccuracyStatic(TestAccuracyDynamic): self.squeeze_label = True def test_main(self): + paddle.enable_static() + main_prog = fluid.Program() startup_prog = fluid.Program() main_prog.random_seed = 1024 @@ -178,6 +214,8 @@ class TestAccuracyStatic(TestAccuracyDynamic): assert np.sum(acc.total) == 0 assert np.sum(acc.count) == 0 + paddle.disable_static() + class TestAccuracyStaticMultiTopk(TestAccuracyStatic): def setUp(self): @@ -190,7 +228,6 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic): class TestPrecision(unittest.TestCase): def test_1d(self): - paddle.disable_static() x = np.array([0.1, 0.5, 0.6, 0.7]) y = np.array([1, 0, 1, 1]) @@ -206,11 +243,7 @@ class TestPrecision(unittest.TestCase): r = m.accumulate() self.assertAlmostEqual(r, 4. / 6.) - paddle.enable_static() - def test_2d(self): - paddle.disable_static() - x = np.array([0.1, 0.5, 0.6, 0.7]).reshape(-1, 1) y = np.array([1, 0, 1, 1]).reshape(-1, 1) @@ -231,13 +264,9 @@ class TestPrecision(unittest.TestCase): self.assertEqual(m.fp, 0.0) self.assertEqual(m.accumulate(), 0.0) - paddle.enable_static() - class TestRecall(unittest.TestCase): def test_1d(self): - paddle.disable_static() - x = np.array([0.1, 0.5, 0.6, 0.7]) y = np.array([1, 0, 1, 1]) @@ -257,12 +286,10 @@ class TestRecall(unittest.TestCase): self.assertEqual(m.tp, 0.0) self.assertEqual(m.fn, 0.0) self.assertEqual(m.accumulate(), 0.0) - paddle.enable_static() class TestAuc(unittest.TestCase): def test_auc_numpy(self): - paddle.disable_static() x = np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70], [0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]]) y = np.array([[0], [1], [1], [0], [1], [0], [0], [1]]) @@ -274,10 +301,7 @@ class TestAuc(unittest.TestCase): m.reset() self.assertEqual(m.accumulate(), 0.0) - paddle.enable_static() - def test_auc_tensor(self): - paddle.disable_static() x = paddle.to_tensor( np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70], [0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]])) @@ -290,8 +314,6 @@ class TestAuc(unittest.TestCase): m.reset() self.assertEqual(m.accumulate(), 0.0) - paddle.enable_static() - if __name__ == '__main__': unittest.main()