未验证 提交 94d2cf82 编写于 作者: J Jiaqi Liu 提交者: GitHub

update acc func using topk v2 (#35789)

上级 9b2d53fc
...@@ -22,7 +22,6 @@ import numpy as np ...@@ -22,7 +22,6 @@ import numpy as np
from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.layers.nn import topk
from ..fluid.framework import core, _varbase_creator, in_dygraph_mode from ..fluid.framework import core, _varbase_creator, in_dygraph_mode
import paddle import paddle
from paddle import _C_ops from paddle import _C_ops
...@@ -798,7 +797,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ...@@ -798,7 +797,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
if total is None: if total is None:
total = _varbase_creator(dtype="int32") total = _varbase_creator(dtype="int32")
topk_out, topk_indices = topk(input, k=k) topk_out, topk_indices = paddle.topk(input, k=k)
_acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct, _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct,
total) total)
return _acc return _acc
...@@ -806,7 +805,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ...@@ -806,7 +805,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
helper = LayerHelper("accuracy", **locals()) helper = LayerHelper("accuracy", **locals())
check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
'accuracy') 'accuracy')
topk_out, topk_indices = topk(input, k=k) topk_out, topk_indices = paddle.topk(input, k=k)
acc_out = helper.create_variable_for_type_inference(dtype="float32") acc_out = helper.create_variable_for_type_inference(dtype="float32")
if correct is None: if correct is None:
correct = helper.create_variable_for_type_inference(dtype="int32") correct = helper.create_variable_for_type_inference(dtype="int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册