From 71063b81373f2aa1eb0f761fa2ab1d0201935c80 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 17 Dec 2020 20:43:34 +0800 Subject: [PATCH] add conj op for complex types (#29527) * add conj op for complex types * add conj for complex types * add more test case * add conj_op test * modify conj api and impl * add complex type for fill_constant_op xpu * add setConstant for complex type * remove complex conj test file * user define grad for test_conj_op * add test case for static mode of conj api * modify conj doc * change input args name to x * remove useless codes * conj support real types * add conj test case for real number --- paddle/fluid/operators/conj_op.cc | 87 ++++++++++++ paddle/fluid/operators/conj_op.cu | 28 ++++ paddle/fluid/operators/conj_op.h | 85 ++++++++++++ .../fluid/operators/fill_constant_op_xpu.cc | 4 +- python/paddle/__init__.py | 1 + .../paddle/fluid/tests/unittests/op_test.py | 10 +- .../fluid/tests/unittests/test_conj_op.py | 126 ++++++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/math.py | 44 +++++- 9 files changed, 382 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/conj_op.cc create mode 100644 paddle/fluid/operators/conj_op.cu create mode 100644 paddle/fluid/operators/conj_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_conj_op.py diff --git a/paddle/fluid/operators/conj_op.cc b/paddle/fluid/operators/conj_op.cc new file mode 100644 index 00000000000..3afe4f1e3d1 --- /dev/null +++ b/paddle/fluid/operators/conj_op.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/conj_op.h" + +#include +#include +#include +#include +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +class ConjOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "conj"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "conj"); + + auto in_dims = ctx->GetInputDim("X"); + + ctx->SetOutputDim("Out", in_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ConjOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of conj op."); + AddOutput("Out", "(Tensor), The output tensor of conj op."); + AddComment(R"DOC( +Conj Operator. + +This operator is used to perform elementwise conjugate for input $X$. + +)DOC"); + } +}; + +template +class ConjGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("conj"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput("Out", this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker, + ops::ConjGradMaker, + ops::ConjGradMaker); + +REGISTER_OP_CPU_KERNEL( + conj, ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel); diff --git a/paddle/fluid/operators/conj_op.cu b/paddle/fluid/operators/conj_op.cu new file mode 100644 index 00000000000..601caeb5055 --- /dev/null +++ b/paddle/fluid/operators/conj_op.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + conj, ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel, + ops::ConjKernel); diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h new file mode 100644 index 00000000000..0bec7b707e3 --- /dev/null +++ b/paddle/fluid/operators/conj_op.h @@ -0,0 +1,85 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +using EnableComplex = + typename std::enable_if::value || + std::is_same::value>::type; + +template +using DisableComplex = typename std::enable_if< + !std::is_same::value && + !std::is_same::value>::type; + +template +struct ConjFunctor; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[idx] = T(input_[idx].real, -input_[idx].imag); + } + const T* input_; + int64_t numel_; + T* output_; +}; + +template +struct ConjFunctor> { + ConjFunctor(const T* input, int64_t numel, T* output) + : input_(input), numel_(numel), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; } + const T* input_; + int64_t numel_; + T* output_; +}; + +template +class ConjKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + auto numel = x->numel(); + auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace(), + size_t(x->numel() * sizeof(T))); + + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + ConjFunctor functor(x_data, numel, out_data); + for_range(functor); + } +}; + +DECLARE_INPLACE_OP_INFERER(ConjOpInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_constant_op_xpu.cc b/paddle/fluid/operators/fill_constant_op_xpu.cc index 2bf836272a4..16dd4c9292f 100644 --- a/paddle/fluid/operators/fill_constant_op_xpu.cc +++ b/paddle/fluid/operators/fill_constant_op_xpu.cc @@ -19,5 +19,7 @@ REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, - ops::FillConstantKernel); + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel); #endif diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 602df10c653..75872ade77d 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -199,6 +199,7 @@ from .tensor.math import isinf #DEFINE_ALIAS from .tensor.math import isnan #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS from .tensor.math import broadcast_shape #DEFINE_ALIAS +from .tensor.math import conj #DEFINE_ALIAS from .tensor.random import multinomial #DEFINE_ALIAS from .tensor.random import standard_normal diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 836c24d703b..bd38bae42e0 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -145,8 +145,11 @@ def get_numeric_gradient(place, return numpy_tensor[i] elif tensor_to_check_dtype == np.float32: return tensor._get_float_element(i) - else: + elif tensor_to_check_dtype == np.float64: return tensor._get_double_element(i) + else: + raise TypeError("Unsupported test data type %s." % + tensor_to_check_dtype) def __set_elem__(tensor, i, e): if tensor_to_check_dtype == np.float16: @@ -158,8 +161,11 @@ def get_numeric_gradient(place, tensor.set(numpy_tensor, place) elif tensor_to_check_dtype == np.float32: tensor._set_float_element(i, e) - else: + elif tensor_to_check_dtype == np.float64: tensor._set_double_element(i, e) + else: + raise TypeError("Unsupported test data type %s." % + tensor_to_check_dtype) # we only compute gradient of one element each time. # we use a for loop to compute the gradient of every element. diff --git a/python/paddle/fluid/tests/unittests/test_conj_op.py b/python/paddle/fluid/tests/unittests/test_conj_op.py new file mode 100644 index 00000000000..774a29ada4a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_conj_op.py @@ -0,0 +1,126 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test import OpTest +from paddle.fluid import Program, program_guard +import paddle.fluid.dygraph as dg +import paddle.static as static +from numpy.random import random as rand + +paddle.enable_static() + + +class TestConjOp(OpTest): + def setUp(self): + self.op_type = "conj" + self.init_dtype_type() + self.init_input_output() + self.init_grad_input_output() + + def init_dtype_type(self): + self.dtype = np.complex64 + + def init_input_output(self): + x = (np.random.random((12, 14)) + 1j * np.random.random( + (12, 14))).astype(self.dtype) + out = np.conj(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def init_grad_input_output(self): + self.grad_out = (np.ones((12, 14)) + 1j * np.ones( + (12, 14))).astype(self.dtype) + self.grad_in = np.conj(self.grad_out) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.grad_in], + user_defined_grad_outputs=[self.grad_out]) + + +class TestComplexConjOp(unittest.TestCase): + def setUp(self): + self._dtypes = ["float32", "float64"] + self._places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + self._places.append(paddle.CUDAPlace(0)) + + def test_conj_api(self): + for dtype in self._dtypes: + input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand( + [2, 20, 2, 3]).astype(dtype) + for place in self._places: + with dg.guard(place): + var_x = paddle.to_tensor(input) + result = paddle.conj(var_x).numpy() + target = np.conj(input) + self.assertTrue(np.array_equal(result, target)) + + def test_conj_operator(self): + for dtype in self._dtypes: + input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand( + [2, 20, 2, 3]).astype(dtype) + for place in self._places: + with dg.guard(place): + var_x = paddle.to_tensor(input) + result = var_x.conj().numpy() + target = np.conj(input) + self.assertTrue(np.array_equal(result, target)) + + def test_conj_static_mode(self): + def init_input_output(dtype): + input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand( + [2, 20, 2, 3]).astype(dtype) + return {'x': input}, np.conj(input) + + for dtype in self._dtypes: + input_dict, np_res = init_input_output(dtype) + for place in self._places: + with static.program_guard(static.Program()): + x_dtype = np.complex64 if dtype == "float32" else np.complex128 + x = static.data( + name="x", shape=[2, 20, 2, 3], dtype=x_dtype) + out = paddle.conj(x) + + exe = static.Executor(place) + out_value = exe.run(feed=input_dict, fetch_list=[out.name]) + self.assertTrue(np.array_equal(np_res, out_value[0])) + + def test_conj_api_real_number(self): + for dtype in self._dtypes: + input = rand([2, 20, 2, 3]).astype(dtype) + for place in self._places: + with dg.guard(place): + var_x = paddle.to_tensor(input) + result = paddle.conj(var_x).numpy() + target = np.conj(input) + self.assertTrue(np.array_equal(result, target)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f6e0ccd85fa..317c38494bb 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -170,6 +170,7 @@ from .math import prod #DEFINE_ALIAS from .math import all #DEFINE_ALIAS from .math import any #DEFINE_ALIAS from .math import broadcast_shape #DEFINE_ALIAS +from .math import conj #DEFINE_ALIAS from .random import multinomial #DEFINE_ALIAS from .random import standard_normal diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 3d3d24c7c25..a7b75491814 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -125,7 +125,8 @@ __all__ = [ 'isfinite', 'isinf', 'isnan', - 'broadcast_shape' + 'broadcast_shape', + 'conj' ] # yapf: enable. @@ -2214,3 +2215,44 @@ def broadcast_shape(x_shape, y_shape): """ return core.broadcast_shape(x_shape, y_shape) + +def conj(x, name=None): + r""" + This function computes the conjugate of the Tensor elementwisely. + + Args: + x (Tensor): The input tensor which hold the complex numbers. + Optional data types are: complex64, complex128, float32, float64, int32 or int64. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + out (Tensor): The conjugate of input. The shape and data type is the same with input. + If the elements of tensor is real type such as float32, float64, int32 or int64, the out is the same with input. + + Examples: + .. code-block:: python + + import paddle + data=paddle.to_tensor([[1+1j, 2+2j, 3+3j], [4+4j, 5+5j, 6+6j]]) + #Tensor(shape=[2, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[(1+1j), (2+2j), (3+3j)], + # [(4+4j), (5+5j), (6+6j)]]) + + conj_data=paddle.conj(data) + #Tensor(shape=[2, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[(1-1j), (2-2j), (3-3j)], + # [(4-4j), (5-5j), (6-6j)]]) + + """ + if in_dygraph_mode(): + return core.ops.conj(x) + + check_variable_and_dtype(x, "x", ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], 'conj') + + helper = LayerHelper('conj', **locals()) + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + + helper.append_op(type='conj', inputs={'X': x}, outputs={'Out': [out]}) + return out -- GitLab