Unverified commit 79d918d9, authored by chentianyu03, committed by GitHub

replace complex64/128 with complex template in cast Op (#33019)

* replace complex in set tensor from and to numpy

* replace complex template in cast op
Parent 79ed7177
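
For context: numpy's complex64 packs two float32 values and complex128 packs two float64 values, which is exactly the pair the old paddle::platform::complex64/complex128 classes mirrored and the new paddle::platform::complex<T> template now covers with a single definition. A minimal numpy-level sketch of that correspondence:

import numpy as np

# complex64 = 2 x float32, complex128 = 2 x float64, matching
# complex<float> and complex<double> on the C++ side.
assert np.dtype(np.complex64).itemsize == 2 * np.dtype(np.float32).itemsize
assert np.dtype(np.complex128).itemsize == 2 * np.dtype(np.float64).itemsize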
@@ -90,13 +90,11 @@ REGISTER_OPERATOR(cast, ops::CastOp,
                   ops::CastOpGradMaker<paddle::framework::OpDesc>,
                   ops::CastOpGradMaker<paddle::imperative::OpBase>,
                   ops::CastOpProtoMaker);
-REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
-                       ops::CastOpKernel<CPU, double>,
-                       ops::CastOpKernel<CPU, int>,
-                       ops::CastOpKernel<CPU, int64_t>,
-                       ops::CastOpKernel<CPU, bool>,
-                       ops::CastOpKernel<CPU, uint8_t>,
-                       ops::CastOpKernel<CPU, paddle::platform::float16>,
-                       ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
-                       ops::CastOpKernel<CPU, paddle::platform::complex64>,
-                       ops::CastOpKernel<CPU, paddle::platform::complex128>);
+REGISTER_OP_CPU_KERNEL(
+    cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>,
+    ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,
+    ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,
+    ops::CastOpKernel<CPU, paddle::platform::float16>,
+    ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
+    ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
+    ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
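
These registrations are what let paddle.cast reach complex dtypes from Python. A quick sketch, assuming a Paddle build that includes this change:

import paddle

x = paddle.ones([2, 2], dtype='float32')
y = paddle.cast(x, 'complex64')  # served by CastOpKernel<CPU, complex<float>> on CPU
print(y.dtype)  # paddle.complex64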
@@ -106,9 +106,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
                       paddle::platform::float16>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
-                      paddle::platform::complex64>,
+                      paddle::platform::complex<float>>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
-                      paddle::platform::complex128>);
+                      paddle::platform::complex<double>>);
 #else
 REGISTER_OP_CUDA_KERNEL(
     cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
@@ -122,7 +122,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
                       paddle::platform::bfloat16>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
-                      paddle::platform::complex64>,
+                      paddle::platform::complex<float>>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
-                      paddle::platform::complex128>);
+                      paddle::platform::complex<double>>);
 #endif
@@ -84,45 +84,7 @@ struct npy_format_descriptor<paddle::platform::bfloat16> {
   static constexpr auto name = _("bfloat16");
 };
 
-// we register paddle::platform::complex64 as numpy.complex64.
-template <>
-struct npy_format_descriptor<paddle::platform::complex64> {
-  static py::dtype dtype() {
-    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64);
-    return reinterpret_borrow<py::dtype>(ptr);
-  }
-
-  static std::string format() {
-    // Note: "F" represents complex64.
-    // Details at:
-    // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
-    // for k, v in np.sctypeDict.iteritems():
-    //     print '{0:14s} : {1:40s}'.format(str(k), v)
-    return "F";
-  }
-  static constexpr auto name = _("complext64");
-};
-
-// we register paddle::platform::complex128 as numpy.complex128.
-template <>
-struct npy_format_descriptor<paddle::platform::complex128> {
-  static py::dtype dtype() {
-    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128);
-    return reinterpret_borrow<py::dtype>(ptr);
-  }
-
-  static std::string format() {
-    // Note: "D" represents complex128.
-    // Details at:
-    // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
-    // for k, v in np.sctypeDict.iteritems():
-    //     print '{0:14s} : {1:40s}'.format(str(k), v)
-    return "D";
-  }
-  static constexpr auto name = _("complext128");
-};
-
-// we register paddle::platform::complex64 as numpy.complex64.
+// we register paddle::platform::complex<float> as numpy.complex64.
 template <>
 struct npy_format_descriptor<paddle::platform::complex<float>> {
   static py::dtype dtype() {
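
The "F" and "D" format strings used by these descriptors are numpy's one-character typecodes for complex64 and complex128, which is easy to verify:

import numpy as np

assert np.dtype(np.complex64).char == 'F'   # 'F': single-precision complex
assert np.dtype(np.complex128).char == 'D'  # 'D': double-precision complex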
@@ -205,8 +167,6 @@ struct ValidDTypeToPyArrayChecker {
 
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
-DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex64);
-DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex128);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<float>);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<double>);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
@@ -227,10 +187,6 @@ inline std::string TensorDTypeToPyDTypeStr(
   } else if (std::is_same<T, platform::bfloat16>::value) {                  \
     /* NumPy character code of uint16 due to no support for bfloat16 */     \
     return "H";                                                             \
-  } else if (std::is_same<T, platform::complex64>::value) {                 \
-    return "F";                                                             \
-  } else if (std::is_same<T, platform::complex128>::value) {                \
-    return "D";                                                             \
   } else if (std::is_same<T, platform::complex<float>>::value) {            \
     return "F";                                                             \
   } else if (std::is_same<T, platform::complex<double>>::value) {           \
@@ -410,12 +366,6 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
   } else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
     SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
                                                         zero_copy);
-  } else if (py::isinstance<py::array_t<paddle::platform::complex64>>(array)) {
-    SetTensorFromPyArrayT<paddle::platform::complex64, P>(self, array, place,
-                                                          zero_copy);
-  } else if (py::isinstance<py::array_t<paddle::platform::complex128>>(array)) {
-    SetTensorFromPyArrayT<paddle::platform::complex128, P>(self, array, place,
-                                                           zero_copy);
   } else if (py::isinstance<py::array_t<paddle::platform::complex<float>>>(
                  array)) {
     SetTensorFromPyArrayT<paddle::platform::complex<float>, P>(
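
These branches run whenever a complex numpy array is turned into a Paddle tensor. A short sketch, assuming a build with this change:

import numpy as np
import paddle

arr = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
t = paddle.to_tensor(arr)  # handled by the complex<float> branch above
print(t.dtype)  # paddle.complex64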
@@ -645,9 +595,9 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
     case framework::proto::VarType::BF16:
       return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
     case framework::proto::VarType::COMPLEX64:
-      return _sliceAndConcat<paddle::platform::complex64>(self, obj, dim);
+      return _sliceAndConcat<paddle::platform::complex<float>>(self, obj, dim);
     case framework::proto::VarType::COMPLEX128:
-      return _sliceAndConcat<paddle::platform::complex128>(self, obj, dim);
+      return _sliceAndConcat<paddle::platform::complex<double>>(self, obj, dim);
     case framework::proto::VarType::FP32:
       return _sliceAndConcat<float>(self, obj, dim);
     case framework::proto::VarType::FP64:
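
A sketch of the Python-side behavior the COMPLEX64/COMPLEX128 cases above support, namely indexing into a complex tensor; the exact dispatch path is an implementation detail, and this assumes a build with this change. The test file below then exercises the cast op itself end to end.

import numpy as np
import paddle

t = paddle.to_tensor(np.arange(4).reshape([2, 2]).astype(np.complex128))
print(t[0].numpy())  # [0.+0.j 1.+0.j]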
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division

import unittest

import numpy as np

import paddle


class TestComplexCastOp(unittest.TestCase):
    def test_complex_to_real(self):
        r = np.random.random(size=[10, 10]) * 10
        i = np.random.random(size=[10, 10])
        c_t = paddle.to_tensor(r + i * 1J, dtype='complex64')

        self.assertEqual(c_t.cast('int64').dtype, paddle.int64)
        self.assertEqual(c_t.cast('int32').dtype, paddle.int32)
        self.assertEqual(c_t.cast('float32').dtype, paddle.float32)
        self.assertEqual(c_t.cast('float64').dtype, paddle.float64)
        self.assertEqual(c_t.cast('bool').dtype, paddle.bool)

        self.assertTrue(
            np.allclose(c_t.cast('int64').numpy(), r.astype('int64')))
        self.assertTrue(
            np.allclose(c_t.cast('int32').numpy(), r.astype('int32')))
        self.assertTrue(
            np.allclose(c_t.cast('float32').numpy(), r.astype('float32')))
        self.assertTrue(
            np.allclose(c_t.cast('float64').numpy(), r.astype('float64')))
        self.assertTrue(np.allclose(c_t.cast('bool').numpy(), r.astype('bool')))

    def test_real_to_complex(self):
        r = np.random.random(size=[10, 10]) * 10
        r_t = paddle.to_tensor(r)

        self.assertEqual(r_t.cast('complex64').dtype, paddle.complex64)
        self.assertEqual(r_t.cast('complex128').dtype, paddle.complex128)

        self.assertTrue(np.allclose(r_t.cast('complex64').real().numpy(), r))
        self.assertTrue(np.allclose(r_t.cast('complex128').real().numpy(), r))

    def test_complex64_complex128(self):
        r = np.random.random(size=[10, 10])
        i = np.random.random(size=[10, 10])
        c = r + i * 1J
        c_64 = paddle.to_tensor(c, dtype='complex64')
        c_128 = paddle.to_tensor(c, dtype='complex128')

        self.assertEqual(c_64.cast('complex128').dtype, paddle.complex128)
        self.assertEqual(c_128.cast('complex64').dtype, paddle.complex64)

        self.assertTrue(
            np.allclose(c_64.cast('complex128').numpy(), c_128.numpy()))
        self.assertTrue(
            np.allclose(c_128.cast('complex64').numpy(), c_64.numpy()))


if __name__ == '__main__':
    unittest.main()