未验证 提交 f7dd889c 编写于 作者: Q qingqing01 提交者: GitHub

Support squeezed label as input in paddle.metric.Accuracy (#28535)

* Support squeezed label as input in paddle.metric.Accuracy
* Revert cifar and fix UT
上级 8b97bb2e
...@@ -244,6 +244,7 @@ class Accuracy(Metric): ...@@ -244,6 +244,7 @@ class Accuracy(Metric):
Tensor: Correct mask, a tensor with shape [batch_size, topk]. Tensor: Correct mask, a tensor with shape [batch_size, topk].
""" """
pred = paddle.argsort(pred, descending=True)[:, :self.maxk] pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
label = paddle.reshape(label, (-1, 1))
correct = pred == label correct = pred == label
return paddle.cast(correct, dtype='float32') return paddle.cast(correct, dtype='float32')
......
...@@ -28,6 +28,7 @@ from paddle.hapi.model import to_list ...@@ -28,6 +28,7 @@ from paddle.hapi.model import to_list
def accuracy(pred, label, topk=(1, )): def accuracy(pred, label, topk=(1, )):
maxk = max(topk) maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk] pred = np.argsort(pred)[:, ::-1][:, :maxk]
label = label.reshape(-1, 1)
correct = (pred == np.repeat(label, maxk, 1)) correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0] batch_size = label.shape[0]
...@@ -47,13 +48,18 @@ def convert_to_one_hot(y, C): ...@@ -47,13 +48,18 @@ def convert_to_one_hot(y, C):
class TestAccuracy(unittest.TestCase): class TestAccuracy(unittest.TestCase):
def test_acc(self): def test_acc(self, squeeze_y=False):
paddle.disable_static() paddle.disable_static()
x = paddle.to_tensor( x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2], np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]])) [0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
y = np.array([[0], [1], [2], [3]])
if squeeze_y:
y = y.squeeze()
y = paddle.to_tensor(y)
m = paddle.metric.Accuracy(name='my_acc') m = paddle.metric.Accuracy(name='my_acc')
...@@ -61,7 +67,8 @@ class TestAccuracy(unittest.TestCase): ...@@ -61,7 +67,8 @@ class TestAccuracy(unittest.TestCase):
self.assertEqual(m.name(), ['my_acc']) self.assertEqual(m.name(), ['my_acc'])
correct = m.compute(x, y) correct = m.compute(x, y)
# check results # check shape and results
self.assertEqual(correct.shape, [4, 1])
self.assertEqual(m.update(correct), 0.75) self.assertEqual(m.update(correct), 0.75)
self.assertEqual(m.accumulate(), 0.75) self.assertEqual(m.accumulate(), 0.75)
...@@ -80,6 +87,9 @@ class TestAccuracy(unittest.TestCase): ...@@ -80,6 +87,9 @@ class TestAccuracy(unittest.TestCase):
self.assertEqual(m.count[0], 0.0) self.assertEqual(m.count[0], 0.0)
paddle.enable_static() paddle.enable_static()
def test_1d_label(self):
self.test_acc(True)
class TestAccuracyDynamic(unittest.TestCase): class TestAccuracyDynamic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -87,12 +97,15 @@ class TestAccuracyDynamic(unittest.TestCase): ...@@ -87,12 +97,15 @@ class TestAccuracyDynamic(unittest.TestCase):
self.class_num = 5 self.class_num = 5
self.sample_num = 1000 self.sample_num = 1000
self.name = None self.name = None
self.squeeze_label = False
def random_pred_label(self): def random_pred_label(self):
label = np.random.randint(0, self.class_num, label = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int64') (self.sample_num, 1)).astype('int64')
pred = np.random.randint(0, self.class_num, pred = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int32') (self.sample_num, 1)).astype('int32')
if self.squeeze_label:
label = label.squeeze()
pred_one_hot = convert_to_one_hot(pred, self.class_num) pred_one_hot = convert_to_one_hot(pred, self.class_num)
pred_one_hot = pred_one_hot.astype('float32') pred_one_hot = pred_one_hot.astype('float32')
...@@ -123,9 +136,17 @@ class TestAccuracyDynamicMultiTopk(TestAccuracyDynamic): ...@@ -123,9 +136,17 @@ class TestAccuracyDynamicMultiTopk(TestAccuracyDynamic):
self.class_num = 10 self.class_num = 10
self.sample_num = 1000 self.sample_num = 1000
self.name = "accuracy" self.name = "accuracy"
self.squeeze_label = True
class TestAccuracyStatic(TestAccuracyDynamic): class TestAccuracyStatic(TestAccuracyDynamic):
def setUp(self):
self.topk = (1, )
self.class_num = 5
self.sample_num = 1000
self.name = None
self.squeeze_label = True
def test_main(self): def test_main(self):
main_prog = fluid.Program() main_prog = fluid.Program()
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -164,6 +185,7 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic): ...@@ -164,6 +185,7 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
self.class_num = 10 self.class_num = 10
self.sample_num = 100 self.sample_num = 100
self.name = "accuracy" self.name = "accuracy"
self.squeeze_label = False
class TestPrecision(unittest.TestCase): class TestPrecision(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册