From 918aeb714f3694e1dbad5ffced57d484d15d33ce Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Wed, 16 Jun 2021 21:38:57 -0500 Subject: [PATCH] Add atan2 op and test (#33067) * add atan2_op * fix --- paddle/fluid/operators/atan2_op.cc | 138 ++++++++++++++ paddle/fluid/operators/atan2_op.cu | 31 ++++ paddle/fluid/operators/atan2_op.h | 168 ++++++++++++++++++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_atan2_op.py | 132 ++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/math.py | 56 ++++++ 7 files changed, 528 insertions(+) create mode 100644 paddle/fluid/operators/atan2_op.cc create mode 100644 paddle/fluid/operators/atan2_op.cu create mode 100644 paddle/fluid/operators/atan2_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_atan2_op.py diff --git a/paddle/fluid/operators/atan2_op.cc b/paddle/fluid/operators/atan2_op.cc new file mode 100644 index 00000000000..8ee6540bfa5 --- /dev/null +++ b/paddle/fluid/operators/atan2_op.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/atan2_op.h" + +#include +#include +#include +#include + +namespace paddle { +namespace operators { + +class Atan2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "atan2"); + OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "atan2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "atan2"); + + auto in_dims = ctx->GetInputDim("X1"); + + ctx->SetOutputDim("Out", in_dims); + } +}; + +class Atan2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X1", "(Tensor), The input tensor of atan2 op."); + AddInput("X2", "(Tensor), The input tensor of atan2 op."); + AddOutput("Out", "(Tensor), The output tensor of atan2 op."); + AddComment(R"DOC( +Atan2 Operator. + +This operator is used to perform elementwise atan2 for input $X1$, $X2$. +$$out = atan2(x1, x2)$$ + +)DOC"); + } +}; + +class Atan2GradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "Atan2Grad"); + OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "Atan2Grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@Grad", "Atan2Grad"); + + auto x1_grad_name = framework::GradVarName("X1"); + auto x2_grad_name = framework::GradVarName("X2"); + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + if (ctx->HasOutput(x1_grad_name)) { + ctx->SetOutputDim(framework::GradVarName("X1"), dout_dims); + } + if (ctx->HasOutput(x2_grad_name)) { + ctx->SetOutputDim(framework::GradVarName("X2"), dout_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X1"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class Atan2GradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("atan2_grad"); + retv->SetInput("X1", this->Input("X1")); + retv->SetInput("X2", this->Input("X2")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X1"), this->InputGrad("X1")); + retv->SetOutput(framework::GradVarName("X2"), this->InputGrad("X2")); + } +}; + +class Atan2OpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override { + auto type = ctx->GetInputDataType("X1"); + if (ctx->GetInputDataType("X1") == framework::proto::VarType::INT32 || + ctx->GetInputDataType("X1") == framework::proto::VarType::INT64 || + ctx->GetInputDataType("X2") == framework::proto::VarType::INT32 || + ctx->GetInputDataType("X2") == framework::proto::VarType::INT64) { + type = framework::proto::VarType::FP64; + } + ctx->SetOutputDataType("Out", type); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(atan2, ops::Atan2Op, ops::Atan2OpMaker, + ops::Atan2GradMaker, + ops::Atan2GradMaker, + ops::Atan2OpVarTypeInference); + +REGISTER_OPERATOR(atan2_grad, ops::Atan2GradOp); + +REGISTER_OP_CPU_KERNEL( + atan2, ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel); + +REGISTER_OP_CPU_KERNEL( + atan2_grad, ops::Atan2GradKernel, + ops::Atan2GradKernel, + ops::Atan2GradKernel); diff --git a/paddle/fluid/operators/atan2_op.cu b/paddle/fluid/operators/atan2_op.cu new file mode 100644 index 00000000000..faf1fde47e4 --- /dev/null +++ b/paddle/fluid/operators/atan2_op.cu @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/atan2_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + atan2, ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel, + ops::Atan2Kernel); + +REGISTER_OP_CUDA_KERNEL( + atan2_grad, + ops::Atan2GradKernel, + ops::Atan2GradKernel, + ops::Atan2GradKernel); diff --git a/paddle/fluid/operators/atan2_op.h b/paddle/fluid/operators/atan2_op.h new file mode 100644 index 00000000000..8ed0fda843d --- /dev/null +++ b/paddle/fluid/operators/atan2_op.h @@ -0,0 +1,168 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using framework::To32BitIndex; + +template +struct Atan2Out { + using type = T; +}; + +template <> +struct Atan2Out { + using type = double; +}; + +template <> +struct Atan2Out { + using type = double; +}; + +template +struct Atan2Functor { + Atan2Functor(const T* x1, const T* x2, typename Atan2Out::type* out, + int64_t numel) + : x1_(x1), x2_(x2), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = static_cast::type>( + ::atan2f(static_cast(x1_[idx]), static_cast(x2_[idx]))); + } + + const T* x1_; + const T* x2_; + typename Atan2Out::type* out_; + int64_t numel_; +}; + +template <> +struct Atan2Functor { + Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel) + : x1_(x1), x2_(x2), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = ::atan2(x1_[idx], x2_[idx]); + } + + const double* x1_; + const double* x2_; + double* out_; + int64_t numel_; +}; + +// dx1 = dout * x2 / ((x1)^2 + (x2)^2) +// dx2 = - dout * x1 / ((x1)^2 + (x2)^2) +template +struct Atan2GradFunctor { + Atan2GradFunctor(const T* x1, const T* x2, const T* dout, T* dx1, T* dx2, + int64_t numel) + : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + float x1 = static_cast(x1_[idx]); + float x2 = static_cast(x2_[idx]); + float x = x1 * x1 + x2 * x2; + dx1_[idx] = static_cast(static_cast(dout_[idx]) * x2 / x); + dx2_[idx] = static_cast(-static_cast(dout_[idx]) * x1 / x); + } + + const T* x1_; + const T* x2_; + const T* dout_; + T* dx1_; + T* dx2_; + int64_t numel_; +}; + +template <> +struct Atan2GradFunctor { + Atan2GradFunctor(const double* x1, const double* x2, const double* dout, + double* dx1, double* dx2, int64_t numel) + : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx]; + dx1_[idx] = dout_[idx] * x2_[idx] / x; + dx2_[idx] = -dout_[idx] * x1_[idx] / x; + } + + const double* x1_; + const double* x2_; + const double* dout_; + double* dx1_; + double* dx2_; + int64_t numel_; +}; + +template +class Atan2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* X1 = context.Input("X1"); + const Tensor* X2 = context.Input("X2"); + Tensor* Out = context.Output("Out"); + + auto numel = X1->numel(); + auto x1 = X1->data(); + auto x2 = X2->data(); + auto out = Out->mutable_data::type>( + context.GetPlace(), size_t(numel * sizeof(typename Atan2Out::type))); + auto& dev_ctx = context.template device_context(); + + platform::ForRange for_range(dev_ctx, numel); + Atan2Functor functor(x1, x2, out, numel); + for_range(functor); + } +}; + +template +class Atan2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const { + const Tensor* X1 = context.Input("X1"); + const Tensor* X2 = context.Input("X2"); + const Tensor* dOut = context.Input(framework::GradVarName("Out")); + Tensor* dX1 = context.Output(framework::GradVarName("X1")); + Tensor* dX2 = context.Output(framework::GradVarName("X2")); + + auto numel = X1->numel(); + auto x1 = X1->data(); + auto x2 = X2->data(); + auto dout = dOut->data(); + auto dx1 = + dX1->mutable_data(context.GetPlace(), size_t(numel * sizeof(T))); + auto dx2 = + dX2->mutable_data(context.GetPlace(), size_t(numel * sizeof(T))); + auto& dev_ctx = context.template device_context(); + + platform::ForRange for_range(dev_ctx, numel); + Atan2GradFunctor functor(x1, x2, dout, dx1, dx2, numel); + for_range(functor); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index cc8a43c572c..a3b01573b62 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -152,6 +152,7 @@ from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 from .tensor.math import asin # noqa: F401 from .tensor.math import atan # noqa: F401 +from .tensor.math import atan2 # noqa: F401 from .tensor.math import ceil # noqa: F401 from .tensor.math import cos # noqa: F401 from .tensor.math import tan # noqa: F401 @@ -434,6 +435,7 @@ __all__ = [ # noqa 'divide', 'ceil', 'atan', + 'atan2', 'expand', 'broadcast_to', 'ones_like', diff --git a/python/paddle/fluid/tests/unittests/test_atan2_op.py b/python/paddle/fluid/tests/unittests/test_atan2_op.py new file mode 100644 index 00000000000..b29ab822f25 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_atan2_op.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest + +from op_test import OpTest +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() +np.random.seed(0) + + +def atan2_grad(x1, x2, dout): + dx1 = dout * x2 / (x1 * x1 + x2 * x2) + dx2 = -dout * x1 / (x1 * x1 + x2 * x2) + return dx1, dx2 + + +class TestAtan2(OpTest): + def setUp(self): + self.op_type = "atan2" + self.init_dtype() + + x1 = np.random.uniform(-1, -0.1, [15, 17]).astype(self.dtype) + x2 = np.random.uniform(0.1, 1, [15, 17]).astype(self.dtype) + out = np.arctan2(x1, x2) + + self.inputs = {'X1': x1, 'X2': x2} + self.outputs = {'Out': out} + + def test_check_grad(self): + self.check_grad(['X1', 'X2'], 'Out') + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + self.dtype = np.float64 + + +class TestAtan2_float(TestAtan2): + def init_dtype(self): + self.dtype = np.float32 + + def test_check_grad(self): + if self.dtype not in [np.int32, np.int64]: + self.check_grad( + ['X1', 'X2'], + 'Out', + user_defined_grads=atan2_grad(self.inputs['X1'], + self.inputs['X2'], + 1 / self.inputs['X1'].size)) + + +class TestAtan2_float16(TestAtan2_float): + def init_dtype(self): + self.dtype = np.float16 + + +class TestAtan2_int32(TestAtan2_float): + def init_dtype(self): + self.dtype = np.int32 + + +class TestAtan2_int64(TestAtan2_float): + def init_dtype(self): + self.dtype = np.int64 + + +class TestAtan2API(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float64' + self.shape = [11, 17] + + def setUp(self): + self.init_dtype() + self.x1 = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + self.x2 = np.random.uniform(-1, -0.1, self.shape).astype(self.dtype) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + X1 = paddle.fluid.data('X1', self.shape, dtype=self.dtype) + X2 = paddle.fluid.data('X2', self.shape, dtype=self.dtype) + out = paddle.atan2(X1, X2) + exe = paddle.static.Executor(place) + res = exe.run(feed={'X1': self.x1, 'X2': self.x2}) + out_ref = np.arctan2(self.x1, self.x2) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + for place in self.place: + run(place) + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + X1 = paddle.to_tensor(self.x1) + X2 = paddle.to_tensor(self.x2) + out = paddle.atan2(X1, X2) + out_ref = np.arctan2(self.x1, self.x2) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 2cb3f540634..bdefece122a 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -147,6 +147,7 @@ from .math import add # noqa: F401 from .math import add_ # noqa: F401 from .math import subtract # noqa: F401 from .math import subtract_ # noqa: F401 +from .math import atan2 # noqa: F401 from .math import logsumexp # noqa: F401 from .math import inverse # noqa: F401 from .math import log2 # noqa: F401 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2ffb8d9302c..3f1f2b42147 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2386,3 +2386,59 @@ def neg(x, name=None): """ return layers.scale(x, scale=-1.0, bias=0.0, bias_after_scale=True, act=None, name=name) + +def atan2(y, x, name=None): + r""" + Element-wise arctangent of y/x with consideration of the quadrant. + + Equation: + .. math:: + + atan2(y,x)=\left\{\begin{matrix} + & tan^{-1}(\frac{y}{x}) & x > 0 \\ + & tan^{-1}(\frac{y}{x}) + \pi & y>=0, x < 0 \\ + & tan^{-1}(\frac{y}{x}) - \pi & y<0, x < 0 \\ + & +\frac{\pi}{2} & y>0, x = 0 \\ + & -\frac{\pi}{2} & y<0, x = 0 \\ + &\text{undefined} & y=0, x = 0 + \end{matrix}\right. + + Args: + y (Tensor): An N-D Tensor, the data type is int32, int64, float16, float32, float64. + x (Tensor): An N-D Tensor, must have the same type as `x`. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the shape and data type is the same with input (The output data type is float64 when the input data type is int). + + Examples: + .. code-block:: python + + import paddle + + y = paddle.to_tensor([-1, +1, +1, -1]).astype('float32') + #Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [-1, 1, 1, -1]) + + x = paddle.to_tensor([-1, -1, +1, +1]).astype('float32') + #Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [-1, -1, 1, 1]) + + out = paddle.atan2(y, x) + #Tensor(shape=[4], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [-2.35619450, 2.35619450, 0.78539819, -0.78539819]) + + """ + + if in_dygraph_mode(): + return core.ops.atan2(y, x) + else: + check_variable_and_dtype(y, 'y', ['int32', 'int64', 'float16', 'float32', 'float64'], 'atan2') + check_variable_and_dtype(x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'], 'atan2') + + helper = LayerHelper('atan2', **locals()) + inputs = {'X1' : y, 'X2' : x} + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='atan2', inputs=inputs, outputs={'Out': out}) + return out -- GitLab