diff --git a/paddle/fluid/operators/atan2_op.cc b/paddle/fluid/operators/atan2_op.cc index 8ee6540bfa5f0c413f759f58ab506ac181c19c49..71a895c244c54f62c0af1745635c08fea35436c4 100644 --- a/paddle/fluid/operators/atan2_op.cc +++ b/paddle/fluid/operators/atan2_op.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/atan2_op.h" - -#include -#include -#include -#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -25,16 +25,6 @@ namespace operators { class Atan2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "atan2"); - OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "atan2"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "atan2"); - - auto in_dims = ctx->GetInputDim("X1"); - - ctx->SetOutputDim("Out", in_dims); - } }; class Atan2OpMaker : public framework::OpProtoAndCheckerMaker { @@ -115,24 +105,11 @@ class Atan2OpVarTypeInference : public framework::VarTypeInference { } // namespace paddle namespace ops = paddle::operators; - +DELCARE_INFER_SHAPE_FUNCTOR(atan2, Atan2InferShapeFunctor, + PT_INFER_META(phi::Atan2InferMeta)); REGISTER_OPERATOR(atan2, ops::Atan2Op, ops::Atan2OpMaker, ops::Atan2GradMaker, ops::Atan2GradMaker, - ops::Atan2OpVarTypeInference); + ops::Atan2OpVarTypeInference, Atan2InferShapeFunctor); REGISTER_OPERATOR(atan2_grad, ops::Atan2GradOp); - -REGISTER_OP_CPU_KERNEL( - atan2, ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel); - -REGISTER_OP_CPU_KERNEL( - atan2_grad, ops::Atan2GradKernel, - ops::Atan2GradKernel, - ops::Atan2GradKernel); diff --git a/paddle/fluid/operators/atan2_op.cu b/paddle/fluid/operators/atan2_op.cu deleted file mode 100644 index faf1fde47e4c45a00836eee1d81ed1233170ecbe..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/atan2_op.cu +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/atan2_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - atan2, ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel, - ops::Atan2Kernel); - -REGISTER_OP_CUDA_KERNEL( - atan2_grad, - ops::Atan2GradKernel, - ops::Atan2GradKernel, - ops::Atan2GradKernel); diff --git a/paddle/fluid/operators/atan2_op.h b/paddle/fluid/operators/atan2_op.h deleted file mode 100644 index a0e64c301524e2051abf8d2fc1641e0bcfafe69d..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/atan2_op.h +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; -using framework::To32BitIndex; - -template -struct Atan2Out { - using type = T; -}; - -template <> -struct Atan2Out { - using type = double; -}; - -template <> -struct Atan2Out { - using type = double; -}; - -template -struct Atan2Functor { - Atan2Functor(const T* x1, const T* x2, typename Atan2Out::type* out, - int64_t numel) - : x1_(x1), x2_(x2), out_(out), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - out_[idx] = static_cast::type>( - ::atan2f(static_cast(x1_[idx]), static_cast(x2_[idx]))); - } - - const T* x1_; - const T* x2_; - typename Atan2Out::type* out_; - int64_t numel_; -}; - -template <> -struct Atan2Functor { - Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel) - : x1_(x1), x2_(x2), out_(out), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - out_[idx] = ::atan2(x1_[idx], x2_[idx]); - } - - const double* x1_; - const double* x2_; - double* out_; - int64_t numel_; -}; - -// dx1 = dout * x2 / ((x1)^2 + (x2)^2) -// dx2 = - dout * x1 / ((x1)^2 + (x2)^2) -template -struct Atan2GradFunctor { - Atan2GradFunctor(const T* x1, const T* x2, const T* dout, T* dx1, T* dx2, - int64_t numel) - : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - float x1 = static_cast(x1_[idx]); - float x2 = static_cast(x2_[idx]); - float x = x1 * x1 + x2 * x2; - dx1_[idx] = static_cast(static_cast(dout_[idx]) * x2 / x); - dx2_[idx] = static_cast(-static_cast(dout_[idx]) * x1 / x); - } - - const T* x1_; - const T* x2_; - const T* dout_; - T* dx1_; - T* dx2_; - int64_t numel_; -}; - -template <> -struct Atan2GradFunctor { - Atan2GradFunctor(const double* x1, const double* x2, const double* dout, - double* dx1, double* dx2, int64_t numel) - : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx]; - dx1_[idx] = dout_[idx] * x2_[idx] / x; - dx2_[idx] = -dout_[idx] * x1_[idx] / x; - } - - const double* x1_; - const double* x2_; - const double* dout_; - double* dx1_; - double* dx2_; - int64_t numel_; -}; - -template -class Atan2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* X1 = context.Input("X1"); - const Tensor* X2 = context.Input("X2"); - Tensor* Out = context.Output("Out"); - - auto numel = X1->numel(); - auto x1 = X1->data(); - auto x2 = X2->data(); - auto out = Out->mutable_data::type>( - context.GetPlace(), size_t(numel * sizeof(typename Atan2Out::type))); - auto& dev_ctx = context.template device_context(); - - platform::ForRange for_range(dev_ctx, numel); - Atan2Functor functor(x1, x2, out, numel); - for_range(functor); - } -}; - -template -class Atan2GradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const { - const Tensor* X1 = context.Input("X1"); - const Tensor* X2 = context.Input("X2"); - const Tensor* dOut = context.Input(framework::GradVarName("Out")); - Tensor* dX1 = context.Output(framework::GradVarName("X1")); - Tensor* dX2 = context.Output(framework::GradVarName("X2")); - - auto numel = X1->numel(); - auto x1 = X1->data(); - auto x2 = X2->data(); - auto dout = dOut->data(); - auto dx1 = - dX1->mutable_data(context.GetPlace(), size_t(numel * sizeof(T))); - auto dx2 = - dX2->mutable_data(context.GetPlace(), size_t(numel * sizeof(T))); - auto& dev_ctx = context.template device_context(); - - platform::ForRange for_range(dev_ctx, numel); - Atan2GradFunctor functor(x1, x2, dout, dx1, dx2, numel); - for_range(functor); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7455f1e6a0896fa25a3b02a03da3f3223f1d087b..e94926a9c1403c8d5c2da8cdd939c11745c0768f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -225,4 +225,9 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { + auto in_dims = x.dims(); + out->set_dims(in_dims); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 93ef9f5f35abbac2fd6c2c804efeb5a767a0d20f..f23382be89b6aa726b99ab5a6e0a1bf40c60cf44 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -52,4 +52,6 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* out, MetaTensor* residual, MetaConfig config = MetaConfig()); + +void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/atan2_grad_kernel.h b/paddle/phi/kernels/atan2_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ddd87c9da156d4b9ff983972010b90a74a231c4a --- /dev/null +++ b/paddle/phi/kernels/atan2_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void Atan2GradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/atan2_kernel.h b/paddle/phi/kernels/atan2_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..38276fa4f73ce5c0c94383a90e6f6f711efd9bcf --- /dev/null +++ b/paddle/phi/kernels/atan2_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void Atan2Kernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/atan2_grad_kernel.cc b/paddle/phi/kernels/cpu/atan2_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ff7431f0c8c556770b54e1328251e5996850fc9 --- /dev/null +++ b/paddle/phi/kernels/cpu/atan2_grad_kernel.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/atan2_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(atan2_grad, + CPU, + ALL_LAYOUT, + phi::Atan2GradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/atan2_kernel.cc b/paddle/phi/kernels/cpu/atan2_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb38a6c90b7938ef16cf9d56dfdb93903cc3c6a1 --- /dev/null +++ b/paddle/phi/kernels/cpu/atan2_kernel.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/atan2_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" + +PD_REGISTER_KERNEL(atan2, + CPU, + ALL_LAYOUT, + phi::Atan2Kernel, + float, + double, + phi::dtype::float16, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..1cc3311c3639820ef9b6d3a29d9274ac93bb5963 --- /dev/null +++ b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/atan2_grad_kernel.h" +#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(atan2_grad, + GPU, + ALL_LAYOUT, + phi::Atan2GradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/atan2_kernel.cu b/paddle/phi/kernels/gpu/atan2_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..702c959b78f75d0e52511d9bdc9d4330c6838aa4 --- /dev/null +++ b/paddle/phi/kernels/gpu/atan2_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/atan2_kernel.h" +#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" + +PD_REGISTER_KERNEL(atan2, + GPU, + ALL_LAYOUT, + phi::Atan2Kernel, + float, + double, + phi::dtype::float16, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5f75a95f4a7b18f0ccf450e003860eeeef3c649d --- /dev/null +++ b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/atan2_grad_kernel.h" + +namespace phi { + +// dx1 = dout * x2 / ((x1)^2 + (x2)^2) +// dx2 = - dout * x1 / ((x1)^2 + (x2)^2) +template +struct Atan2GradFunctor { + Atan2GradFunctor( + const T* x1, const T* x2, const T* dout, T* dx1, T* dx2, int64_t numel) + : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + float x1 = static_cast(x1_[idx]); + float x2 = static_cast(x2_[idx]); + float x = x1 * x1 + x2 * x2; + dx1_[idx] = static_cast(static_cast(dout_[idx]) * x2 / x); + dx2_[idx] = static_cast(-static_cast(dout_[idx]) * x1 / x); + } + + const T* x1_; + const T* x2_; + const T* dout_; + T* dx1_; + T* dx2_; + int64_t numel_; +}; + +template <> +struct Atan2GradFunctor { + Atan2GradFunctor(const double* x1, + const double* x2, + const double* dout, + double* dx1, + double* dx2, + int64_t numel) + : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx]; + dx1_[idx] = dout_[idx] * x2_[idx] / x; + dx2_[idx] = -dout_[idx] * x1_[idx] / x; + } + + const double* x1_; + const double* x2_; + const double* dout_; + double* dx1_; + double* dx2_; + int64_t numel_; +}; + +template +void Atan2GradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto numel = x.numel(); + auto x_data = x.data(); + auto y_data = y.data(); + auto out_grad_data = out_grad.data(); + + auto* x_grad_data = + ctx.template Alloc(x_grad, size_t(x.numel() * sizeof(T))); + auto* y_grad_data = + ctx.template Alloc(y_grad, size_t(y.numel() * sizeof(T))); + + paddle::platform::ForRange for_range(ctx, numel); + phi::Atan2GradFunctor functor( + x_data, y_data, out_grad_data, x_grad_data, y_grad_data, numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..c29449a27e0b5603c4e6f50c8ed676677c29796a --- /dev/null +++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/atan2_kernel.h" + +namespace phi { +template +struct Atan2Out { + using type = T; +}; + +template <> +struct Atan2Out { + using type = double; +}; + +template <> +struct Atan2Out { + using type = double; +}; + +template +struct Atan2Functor { + Atan2Functor(const T* x1, + const T* x2, + typename Atan2Out::type* out, + int64_t numel) + : x1_(x1), x2_(x2), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = static_cast::type>( + ::atan2f(static_cast(x1_[idx]), static_cast(x2_[idx]))); + } + + const T* x1_; + const T* x2_; + typename Atan2Out::type* out_; + int64_t numel_; +}; + +template <> +struct Atan2Functor { + Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel) + : x1_(x1), x2_(x2), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = ::atan2(x1_[idx], x2_[idx]); + } + + const double* x1_; + const double* x2_; + double* out_; + int64_t numel_; +}; + +template +void Atan2Kernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + auto numel = x.numel(); + auto x_data = x.data(); + auto y_data = y.data(); + + auto* out_data = ctx.template Alloc::type>( + out, size_t(x.numel() * sizeof(typename Atan2Out::type))); + + paddle::platform::ForRange for_range(ctx, numel); + phi::Atan2Functor functor(x_data, y_data, out_data, numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/atan2_sig.cc b/paddle/phi/ops/compat/atan2_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a6049e67b668e4cd97e928414bbca10bf29c0c4 --- /dev/null +++ b/paddle/phi/ops/compat/atan2_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("atan2_grad", + {"X1", "X2", GradVarName("Out")}, + {}, + {GradVarName("X1"), GradVarName("X2")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(atan2_grad, phi::Atan2GradOpArgumentMapping);