未验证 提交 0e2dd2f3 编写于 作者: W Weilong Wu 提交者: GitHub

[Phi] migrate as_complex kernel to phi (#44438)

* migrate as_complex kernel to phi

* support as_complex and as_real in phi

* rm GetExpectedKernelType for AsRealOp
上级 55427f15
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/complex_view_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -21,6 +19,8 @@ ...@@ -21,6 +19,8 @@
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
...@@ -30,36 +30,6 @@ namespace operators { ...@@ -30,36 +30,6 @@ namespace operators {
class AsComplexOp : public framework::OperatorWithKernel { class AsComplexOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_complex");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_complex");
auto in_dims = ctx->GetInputDim("X");
const int input_rank = in_dims.size();
PADDLE_ENFORCE_GE(
input_rank,
1,
platform::errors::InvalidArgument(
"The rank of input(X) is less than 1. "
"Expected the rank of input(X) to be equal to or greater than 1."
"But received rank of input(X) = %d",
input_rank));
const int last_dim_size = in_dims[input_rank - 1];
PADDLE_ENFORCE_EQ(
last_dim_size,
2,
platform::errors::InvalidArgument(
"The size of the last dimension of input(X)"
"does not equals 2."
"Expected the size of last dimension of input(X) to be 2."
"But received %d",
last_dim_size));
const framework::DDim out_dims(in_dims.Get(), input_rank - 1);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class AsComplexOpMaker : public framework::OpProtoAndCheckerMaker { class AsComplexOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -95,15 +65,6 @@ class AsComplexGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -95,15 +65,6 @@ class AsComplexGradMaker : public framework::SingleGradOpMaker<T> {
class AsRealOp : public framework::OperatorWithKernel { class AsRealOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(framework::ToRealType(input_data_type),
ctx.GetPlace());
}
}; };
class AsRealOpMaker : public framework::OpProtoAndCheckerMaker { class AsRealOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -143,12 +104,6 @@ DECLARE_INFER_SHAPE_FUNCTOR(as_real, ...@@ -143,12 +104,6 @@ DECLARE_INFER_SHAPE_FUNCTOR(as_real,
AsRealInferShapeFunctor, AsRealInferShapeFunctor,
PD_INFER_META(phi::AsRealInferMeta)); PD_INFER_META(phi::AsRealInferMeta));
REGISTER_OPERATOR(as_complex,
ops::AsComplexOp,
ops::AsComplexOpMaker,
ops::AsComplexGradMaker<paddle::framework::OpDesc>,
ops::AsComplexGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(as_real, REGISTER_OPERATOR(as_real,
ops::AsRealOp, ops::AsRealOp,
ops::AsRealOpMaker, ops::AsRealOpMaker,
...@@ -156,6 +111,13 @@ REGISTER_OPERATOR(as_real, ...@@ -156,6 +111,13 @@ REGISTER_OPERATOR(as_real,
ops::AsRealGradMaker<paddle::framework::OpDesc>, ops::AsRealGradMaker<paddle::framework::OpDesc>,
ops::AsRealGradMaker<paddle::imperative::OpBase>); ops::AsRealGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(as_complex, DECLARE_INFER_SHAPE_FUNCTOR(as_complex,
ops::AsComplexKernel<phi::CPUContext, float>, AsComplexInferShapeFunctor,
ops::AsComplexKernel<phi::CPUContext, double>); PD_INFER_META(phi::AsComplexInferMeta));
REGISTER_OPERATOR(as_complex,
ops::AsComplexOp,
ops::AsComplexOpMaker,
AsComplexInferShapeFunctor,
ops::AsComplexGradMaker<paddle::framework::OpDesc>,
ops::AsComplexGradMaker<paddle::imperative::OpBase>);
...@@ -176,6 +176,15 @@ ...@@ -176,6 +176,15 @@
func : argsort func : argsort
backward : argsort_grad backward : argsort_grad
- api : as_complex
args : (Tensor x)
output : Tensor
infer_meta :
func : AsComplexInferMeta
kernel :
func : as_complex
backward : as_complex_grad
- api : as_real - api : as_real
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
...@@ -183,8 +192,7 @@ ...@@ -183,8 +192,7 @@
func : AsRealInferMeta func : AsRealInferMeta
kernel : kernel :
func : as_real func : as_real
# backward : as_complex backward : as_real_grad
# asin # asin
- api : asin - api : asin
args : (Tensor x) args : (Tensor x)
......
...@@ -116,6 +116,18 @@ ...@@ -116,6 +116,18 @@
data_type : out_grad data_type : out_grad
no_need_buffer : x no_need_buffer : x
- backward_api : as_complex_grad
forward : as_complex (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_real(out_grad)
- backward_api : as_real_grad
forward : as_real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : as_complex(out_grad)
- backward_api : asin_grad - backward_api : asin_grad
forward : asin (Tensor x) -> Tensor(out) forward : asin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad) args : (Tensor x, Tensor out_grad)
......
...@@ -156,6 +156,33 @@ void AsRealInferMeta(const MetaTensor& input, MetaTensor* output) { ...@@ -156,6 +156,33 @@ void AsRealInferMeta(const MetaTensor& input, MetaTensor* output) {
output->share_lod(input); output->share_lod(input);
} }
void AsComplexInferMeta(const MetaTensor& input, MetaTensor* output) {
auto in_dims = input.dims();
const int input_rank = in_dims.size();
PADDLE_ENFORCE_GE(
input_rank,
1,
phi::errors::InvalidArgument(
"The rank of input(X) is less than 1. "
"Expected the rank of input(X) to be equal to or greater than 1."
"But received rank of input(X) = %d",
input_rank));
const int last_dim_size = in_dims[input_rank - 1];
PADDLE_ENFORCE_EQ(
last_dim_size,
2,
phi::errors::InvalidArgument(
"The size of the last dimension of input(X)"
"does not equals 2."
"Expected the size of last dimension of input(X) to be 2."
"But received %d",
last_dim_size));
const phi::DDim out_dims(in_dims.Get(), input_rank - 1);
output->set_dims(out_dims);
output->share_lod(input);
}
void BatchSizeLikeInferMeta(const MetaTensor& x, void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape, const std::vector<int>& shape,
int x_batch_size_dim, int x_batch_size_dim,
......
...@@ -50,6 +50,8 @@ void ArgsortInferMeta(const MetaTensor& input, ...@@ -50,6 +50,8 @@ void ArgsortInferMeta(const MetaTensor& input,
void AsRealInferMeta(const MetaTensor& input, MetaTensor* output); void AsRealInferMeta(const MetaTensor& input, MetaTensor* output);
void AsComplexInferMeta(const MetaTensor& input, MetaTensor* output);
void BatchSizeLikeInferMeta(const MetaTensor& x, void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape, const std::vector<int>& shape,
int x_batch_size_dim, int x_batch_size_dim,
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/complex_view_op.h" #pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h"
namespace ops = paddle::operators; #include "paddle/phi/core/dense_tensor.h"
REGISTER_OP_CUDA_KERNEL( namespace phi {
as_complex,
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, float>, template <typename T, typename Context>
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, double>); void AsComplexKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,34 +12,11 @@ ...@@ -12,34 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/phi/kernels/impl/as_complex_impl.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace paddle { PD_REGISTER_KERNEL(
namespace operators { as_complex, CPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {}
template <typename DeviceContext, typename T>
class AsComplexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
out->mutable_data<platform::complex<T>>(context.GetPlace());
// TensorCopy also changes output's shape & dtype
const framework::DDim out_dims_original = out->dims();
framework::TensorCopy(*x, context.GetPlace(), out);
out->Resize(out_dims_original); // restored the shape
out->mutable_data<platform::complex<T>>(
context.GetPlace()); // restore the dtype
}
};
} // namespace operators
} // namespace paddle
...@@ -15,8 +15,12 @@ ...@@ -15,8 +15,12 @@
#include "paddle/phi/kernels/as_real_kernel.h" #include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_real_impl.h" #include "paddle/phi/kernels/impl/as_real_impl.h"
PD_REGISTER_KERNEL(as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, float, double) { using complex64 = ::phi::dtype::complex<float>;
} using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_complex_impl.h"
PD_REGISTER_KERNEL(
as_complex, GPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {}
...@@ -15,8 +15,12 @@ ...@@ -15,8 +15,12 @@
#include "paddle/phi/kernels/as_real_kernel.h" #include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_real_impl.h" #include "paddle/phi/kernels/impl/as_real_impl.h"
PD_REGISTER_KERNEL(as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, float, double) { using complex64 = ::phi::dtype::complex<float>;
} using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
/**
* @brief This operator is used to return a complex tensor represented by an
* old-fashioned real tensor. The size of the last dimension of the input tensor
* should be 2, which corresponds to 'real' and 'complex', respectively.
*
* @param ctx device context
* @param x the input tensor of as_complex
* @param out the output tensor of as_complex
*/
template <typename T, typename Context>
void AsComplexKernel(const Context& ctx,
const DenseTensor& x,
DenseTensor* out) {
ctx.template Alloc<phi::dtype::complex<T>>(out);
auto out_dims_original = out->dims();
Copy(ctx, x, ctx.GetPlace(), false, out);
out->Resize(out_dims_original); // restored the shape.
out->set_type(paddle::experimental::CppTypeToDataType<
phi::dtype::complex<T>>::Type()); // restored the dtype.
}
} // namespace phi
...@@ -21,15 +21,24 @@ ...@@ -21,15 +21,24 @@
#include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/for_range.h"
namespace phi { namespace phi {
/**
* @brief This operator is used to return an old-fashioned real tensor from a
* complex tensor. The size of the last dimension of the output tensor is 2,
* which corresponds to 'real' and 'complex', respectively.
*
* @param ctx device context
* @param x the input tensor of as_real
* @param out the output tensor of as_real
*/
template <typename T, typename Context> template <typename T, typename Context>
void AsRealKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { void AsRealKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out); ctx.template Alloc<typename T::value_type>(out);
auto out_dims_original = out->dims(); auto out_dims_original = out->dims();
Copy(ctx, x, ctx.GetPlace(), false, out); Copy(ctx, x, ctx.GetPlace(), false, out);
out->Resize(out_dims_original); // restored the shape. out->Resize(out_dims_original); // restored the shape.
out->set_type( out->set_type(paddle::experimental::CppTypeToDataType<
paddle::experimental::CppTypeToDataType<T>::Type()); // restored the typename T::value_type>::Type()); // restored the dtype.
// dtype.
} }
} // namespace phi } // namespace phi
...@@ -39,6 +39,7 @@ class TestViewAsComplexOp(OpTest): ...@@ -39,6 +39,7 @@ class TestViewAsComplexOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "as_complex" self.op_type = "as_complex"
self.python_api = paddle.as_complex
x = np.random.randn(10, 10, 2).astype("float64") x = np.random.randn(10, 10, 2).astype("float64")
out_ref = ref_view_as_complex(x) out_ref = ref_view_as_complex(x)
self.out_grad = np.ones( self.out_grad = np.ones(
......
...@@ -3877,8 +3877,10 @@ def as_complex(x, name=None): ...@@ -3877,8 +3877,10 @@ def as_complex(x, name=None):
# [[ 0. +1.j 2. +3.j 4. +5.j] # [[ 0. +1.j 2. +3.j 4. +5.j]
# [ 6. +7.j 8. +9.j 10.+11.j]] # [ 6. +7.j 8. +9.j 10.+11.j]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return paddle._C_ops.as_complex(x) return _C_ops.final_state_as_complex(x)
if _in_legacy_dygraph():
return _C_ops.as_complex(x)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex') check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex')
op_type = "as_complex" op_type = "as_complex"
...@@ -3926,8 +3928,10 @@ def as_real(x, name=None): ...@@ -3926,8 +3928,10 @@ def as_real(x, name=None):
# [ 8. 9.] # [ 8. 9.]
# [10. 11.]]] # [10. 11.]]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return paddle._C_ops.as_real(x) return _C_ops.final_state_as_real(x)
if _in_legacy_dygraph():
return _C_ops.as_real(x)
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real') check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real')
op_type = "as_real" op_type = "as_real"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册