Unverified commit 14b91f60 — authored by: H hong, committed by: GitHub

add topk cast (#41304)

Parent commit: 7dd4a9fe
......@@ -4791,7 +4791,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_prod(y, dim=[1, 2]) # [24.0, 1680.0]
fluid.layers.reduce_prod(y, dim=[0, 1]) # [105.0, 384.0]
"""
helper = LayerHelper('reduce_prod', **locals())
if dim is not None and not isinstance(dim, list):
if isinstance(dim, tuple):
dim = list(dim)
......@@ -4801,6 +4801,12 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".
format(type(dim)))
if in_dygraph_mode():
return _C_ops.final_state_reduce_prod(
input, dim if dim != None and dim != [] else [0], keep_dim, True if
dim == None or dim == [] or len(dim) == len(input.shape) else False)
helper = LayerHelper('reduce_prod', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
......
......@@ -21,7 +21,7 @@ import warnings
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..initializer import Initializer
from ..framework import convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph
from ..framework import convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode
from ..framework import Variable
from ..initializer import Constant
from ..core import VarDesc
......@@ -243,6 +243,11 @@ def cast(x, dtype):
x = paddle.to_tensor([2, 3, 4], 'float64')
y = paddle.cast(x, 'uint8')
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return _C_ops.final_state_cast(x, dtype)
if _non_static_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
......
......@@ -1559,6 +1559,8 @@ class OpTest(unittest.TestCase):
def _compare_numpy(self, name, actual_np, expect_np):
    """Compare an actual output array against its expected value under
    eager (final-state dygraph) mode.

    Delegates to the parent class's comparison logic, but wraps the call
    in ``_test_eager_guard`` so the check exercises the eager code path.

    Args:
        name: label of the output being compared (used in error messages
            by the parent implementation).
        actual_np: numpy array produced by the op under test.
        expect_np: numpy array with the expected values.
    """
    with _test_eager_guard():
        # NOTE(review): the original left two debug ``print`` calls on the
        # compared arrays here; removed to keep the test runner output clean.
        super()._compare_numpy(name, actual_np, expect_np)
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
......
......@@ -22,6 +22,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from op_test import OpTest, convert_uint16_to_float, convert_float_to_uint16
from paddle.fluid.framework import _test_eager_guard
class TestCastOpFp32ToFp64(OpTest):
......@@ -115,6 +116,20 @@ class TestCastOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
class TestCastOpEager(unittest.TestCase):
    """Exercise ``paddle.cast`` (fp16 -> fp32) forward and backward in
    eager dygraph mode."""

    def test_eager(self):
        with paddle.fluid.dygraph.base.guard():
            with _test_eager_guard():
                # Forward: cast a ones tensor from float16 to float32.
                x = paddle.ones([2, 2], dtype="float16")
                x.stop_gradient = False
                out = paddle.cast(x, "float32")
                expected = np.ones([2, 2]).astype("float32")
                self.assertTrue(np.array_equal(out, expected))
                # Backward: cast's gradient is the upstream grad cast back
                # to the input dtype, so it should equal x (all ones) and
                # carry float16.
                out.backward()
                grad = x.gradient()
                self.assertTrue(np.array_equal(grad, x.numpy()))
                self.assertTrue(grad.dtype == np.float16)
if __name__ == '__main__':
    # Run under static-graph mode by default; the eager test above switches
    # modes explicitly via its own guards.
    paddle.enable_static()
    unittest.main()
......@@ -241,6 +241,7 @@ class TestMin8DOp(OpTest):
class TestProdOp(OpTest):
def setUp(self):
    """Configure the reduce_prod op test fixture.

    Builds a random (5, 6, 10) input and the expected product reduced
    along axis 0.
    """
    self.op_type = "reduce_prod"
    # Eager-mode python entry point used when the test runs with
    # check_eager=True.
    self.python_api = paddle.prod
    self.init_data_type()
    data = np.random.random((5, 6, 10)).astype(self.data_type)
    self.inputs = {'X': data}
    self.outputs = {'Out': data.prod(axis=0)}
......
......@@ -57,10 +57,10 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
    # Verify the top_k output; check_eager=True also exercises the eager
    # (final-state) kernel. NOTE(review): the diff rendering showed both
    # the stale pre-diff call (check_eager=False) and this one — only the
    # post-diff call is kept.
    self.check_output(check_eager=True)
def test_check_grad(self):
    # Gradient check of 'Out' w.r.t. 'X'; check_eager=True covers the
    # eager (final-state) backward kernel. NOTE(review): the stale
    # pre-diff duplicate call with check_eager=False is dropped.
    self.check_grad(set(['X']), 'Out', check_eager=True)
class TestTopkOp1(TestTopkOp):
......
......@@ -858,6 +858,12 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None):
"""
if in_dygraph_mode():
if axis == None:
axis = -1
out, indices = _C_ops.final_state_top_k(x, k, axis, largest, sorted)
return out, indices
if _non_static_mode():
if axis is None:
out, indices = _C_ops.top_k_v2(x, 'k',
......
......@@ -1123,7 +1123,8 @@
infer_meta :
func : ReduceInferMetaBase
kernel :
func : reduce_prod
func : prod_raw
backward : reduce_prod_grad
- api : relu
args : (Tensor x)
......
......@@ -721,6 +721,16 @@
kernel :
func : reciprocal_grad
- backward_api : reduce_prod_grad
forward : reduce_prod (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reduce_prod_grad
- backward_api : relu_double_grad
forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor grad_x_grad)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment