From ae40370d2765802f627a6c14e6042cfff79851be Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 10 Dec 2021 14:58:29 +0800 Subject: [PATCH] add as_complex and as_real op (#37784) * add as_complex and as_real op --- paddle/fluid/operators/complex_view_op.cc | 163 ++++++++++++++++++ paddle/fluid/operators/complex_view_op.cu | 29 ++++ paddle/fluid/operators/complex_view_op.h | 60 +++++++ python/paddle/__init__.py | 5 + .../tests/unittests/test_complex_view_op.py | 127 ++++++++++++++ python/paddle/tensor/__init__.py | 5 + python/paddle/tensor/manipulation.py | 92 ++++++++++ python/paddle/tensor/math.py | 3 +- 8 files changed, 482 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/complex_view_op.cc create mode 100644 paddle/fluid/operators/complex_view_op.cu create mode 100644 paddle/fluid/operators/complex_view_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_complex_view_op.py diff --git a/paddle/fluid/operators/complex_view_op.cc b/paddle/fluid/operators/complex_view_op.cc new file mode 100644 index 0000000000..2fb21ca4ea --- /dev/null +++ b/paddle/fluid/operators/complex_view_op.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/complex_view_op.h" + +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +class AsComplexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_complex"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_complex"); + + auto in_dims = ctx->GetInputDim("X"); + const int input_rank = in_dims.size(); + PADDLE_ENFORCE_GE( + input_rank, 1, + platform::errors::InvalidArgument( + "The rank of input(X) is less than 1. " + "Expected the rank of input(X) to be equal to or greater than 1." + "But received rank of input(X) = %d", + input_rank)); + const int last_dim_size = in_dims[input_rank - 1]; + PADDLE_ENFORCE_EQ( + last_dim_size, 2, + platform::errors::InvalidArgument( + "The size of the last dimension of input(X)" + "does not equals 2." + "Expected the size of last dimension of input(X) to be 2." + "But received %d", + last_dim_size)); + + const framework::DDim out_dims(in_dims.Get(), input_rank - 1); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class AsComplexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of view_as_complex op."); + AddOutput("Out", "(Tensor), The output tensor of view_as_complex op."); + AddComment(R"DOC( +As_complex Operator. + +This operator is used to return a complex tensor represented +by an old-fashioned real tensor. The size of the last dimension of +the input tensor should be 2, which corresponds to 'real' and +'complex', respectively. + +)DOC"); + } +}; + +template +class AsComplexGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("as_real"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput("Out", this->InputGrad("X")); + } +}; + +class AsRealOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_real"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_real"); + + auto out_dims_v = framework::vectorize(ctx->GetInputDim("X")); + out_dims_v.push_back(2); + const framework::DDim out_dims = framework::make_ddim(out_dims_v); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(framework::ToRealType(input_data_type), + ctx.GetPlace()); + } +}; + +class AsRealOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of as_real op."); + AddOutput("Out", "(Tensor), The output tensor of as_real op."); + AddComment(R"DOC( +AsReal Operator. + +This operator is used to return an old-fashioned real tensor from a +complex tensor. The size of the last dimension of the output tensor is 2, +which corresponds to 'real' and 'complex', respectively. + +)DOC"); + } +}; + +template +class AsRealGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("as_complex"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput("Out", this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(as_complex, ops::AsComplexOp, ops::AsComplexOpMaker, + ops::AsComplexGradMaker, + ops::AsComplexGradMaker); + +REGISTER_OPERATOR(as_real, ops::AsRealOp, ops::AsRealOpMaker, + ops::AsRealGradMaker, + ops::AsRealGradMaker); + +REGISTER_OP_CPU_KERNEL( + as_complex, ops::AsComplexKernel, + ops::AsComplexKernel); + +REGISTER_OP_CPU_KERNEL( + as_real, ops::AsRealKernel, + ops::AsRealKernel); diff --git a/paddle/fluid/operators/complex_view_op.cu b/paddle/fluid/operators/complex_view_op.cu new file mode 100644 index 0000000000..261881cb8d --- /dev/null +++ b/paddle/fluid/operators/complex_view_op.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/complex_view_op.h" + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/enforce.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + as_complex, + ops::AsComplexKernel, + ops::AsComplexKernel); + +REGISTER_OP_CUDA_KERNEL( + as_real, ops::AsRealKernel, + ops::AsRealKernel); diff --git a/paddle/fluid/operators/complex_view_op.h b/paddle/fluid/operators/complex_view_op.h new file mode 100644 index 0000000000..9a8d89db40 --- /dev/null +++ b/paddle/fluid/operators/complex_view_op.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +class AsComplexKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data>(context.GetPlace()); + + // TensorCopy also changes output's shape & dtype + const framework::DDim out_dims_original = out->dims(); + framework::TensorCopy(*x, context.GetPlace(), out); + out->Resize(out_dims_original); // restored the shape + out->mutable_data>( + context.GetPlace()); // restore the dtype + } +}; + +template +class AsRealKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* x = context.Input("X"); + auto* out = context.Output("Out"); + + out->mutable_data(context.GetPlace()); + const framework::DDim out_dims_original = out->dims(); + framework::TensorCopy(*x, context.GetPlace(), out); + out->Resize(out_dims_original); // restored the shape + out->mutable_data(context.GetPlace()); // restore the dtype + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 28060ad171..6ffdd75f72 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -156,6 +156,9 @@ from .tensor.manipulation import roll # noqa: F401 from .tensor.manipulation import chunk # noqa: F401 from .tensor.manipulation import tolist # noqa: F401 from .tensor.manipulation import tensordot # noqa: F401 +from .tensor.manipulation import as_complex # noqa: F401 +from .tensor.manipulation import as_real # noqa: F401 + from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 from .tensor.math import asin # noqa: F401 @@ -559,6 +562,8 @@ __all__ = [ # noqa 'einsum', 'set_flags', 'get_flags', + 'as_complex', + 'as_real', 'diff', 'angle', ] diff --git a/python/paddle/fluid/tests/unittests/test_complex_view_op.py b/python/paddle/fluid/tests/unittests/test_complex_view_op.py new file mode 100644 index 0000000000..5dac121ff3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_view_op.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle +from paddle.fluid import dygraph +from paddle import static +paddle.enable_static() + + +def ref_view_as_complex(x): + real, imag = np.take(x, 0, axis=-1), np.take(x, 1, axis=-1) + return real + 1j * imag + + +def ref_view_as_real(x): + return np.stack([x.real, x.imag], -1) + + +class TestViewAsComplexOp(OpTest): + def setUp(self): + self.op_type = "as_complex" + x = np.random.randn(10, 10, 2).astype("float64") + out_ref = ref_view_as_complex(x) + self.out_grad = np.ones( + [10, 10], dtype="float64") + 1j * np.ones( + [10, 10], dtype="float64") + self.inputs = {'X': x} + self.outputs = {'Out': out_ref} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[ref_view_as_real(self.out_grad)], + user_defined_grad_outputs=[self.out_grad]) + + +class TestViewAsRealOp(OpTest): + def setUp(self): + self.op_type = "as_real" + real = np.random.randn(10, 10).astype("float64") + imag = np.random.randn(10, 10).astype("float64") + x = real + 1j * imag + out_ref = ref_view_as_real(x) + self.inputs = {'X': x} + self.outputs = {'Out': out_ref} + self.out_grad = np.ones([10, 10, 2], dtype="float64") + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[ref_view_as_complex(self.out_grad)], + user_defined_grad_outputs=[self.out_grad]) + + +class TestViewAsComplexAPI(unittest.TestCase): + def setUp(self): + self.x = np.random.randn(10, 10, 2) + self.out = ref_view_as_complex(self.x) + + def test_dygraph(self): + with dygraph.guard(): + x = paddle.to_tensor(self.x) + out_np = paddle.as_complex(x).numpy() + self.assertTrue(np.allclose(self.out, out_np)) + + def test_static(self): + mp, sp = static.Program(), static.Program() + with static.program_guard(mp, sp): + x = static.data("x", shape=[10, 10, 2], dtype="float64") + out = paddle.as_complex(x) + + exe = static.Executor() + exe.run(sp) + [out_np] = exe.run(mp, feed={"x": self.x}, fetch_list=[out]) + self.assertTrue(np.allclose(self.out, out_np)) + + +class TestViewAsRealAPI(unittest.TestCase): + def setUp(self): + self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10) + self.out = ref_view_as_real(self.x) + + def test_dygraph(self): + with dygraph.guard(): + x = paddle.to_tensor(self.x) + out_np = paddle.as_real(x).numpy() + self.assertTrue(np.allclose(self.out, out_np)) + + def test_static(self): + mp, sp = static.Program(), static.Program() + with static.program_guard(mp, sp): + x = static.data("x", shape=[10, 10], dtype="complex128") + out = paddle.as_real(x) + + exe = static.Executor() + exe.run(sp) + [out_np] = exe.run(mp, feed={"x": self.x}, fetch_list=[out]) + self.assertTrue(np.allclose(self.out, out_np)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 82727b33f9..36dfd717a1 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -111,6 +111,9 @@ from .manipulation import unbind # noqa: F401 from .manipulation import roll # noqa: F401 from .manipulation import chunk # noqa: F401 from .manipulation import tensordot # noqa: F401 +from .manipulation import as_complex # noqa: F401 +from .manipulation import as_real # noqa: F401 + from .math import abs # noqa: F401 from .math import acos # noqa: F401 from .math import asin # noqa: F401 @@ -411,6 +414,8 @@ tensor_method_func = [ #noqa 'multi_dot', 'solve', 'triangular_solve', + 'as_complex', + 'as_real', 'rad2deg', 'deg2rad', 'gcd', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index b4f00ab4ff..a81d8c54ff 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -34,6 +34,7 @@ from ..fluid import layers from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only import paddle from paddle import _C_ops +from paddle.tensor.attribute import _complex_to_real_dtype, _real_to_complex_dtype __all__ = [] @@ -2488,3 +2489,94 @@ def tensordot(x, y, axes=2, name=None): [contraction_size, not_contraction_size_y]) out = x.matmul(y).reshape(shape_out) return out + + +def as_complex(x, name=None): + """Transform a real tensor to a complex tensor. + + The data type of the input tensor is 'float32' or 'float64', and the data + type of the returned tensor is 'complex64' or 'complex128', respectively. + + The shape of the input tensor is ``(* ,2)``, (``*`` means arbitary shape), i.e. + the size of the last axis shoule be 2, which represent the real and imag part + of a complex number. The shape of the returned tensor is ``(*,)``. + + Args: + x (Tensor): The input tensor. Data type is 'float32' or 'float64'. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The output. Data type is 'complex64' or 'complex128', with the same precision as the input. + + Examples: + .. code-block:: python + + import paddle + x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) + y = paddle.as_complex(x) + print(y.numpy()) + + # [[ 0. +1.j 2. +3.j 4. +5.j] + # [ 6. +7.j 8. +9.j 10.+11.j]] + """ + if in_dygraph_mode(): + return paddle._C_ops.as_complex(x) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex') + op_type = "as_complex" + helper = LayerHelper(op_type, **locals()) + inputs = {"X": x} + out = helper.create_variable_for_type_inference( + dtype=_real_to_complex_dtype(x.dtype)) + outputs = {"Out": out} + attrs = {} + helper.append_op(type=op_type, inputs=inputs, attrs=attrs, outputs=outputs) + return out + + +def as_real(x, name=None): + """Transform a complex tensor to a real tensor. + + The data type of the input tensor is 'complex64' or 'complex128', and the data + type of the returned tensor is 'float32' or 'float64', respectively. + + When the shape of the input tensor is ``(*, )``, (``*`` means arbitary shape), + the shape of the output tensor is ``(*, 2)``, i.e. the shape of the output is + the shape of the input appended by an extra ``2``. + + Args: + x (Tensor): The input tensor. Data type is 'complex64' or 'complex128'. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The output. Data type is 'float32' or 'float64', with the same precision as the input. + + Examples: + .. code-block:: python + + import paddle + x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2]) + y = paddle.as_complex(x) + z = paddle.as_real(y) + print(z.numpy()) + + # [[[ 0. 1.] + # [ 2. 3.] + # [ 4. 5.]] + + # [[ 6. 7.] + # [ 8. 9.] + # [10. 11.]]] + """ + if in_dygraph_mode(): + return paddle._C_ops.as_real(x) + + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real') + op_type = "as_real" + helper = LayerHelper(op_type, **locals()) + inputs = {"X": x} + out = helper.create_variable_for_type_inference( + dtype=_complex_to_real_dtype(x.dtype)) + outputs = {"Out": out} + helper.append_op(type=op_type, inputs=inputs, outputs=outputs) + return out diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b79caf0559..fefaecaf60 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3082,7 +3082,6 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): return out - def angle(x, name=None): r""" Element-wise angle of complex numbers. For non-negative real numbers, the angle is 0 while @@ -3098,7 +3097,7 @@ def angle(x, name=None): name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor): y (Tensor): An N-D Tensor of real data type with the same precision as that of x's data type. + Tensor: An N-D Tensor of real data type with the same precision as that of x's data type. Examples: .. code-block:: python -- GitLab