未验证 提交 b089e7cd 编写于 作者: R ronnywang 提交者: GitHub

[phi] migrate atan2_op into phi (#39806)

上级 dba694f4
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/atan2_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -25,16 +25,6 @@ namespace operators {
class Atan2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "atan2");
OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "atan2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "atan2");
auto in_dims = ctx->GetInputDim("X1");
ctx->SetOutputDim("Out", in_dims);
}
};
class Atan2OpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -115,24 +105,11 @@ class Atan2OpVarTypeInference : public framework::VarTypeInference {
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(atan2, Atan2InferShapeFunctor,
PT_INFER_META(phi::Atan2InferMeta));
REGISTER_OPERATOR(atan2, ops::Atan2Op, ops::Atan2OpMaker,
ops::Atan2GradMaker<paddle::framework::OpDesc>,
ops::Atan2GradMaker<paddle::imperative::OpBase>,
ops::Atan2OpVarTypeInference);
ops::Atan2OpVarTypeInference, Atan2InferShapeFunctor);
REGISTER_OPERATOR(atan2_grad, ops::Atan2GradOp);
REGISTER_OP_CPU_KERNEL(
atan2, ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Atan2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_CPU_KERNEL(
atan2_grad, ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Atan2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
......@@ -225,4 +225,9 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto in_dims = x.dims();
out->set_dims(in_dims);
}
} // namespace phi
......@@ -52,4 +52,6 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* out,
MetaTensor* residual,
MetaConfig config = MetaConfig());
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Atan2GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Atan2Kernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
PD_REGISTER_KERNEL(atan2_grad,
CPU,
ALL_LAYOUT,
phi::Atan2GradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL(atan2,
CPU,
ALL_LAYOUT,
phi::Atan2Kernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
......@@ -12,20 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/atan2_op.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
atan2, ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Atan2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
atan2_grad,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Atan2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
PD_REGISTER_KERNEL(atan2_grad,
GPU,
ALL_LAYOUT,
phi::Atan2GradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL(atan2,
GPU,
ALL_LAYOUT,
phi::Atan2Kernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,72 +14,18 @@
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::To32BitIndex;
template <typename T>
struct Atan2Out {
using type = T;
};
template <>
struct Atan2Out<int32_t> {
using type = double;
};
template <>
struct Atan2Out<int64_t> {
using type = double;
};
template <typename T>
struct Atan2Functor {
Atan2Functor(const T* x1, const T* x2, typename Atan2Out<T>::type* out,
int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = static_cast<typename Atan2Out<T>::type>(
::atan2f(static_cast<float>(x1_[idx]), static_cast<float>(x2_[idx])));
}
const T* x1_;
const T* x2_;
typename Atan2Out<T>::type* out_;
int64_t numel_;
};
template <>
struct Atan2Functor<double> {
Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = ::atan2(x1_[idx], x2_[idx]);
}
const double* x1_;
const double* x2_;
double* out_;
int64_t numel_;
};
namespace phi {
// dx1 = dout * x2 / ((x1)^2 + (x2)^2)
// dx2 = - dout * x1 / ((x1)^2 + (x2)^2)
template <typename T>
struct Atan2GradFunctor {
Atan2GradFunctor(const T* x1, const T* x2, const T* dout, T* dx1, T* dx2,
int64_t numel)
Atan2GradFunctor(
const T* x1, const T* x2, const T* dout, T* dx1, T* dx2, int64_t numel)
: x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -100,8 +46,12 @@ struct Atan2GradFunctor {
template <>
struct Atan2GradFunctor<double> {
Atan2GradFunctor(const double* x1, const double* x2, const double* dout,
double* dx1, double* dx2, int64_t numel)
Atan2GradFunctor(const double* x1,
const double* x2,
const double* dout,
double* dx1,
double* dx2,
int64_t numel)
: x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
......@@ -118,51 +68,27 @@ struct Atan2GradFunctor<double> {
int64_t numel_;
};
template <typename DeviceContext, typename T>
class Atan2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* X1 = context.Input<Tensor>("X1");
const Tensor* X2 = context.Input<Tensor>("X2");
Tensor* Out = context.Output<Tensor>("Out");
auto numel = X1->numel();
auto x1 = X1->data<T>();
auto x2 = X2->data<T>();
auto out = Out->mutable_data<typename Atan2Out<T>::type>(
context.GetPlace(), size_t(numel * sizeof(typename Atan2Out<T>::type)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
Atan2Functor<T> functor(x1, x2, out, numel);
template <typename T, typename Context>
void Atan2GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto numel = x.numel();
auto x_data = x.data<T>();
auto y_data = y.data<T>();
auto out_grad_data = out_grad.data<T>();
auto* x_grad_data =
ctx.template Alloc<T>(x_grad, size_t(x.numel() * sizeof(T)));
auto* y_grad_data =
ctx.template Alloc<T>(y_grad, size_t(y.numel() * sizeof(T)));
paddle::platform::ForRange<Context> for_range(ctx, numel);
phi::Atan2GradFunctor<T> functor(
x_data, y_data, out_grad_data, x_grad_data, y_grad_data, numel);
for_range(functor);
}
};
}
template <typename DeviceContext, typename T>
class Atan2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const {
const Tensor* X1 = context.Input<Tensor>("X1");
const Tensor* X2 = context.Input<Tensor>("X2");
const Tensor* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* dX1 = context.Output<Tensor>(framework::GradVarName("X1"));
Tensor* dX2 = context.Output<Tensor>(framework::GradVarName("X2"));
auto numel = X1->numel();
auto x1 = X1->data<T>();
auto x2 = X2->data<T>();
auto dout = dOut->data<T>();
auto dx1 =
dX1->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
auto dx2 =
dX2->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
Atan2GradFunctor<T> functor(x1, x2, dout, dx1, dx2, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_kernel.h"
namespace phi {
template <typename T>
struct Atan2Out {
using type = T;
};
template <>
struct Atan2Out<int32_t> {
using type = double;
};
template <>
struct Atan2Out<int64_t> {
using type = double;
};
template <typename T>
struct Atan2Functor {
Atan2Functor(const T* x1,
const T* x2,
typename Atan2Out<T>::type* out,
int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = static_cast<typename Atan2Out<T>::type>(
::atan2f(static_cast<float>(x1_[idx]), static_cast<float>(x2_[idx])));
}
const T* x1_;
const T* x2_;
typename Atan2Out<T>::type* out_;
int64_t numel_;
};
template <>
struct Atan2Functor<double> {
Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel)
: x1_(x1), x2_(x2), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = ::atan2(x1_[idx], x2_[idx]);
}
const double* x1_;
const double* x2_;
double* out_;
int64_t numel_;
};
template <typename T, typename Context>
void Atan2Kernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto numel = x.numel();
auto x_data = x.data<T>();
auto y_data = y.data<T>();
auto* out_data = ctx.template Alloc<typename Atan2Out<T>::type>(
out, size_t(x.numel() * sizeof(typename Atan2Out<T>::type)));
paddle::platform::ForRange<Context> for_range(ctx, numel);
phi::Atan2Functor<T> functor(x_data, y_data, out_data, numel);
for_range(functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("atan2_grad",
{"X1", "X2", GradVarName("Out")},
{},
{GradVarName("X1"), GradVarName("X2")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(atan2_grad, phi::Atan2GradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册