未验证 提交 f7dd889c 编写于 作者: Q qingqing01 提交者: GitHub

Support squeezed label as input in paddle.metric.Accuracy (#28535)

* Support squeezed label as input in paddle.metric.Accuracy
* Revert cifar and fix UT
上级 8b97bb2e
......@@ -244,6 +244,7 @@ class Accuracy(Metric):
Tensor: Correct mask, a tensor with shape [batch_size, topk].
"""
pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
label = paddle.reshape(label, (-1, 1))
correct = pred == label
return paddle.cast(correct, dtype='float32')
......
......@@ -28,6 +28,7 @@ from paddle.hapi.model import to_list
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
label = label.reshape(-1, 1)
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
......@@ -47,13 +48,18 @@ def convert_to_one_hot(y, C):
class TestAccuracy(unittest.TestCase):
def test_acc(self):
def test_acc(self, squeeze_y=False):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
y = paddle.to_tensor(np.array([[0], [1], [2], [3]]))
y = np.array([[0], [1], [2], [3]])
if squeeze_y:
y = y.squeeze()
y = paddle.to_tensor(y)
m = paddle.metric.Accuracy(name='my_acc')
......@@ -61,7 +67,8 @@ class TestAccuracy(unittest.TestCase):
self.assertEqual(m.name(), ['my_acc'])
correct = m.compute(x, y)
# check results
# check shape and results
self.assertEqual(correct.shape, [4, 1])
self.assertEqual(m.update(correct), 0.75)
self.assertEqual(m.accumulate(), 0.75)
......@@ -80,6 +87,9 @@ class TestAccuracy(unittest.TestCase):
self.assertEqual(m.count[0], 0.0)
paddle.enable_static()
def test_1d_label(self):
self.test_acc(True)
class TestAccuracyDynamic(unittest.TestCase):
def setUp(self):
......@@ -87,12 +97,15 @@ class TestAccuracyDynamic(unittest.TestCase):
self.class_num = 5
self.sample_num = 1000
self.name = None
self.squeeze_label = False
def random_pred_label(self):
label = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int64')
pred = np.random.randint(0, self.class_num,
(self.sample_num, 1)).astype('int32')
if self.squeeze_label:
label = label.squeeze()
pred_one_hot = convert_to_one_hot(pred, self.class_num)
pred_one_hot = pred_one_hot.astype('float32')
......@@ -123,9 +136,17 @@ class TestAccuracyDynamicMultiTopk(TestAccuracyDynamic):
self.class_num = 10
self.sample_num = 1000
self.name = "accuracy"
self.squeeze_label = True
class TestAccuracyStatic(TestAccuracyDynamic):
def setUp(self):
self.topk = (1, )
self.class_num = 5
self.sample_num = 1000
self.name = None
self.squeeze_label = True
def test_main(self):
main_prog = fluid.Program()
startup_prog = fluid.Program()
......@@ -164,6 +185,7 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
self.class_num = 10
self.sample_num = 100
self.name = "accuracy"
self.squeeze_label = False
class TestPrecision(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册