未验证 提交 058f1b22 编写于 作者: Q qingqing01 提交者: GitHub

Enhance paddle.metric.Accuracy (#29125)

上级 dc070ecf
......@@ -246,16 +246,27 @@ class Accuracy(Metric):
Compute the top-k (maxinum value in `topk`) indices.
Args:
pred (Tensor): The predicted value is a Tensor wit type
float32 or float64.
label (Tensor): The ground truth value is a 2D Tensor, its
shape is [batch_size, 1] and type is int64.
pred (Tensor): The predicted value is a Tensor with dtype
float32 or float64. Shape is [batch_size, d0, ..., dN].
label (Tensor): The ground truth value is Tensor with dtype
int64. Shape is [batch_size, d0, ..., 1], or
[batch_size, d0, ..., num_classes] in one hot representation.
Return:
Tensor: Correct mask, a tensor with shape [batch_size, topk].
"""
pred = paddle.argsort(pred, descending=True)[:, :self.maxk]
label = paddle.reshape(label, (-1, 1))
pred = paddle.argsort(pred, descending=True)
pred = paddle.slice(
pred, axes=[len(pred.shape) - 1], starts=[0], ends=[self.maxk])
if (len(label.shape) == 1) or \
(len(label.shape) == 2 and label.shape[-1] == 1):
# In static mode, the real label data shape may be different
# from shape defined by paddle.static.InputSpec in model
# building, reshape to the right shape.
label = paddle.reshape(label, (-1, 1))
elif label.shape[-1] != 1:
# one-hot label
label = paddle.argmax(label, axis=-1, keepdim=True)
correct = pred == label
return paddle.cast(correct, dtype='float32')
......@@ -273,10 +284,10 @@ class Accuracy(Metric):
"""
if isinstance(correct, paddle.Tensor):
correct = correct.numpy()
num_samples = np.prod(np.array(correct.shape[:-1]))
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
num_corrects = correct[..., :k].sum()
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
......
......@@ -25,17 +25,28 @@ import paddle.fluid as fluid
from paddle.hapi.model import to_list
def one_hot(x, n_class):
res = np.eye(n_class)[np.array(x).reshape(-1)]
res = res.reshape(list(x.shape) + [n_class])
return res
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
label = label.reshape(-1, 1)
correct = (pred == np.repeat(label, maxk, 1))
pred = np.argsort(pred)[..., ::-1][..., :maxk]
if len(label.shape) == 1:
label = label.reshape(-1, 1)
elif label.shape[-1] != 1:
label = np.argmax(label, axis=-1)
label = label[..., np.newaxis]
correct = (pred == np.repeat(label, maxk, -1))
total = np.prod(np.array(label.shape[:-1]))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(float(correct_k) / batch_size)
correct_k = correct[..., :k].sum()
res.append(float(correct_k) / total)
return res
......@@ -49,8 +60,6 @@ def convert_to_one_hot(y, C):
class TestAccuracy(unittest.TestCase):
def test_acc(self, squeeze_y=False):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2],
[0.1, 0.2, 0.4, 0.3], [0.1, 0.2, 0.3, 0.4]]))
......@@ -85,11 +94,36 @@ class TestAccuracy(unittest.TestCase):
m.reset()
self.assertEqual(m.total[0], 0.0)
self.assertEqual(m.count[0], 0.0)
paddle.enable_static()
def test_1d_label(self):
self.test_acc(True)
def compare(self, x_np, y_np, k=(1, )):
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
m = paddle.metric.Accuracy(name='my_acc', topk=k)
correct = m.compute(x, y)
acc_np = accuracy(x_np, y_np, k)
acc_np = acc_np[0] if len(acc_np) == 1 else acc_np
# check shape and results
self.assertEqual(correct.shape, list(x_np.shape)[:-1] + [max(k)])
self.assertEqual(m.update(correct), acc_np)
self.assertEqual(m.accumulate(), acc_np)
def test_3d(self):
x_np = np.random.rand(2, 3, 4)
y_np = np.random.randint(4, size=(2, 3, 1))
self.compare(x_np, y_np)
def test_one_hot(self):
x_np = np.random.rand(2, 3, 4)
y_np = np.random.randint(4, size=(2, 3))
y_one_hot_np = one_hot(y_np, 4)
self.compare(x_np, y_one_hot_np, (1, 2))
class TestAccuracyDynamic(unittest.TestCase):
def setUp(self):
......@@ -148,6 +182,8 @@ class TestAccuracyStatic(TestAccuracyDynamic):
self.squeeze_label = True
def test_main(self):
paddle.enable_static()
main_prog = fluid.Program()
startup_prog = fluid.Program()
main_prog.random_seed = 1024
......@@ -178,6 +214,8 @@ class TestAccuracyStatic(TestAccuracyDynamic):
assert np.sum(acc.total) == 0
assert np.sum(acc.count) == 0
paddle.disable_static()
class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
def setUp(self):
......@@ -190,7 +228,6 @@ class TestAccuracyStaticMultiTopk(TestAccuracyStatic):
class TestPrecision(unittest.TestCase):
def test_1d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([1, 0, 1, 1])
......@@ -206,11 +243,7 @@ class TestPrecision(unittest.TestCase):
r = m.accumulate()
self.assertAlmostEqual(r, 4. / 6.)
paddle.enable_static()
def test_2d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7]).reshape(-1, 1)
y = np.array([1, 0, 1, 1]).reshape(-1, 1)
......@@ -231,13 +264,9 @@ class TestPrecision(unittest.TestCase):
self.assertEqual(m.fp, 0.0)
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
class TestRecall(unittest.TestCase):
def test_1d(self):
paddle.disable_static()
x = np.array([0.1, 0.5, 0.6, 0.7])
y = np.array([1, 0, 1, 1])
......@@ -257,12 +286,10 @@ class TestRecall(unittest.TestCase):
self.assertEqual(m.tp, 0.0)
self.assertEqual(m.fn, 0.0)
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
class TestAuc(unittest.TestCase):
def test_auc_numpy(self):
paddle.disable_static()
x = np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70],
[0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]])
y = np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
......@@ -274,10 +301,7 @@ class TestAuc(unittest.TestCase):
m.reset()
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
def test_auc_tensor(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.78, 0.22], [0.62, 0.38], [0.55, 0.45], [0.30, 0.70],
[0.14, 0.86], [0.59, 0.41], [0.91, 0.08], [0.16, 0.84]]))
......@@ -290,8 +314,6 @@ class TestAuc(unittest.TestCase):
m.reset()
self.assertEqual(m.accumulate(), 0.0)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册