Unverified commit 345cc8fa, authored by From00, committed by GitHub

Move real and imag op to phi (#39777)

* Move Real OP to phi

* Move Imag OP to phi

* Move Real and Imag InferShape to phi

* Move Real and Imag to complex_kernel

* Change PT_REGISTER_XXX to PD_REGISTER_XXX
Parent 74c0bc1c
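In summary, every op in this diff follows the same migration pattern: the element-wise kernel is rewritten as a free function template in the phi namespace, implemented with ForRange plus the existing complex functors, and registered with PD_REGISTER_KERNEL, while the fluid operator keeps only its maker and grad maker and obtains shape inference from an InferShapeFunctor bound to phi::UnchangedInferMeta. Below is a condensed sketch of that pattern, assembled from the hunks that follow for orientation only; it is not a standalone compilable file, and the file paths in the comments are inferred from the #include lines and hunk contexts in this diff.

// Declaration of the functional phi kernel (complex_kernel.h hunk below).
template <typename T, typename Context>
void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

// Header implementation (impl/complex_kernel_impl.h hunk below): allocate a
// real-typed output buffer and apply the element-wise functor over [0, numel).
template <typename T, typename Context>
void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
  auto numel = x.numel();
  auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
      out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
  paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
  for_range(phi::funcs::RealFunctor<T>(x.data<T>(), out_data, numel));
}

// Registration (cpu/complex_kernel.cc hunk below); this replaces the old
// REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL calls removed from fluid.
PD_REGISTER_KERNEL(real,
                   CPU,
                   ALL_LAYOUT,
                   phi::RealKernel,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

The gradient kernels (real_grad, imag_grad) follow the same declaration/impl/registration layout, plus an ArgumentMapping function in the ops/compat hunk at the end so that the existing fluid op names and GradVarName inputs resolve to the new phi kernel signatures.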
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/imag_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -20,15 +23,6 @@ namespace operators {
class ImagOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Imag");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Imag");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", "Out");
}
};
class ImagOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -88,19 +82,13 @@ DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
} // namespace operators
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(imag, ImagInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
ops::ImagGradOpMaker<paddle::framework::OpDesc>,
ops::ImagGradOpMaker<paddle::imperative::OpBase>);
ops::ImagGradOpMaker<paddle::imperative::OpBase>,
ImagInferShapeFunctor);
REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ImagKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(imag_grad,
ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ImagKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<phi::funcs::Real<T>>(
ctx.GetPlace(),
static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::ImagFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class ImagGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<phi::funcs::Real<T>>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/real_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -20,14 +23,6 @@ namespace operators {
class RealOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Real");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Real");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", "Out");
}
};
class RealOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -87,19 +82,13 @@ DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
} // namespace operators
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(real, RealInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
ops::RealGradOpMaker<::paddle::framework::OpDesc>,
ops::RealGradOpMaker<::paddle::imperative::OpBase>);
ops::RealGradOpMaker<::paddle::imperative::OpBase>,
RealInferShapeFunctor);
REGISTER_OPERATOR(real_grad, ops::RealGradOp);
REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RealKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(real_grad,
ops::RealGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::RealGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/real_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(real,
ops::RealKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RealKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(real_grad,
ops::RealGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::RealGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RealKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<phi::funcs::Real<T>>(
ctx.GetPlace(),
static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class RealGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<phi::funcs::Real<T>>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
phi::funcs::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,17 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/imag_op.h"
#pragma once
namespace ops = paddle::operators;
#include "paddle/phi/core/dense_tensor.h"
REGISTER_OP_CUDA_KERNEL(imag,
ops::ImagKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::ImagKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(imag_grad,
ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
namespace phi {
template <typename T, typename Context>
void RealGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx);
template <typename T, typename Context>
void ImagGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx);
} // namespace phi
@@ -50,4 +50,14 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return x;
}
template <typename T, typename DeviceContext>
void RealKernel(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename DeviceContext>
void ImagKernel(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/complex_grad_kernel.h"
#include "paddle/phi/kernels/impl/complex_grad_kernel_impl.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(real_grad,
CPU,
ALL_LAYOUT,
phi::RealGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(imag_grad,
CPU,
ALL_LAYOUT,
phi::ImagGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
@@ -31,3 +31,17 @@ PD_REGISTER_KERNEL(conj,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(real,
CPU,
ALL_LAYOUT,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(imag,
CPU,
ALL_LAYOUT,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/complex_grad_kernel.h"
#include "paddle/phi/kernels/impl/complex_grad_kernel_impl.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(imag_grad,
GPU,
ALL_LAYOUT,
phi::ImagGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(real_grad,
GPU,
ALL_LAYOUT,
phi::RealGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
@@ -32,3 +32,17 @@ PD_REGISTER_KERNEL(conj,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(real,
GPU,
ALL_LAYOUT,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(imag,
GPU,
ALL_LAYOUT,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace phi {
template <typename T, typename Context>
void RealGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::funcs::Real<T>>();
auto* dx_data =
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
for_range(functor);
}
template <typename T, typename Context>
void ImagGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::funcs::Real<T>>();
auto* dx_data =
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
for_range(functor);
}
} // namespace phi
@@ -33,4 +33,32 @@ void ConjKernel(const Context& dev_ctx,
for_range(functor);
}
template <typename T, typename Context>
void RealKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::RealFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
template <typename T, typename Context>
void ImagKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<phi::funcs::Real<T>>(
out, static_cast<size_t>(numel * sizeof(phi::funcs::Real<T>)));
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::ImagFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"real_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
}
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping);