From 692466b4c700b9e369efead356aebf8a2022f9fb Mon Sep 17 00:00:00 2001 From: Jiaqi Liu <709153940@qq.com> Date: Fri, 1 Jul 2022 16:16:08 +0800 Subject: [PATCH] Make accuracy function support dtype int64 for input label (#43003) * support int64 for acc * support int64 for acc --- python/paddle/metric/metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index 4d28b68f994..919daa31d06 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -776,7 +776,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): Args: input(Tensor): The input of accuracy layer, which is the predictions of network. A Tensor with type float32,float64. The shape is ``[sample_number, class_dim]`` . - label(Tensor): The label of dataset. Tensor with type int64. The shape is ``[sample_number, 1]`` . + label(Tensor): The label of dataset. Tensor with type int64 or int32. The shape is ``[sample_number, 1]`` . k(int, optional): The top k predictions for each class will be checked. Data type is int64 or int32. correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32. total(Tensor, optional): The total entries count. A tensor with type int64 or int32. @@ -796,6 +796,8 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): result = paddle.metric.accuracy(input=predictions, label=label, k=1) # [0.5] """ + if label.dtype == paddle.int32: + label = paddle.cast(label, paddle.int64) if _non_static_mode(): if correct is None: correct = _varbase_creator(dtype="int32") -- GitLab