diff --git a/paddle/fluid/operators/angle_op.cc b/paddle/fluid/operators/angle_op.cc index f925c7fa747594b66ec19010e07cc9f63284fa6e..e787feb7dab943b37a71a697d65553034676c6a2 100644 --- a/paddle/fluid/operators/angle_op.cc +++ b/paddle/fluid/operators/angle_op.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/angle_op.h" - #include #include #include @@ -22,22 +20,18 @@ #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { class AngleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "angle"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "angle"); - - auto in_dims = ctx->GetInputDim("X"); - - ctx->SetOutputDim("Out", in_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class AngleOpMaker : public framework::OpProtoAndCheckerMaker { @@ -67,20 +61,6 @@ $$out = angle(x)$$ class AngleGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@Grad", - "angle_grad"); - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Out@Grad", "angle_grad"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X@Grad", - "angle_grad"); - - auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -108,24 +88,19 @@ class AngleGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(angle, + AngleInferShapeFunctor, + PD_INFER_META(phi::RealAndImagInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(angle_grad, + AngleGradInferShapeFunctor, + PD_INFER_META(phi::AngleGradInferMeta)); + REGISTER_OPERATOR(angle, ops::AngleOp, ops::AngleOpMaker, ops::AngleGradMaker, - ops::AngleGradMaker); - -REGISTER_OP_CPU_KERNEL( - angle, - ops::AngleKernel, - ops::AngleKernel, - ops::AngleKernel>, - ops::AngleKernel>); - -REGISTER_OPERATOR(angle_grad, ops::AngleGradOp); - -REGISTER_OP_CPU_KERNEL( - angle_grad, - ops::AngleGradKernel, - ops::AngleGradKernel, - ops::AngleGradKernel>, - ops::AngleGradKernel>); + ops::AngleGradMaker, + AngleInferShapeFunctor); + +REGISTER_OPERATOR(angle_grad, ops::AngleGradOp, AngleGradInferShapeFunctor); diff --git a/paddle/fluid/operators/angle_op.cu b/paddle/fluid/operators/angle_op.cu deleted file mode 100644 index 26e6aebba9f95edff97df606dbcd58d7e016df8a..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/angle_op.cu +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/angle_op.h" -#include "paddle/fluid/platform/complex.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - angle, - ops::AngleKernel, - ops::AngleKernel, - ops::AngleKernel>, - ops::AngleKernel>); - -REGISTER_OP_CUDA_KERNEL( - angle_grad, - ops::AngleGradKernel, - ops::AngleGradKernel, - ops::AngleGradKernel>, - ops::AngleGradKernel>); diff --git a/paddle/fluid/operators/angle_op.h b/paddle/fluid/operators/angle_op.h deleted file mode 100644 index ace345465dc2590b2c1fb6d3366e92c111beb7d9..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/angle_op.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#ifndef _USE_MATH_DEFINES -#define _USE_MATH_DEFINES -#endif -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -class AngleKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - Tensor* out = context.Output("Out"); - - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data>( - context.GetPlace(), size_t(x->numel() * sizeof(phi::dtype::Real))); - - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::AngleFunctor functor(x_data, out_data, numel); - for_range(functor); - } -}; - -template -class AngleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - auto* dout_data = d_out->data>(); - auto* x_data = x->data(); - auto* dx_data = d_x->mutable_data( - ctx.GetPlace(), static_cast(numel * sizeof(T))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::AngleGradFunctor functor(dout_data, x_data, dx_data, numel); - for_range(functor); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 564de7b67435e301902ed1195d31fffe317bc9dc..ca5a788d6cbc4c0bdcf429b3f313b0b1cc5be466 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -118,6 +118,15 @@ kernel : func : allclose +- api : angle + args : (Tensor x) + output : Tensor + infer_meta : + func : RealAndImagInferMeta + kernel : + func : angle + backward : angle_grad + - api : any args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) output : Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index fa365d28dabcbcb9930cb71833ac1e5c01b00366..d02fd21f232274b1b68552eede1d78a6f2e75fb6 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -92,6 +92,18 @@ kernel : func : addmm_grad +- backward_api : angle_grad + forward : angle (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : angle_grad + data_transform: + skip_transform : out_grad + - backward_api : argsort_grad forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices) args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index eee75af3a329bb11d4b38d7b07a2f0dcc6be7ad6..82eefdea596c9bd4f5a7271a5061975b771ca022 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -19,6 +19,12 @@ limitations under the License. */ namespace phi { +void AngleGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + MetaTensor* x_grad) { + UnchangedInferMeta(x, x_grad); +} + void BilinearTensorProductGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 527a3c107f820c8bed119cec0fabdd8731e54375..042de0b2dd64f3fed557397ced8799fffc57b8f4 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -28,6 +28,10 @@ namespace phi { // // NOTE: The InferMeta Functions in this file are arranged in alphabetic order. +void AngleGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + MetaTensor* x_grad); + void BilinearTensorProductGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/kernels/angle_grad_kernel.h b/paddle/phi/kernels/angle_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3082ef7e477d1dbb61fdf11d347295630a1782da --- /dev/null +++ b/paddle/phi/kernels/angle_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AngleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/angle_kernel.h b/paddle/phi/kernels/angle_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..e947fe0e3e38d905010cebc401d4c9398e63f9b4 --- /dev/null +++ b/paddle/phi/kernels/angle_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AngleKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/angle_grad_kernel.cc b/paddle/phi/kernels/cpu/angle_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d12501916d85d21b5fb2264036ee52140900fdd0 --- /dev/null +++ b/paddle/phi/kernels/cpu/angle_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/angle_grad_kernel.h" +#include "paddle/phi/kernels/impl/angle_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(angle_grad, + CPU, + ALL_LAYOUT, + phi::AngleGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/angle_kernel.cc b/paddle/phi/kernels/cpu/angle_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc994f763cbad27ec5545bb00399a46d3e197bf6 --- /dev/null +++ b/paddle/phi/kernels/cpu/angle_kernel.cc @@ -0,0 +1,29 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/angle_kernel.h" +#include "paddle/phi/kernels/impl/angle_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(angle, + CPU, + ALL_LAYOUT, + phi::AngleKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/angle_grad_kernel.cu b/paddle/phi/kernels/gpu/angle_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..062c39a9d1f3f7df72500f76eacf5de2c999f387 --- /dev/null +++ b/paddle/phi/kernels/gpu/angle_grad_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/angle_grad_kernel.h" +#include "paddle/phi/kernels/impl/angle_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(angle_grad, + GPU, + ALL_LAYOUT, + phi::AngleGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/angle_kernel.cu b/paddle/phi/kernels/gpu/angle_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..237d97de21997dd3659de546bcde4bfda22dce2a --- /dev/null +++ b/paddle/phi/kernels/gpu/angle_kernel.cu @@ -0,0 +1,29 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/angle_kernel.h" +#include "paddle/phi/kernels/impl/angle_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(angle, + GPU, + ALL_LAYOUT, + phi::AngleKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/angle_grad_kernel_impl.h b/paddle/phi/kernels/impl/angle_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..84fc0c4cf2ad0f0940ae0076488e13eb07ff16dd --- /dev/null +++ b/paddle/phi/kernels/impl/angle_grad_kernel_impl.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { + +template +void AngleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto numel = out_grad.numel(); + auto* dout_data = out_grad.data>(); + auto* x_data = x.data(); + x_grad->Resize(out_grad.dims()); + auto* dx_data = dev_ctx.template Alloc(x_grad); + + phi::funcs::ForRange for_range(dev_ctx, numel); + phi::funcs::AngleGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/angle_kernel_impl.h b/paddle/phi/kernels/impl/angle_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..d0d51dcba399ff46ae63bb3976546b61daeb57fb --- /dev/null +++ b/paddle/phi/kernels/impl/angle_kernel_impl.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { + +template +void AngleKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + out->Resize(x.dims()); + auto* out_data = dev_ctx.template Alloc>(out); + + funcs::ForRange for_range(dev_ctx, numel); + funcs::AngleFunctor functor(x_data, out_data, numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/angle_sig.cc b/paddle/phi/ops/compat/angle_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..63b10e6bf401dcd9c9c069fb06adcce6d7db997c --- /dev/null +++ b/paddle/phi/ops/compat/angle_sig.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature AngleOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("angle", {"X"}, {}, {"Out"}); +} + +KernelSignature AngleGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("angle_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(angle, phi::AngleOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(angle_grad, phi::AngleGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_angle_op.py b/python/paddle/fluid/tests/unittests/test_angle_op.py index d21eb61b77dd952e16b1fdf85b6434a8552e78b1..9848c8320f2b4defe793b42d2eb830e14b32d555 100644 --- a/python/paddle/fluid/tests/unittests/test_angle_op.py +++ b/python/paddle/fluid/tests/unittests/test_angle_op.py @@ -43,6 +43,7 @@ class TestAngleOpFloat(OpTest): def setUp(self): self.op_type = "angle" + self.python_api = paddle.angle self.dtype = "float64" self.x = np.linspace(-5, 5, 101).astype(self.dtype) out_ref = np.angle(self.x) @@ -50,7 +51,7 @@ class TestAngleOpFloat(OpTest): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad(['X'], @@ -58,13 +59,15 @@ class TestAngleOpFloat(OpTest): user_defined_grads=[ angle_grad(self.x, np.ones_like(self.x) / self.x.size) - ]) + ], + check_eager=True) class TestAngleOpComplex(OpTest): def setUp(self): self.op_type = "angle" + self.python_api = paddle.angle self.dtype = "complex128" real = np.expand_dims(np.linspace(-2, 2, 11), -1).astype("float64") imag = np.linspace(-2, 2, 11).astype("float64") @@ -74,7 +77,7 @@ class TestAngleOpComplex(OpTest): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad(['X'], @@ -82,7 +85,8 @@ class TestAngleOpComplex(OpTest): user_defined_grads=[ angle_grad(self.x, np.ones_like(self.x) / self.x.size) - ]) + ], + check_eager=True) class TestAngleAPI(unittest.TestCase): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 0b5cc5bf6491bf9dda5118ee81c4c03514737a18..4fc725bf91303c4b10d06d196d77b3100ea7b63b 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4434,7 +4434,9 @@ def angle(x, name=None): # [-1.1071488 -0.7853982 0. 0.7853982]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_angle(x) + elif paddle.in_dynamic_mode(): return _C_ops.angle(x) check_variable_and_dtype(x, 'x',