未验证 提交 547075e9 编写于 作者: W WangZhen 提交者: GitHub

[Phi]Move angle op to phi (#44393)

* Move angle op to phi

* Replace mutable_data using Alloc

* Remove some include

* Try to fix windows ci error

* include math.h to fix windows ci error

* Fix kernel name

* Move angle_grad infershape
上级 3788f5e5
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/angle_op.h"
#include <memory>
#include <string>
#include <unordered_map>
......@@ -22,22 +20,18 @@
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class AngleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "angle");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "angle");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class AngleOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -67,20 +61,6 @@ $$out = angle(x)$$
class AngleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@Grad",
"angle_grad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Out@Grad", "angle_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
"Output",
"X@Grad",
"angle_grad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -108,24 +88,19 @@ class AngleGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(angle,
AngleInferShapeFunctor,
PD_INFER_META(phi::RealAndImagInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(angle_grad,
AngleGradInferShapeFunctor,
PD_INFER_META(phi::AngleGradInferMeta));
REGISTER_OPERATOR(angle,
ops::AngleOp,
ops::AngleOpMaker,
ops::AngleGradMaker<paddle::framework::OpDesc>,
ops::AngleGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
angle,
ops::AngleKernel<phi::CPUContext, float>,
ops::AngleKernel<phi::CPUContext, double>,
ops::AngleKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::AngleKernel<phi::CPUContext, paddle::platform::complex<double>>);
REGISTER_OPERATOR(angle_grad, ops::AngleGradOp);
REGISTER_OP_CPU_KERNEL(
angle_grad,
ops::AngleGradKernel<phi::CPUContext, float>,
ops::AngleGradKernel<phi::CPUContext, double>,
ops::AngleGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::AngleGradKernel<phi::CPUContext, paddle::platform::complex<double>>);
ops::AngleGradMaker<paddle::imperative::OpBase>,
AngleInferShapeFunctor);
REGISTER_OPERATOR(angle_grad, ops::AngleGradOp, AngleGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AngleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<phi::dtype::Real<T>>(
context.GetPlace(), size_t(x->numel() * sizeof(phi::dtype::Real<T>)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::AngleFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class AngleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<phi::dtype::Real<T>>();
auto* x_data = x->data<T>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::AngleGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......@@ -118,6 +118,15 @@
kernel :
func : allclose
- api : angle
args : (Tensor x)
output : Tensor
infer_meta :
func : RealAndImagInferMeta
kernel :
func : angle
backward : angle_grad
- api : any
args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
output : Tensor(out)
......
......@@ -92,6 +92,18 @@
kernel :
func : addmm_grad
- backward_api : angle_grad
forward : angle (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : angle_grad
data_transform:
skip_transform : out_grad
- backward_api : argsort_grad
forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices)
args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending)
......
......@@ -19,6 +19,12 @@ limitations under the License. */
namespace phi {
void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad) {
UnchangedInferMeta(x, x_grad);
}
void BilinearTensorProductGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
......
......@@ -28,6 +28,10 @@ namespace phi {
//
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad);
void BilinearTensorProductGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,22 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/angle_op.h"
#include "paddle/fluid/platform/complex.h"
#pragma once
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#include "paddle/phi/core/dense_tensor.h"
REGISTER_OP_CUDA_KERNEL(
angle,
ops::AngleKernel<plat::CUDADeviceContext, float>,
ops::AngleKernel<plat::CUDADeviceContext, double>,
ops::AngleKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AngleKernel<plat::CUDADeviceContext, plat::complex<double>>);
namespace phi {
REGISTER_OP_CUDA_KERNEL(
angle_grad,
ops::AngleGradKernel<plat::CUDADeviceContext, float>,
ops::AngleGradKernel<plat::CUDADeviceContext, double>,
ops::AngleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AngleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
template <typename T, typename Context>
void AngleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AngleKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/angle_grad_kernel.h"
#include "paddle/phi/kernels/impl/angle_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(angle_grad,
CPU,
ALL_LAYOUT,
phi::AngleGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/angle_kernel.h"
#include "paddle/phi/kernels/impl/angle_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(angle,
CPU,
ALL_LAYOUT,
phi::AngleKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/angle_grad_kernel.h"
#include "paddle/phi/kernels/impl/angle_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(angle_grad,
GPU,
ALL_LAYOUT,
phi::AngleGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/angle_kernel.h"
#include "paddle/phi/kernels/impl/angle_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(angle,
GPU,
ALL_LAYOUT,
phi::AngleKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T, typename Context>
void AngleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto numel = out_grad.numel();
auto* dout_data = out_grad.data<phi::dtype::Real<T>>();
auto* x_data = x.data<T>();
x_grad->Resize(out_grad.dims());
auto* dx_data = dev_ctx.template Alloc<T>(x_grad);
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::AngleGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T, typename Context>
void AngleKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
out->Resize(x.dims());
auto* out_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(out);
funcs::ForRange<Context> for_range(dev_ctx, numel);
funcs::AngleFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature AngleOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("angle", {"X"}, {}, {"Out"});
}
KernelSignature AngleGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("angle_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(angle, phi::AngleOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(angle_grad, phi::AngleGradOpArgumentMapping);
......@@ -43,6 +43,7 @@ class TestAngleOpFloat(OpTest):
def setUp(self):
self.op_type = "angle"
self.python_api = paddle.angle
self.dtype = "float64"
self.x = np.linspace(-5, 5, 101).astype(self.dtype)
out_ref = np.angle(self.x)
......@@ -50,7 +51,7 @@ class TestAngleOpFloat(OpTest):
self.outputs = {'Out': out_ref}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'],
......@@ -58,13 +59,15 @@ class TestAngleOpFloat(OpTest):
user_defined_grads=[
angle_grad(self.x,
np.ones_like(self.x) / self.x.size)
])
],
check_eager=True)
class TestAngleOpComplex(OpTest):
def setUp(self):
self.op_type = "angle"
self.python_api = paddle.angle
self.dtype = "complex128"
real = np.expand_dims(np.linspace(-2, 2, 11), -1).astype("float64")
imag = np.linspace(-2, 2, 11).astype("float64")
......@@ -74,7 +77,7 @@ class TestAngleOpComplex(OpTest):
self.outputs = {'Out': out_ref}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'],
......@@ -82,7 +85,8 @@ class TestAngleOpComplex(OpTest):
user_defined_grads=[
angle_grad(self.x,
np.ones_like(self.x) / self.x.size)
])
],
check_eager=True)
class TestAngleAPI(unittest.TestCase):
......
......@@ -4434,7 +4434,9 @@ def angle(x, name=None):
# [-1.1071488 -0.7853982 0. 0.7853982]]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_angle(x)
elif paddle.in_dynamic_mode():
return _C_ops.angle(x)
check_variable_and_dtype(x, 'x',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册