Unverified commit b1aa693e, authored by zhangbo9674, committed by GitHub

[Phi] Migrate complex_op into Phi & Add complex api yaml (#44233)

* mv to phi

* refine infermeta code position

* refine grad code

* add api yaml and add final_state_api

* refine code
Parent 05d5bbfb
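For orientation, a minimal usage sketch of the paddle.complex API that this commit migrates into Phi and exposes through the final-state path (mirroring the example in the op's docstring; output formatting is illustrative and may differ across Paddle versions):

import paddle

x = paddle.arange(2, dtype=paddle.float32).unsqueeze(-1)  # shape [2, 1]
y = paddle.arange(3, dtype=paddle.float32)                # shape [3]
z = paddle.complex(x, y)                                  # broadcasts to shape [2, 3]
print(z)
# [[0+0j, 0+1j, 0+2j],
#  [1+0j, 1+1j, 1+2j]]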
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/complex_op.h"
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -59,36 +59,6 @@ class ComplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "complex");
if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
ctx->ShareDim("X", /*->*/ "Out");
// NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
// so tensors with lod are not supported here
} else {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int max_dim = std::max(x_dims.size(), y_dims.size());
// start align axis
int axis = std::abs(x_dims.size() - y_dims.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -101,25 +71,6 @@ class ComplexGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex_grad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_complex_gradgrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"complex_grad");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -135,18 +86,21 @@ class ComplexGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(complex,
ComplexInferShapeFunctor,
PD_INFER_META(phi::ComplexInferMeta));
REGISTER_OPERATOR(complex,
ops::ComplexOp,
ops::ComplexOpMaker,
ops::ComplexGradOpMaker<paddle::framework::OpDesc>,
ops::ComplexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(complex_grad, ops::ComplexGradOp);
ops::ComplexGradOpMaker<paddle::imperative::OpBase>,
ComplexInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(complex,
ops::ComplexKernel<phi::CPUContext, float>,
ops::ComplexKernel<phi::CPUContext, double>);
DECLARE_INFER_SHAPE_FUNCTOR(complex_grad,
ComplexGradInferShapeFunctor,
PD_INFER_META(phi::ComplexGradInferMeta));
REGISTER_OP_CPU_KERNEL(complex_grad,
ops::ComplexGradKernel<phi::CPUContext, float>,
ops::ComplexGradKernel<phi::CPUContext, double>);
REGISTER_OPERATOR(complex_grad,
ops::ComplexGradOp,
ComplexGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/complex_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
complex,
ops::ComplexKernel<paddle::platform::CUDADeviceContext, float>,
ops::ComplexKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
complex_grad,
ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
// functors to use with ElementwiseComputeEx
template <typename T>
struct RealAndImagToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T x, const T y) {
return platform::complex<T>(x, y);
}
};
template <typename T>
struct ImagAndRealToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T y, const T x) {
return platform::complex<T>(x, y);
}
};
template <typename T>
struct ComplexGradForRealFunctor {
inline HOSTDEVICE T operator()(const T x,
const T y,
const platform::complex<T> out,
const platform::complex<T> dout) {
return dout.real;
}
};
template <typename T>
struct ComplexGradForImagFunctor {
inline HOSTDEVICE T operator()(const T x,
const T y,
const platform::complex<T> out,
const platform::complex<T> dout) {
return dout.imag;
}
};
template <typename DeviceContext, typename T>
class ComplexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
auto* z = ctx.Output<framework::Tensor>("Out");
using C = platform::complex<T>;
z->mutable_data<C>(ctx.GetPlace());
// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
// facility functions
#if defined(__NVCC__) || defined(__HIPCC__)
ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
#else
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
} else {
ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
}
#endif
}
};
template <typename DeviceContext, typename T>
class ComplexGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
using C = platform::complex<T>;
// skip out in a hacky way
auto* out = dout;
ElemwiseGradCompute<DeviceContext,
T,
ComplexGradForRealFunctor<T>,
ComplexGradForImagFunctor<T>,
C>(ctx,
*x,
*y,
*out,
*dout,
/*axis*/ -1,
dx,
dy,
ComplexGradForRealFunctor<T>(),
ComplexGradForImagFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
......@@ -342,6 +342,15 @@
func : clip
backward : clip_grad
- api : complex
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : ComplexInferMeta
kernel :
func : complex
backward : complex_grad
- api : concat
args : (Tensor[] x, Scalar(int64_t) axis)
output : Tensor
......
......@@ -306,6 +306,16 @@
backward : clip_double_grad
inplace : (out_grad -> x_grad)
- backward_api : complex_grad
forward : complex (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : ComplexGradInferMeta
kernel :
func : complex_grad
data_type : x
- backward_api : concat_double_grad
forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x)
args : (Tensor[] grad_x_grad, Scalar axis = 0)
......
......@@ -83,6 +83,23 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
x_grad->set_dtype(out_grad.dtype());
}
void ComplexGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& dout,
MetaTensor* dx,
MetaTensor* dy) {
auto x_dims = x.dims();
if (dx) {
dx->set_dims(x_dims);
dx->set_dtype(x.dtype());
}
auto y_dims = y.dims();
if (dy) {
dy->set_dims(y_dims);
dy->set_dtype(y.dtype());
}
}
void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
......
......@@ -42,6 +42,12 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
const std::string& data_format,
MetaTensor* x_grad);
void ComplexGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& dout,
MetaTensor* dx,
MetaTensor* dy);
void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
......@@ -358,6 +359,37 @@ void CompareAllInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL);
}
void ComplexInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
if (x.dims() == y.dims()) {
auto sizes = vectorize(x.dims());
out->set_dims(phi::make_ddim(sizes));
out->set_dtype(dtype::ToComplex(x.dtype()));
// NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
// so tensors with lod are not supported here
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
// start align axis
int axis = std::abs(x_dims.size() - y_dims.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
phi::funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out->set_dims(phi::make_ddim(out_dims_array));
out->set_dtype(dtype::ToComplex(x.dtype()));
}
}
void ConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
......
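The ComplexInferMeta added above reproduces the elementwise broadcasting rule (right-aligned axes via GetBroadcastDimsArrays) and promotes the real dtype to its complex counterpart. A quick illustrative check from Python, with arbitrarily chosen shapes:

import paddle

real = paddle.rand([2, 1, 3], dtype='float32')
imag = paddle.rand([4, 1], dtype='float32')
out = paddle.complex(real, imag)
print(out.shape)  # [2, 4, 3] after right-aligned broadcasting
print(out.dtype)  # paddle.complex64, i.e. ToComplex(float32)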
......@@ -74,6 +74,10 @@ void CompareInferMeta(const MetaTensor& x,
int axis,
MetaTensor* out);
void ComplexInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void ConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
......
......@@ -28,4 +28,12 @@ void ImagGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx);
template <typename T, typename Context>
void ComplexGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);
} // namespace phi
......@@ -30,6 +30,12 @@ void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void ComplexKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
// If T is complex
template <
typename T,
......
......@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(imag_grad,
phi::ImagGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
complex_grad, CPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -49,3 +49,8 @@ PD_REGISTER_KERNEL(imag,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(
complex, CPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(real_grad,
phi::RealGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
complex_grad, GPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -50,3 +50,8 @@ PD_REGISTER_KERNEL(imag,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(
complex, GPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
......@@ -47,4 +48,51 @@ void ImagGradKernel(const Context& dev_ctx,
for_range(functor);
}
template <typename T>
struct ComplexGradForRealFunctor {
inline HOSTDEVICE T operator()(const T x,
const T y,
const phi::dtype::complex<T> out,
const phi::dtype::complex<T> dout) {
return dout.real;
}
};
template <typename T>
struct ComplexGradForImagFunctor {
inline HOSTDEVICE T operator()(const T x,
const T y,
const phi::dtype::complex<T> out,
const phi::dtype::complex<T> dout) {
return dout.imag;
}
};
template <typename T, typename Context>
void ComplexGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
using C = phi::dtype::complex<T>;
// skip out in a hacky way
auto out = dout;
phi::funcs::ElemwiseGradCompute<Context,
T,
ComplexGradForRealFunctor<T>,
ComplexGradForImagFunctor<T>,
C>(dev_ctx,
x,
y,
out,
dout,
/*axis*/ -1,
dx,
dy,
ComplexGradForRealFunctor<T>(),
ComplexGradForImagFunctor<T>());
}
} // namespace phi
......@@ -15,7 +15,9 @@
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
......@@ -61,4 +63,45 @@ void ImagKernel(const Context& dev_ctx,
for_range(functor);
}
// functors to use with ElementwiseComputeEx
template <typename T>
struct RealAndImagToComplexFunctor {
inline HOSTDEVICE phi::dtype::complex<T> operator()(const T x, const T y) {
return phi::dtype::complex<T>(x, y);
}
};
template <typename T>
struct ImagAndRealToComplexFunctor {
inline HOSTDEVICE phi::dtype::complex<T> operator()(const T y, const T x) {
return phi::dtype::complex<T>(x, y);
}
};
template <typename T, typename Context>
void ComplexKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using C = phi::dtype::complex<T>;
dev_ctx.template Alloc<C>(out);
// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
// facility functions
#if defined(__NVCC__) || defined(__HIPCC__)
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
#else
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
} else {
phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>(
dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), out);
}
#endif
}
} // namespace phi
......@@ -24,7 +24,14 @@ KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ComplexGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"complex_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(complex_grad, phi::ComplexGradOpArgumentMapping);
......@@ -58,6 +58,7 @@ class TestComplexOp(OpTest):
def setUp(self):
self.op_type = "complex"
self.python_api = paddle.complex
self.init_spec()
x = np.random.randn(*self.x_shape).astype(self.dtype)
y = np.random.randn(*self.y_shape).astype(self.dtype)
......
......@@ -1701,6 +1701,9 @@ def complex(real, imag, name=None):
# [[0.+0.j 0.+1.j 0.+2.j]
# [1.+0.j 1.+1.j 1.+2.j]]
"""
if in_dygraph_mode():
return _C_ops.final_state_complex(real, imag)
if paddle.in_dynamic_mode():
return paddle._C_ops.complex(real, imag)
......