From 90cb337ee315abb133d094340081ed7f4744c8e5 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 7 Apr 2022 20:32:14 +0800 Subject: [PATCH] [Phi]Add hard_swish/kron/linspace/logit yaml file (#41298) * add yaml * perfect converage --- paddle/fluid/operators/linspace_op.cc | 2 +- paddle/phi/infermeta/ternary.cc | 16 ++++++-- paddle/phi/infermeta/ternary.h | 6 +++ paddle/phi/kernels/activation_grad_kernel.h | 1 + python/paddle/fluid/layers/tensor.py | 6 ++- .../tests/unittests/test_activation_op.py | 13 ++++++- .../fluid/tests/unittests/test_kron_op.py | 29 ++++++++++---- .../fluid/tests/unittests/test_linspace.py | 15 +++++-- .../fluid/tests/unittests/test_logit_op.py | 12 +++++- python/paddle/nn/functional/activation.py | 5 ++- python/paddle/tensor/math.py | 10 +++-- python/paddle/utils/code_gen/api.yaml | 39 +++++++++++++++++++ python/paddle/utils/code_gen/backward.yaml | 31 +++++++++++++++ 13 files changed, 158 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index 5599debbf3..1cd59672f9 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -67,7 +67,7 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(linspace, LinspaceInferShapeFunctor, - PD_INFER_META(phi::LinspaceInferMeta)); + PD_INFER_META(phi::LinspaceRawInferMeta)); REGISTER_OPERATOR( linspace, ops::LinspaceOp, ops::LinspaceOpMaker, paddle::framework::EmptyGradOpMaker, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 3e4aa7b444..c692b6c8fc 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -276,10 +276,10 @@ void LerpInferMeta(const MetaTensor& x, out->share_lod(x); } -void LinspaceInferMeta(const MetaTensor& start, - const MetaTensor& stop, - const MetaTensor& number, - MetaTensor* out) { +void LinspaceRawInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + MetaTensor* out) { auto s_dims = start.dims(); PADDLE_ENFORCE_EQ( (s_dims.size() == 1) && (s_dims[0] == 1), @@ -305,6 +305,14 @@ void LinspaceInferMeta(const MetaTensor& start, out->set_dtype(start.dtype()); } +void LinspaceInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + DataType dtype, + MetaTensor* out) { + LinspaceRawInferMeta(start, stop, number, out); +} + void NllLossRawInferMeta(const MetaTensor& input, const MetaTensor& label, paddle::optional weight, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 00e4981168..83505f2c2f 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -65,9 +65,15 @@ void LerpInferMeta(const MetaTensor& x, const MetaTensor& weight, MetaTensor* out); +void LinspaceRawInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + MetaTensor* out); + void LinspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, + DataType dtype, MetaTensor* out); void NllLossRawInferMeta(const MetaTensor& input, diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 82e168a3c6..065d018852 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -197,6 +197,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta); +DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max); diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index a63e87472e..e302371988 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1548,10 +1548,12 @@ def linspace(start, stop, num, dtype=None, name=None): if not isinstance(num, Variable): with device_guard("cpu"): tensor_num = fill_constant([1], 'int32', num) - if _non_static_mode(): + if _in_legacy_dygraph(): return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', dtype) - + if in_dygraph_mode(): + return _C_ops.final_state_linspace(tensor_start, tensor_stop, + tensor_num, dtype) helper = LayerHelper("linspace", **locals()) start_dtype = convert_dtype(tensor_start.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 89f8ebbd0c..80fef6d375 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -25,6 +25,7 @@ import paddle.nn.functional as F import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import _test_eager_guard paddle.enable_static() @@ -1755,7 +1756,7 @@ class TestHardSwish(TestActivation): def setUp(self): self.op_type = 'hard_swish' self.init_dtype() - + self.python_api = paddle.nn.functional.hardswish skip_check_grad_ci(reason="not implemented yet") np.random.seed(1024) @@ -1777,7 +1778,10 @@ class TestHardSwish(TestActivation): return return # not implemented yet - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) + + def test_check_output(self): + self.check_output(check_eager=True) class TestHardswishAPI(unittest.TestCase): @@ -1838,6 +1842,11 @@ class TestHardswishAPI(unittest.TestCase): name='x_fp16', shape=[12, 10], dtype='float16') F.hardswish(x_fp16) + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_dygraph_api() + self.test_errors() + class TestSoftRelu(TestActivation): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py index d6db4c2f07..f4d013b7c6 100644 --- a/python/paddle/fluid/tests/unittests/test_kron_op.py +++ b/python/paddle/fluid/tests/unittests/test_kron_op.py @@ -21,11 +21,13 @@ from op_test import OpTest import paddle import paddle.fluid as fluid import paddle.fluid.dygraph as dg +from paddle.fluid.framework import _test_eager_guard class TestKronOp(OpTest): def setUp(self): self.op_type = "kron" + self.python_api = paddle.kron self.dtype = self._init_dtype() x = np.random.uniform(size=(10, 10)).astype(self.dtype) y = np.random.uniform(size=(10, 10)).astype(self.dtype) @@ -37,21 +39,22 @@ class TestKronOp(OpTest): return "float64" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ignore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set('X')) + self.check_grad(['Y'], 'Out', no_grad_set=set('X'), check_eager=True) def test_check_grad_ignore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_eager=True) class TestKronOp2(TestKronOp): def setUp(self): self.op_type = "kron" + self.python_api = paddle.kron self.dtype = self._init_dtype() x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype) y = np.random.uniform(size=(10, 10)).astype(self.dtype) @@ -63,6 +66,7 @@ class TestKronOp2(TestKronOp): class TestKronOp3(TestKronOp): def setUp(self): self.op_type = "kron" + self.python_api = paddle.kron self.dtype = self._init_dtype() x = np.random.uniform(size=(10, 10)).astype(self.dtype) y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype) @@ -101,10 +105,16 @@ class TestKronLayer(unittest.TestCase): c, = exe.run(main, feed={'a': a, 'b': b}, fetch_list=[out_var]) np.testing.assert_allclose(c, np.kron(a, b)) + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_case() + self.test_case_with_output() + class TestComplexKronOp(OpTest): def setUp(self): self.op_type = "kron" + self.python_api = paddle.kron self.x_shape = np.array([10, 10]) self.y_shape = np.array([3, 35]) self.out_shape = self.x_shape * self.y_shape @@ -160,14 +170,15 @@ class TestComplexKronOp(OpTest): return grad_y def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): self.check_grad( ['X', 'Y'], 'Out', user_defined_grads=[self.grad_x, self.grad_y], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) def test_check_grad_ingore_x(self): self.check_grad( @@ -175,7 +186,8 @@ class TestComplexKronOp(OpTest): 'Out', no_grad_set=set("X"), user_defined_grads=[self.grad_y], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) def test_check_grad_ingore_y(self): self.check_grad( @@ -183,7 +195,8 @@ class TestComplexKronOp(OpTest): 'Out', no_grad_set=set('Y'), user_defined_grads=[self.grad_x], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) class TestKronOpTypePromotion(TestComplexKronOp): diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py index 54846e6a14..65a6c21fb0 100644 --- a/python/paddle/fluid/tests/unittests/test_linspace.py +++ b/python/paddle/fluid/tests/unittests/test_linspace.py @@ -21,11 +21,13 @@ import paddle import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard class TestLinspaceOpCommonCase(OpTest): def setUp(self): self.op_type = "linspace" + self.python_api = paddle.linspace dtype = 'float32' self.inputs = { 'Start': np.array([0]).astype(dtype), @@ -37,12 +39,13 @@ class TestLinspaceOpCommonCase(OpTest): self.outputs = {'Out': np.arange(0, 11).astype(dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestLinspaceOpReverseCase(OpTest): def setUp(self): self.op_type = "linspace" + self.python_api = paddle.linspace dtype = 'float32' self.inputs = { 'Start': np.array([10]).astype(dtype), @@ -54,12 +57,13 @@ class TestLinspaceOpReverseCase(OpTest): self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestLinspaceOpNumOneCase(OpTest): def setUp(self): self.op_type = "linspace" + self.python_api = paddle.linspace dtype = 'float32' self.inputs = { 'Start': np.array([10]).astype(dtype), @@ -71,7 +75,7 @@ class TestLinspaceOpNumOneCase(OpTest): self.outputs = {'Out': np.array(10, dtype=dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestLinspaceAPI(unittest.TestCase): @@ -123,6 +127,11 @@ class TestLinspaceAPI(unittest.TestCase): self.assertEqual((out2.numpy() == np_out2).all(), True) self.assertEqual((out3.numpy() == np_out3).all(), True) + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_variable_input2() + self.test_imperative() + class TestLinspaceOpError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/fluid/tests/unittests/test_logit_op.py b/python/paddle/fluid/tests/unittests/test_logit_op.py index 9254996eb4..9b46039da1 100644 --- a/python/paddle/fluid/tests/unittests/test_logit_op.py +++ b/python/paddle/fluid/tests/unittests/test_logit_op.py @@ -16,6 +16,7 @@ import unittest import numpy as np from op_test import OpTest import paddle +from paddle.fluid.framework import _test_eager_guard np.random.seed(10) @@ -37,6 +38,7 @@ def logit_grad(x, eps=1e-8): class TestLogitOp(OpTest): def setUp(self): self.op_type = 'logit' + self.python_api = paddle.logit self.dtype = np.float64 self.shape = [120] self.eps = 1e-8 @@ -52,10 +54,11 @@ class TestLogitOp(OpTest): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True) class TestLogitShape(TestLogitOp): @@ -106,6 +109,11 @@ class TestLogitAPI(unittest.TestCase): x = paddle.fluid.data(name='X2', shape=[100], dtype='float32') self.assertRaises(TypeError, paddle.logit, x, dtype='int32') + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_check_api() + self.test_errors() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index d145b615c3..10bf5d9a46 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -28,6 +28,7 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle from paddle import _C_ops, in_dynamic_mode from paddle.framework import core +from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode __all__ = [] @@ -386,8 +387,10 @@ def hardswish(x, name=None): out = F.hardswish(x) # [0., 5., 0.666667] """ - if in_dynamic_mode(): + if _in_legacy_dygraph(): return _C_ops.hard_swish(x) + if in_dygraph_mode(): + return _C_ops.final_state_hard_swish(x, 6, 6, 3) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'hardswish') diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9751892e70..311f5f8edd 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2674,9 +2674,10 @@ ${comment} # [12, 15, 18, 16, 20, 24], # [21, 24, 27, 28, 32, 36]]) """ - if paddle.in_dynamic_mode(): + if _in_legacy_dygraph(): return _C_ops.kron(x, y) - + if in_dygraph_mode(): + return _C_ops.final_state_kron(x, y) helper = LayerHelper('kron', **locals()) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron') check_variable_and_dtype(y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], 'kron') @@ -3525,9 +3526,10 @@ def logit(x, eps=None, name=None): if eps == None: eps = 0.0 - if paddle.in_dynamic_mode(): + if _in_legacy_dygraph(): return _C_ops.logit(x, 'eps', eps) - + if in_dygraph_mode(): + return _C_ops.final_state_logit(x, eps) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'logit') helper = LayerHelper("logit", **locals()) out = helper.create_variable_for_type_inference(x.dtype) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 97e8795818..e41495bf0c 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -838,6 +838,16 @@ func : hard_sigmoid backward : hard_sigmoid_grad +- api : hard_swish + args : (Tensor x, float threshold = 6.0, float scale = 6.0, float offset = 3.0) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : hard_swish + backward : hard_swish_grad + # histogram - api : histogram args : (Tensor x, int64_t bins, int min, int max) @@ -949,6 +959,15 @@ data_type : x backward : kldiv_loss_grad +- api : kron + args : (Tensor x, Tensor y) + output : Tensor + infer_meta : + func : KronInferMeta + kernel : + func : kron + backward : kron_grad + - api : kthvalue args : (Tensor x, int k, int axis, bool keepdim) output : Tensor(out), Tensor(indices) @@ -1016,6 +1035,15 @@ func : lgamma backward : lgamma_grad +- api : linspace + args : (Tensor start, Tensor stop, Tensor number, DataType dtype) + output : Tensor + infer_meta : + func : LinspaceInferMeta + kernel : + func : linspace + data_type : dtype + - api : log args : (Tensor x) output : Tensor @@ -1107,6 +1135,17 @@ kernel : func : logical_xor +# logit +- api : logit + args : (Tensor x, float eps = 1e-6f) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : logit + backward : logit_grad + # logsigmoid - api : logsigmoid args : (Tensor x) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 3f6dc0e747..917fd5ec44 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -568,6 +568,16 @@ kernel : func : hard_sigmoid_grad +- backward_api : hard_swish_grad + forward : hard_swish (Tensor x, float threshold = 6.0, float scale = 6.0, float offset = 3.0) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float threshold, float scale, float offset) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : hard_swish_grad + - backward_api : huber_loss_grad forward : huber_loss (Tensor input, Tensor label, float delta) -> Tensor(out), Tensor(residual) args : (Tensor residual, Tensor out_grad, float delta) @@ -617,6 +627,17 @@ kernel : func : kldiv_loss_grad +- backward_api : kron_grad + forward : kron (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : kron_grad + data_type : out_grad + - backward_api : kthvalue_grad forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices) args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim) @@ -728,6 +749,16 @@ kernel : func : log_softmax_grad +- backward_api : logit_grad + forward : logit (Tensor x, float eps = 1e-6f) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float eps) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : logit_grad + - backward_api : logsigmoid_grad forward : logsigmoid (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) -- GitLab