From 06c63ca0d5a97460ee324ba8c3869d33f0cf3e48 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Tue, 1 Jun 2021 10:35:25 +0800
Subject: [PATCH] replace and remove complex64/128 types in custom OP and
 other files (#33195)

* replace and remove complex64/128 types in custom OP and other files

* fix custom_tensor_test fail bug

* fix custom_conj_test fail bug

* fix dispatch_test_op build fail bug
---
 cmake/inference_lib.cmake                     |  5 +-
 paddle/fluid/extension/include/ext_dtype.h    |  7 +--
 paddle/fluid/extension/src/ext_tensor.cc      | 39 ++++++------
 paddle/fluid/framework/custom_tensor_test.cc  |  8 +--
 paddle/fluid/framework/data_type.h            | 60 +++++++++----------
 paddle/fluid/framework/data_type_transform.cc |  6 +-
 paddle/fluid/framework/dlpack_tensor_test.cc  |  4 +-
 paddle/fluid/framework/tensor_util.cc         | 11 ++--
 paddle/fluid/operators/abs_op.cu              |  2 -
 paddle/fluid/operators/dot_op.h               |  2 -
 .../fluid/operators/math/complex_functors.h   |  8 +--
 .../fluid/operators/math/concat_and_split.h   | 28 ++++-----
 paddle/fluid/operators/math/math_function.cc  | 44 ++++++--------
 paddle/fluid/operators/math/math_function.cu  | 34 ++++-------
 .../operators/math/selected_rows_functor.cu   |  5 +-
 paddle/fluid/platform/cuda_primitives.h       | 15 +++--
 python/setup.py.in                            |  1 +
 17 files changed, 117 insertions(+), 162 deletions(-)

diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 8220680cec..84ab072ddc 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -193,10 +193,7 @@ copy(inference_lib_dist
         SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
         DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 copy(inference_lib_dist
-        SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
-        DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
-copy(inference_lib_dist
-        SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
+        SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex.h
         DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 copy(inference_lib_dist
         SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
diff --git a/paddle/fluid/extension/include/ext_dtype.h b/paddle/fluid/extension/include/ext_dtype.h
index 3890631a6f..a0816b65a3 100644
--- a/paddle/fluid/extension/include/ext_dtype.h
+++ b/paddle/fluid/extension/include/ext_dtype.h
@@ -16,15 +16,14 @@ limitations under the License. */
 #include <cstdint>
 #include <string>
 
-#include "complex128.h"     // NOLINT
-#include "complex64.h"      // NOLINT
+#include "complex.h"        // NOLINT
 #include "ext_exception.h"  // NOLINT
 #include "float16.h"        // NOLINT
 
 namespace paddle {
 
-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;
+using complex64 = paddle::platform::complex<float>;
+using complex128 = paddle::platform::complex<double>;
 using float16 = paddle::platform::float16;
 
 enum class DataType {
diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc
index 8b2f7cc5bf..ab98bdc0bf 100644
--- a/paddle/fluid/extension/src/ext_tensor.cc
+++ b/paddle/fluid/extension/src/ext_tensor.cc
@@ -19,8 +19,7 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/transform.h" @@ -238,9 +237,9 @@ template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; -template PD_DLL_DECL Tensor Tensor::copy_to( +template PD_DLL_DECL Tensor Tensor::copy_to>( const PlaceType &target_place) const; -template PD_DLL_DECL Tensor Tensor::copy_to( +template PD_DLL_DECL Tensor Tensor::copy_to>( const PlaceType &target_place) const; template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; @@ -253,10 +252,10 @@ template PD_DLL_DECL uint8_t *Tensor::data() const; template PD_DLL_DECL int8_t *Tensor::data() const; template PD_DLL_DECL int16_t *Tensor::data() const; template PD_DLL_DECL bool *Tensor::data() const; -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::data() const; -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::data() const; +template PD_DLL_DECL paddle::platform::complex + *Tensor::data>() const; +template PD_DLL_DECL paddle::platform::complex + *Tensor::data>() const; template PD_DLL_DECL paddle::platform::float16 * Tensor::data() const; @@ -268,10 +267,10 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data(); template PD_DLL_DECL int8_t *Tensor::mutable_data(); template PD_DLL_DECL int16_t *Tensor::mutable_data(); template PD_DLL_DECL bool *Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::mutable_data(); +template PD_DLL_DECL paddle::platform::complex + *Tensor::mutable_data>(); +template PD_DLL_DECL paddle::platform::complex + *Tensor::mutable_data>(); template PD_DLL_DECL paddle::platform::float16 * Tensor::mutable_data(); @@ -289,10 +288,10 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data( template PD_DLL_DECL int16_t *Tensor::mutable_data( const PlaceType &place); template PD_DLL_DECL bool *Tensor::mutable_data(const PlaceType &place); -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::mutable_data(const PlaceType &place); -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::mutable_data(const PlaceType &place); +template PD_DLL_DECL paddle::platform::complex * +Tensor::mutable_data>(const PlaceType &place); +template PD_DLL_DECL paddle::platform::complex * +Tensor::mutable_data>(const PlaceType &place); template PD_DLL_DECL paddle::platform::float16 * Tensor::mutable_data(const PlaceType &place); @@ -356,13 +355,13 @@ Tensor Tensor::cast(const DataType &target_type) const { dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); break; case framework::proto::VarType::COMPLEX64: - framework::VisitDataType( - dst_type, - CastDataType(*tensor, rlt_tensor_, ctx)); + framework::VisitDataType(dst_type, + CastDataType>( + *tensor, rlt_tensor_, ctx)); break; case framework::proto::VarType::COMPLEX128: framework::VisitDataType(dst_type, - CastDataType( + CastDataType>( *tensor, rlt_tensor_, ctx)); break; case framework::proto::VarType::FP16: diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index a65dcbd55f..733831263a 100644 --- 
+++ b/paddle/fluid/framework/custom_tensor_test.cc
@@ -109,9 +109,9 @@ void GroupTestCopy() {
   TestCopyTensor<int8_t>();
   VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<uint8_t>();
-  VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
+  VLOG(2) << "complex<float> cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::complex64>();
-  VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
+  VLOG(2) << "complex<double> cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::complex128>();
   VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::float16>();
@@ -132,9 +132,9 @@ void GroupTestCast() {
   TestCast<uint8_t>(paddle::DataType::FLOAT32);
   VLOG(2) << "float cast";
   TestCast<float>(paddle::DataType::FLOAT32);
-  VLOG(2) << "complex64 cast";
+  VLOG(2) << "complex<float> cast";
   TestCast<paddle::complex64>(paddle::DataType::FLOAT32);
-  VLOG(2) << "complex128 cast";
+  VLOG(2) << "complex<double> cast";
   TestCast<paddle::complex128>(paddle::DataType::FLOAT32);
   VLOG(2) << "float16 cast";
   TestCast<paddle::float16>(paddle::DataType::FLOAT16);
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 648a32420a..a16f35dc11 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -19,8 +19,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/eigen_ext.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -28,8 +26,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 struct bfloat16;
-struct complex128;
-struct complex64;
+template <typename T>
+struct complex;
 struct float16;
 template <typename T>
 struct complex;
@@ -53,35 +51,31 @@ struct DataTypeTrait {
 #define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
   callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
 
-#define _ForEachDataType_(callback)                                          \
-  _ForEachDataTypeHelper_(callback, float, FP32);                            \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);      \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16);     \
-  _ForEachDataTypeHelper_(callback, double, FP64);                           \
-  _ForEachDataTypeHelper_(callback, int, INT32);                             \
-  _ForEachDataTypeHelper_(callback, int64_t, INT64);                         \
-  _ForEachDataTypeHelper_(callback, bool, BOOL);                             \
-  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                         \
-  _ForEachDataTypeHelper_(callback, int16_t, INT16);                         \
-  _ForEachDataTypeHelper_(callback, int8_t, INT8);                           \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,      \
-                          COMPLEX64);                                        \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>,     \
-                          COMPLEX128);                                       \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
-
-#define _ForEachDataTypeSmall_(callback)                                     \
-  _ForEachDataTypeHelper_(callback, float, FP32);                            \
-  _ForEachDataTypeHelper_(callback, double, FP64);                           \
-  _ForEachDataTypeHelper_(callback, int, INT32);                             \
-  _ForEachDataTypeHelper_(callback, int64_t, INT64);                         \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,      \
-                          COMPLEX64);                                        \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>,     \
-                          COMPLEX128);                                       \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
+#define _ForEachDataType_(callback)                                      \
+  _ForEachDataTypeHelper_(callback, float, FP32);                        \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
+  _ForEachDataTypeHelper_(callback, double, FP64);                       \
+  _ForEachDataTypeHelper_(callback, int, INT32);                         \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8);                       \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
+                          COMPLEX64);                                    \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
+                          COMPLEX128);
+
+#define _ForEachDataTypeSmall_(callback)                                 \
+  _ForEachDataTypeHelper_(callback, float, FP32);                        \
+  _ForEachDataTypeHelper_(callback, double, FP64);                       \
+  _ForEachDataTypeHelper_(callback, int, INT32);                        \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
+                          COMPLEX64);                                    \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
+                          COMPLEX128);
 
 // For the use of thrust, as index-type elements can be only integers.
 #define _ForEachDataTypeTiny_(callback) \
diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc
index 5a716eba8d..888687c06c 100644
--- a/paddle/fluid/framework/data_type_transform.cc
+++ b/paddle/fluid/framework/data_type_transform.cc
@@ -119,12 +119,12 @@ void TransComplexToReal(const proto::VarType::Type& dst_type,
   // complex -> real
   switch (src_type) {
     case proto::VarType::COMPLEX64:
-      framework::VisitDataType(dst_type,
-                               CastDataType<platform::complex64>(in, out, ctx));
+      framework::VisitDataType(
+          dst_type, CastDataType<platform::complex<float>>(in, out, ctx));
       break;
     case proto::VarType::COMPLEX128:
       framework::VisitDataType(
-          dst_type, CastDataType<platform::complex128>(in, out, ctx));
+          dst_type, CastDataType<platform::complex<double>>(in, out, ctx));
       break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/paddle/fluid/framework/dlpack_tensor_test.cc
index 1a79ada0be..8265d105ac 100644
--- a/paddle/fluid/framework/dlpack_tensor_test.cc
+++ b/paddle/fluid/framework/dlpack_tensor_test.cc
@@ -29,9 +29,7 @@ namespace {  // NOLINT
 template <typename T>
 constexpr uint8_t GetDLDataTypeCode() {
   if (std::is_same<T, platform::complex<float>>::value ||
-      std::is_same<T, platform::complex<double>>::value ||
-      std::is_same<T, platform::complex64>::value ||
-      std::is_same<T, platform::complex128>::value) {
+      std::is_same<T, platform::complex<double>>::value) {
     return static_cast<uint8_t>(5);
   }
 
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 105751645b..32460a98ce 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -22,8 +22,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
@@ -1137,9 +1136,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
 }
 
 template <>
-std::ostream& print_tensor<paddle::platform::complex64>(
+std::ostream& print_tensor<paddle::platform::complex<float>>(
     std::ostream& os, const framework::Tensor& tensor) {
-  auto inspect = tensor.data<paddle::platform::complex64>();
+  auto inspect = tensor.data<paddle::platform::complex<float>>();
   auto element_num = tensor.numel();
 
   os << "  - data: [";
@@ -1155,9 +1154,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
 }
 
 template <>
-std::ostream& print_tensor<paddle::platform::complex128>(
+std::ostream& print_tensor<paddle::platform::complex<double>>(
    std::ostream& os, const framework::Tensor& tensor) {
-  auto inspect = tensor.data<paddle::platform::complex128>();
+  auto inspect = tensor.data<paddle::platform::complex<double>>();
   auto element_num = tensor.numel();
 
   os << "  - data: [";
diff --git a/paddle/fluid/operators/abs_op.cu b/paddle/fluid/operators/abs_op.cu
index d03de7a456..b0eba229fd 100644
--- a/paddle/fluid/operators/abs_op.cu
+++ b/paddle/fluid/operators/abs_op.cu
@@ -14,8 +14,6 @@
 
 #include "paddle/fluid/operators/abs_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h
index 0987118ba3..09d607891b 100644
--- a/paddle/fluid/operators/dot_op.h
+++ b/paddle/fluid/operators/dot_op.h
@@ -23,8 +23,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using complex64 = platform::complex64;
-using complex128 = platform::complex128;
 
 template <typename DeviceContext, typename T, typename Enabled = void>
 struct P {
diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h
index f530256677..c4bd6ec4f1 100644
--- a/paddle/fluid/operators/math/complex_functors.h
+++ b/paddle/fluid/operators/math/complex_functors.h
@@ -64,9 +64,7 @@ using select_t = typename select<Head, Tail...>::type;
 
 template <typename T>
 using Real =
-    select_t<cond<std::is_same<T, platform::complex64>::value, float>,
-             cond<std::is_same<T, platform::complex128>::value, double>,
-             cond<std::is_same<T, platform::complex<float>>::value, float>,
+    select_t<cond<std::is_same<T, platform::complex<float>>::value, float>,
              cond<std::is_same<T, platform::complex<double>>::value, double>,
             T>;
 
@@ -79,15 +77,11 @@ template <typename T>
 using NoComplex = typename std::enable_if<std::is_floating_point<T>::value>::type;
 
 template <typename T>
 using EnableComplex = typename std::enable_if<
-    std::is_same<T, platform::complex64>::value ||
-    std::is_same<T, platform::complex128>::value ||
     std::is_same<T, platform::complex<float>>::value ||
     std::is_same<T, platform::complex<double>>::value>::type;
 
 template <typename T>
 using DisableComplex = typename std::enable_if<
-    !std::is_same<T, platform::complex64>::value &&
-    !std::is_same<T, platform::complex128>::value &&
     !std::is_same<T, platform::complex<float>>::value &&
     !std::is_same<T, platform::complex<double>>::value>::type;
diff --git a/paddle/fluid/operators/math/concat_and_split.h b/paddle/fluid/operators/math/concat_and_split.h
index a79a9da0b3..65d2ca79e6 100644
--- a/paddle/fluid/operators/math/concat_and_split.h
+++ b/paddle/fluid/operators/math/concat_and_split.h
@@ -65,18 +65,16 @@ class SplitFunctor {
 }  // namespace operators
 }  // namespace paddle
 
-#define FOR_ALL_TYPES(macro)                   \
-  macro(int);                                  \
-  macro(float);                                \
-  macro(double);                               \
-  macro(bool);                                 \
-  macro(int64_t);                              \
-  macro(int16_t);                              \
-  macro(uint8_t);                              \
-  macro(int8_t);                               \
-  macro(::paddle::platform::float16);          \
-  macro(::paddle::platform::bfloat16);         \
-  macro(::paddle::platform::complex<float>);   \
-  macro(::paddle::platform::complex<double>);  \
-  macro(::paddle::platform::complex64);        \
-  macro(::paddle::platform::complex128)
+#define FOR_ALL_TYPES(macro)                   \
+  macro(int);                                  \
+  macro(float);                                \
+  macro(double);                               \
+  macro(bool);                                 \
+  macro(int64_t);                              \
+  macro(int16_t);                              \
+  macro(uint8_t);                              \
+  macro(int8_t);                               \
+  macro(::paddle::platform::float16);          \
+  macro(::paddle::platform::bfloat16);         \
+  macro(::paddle::platform::complex<float>);   \
+  macro(::paddle::platform::complex<double>);
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index d01a39ecb7..1266ee7462 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -45,8 +45,6 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
 template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
-template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
-template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
 template struct SetConstant<platform::CPUDeviceContext,
                             platform::complex<float>>;
 template struct SetConstant<platform::CPUDeviceContext,
@@ -62,35 +60,29 @@ template struct SetConstant<platform::XPUDeviceContext, int>;
 template struct SetConstant<platform::XPUDeviceContext, int64_t>;
 template struct SetConstant<platform::XPUDeviceContext, bool>;
 template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
-template struct SetConstant<platform::XPUDeviceContext, platform::complex64>;
-template struct SetConstant<platform::XPUDeviceContext, platform::complex128>;
 template struct SetConstant<platform::XPUDeviceContext,
                             platform::complex<float>>;
 template struct SetConstant<platform::XPUDeviceContext,
                             platform::complex<double>>;
 #endif
 
-#define DEFINE_CPU_TRANS(RANK)                                               \
-  template struct Transpose<platform::CPUDeviceContext, platform::float16,   \
-                            RANK>;                                           \
-  template struct Transpose<platform::CPUDeviceContext, platform::bfloat16,  \
-                            RANK>;                                           \
-  template struct Transpose<platform::CPUDeviceContext, float, RANK>;        \
-  template struct Transpose<platform::CPUDeviceContext, double, RANK>;       \
-  template struct Transpose<platform::CPUDeviceContext, int, RANK>;          \
-  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;      \
-  template struct Transpose<platform::CPUDeviceContext, bool, RANK>;         \
-  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;      \
-  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;      \
-  template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;       \
-  template struct Transpose<platform::CPUDeviceContext,                      \
-                            platform::complex<float>, RANK>;                 \
-  template struct Transpose<platform::CPUDeviceContext,                      \
-                            platform::complex<double>, RANK>;                \
-  template struct Transpose<platform::CPUDeviceContext,                      \
-                            platform::complex64, RANK>;                      \
-  template struct Transpose<platform::CPUDeviceContext,                      \
-                            platform::complex128, RANK>;
+#define DEFINE_CPU_TRANS(RANK)                                               \
+  template struct Transpose<platform::CPUDeviceContext, platform::float16,   \
+                            RANK>;                                           \
+  template struct Transpose<platform::CPUDeviceContext, platform::bfloat16,  \
+                            RANK>;                                           \
+  template struct Transpose<platform::CPUDeviceContext, float, RANK>;        \
+  template struct Transpose<platform::CPUDeviceContext, double, RANK>;       \
+  template struct Transpose<platform::CPUDeviceContext, int, RANK>;          \
+  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, bool, RANK>;         \
+  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;       \
+  template struct Transpose<platform::CPUDeviceContext,                      \
+                            platform::complex<float>, RANK>;                 \
+  template struct Transpose<platform::CPUDeviceContext,                      \
+                            platform::complex<double>, RANK>;
 
 DEFINE_CPU_TRANS(1);
 DEFINE_CPU_TRANS(2);
@@ -140,8 +132,6 @@ DEFINE_CPU_TRANS_NORMAL(bool);
 DEFINE_CPU_TRANS_NORMAL(int16_t);
 DEFINE_CPU_TRANS_NORMAL(uint8_t);
 DEFINE_CPU_TRANS_NORMAL(int8_t);
-DEFINE_CPU_TRANS_NORMAL(platform::complex64);
-DEFINE_CPU_TRANS_NORMAL(platform::complex128);
 DEFINE_CPU_TRANS_NORMAL(platform::complex<float>);
 DEFINE_CPU_TRANS_NORMAL(platform::complex<double>);
diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu
index c5c78c87f7..248f621299 100644
--- a/paddle/fluid/operators/math/math_function.cu
+++ b/paddle/fluid/operators/math/math_function.cu
@@ -20,8 +20,6 @@ limitations under the License. */
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -30,8 +28,6 @@ namespace math { using float16 = paddle::platform::float16; using bfloat16 = paddle::platform::bfloat16; -using complex64 = paddle::platform::complex64; -using complex128 = paddle::platform::complex128; template struct SetConstant; template struct SetConstant; @@ -41,27 +37,23 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; -template struct SetConstant; -template struct SetConstant; template struct SetConstant>; template struct SetConstant>; -#define DEFINE_GPU_TRANS(RANK) \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose; \ - template struct Transpose, RANK>; \ - template struct Transpose, RANK>; \ - template struct Transpose; \ - template struct Transpose; +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose, RANK>; \ + template struct Transpose, RANK>; DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(2); @@ -151,8 +143,6 @@ DEFINE_GPU_TRANS_NORMAL(bool); DEFINE_GPU_TRANS_NORMAL(int16_t); DEFINE_GPU_TRANS_NORMAL(uint8_t); DEFINE_GPU_TRANS_NORMAL(int8_t); -DEFINE_GPU_TRANS_NORMAL(complex64); -DEFINE_GPU_TRANS_NORMAL(complex128); DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 26e9a0de60..f3ef537a31 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -448,8 +448,9 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; -template struct MergeAdd; -template struct MergeAdd; +template struct MergeAdd>; +template struct MergeAdd>; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 94f64d158a..4708a99e8f 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -20,8 +20,7 @@ limitations under the License. 
 #include <cuda.h>
 #endif
 #include <stdio.h>
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -135,18 +134,18 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 }
 #endif
 
-CUDA_ATOMIC_WRAPPER(Add, complex64) {
+CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);
   float *imag = real + 1;
-  return complex64(CudaAtomicAdd(real, val.real),
-                   CudaAtomicAdd(imag, val.imag));
+  return complex<float>(CudaAtomicAdd(real, val.real),
+                        CudaAtomicAdd(imag, val.imag));
 }
 
-CUDA_ATOMIC_WRAPPER(Add, complex128) {
+CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
   double *real = reinterpret_cast<double *>(address);
   double *imag = real + 1;
-  return complex128(CudaAtomicAdd(real, val.real),
-                    CudaAtomicAdd(imag, val.imag));
+  return complex<double>(CudaAtomicAdd(real, val.real),
+                         CudaAtomicAdd(imag, val.imag));
 }
 
 // For atomicMax
diff --git a/python/setup.py.in b/python/setup.py.in
index 3fbe796a81..3bc3057b33 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -406,6 +406,7 @@ headers = (
     # to `extension/incude`,
     ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex64.h'] +
     ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex128.h'] +
+    ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex.h'] +
     ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/float16.h'])
 
 if '${WITH_MKLDNN}' == 'ON':
-- 
GitLab
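
Editor's note on the migration pattern: the sketch below shows what this patch means for code that used the removed types. It is a minimal sketch, not part of the patch — the `example` namespace and the `Conj` helper are hypothetical. Only `paddle::platform::complex<T>` (from paddle/fluid/platform/complex.h), its public `real`/`imag` members and two-argument constructor (both used by the CudaAtomicAdd wrappers above), and the `complex64`/`complex128` aliases kept in ext_dtype.h come from the patched code.

    // Before this patch, complex64 and complex128 were two unrelated structs
    // defined in complex64.h and complex128.h; afterwards both are
    // instantiations of one template, paddle::platform::complex<T>.
    #include <cstdint>
    #include "paddle/fluid/platform/complex.h"

    namespace example {  // hypothetical namespace, for illustration only

    // The aliases retained in ext_dtype.h, so existing custom ops compile:
    using complex64 = paddle::platform::complex<float>;
    using complex128 = paddle::platform::complex<double>;

    // A hypothetical element-wise conjugate: one template now covers what
    // previously needed a separate overload per concrete complex struct.
    template <typename T>
    void Conj(const paddle::platform::complex<T>* in,
              paddle::platform::complex<T>* out, int64_t n) {
      for (int64_t i = 0; i < n; ++i) {
        // real and imag are public members, as in the atomic-add wrappers.
        out[i] = paddle::platform::complex<T>(in[i].real, -in[i].imag);
      }
    }

    }  // namespace example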