From 6cfa59de1b57b7aad84ad87c6256c22bb4c5aed2 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Thu, 17 Dec 2020 02:05:42 -0600
Subject: [PATCH] [Complex] Add real & imag op and api for complex tensor (#29672)

* add complex real op & api & unittest
* add imag op & api & unittest
* refactor op impl
* revert simplified writing due to compile failure
* polish details
* polish grad op code
---
 paddle/fluid/framework/data_type.h            |  14 ++
 paddle/fluid/framework/tensor.cc              |   2 +-
 paddle/fluid/operators/imag_op.cc             | 106 +++++++++++
 paddle/fluid/operators/imag_op.cu             |  28 +++
 paddle/fluid/operators/imag_op.h              |  66 +++++++
 .../fluid/operators/math/complex_functors.h   | 140 +++++++++++++++
 paddle/fluid/operators/real_op.cc             | 105 +++++++++++
 paddle/fluid/operators/real_op.cu             |  28 +++
 paddle/fluid/operators/real_op.h              |  66 +++++++
 python/paddle/__init__.py                     |   2 +
 .../tests/unittests/test_real_imag_op.py      | 167 ++++++++++++++++++
 python/paddle/tensor/__init__.py              |   2 +
 python/paddle/tensor/attribute.py             | 105 ++++++++++-
 13 files changed, 829 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/imag_op.cc
 create mode 100644 paddle/fluid/operators/imag_op.cu
 create mode 100644 paddle/fluid/operators/imag_op.h
 create mode 100644 paddle/fluid/operators/math/complex_functors.h
 create mode 100644 paddle/fluid/operators/real_op.cc
 create mode 100644 paddle/fluid/operators/real_op.cu
 create mode 100644 paddle/fluid/operators/real_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_real_imag_op.py

diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index eafb8ade9e5..6a48378dc29 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -150,5 +150,19 @@ extern inline bool IsComplexType(const proto::VarType::Type type) {
 extern proto::VarType::Type PromoteTypesIfComplexExists(
     const proto::VarType::Type type_a, const proto::VarType::Type type_b);
 
+extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
+  switch (t) {
+    case proto::VarType::FP32:
+      return proto::VarType::COMPLEX64;
+    case proto::VarType::FP64:
+      return proto::VarType::COMPLEX128;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unknown complex value data type (%s); only float32 and "
+          "float64 are supported now.",
+          DataTypeToString(t)));
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc
index 9f5d8d30c9c..f721caaae9c 100644
--- a/paddle/fluid/framework/tensor.cc
+++ b/paddle/fluid/framework/tensor.cc
@@ -60,7 +60,7 @@ void* Tensor::mutable_data(const platform::Place& place,
       requested_size, size,
       platform::errors::InvalidArgument(
           "The requested memory size is less than the memory size of Tensor. "
-          "But received requested memory size is d%, "
+          "But received requested memory size is %d, "
           "memory size of Tensor is %d.",
           requested_size, size));
   size = requested_size;
diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc
new file mode 100644
index 00000000000..899025ae709
--- /dev/null
+++ b/paddle/fluid/operators/imag_op.cc
@@ -0,0 +1,106 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/imag_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ImagOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Imag");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Imag");
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class ImagOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of imag op.");
+    AddOutput("Out", "(Tensor), The output tensor of imag op.");
+    AddComment(R"DOC(
+Imag Operator.
+
+This operator is used to get a new tensor containing imaginary values
+from a tensor with complex data type.
+
+)DOC");
+  }
+};
+
+class ImagGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@Grad", "ImagGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@Grad", "ImagGrad");
+
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto dtype = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    auto complex_dtype = framework::ToComplexType(dtype);
+    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class ImagGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("imag_grad");
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+DECLARE_INPLACE_OP_INFERER(ImagOpInplaceInferer, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
+                  ops::ImagGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ImagGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
+
+REGISTER_OP_CPU_KERNEL(imag,
+                       ops::ImagKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::complex64>,
+                       ops::ImagKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::complex128>);
+REGISTER_OP_CPU_KERNEL(imag_grad,
+                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
+                                           paddle::platform::complex64>,
+                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
+                                           paddle::platform::complex128>);
diff --git a/paddle/fluid/operators/imag_op.cu b/paddle/fluid/operators/imag_op.cu
new file mode 100644
index 00000000000..a7a3b136821
--- /dev/null
+++ b/paddle/fluid/operators/imag_op.cu
@@ -0,0 +1,28 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/imag_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(imag,
+                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::complex64>,
+                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::complex128>);
+REGISTER_OP_CUDA_KERNEL(imag_grad,
+                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
+                                            paddle::platform::complex64>,
+                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
+                                            paddle::platform::complex128>);
diff --git a/paddle/fluid/operators/imag_op.h b/paddle/fluid/operators/imag_op.h
new file mode 100644
index 00000000000..562a8dffa90
--- /dev/null
+++ b/paddle/fluid/operators/imag_op.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class ImagKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
+
+    auto numel = x->numel();
+    auto* x_data = x->data<T>();
+    auto* out_data = out->mutable_data<math::Real<T>>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::ImagFunctor<T> functor(x_data, out_data, numel);
+    for_range(functor);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ImagGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const framework::Tensor* d_out =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    framework::Tensor* d_x =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    auto* dout_data = d_out->data<math::Real<T>>();
+    auto* dx_data = d_x->mutable_data<T>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h
new file mode 100644
index 00000000000..302e3d562c6
--- /dev/null
+++ b/paddle/fluid/operators/math/complex_functors.h
@@ -0,0 +1,140 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <type_traits>
+
+#include "paddle/fluid/platform/complex128.h"
+#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <bool B, typename T>
+struct cond {
+  static constexpr bool value = B;
+  using type = T;
+};
+
+template <bool B, typename TrueF, typename FalseF>
+struct eval_if {
+  using type = typename TrueF::type;
+};
+
+template <typename TrueF, typename FalseF>
+struct eval_if<false, TrueF, FalseF> {
+  using type = typename FalseF::type;
+};
+
+template <bool B, typename T, typename F>
+using eval_if_t = typename eval_if<B, T, F>::type;
+
+template <typename Head, typename... Tail>
+struct select {
+  using type = eval_if_t<Head::value, Head, select<Tail...>>;
+};
+
+template <typename T>
+struct select<T> {
+  using type = T;
+};
+
+template <typename Head, typename... Tail>
+using select_t = typename select<Head, Tail...>::type;
+
+template <typename T>
+using Real =
+    select_t<cond<std::is_same<T, platform::complex64>::value, float>,
+             cond<std::is_same<T, platform::complex128>::value, double>, T>;
+
+template <typename T, typename RealT>
+using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
+
+// There are no NoComplex cases now, implement later if needed
+template <typename T, typename RealT>
+using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
+
+template <typename T, typename Enable = void>
+struct RealFunctor;
+
+template <typename T>
+struct RealFunctor<T, Complex<T, Real<T>>> {
+ public:
+  RealFunctor(const T* input, Real<T>* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = input_[idx].real;
+  }
+
+ private:
+  const T* input_;
+  Real<T>* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Enable = void>
+struct ImagFunctor;
+
+template <typename T>
+struct ImagFunctor<T, Complex<T, Real<T>>> {
+  ImagFunctor(const T* input, Real<T>* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = input_[idx].imag;
+  }
+
+  const T* input_;
+  Real<T>* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Enable = void>
+struct RealToComplexFunctor;
+
+template <typename T>
+struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
+  RealToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx].real = input_[idx];
+    output_[idx].imag = 0;
+  }
+
+  const Real<T>* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Enable = void>
+struct ImagToComplexFunctor;
+
+template <typename T>
+struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
+  ImagToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx].real = 0;
+    output_[idx].imag = input_[idx];
+  }
+
+  const Real<T>* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc
new file mode 100644
index 00000000000..5f667999ee6
--- /dev/null
+++ b/paddle/fluid/operators/real_op.cc
@@ -0,0 +1,105 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/real_op.h"
+
+namespace paddle {
+namespace operators {
+
+class RealOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Real");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Real");
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class RealOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of real op.");
+    AddOutput("Out", "(Tensor), The output tensor of real op.");
+    AddComment(R"DOC(
+Real Operator.
+
+This operator is used to get a new tensor containing real values
+from a tensor with complex data type.
+
+)DOC");
+  }
+};
+
+class RealGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@Grad", "RealGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@Grad", "RealGrad");
+
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto dtype = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    auto complex_dtype = framework::ToComplexType(dtype);
+    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class RealGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("real_grad");
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+DECLARE_INPLACE_OP_INFERER(RealOpInplaceInferer, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
+                  ops::RealGradOpMaker<::paddle::framework::OpDesc>,
+                  ops::RealGradOpMaker<::paddle::imperative::OpBase>);
+REGISTER_OPERATOR(real_grad, ops::RealGradOp);
+
+REGISTER_OP_CPU_KERNEL(real,
+                       ops::RealKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::complex64>,
+                       ops::RealKernel<paddle::platform::CPUDeviceContext,
+                                       paddle::platform::complex128>);
+REGISTER_OP_CPU_KERNEL(real_grad,
+                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
+                                           paddle::platform::complex64>,
+                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
+                                           paddle::platform::complex128>);
diff --git a/paddle/fluid/operators/real_op.cu b/paddle/fluid/operators/real_op.cu
new file mode 100644
index 00000000000..b3d0855111b
--- /dev/null
+++ b/paddle/fluid/operators/real_op.cu
@@ -0,0 +1,28 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/real_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(real,
+                        ops::RealKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::complex64>,
+                        ops::RealKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::complex128>);
+REGISTER_OP_CUDA_KERNEL(real_grad,
+                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
+                                            paddle::platform::complex64>,
+                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
+                                            paddle::platform::complex128>);
diff --git a/paddle/fluid/operators/real_op.h b/paddle/fluid/operators/real_op.h
new file mode 100644
index 00000000000..6cc9065269c
--- /dev/null
+++ b/paddle/fluid/operators/real_op.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class RealKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
+
+    auto numel = x->numel();
+    auto* x_data = x->data<T>();
+    auto* out_data = out->mutable_data<math::Real<T>>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::RealFunctor<T> functor(x_data, out_data, numel);
+    for_range(functor);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class RealGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const framework::Tensor* d_out =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    framework::Tensor* d_x =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    auto* dout_data = d_out->data<math::Real<T>>();
+    auto* dx_data = d_x->mutable_data<T>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ac279b796e4..602df10c653 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -51,6 +51,8 @@ from .tensor.random import bernoulli
 from .tensor.attribute import rank #DEFINE_ALIAS
 from .tensor.attribute import shape #DEFINE_ALIAS
+from .tensor.attribute import real #DEFINE_ALIAS
+from .tensor.attribute import imag #DEFINE_ALIAS
 from .tensor.creation import to_tensor #DEFINE_ALIAS
 from .tensor.creation import diag #DEFINE_ALIAS
 from .tensor.creation import eye #DEFINE_ALIAS
diff --git a/python/paddle/fluid/tests/unittests/test_real_imag_op.py b/python/paddle/fluid/tests/unittests/test_real_imag_op.py
new file mode 100644
index 00000000000..ab24506f801
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_real_imag_op.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddle.static as static
+from op_test import OpTest
+
+numpy_apis = {
+    "real": np.real,
+    "imag": np.imag,
+}
+
+paddle_apis = {
+    "real": paddle.real,
+    "imag": paddle.imag,
+}
+
+
+class TestRealOp(OpTest):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+        # op test attrs
+        self.op_type = "real"
+        self.dtype = np.float64
+        self.init_input_output()
+        # backward attrs
+        self.init_grad_input_output()
+
+    def init_input_output(self):
+        self.inputs = {
+            'X': np.random.random(
+                (20, 5)).astype(self.dtype) + 1j * np.random.random(
+                    (20, 5)).astype(self.dtype)
+        }
+        self.outputs = {'Out': numpy_apis[self.op_type](self.inputs['X'])}
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((20, 5), self.dtype)
+        self.grad_x = np.real(self.grad_out) + 1j * np.zeros(
+            self.grad_out.shape)
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
+class TestImagOp(TestRealOp):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+        # op test attrs
+        self.op_type = "imag"
+        self.dtype = np.float64
+        self.init_input_output()
+        # backward attrs
+        self.init_grad_input_output()
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((20, 5), self.dtype)
+        self.grad_x = np.zeros(self.grad_out.shape) + 1j * np.real(
+            self.grad_out)
+
+
+class TestRealAPI(unittest.TestCase):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+        # prepare test attrs
+        self.api = "real"
+        self.dtypes = ["complex64", "complex128"]
+        self.places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+        self._shape = [2, 20, 2, 3]
+
+    def test_in_static_mode(self):
+        def init_input_output(dtype):
+            input = np.random.random(self._shape).astype(
+                dtype) + 1j * np.random.random(self._shape).astype(dtype)
+            return {'x': input}, numpy_apis[self.api](input)
+
+        for dtype in self.dtypes:
+            input_dict, np_res = init_input_output(dtype)
+            for place in self.places:
+                with static.program_guard(static.Program()):
+                    x = static.data(name="x", shape=self._shape, dtype=dtype)
+                    out = paddle_apis[self.api](x)
+
+                    exe = static.Executor(place)
+                    out_value = exe.run(feed=input_dict, fetch_list=[out.name])
+                    self.assertTrue(np.array_equal(np_res, out_value[0]))
+
+    def test_in_dynamic_mode(self):
+        for dtype in self.dtypes:
+            input = np.random.random(self._shape).astype(
+                dtype) + 1j * np.random.random(self._shape).astype(dtype)
+            np_res = numpy_apis[self.api](input)
+            for place in self.places:
+                # it is more convenient to use `guard` than `enable/disable_**` here
+                with fluid.dygraph.guard(place):
+                    input_t = paddle.to_tensor(input)
+                    res = paddle_apis[self.api](input_t).numpy()
+                    self.assertTrue(np.array_equal(np_res, res))
+                    res_t = input_t.real().numpy(
+                    ) if self.api == "real" else input_t.imag().numpy()
+                    self.assertTrue(np.array_equal(np_res, res_t))
+
+    def test_name_argument(self):
+        with static.program_guard(static.Program()):
+            x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
+            out = paddle_apis[self.api](x, name="real_res")
+            self.assertTrue("real_res" in out.name)
+
+    def test_dtype_error(self):
+        # in static mode
+        with self.assertRaises(TypeError):
+            with static.program_guard(static.Program()):
+                x = static.data(name="x", shape=self._shape, dtype="float32")
+                out = paddle_apis[self.api](x, name="real_res")
+
+        # in dynamic mode
+        with self.assertRaises(RuntimeError):
+            with fluid.dygraph.guard():
+                input = np.random.random(self._shape).astype("float32")
+                input_t = paddle.to_tensor(input)
+                res = paddle_apis[self.api](input_t)
+
+
+class TestImagAPI(TestRealAPI):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+        # prepare test attrs
+        self.api = "imag"
+        self.dtypes = ["complex64", "complex128"]
+        self.places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+        self._shape = [2, 20, 2, 3]
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index daee64b4204..f6e0ccd85fa 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -22,6 +22,8 @@ from __future__ import print_function
 from .random import randperm
 from .attribute import rank #DEFINE_ALIAS
 from .attribute import shape #DEFINE_ALIAS
+from .attribute import real #DEFINE_ALIAS
+from .attribute import imag #DEFINE_ALIAS
 from .creation import to_tensor #DEFINE_ALIAS
 from .creation import diag #DEFINE_ALIAS
 from .creation import eye #DEFINE_ALIAS
diff --git a/python/paddle/tensor/attribute.py b/python/paddle/tensor/attribute.py
index 255557673c1..499586b083f 100644
--- a/python/paddle/tensor/attribute.py
+++ b/python/paddle/tensor/attribute.py
@@ -12,8 +12,111 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+
+from ..fluid.framework import core, in_dygraph_mode, Variable
+from ..fluid.layer_helper import LayerHelper
+from ..fluid.data_feeder import check_variable_and_dtype
+
 # TODO: define functions to get tensor attributes
 from ..fluid.layers import rank #DEFINE_ALIAS
 from ..fluid.layers import shape #DEFINE_ALIAS
 
-__all__ = ['rank', 'shape']
+__all__ = ['rank', 'shape', 'real', 'imag']
+
+
+def _complex_to_real_dtype(dtype):
+    if dtype == core.VarDesc.VarType.COMPLEX64:
+        return core.VarDesc.VarType.FP32
+    elif dtype == core.VarDesc.VarType.COMPLEX128:
+        return core.VarDesc.VarType.FP64
+    else:
+        return dtype
+
+
+def real(x, name=None):
+    """
+    Returns a new tensor containing the real values of the input tensor.
+
+    Args:
+        x (Tensor): the input tensor, its data type could be complex64 or complex128.
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Tensor: a tensor containing the real values of the input tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor(
+                [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]])
+            # Tensor(shape=[2, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[(1+6j), (2+5j), (3+4j)],
+            #         [(4+3j), (5+2j), (6+1j)]])
+
+            real_res = paddle.real(x)
+            # Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1., 2., 3.],
+            #         [4., 5., 6.]])
+
+            real_t = x.real()
+            # Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[1., 2., 3.],
+            #         [4., 5., 6.]])
+    """
+    if in_dygraph_mode():
+        return core.ops.real(x)
+
+    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'real')
+    helper = LayerHelper('real', **locals())
+    out = helper.create_variable_for_type_inference(
+        dtype=_complex_to_real_dtype(helper.input_dtype()))
+    helper.append_op(type='real', inputs={'X': x}, outputs={'Out': out})
+    return out
+
+
+def imag(x, name=None):
+    """
+    Returns a new tensor containing the imaginary values of the input tensor.
+
+    Args:
+        x (Tensor): the input tensor, its data type could be complex64 or complex128.
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Tensor: a tensor containing the imaginary values of the input tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor(
+                [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]])
+            # Tensor(shape=[2, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True,
+            #        [[(1+6j), (2+5j), (3+4j)],
+            #         [(4+3j), (5+2j), (6+1j)]])
+
+            imag_res = paddle.imag(x)
+            # Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[6., 5., 4.],
+            #         [3., 2., 1.]])
+
+            imag_t = x.imag()
+            # Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[6., 5., 4.],
+            #         [3., 2., 1.]])
+    """
+    if in_dygraph_mode():
+        return core.ops.imag(x)
+
+    check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'imag')
+    helper = LayerHelper('imag', **locals())
+    out = helper.create_variable_for_type_inference(
+        dtype=_complex_to_real_dtype(helper.input_dtype()))
+    helper.append_op(type='imag', inputs={'X': x}, outputs={'Out': out})
+    return out
-- 
GitLab
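For quick reference, a usage sketch of the API this patch adds (not part of the diff; the printed values follow the documented behavior of paddle.real/paddle.imag):

.. code-block:: python

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.array([[1 + 6j, 2 + 5j]], dtype="complex64"))

    # complex64 in, float32 out (complex128 would give float64) -- the
    # inverse of the ToComplexType mapping added in data_type.h above
    print(paddle.real(x).numpy())  # [[1. 2.]]
    print(paddle.imag(x).numpy())  # [[6. 5.]]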
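The grad kernels route the incoming real-valued gradient back into one slot of a zero-filled complex tensor. A NumPy mirror of the math in RealToComplexFunctor/ImagToComplexFunctor above (an illustrative sketch, not part of the patch):

.. code-block:: python

    import numpy as np

    def real_grad(dout):
        # RealToComplexFunctor: dx.real = dout, dx.imag = 0
        return dout.astype(np.complex128)

    def imag_grad(dout):
        # ImagToComplexFunctor: dx.real = 0, dx.imag = dout
        return 1j * dout.astype(np.complex128)

    dout = np.ones((2, 2))
    print(real_grad(dout))  # [[1.+0.j 1.+0.j] [1.+0.j 1.+0.j]]
    print(imag_grad(dout))  # [[0.+1.j 0.+1.j] [0.+1.j 0.+1.j]]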