From 79d918d973c53522c42f5ca42993c50d80a29cc7 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Fri, 21 May 2021 14:46:18 +0800 Subject: [PATCH] replace complex64/128 with complex template in cast Op (#33019) * replace complex in set tensor from and to numpy * replace complex template in cast op --- paddle/fluid/operators/cast_op.cc | 18 ++--- paddle/fluid/operators/cast_op.cu | 8 +- paddle/fluid/pybind/tensor_py.h | 56 +------------- .../tests/unittests/test_complex_cast.py | 73 +++++++++++++++++++ 4 files changed, 88 insertions(+), 67 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_complex_cast.py diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 40f4b969ec..7252ed72b2 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -90,13 +90,11 @@ REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker, ops::CastOpGradMaker, ops::CastOpProtoMaker); -REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel, - ops::CastOpKernel); +REGISTER_OP_CPU_KERNEL( + cast, ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel>, + ops::CastOpKernel>); diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 2ef5b9ae3a..1ac110b3ca 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -106,9 +106,9 @@ REGISTER_OP_CUDA_KERNEL( ops::CastOpKernel, ops::CastOpKernel, + paddle::platform::complex>, ops::CastOpKernel); + paddle::platform::complex>); #else REGISTER_OP_CUDA_KERNEL( cast, ops::CastOpKernel, @@ -122,7 +122,7 @@ REGISTER_OP_CUDA_KERNEL( ops::CastOpKernel, ops::CastOpKernel, + paddle::platform::complex>, ops::CastOpKernel); + 
paddle::platform::complex>); #endif diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 2095b49974..586cbda7cc 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -84,45 +84,7 @@ struct npy_format_descriptor { static constexpr auto name = _("bfloat16"); }; -// we register paddle::platform::complex64 as numpy.complex64. -template <> -struct npy_format_descriptor { - static py::dtype dtype() { - handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64); - return reinterpret_borrow(ptr); - } - - static std::string format() { - // Note: "F" represents complex64. - // Details at: - // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx - // for k, v in np.sctypeDict.iteritems(): - // print '{0:14s} : {1:40s}'.format(str(k), v) - return "F"; - } - static constexpr auto name = _("complext64"); -}; - -// we register paddle::platform::complex128 as numpy.complex128. -template <> -struct npy_format_descriptor { - static py::dtype dtype() { - handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128); - return reinterpret_borrow(ptr); - } - - static std::string format() { - // Note: "D" represents complex128. - // Details at: - // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx - // for k, v in np.sctypeDict.iteritems(): - // print '{0:14s} : {1:40s}'.format(str(k), v) - return "D"; - } - static constexpr auto name = _("complext128"); -}; - -// we register paddle::platform::complex64 as numpy.complex64. +// we register paddle::platform::complex as numpy.complex64. 
template <> struct npy_format_descriptor> { static py::dtype dtype() { @@ -205,8 +167,6 @@ struct ValidDTypeToPyArrayChecker { DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex64); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex128); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); @@ -227,10 +187,6 @@ inline std::string TensorDTypeToPyDTypeStr( } else if (std::is_same::value) { \ /* NumPy character code of uint16 due to no support for bfloat16 */ \ return "H"; \ - } else if (std::is_same::value) { \ - return "F"; \ - } else if (std::is_same::value) { \ - return "D"; \ } else if (std::is_same>::value) { \ return "F"; \ } else if (std::is_same>::value) { \ @@ -410,12 +366,6 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); - } else if (py::isinstance>(array)) { - SetTensorFromPyArrayT(self, array, place, - zero_copy); - } else if (py::isinstance>(array)) { - SetTensorFromPyArrayT(self, array, place, - zero_copy); } else if (py::isinstance>>( array)) { SetTensorFromPyArrayT, P>( @@ -645,9 +595,9 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self, case framework::proto::VarType::BF16: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::COMPLEX64: - return _sliceAndConcat(self, obj, dim); + return _sliceAndConcat>(self, obj, dim); case framework::proto::VarType::COMPLEX128: - return _sliceAndConcat(self, obj, dim); + return _sliceAndConcat>(self, obj, dim); case framework::proto::VarType::FP32: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::FP64: diff --git a/python/paddle/fluid/tests/unittests/test_complex_cast.py b/python/paddle/fluid/tests/unittests/test_complex_cast.py new file mode 
100644 index 0000000000..b4162be5b3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_cast.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division + +import unittest +import numpy as np + +import paddle + + +class TestComplexCastOp(unittest.TestCase): + def test_complex_to_real(self): + r = np.random.random(size=[10, 10]) * 10 + i = np.random.random(size=[10, 10]) + + c_t = paddle.to_tensor(r + i * 1J, dtype='complex64') + + self.assertEqual(c_t.cast('int64').dtype, paddle.int64) + self.assertEqual(c_t.cast('int32').dtype, paddle.int32) + self.assertEqual(c_t.cast('float32').dtype, paddle.float32) + self.assertEqual(c_t.cast('float64').dtype, paddle.float64) + self.assertEqual(c_t.cast('bool').dtype, paddle.bool) + + self.assertTrue( + np.allclose(c_t.cast('int64').numpy(), r.astype('int64'))) + self.assertTrue( + np.allclose(c_t.cast('int32').numpy(), r.astype('int32'))) + self.assertTrue( + np.allclose(c_t.cast('float32').numpy(), r.astype('float32'))) + self.assertTrue( + np.allclose(c_t.cast('float64').numpy(), r.astype('float64'))) + self.assertTrue(np.allclose(c_t.cast('bool').numpy(), r.astype('bool'))) + + def test_real_to_complex(self): + r = np.random.random(size=[10, 10]) * 10 + r_t = paddle.to_tensor(r) + + self.assertEqual(r_t.cast('complex64').dtype, paddle.complex64) + 
self.assertEqual(r_t.cast('complex128').dtype, paddle.complex128) + + self.assertTrue(np.allclose(r_t.cast('complex64').real().numpy(), r)) + self.assertTrue(np.allclose(r_t.cast('complex128').real().numpy(), r)) + + def test_complex64_complex128(self): + r = np.random.random(size=[10, 10]) + i = np.random.random(size=[10, 10]) + + c = r + i * 1J + c_64 = paddle.to_tensor(c, dtype='complex64') + c_128 = paddle.to_tensor(c, dtype='complex128') + + self.assertEqual(c_64.cast('complex128').dtype, paddle.complex128) + self.assertEqual(c_128.cast('complex64').dtype, paddle.complex64) + self.assertTrue( + np.allclose(c_64.cast('complex128').numpy(), c_128.numpy())) + self.assertTrue( + np.allclose(c_128.cast('complex64').numpy(), c_64.numpy())) + + +if __name__ == '__main__': + unittest.main() -- GitLab