未验证 提交 692466b4 编写于 作者: J Jiaqi Liu 提交者: GitHub

Make accuracy function support dtype int64 for input label (#43003)

* support int64 for acc

* support int64 for acc
上级 f3bdabc1
...@@ -776,7 +776,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ...@@ -776,7 +776,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
Args: Args:
input(Tensor): The input of accuracy layer, which is the predictions of network. A Tensor with type float32,float64. input(Tensor): The input of accuracy layer, which is the predictions of network. A Tensor with type float32,float64.
The shape is ``[sample_number, class_dim]`` . The shape is ``[sample_number, class_dim]`` .
label(Tensor): The label of dataset. Tensor with type int64. The shape is ``[sample_number, 1]`` . label(Tensor): The label of dataset. Tensor with type int64 or int32. The shape is ``[sample_number, 1]`` .
k(int, optional): The top k predictions for each class will be checked. Data type is int64 or int32. k(int, optional): The top k predictions for each class will be checked. Data type is int64 or int32.
correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32. correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32.
total(Tensor, optional): The total entries count. A tensor with type int64 or int32. total(Tensor, optional): The total entries count. A tensor with type int64 or int32.
...@@ -796,6 +796,8 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ...@@ -796,6 +796,8 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
result = paddle.metric.accuracy(input=predictions, label=label, k=1) result = paddle.metric.accuracy(input=predictions, label=label, k=1)
# [0.5] # [0.5]
""" """
if label.dtype == paddle.int32:
label = paddle.cast(label, paddle.int64)
if _non_static_mode(): if _non_static_mode():
if correct is None: if correct is None:
correct = _varbase_creator(dtype="int32") correct = _varbase_creator(dtype="int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册