未验证 提交 06c63ca0 编写于 作者: C chentianyu03 提交者: GitHub

replace and remove complex64/128 types in custom OP and other files (#33195)

* replace and remove complex64/128 types in custom OP and other files

* fix custom_tensor_test fail bug

* fix custom_conj_test fail bug

* fix dispatch_test_op build fail bug
上级 dfce571c
......@@ -193,10 +193,7 @@ copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
......
......@@ -16,15 +16,14 @@ limitations under the License. */
#include <cstdint>
#include <string>
#include "complex128.h" // NOLINT
#include "complex64.h" // NOLINT
#include "complex.h" // NOLINT
#include "ext_exception.h" // NOLINT
#include "float16.h" // NOLINT
namespace paddle {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
using complex64 = paddle::platform::complex<float>;
using complex128 = paddle::platform::complex<double>;
using float16 = paddle::platform::float16;
enum class DataType {
......
......@@ -19,8 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"
......@@ -238,9 +237,9 @@ template PD_DLL_DECL Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<float>>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<double>>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
......@@ -253,10 +252,10 @@ template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::complex<float>
*Tensor::data<paddle::platform::complex<float>>() const;
template PD_DLL_DECL paddle::platform::complex<double>
*Tensor::data<paddle::platform::complex<double>>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
......@@ -268,10 +267,10 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::complex<float>
*Tensor::mutable_data<paddle::platform::complex<float>>();
template PD_DLL_DECL paddle::platform::complex<double>
*Tensor::mutable_data<paddle::platform::complex<double>>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
......@@ -289,10 +288,10 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex<float> *
Tensor::mutable_data<paddle::platform::complex<float>>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex<double> *
Tensor::mutable_data<paddle::platform::complex<double>>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
......@@ -356,13 +355,13 @@ Tensor Tensor::cast(const DataType &target_type) const {
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(
dst_type,
CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
framework::VisitDataType(dst_type,
CastDataType<paddle::platform::complex<float>>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type,
CastDataType<paddle::platform::complex128>(
CastDataType<paddle::platform::complex<double>>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP16:
......
......@@ -109,9 +109,9 @@ void GroupTestCopy() {
TestCopyTensor<int8_t>();
VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<uint8_t>();
VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
VLOG(2) << "complex<float> cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::complex64>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
VLOG(2) << "complex<double> cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::complex128>();
VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::float16>();
......@@ -132,9 +132,9 @@ void GroupTestCast() {
TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
VLOG(2) << "complex<float> cast";
TestCast<paddle::complex64>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
VLOG(2) << "complex<double> cast";
TestCast<paddle::complex128>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::float16>(paddle::DataType::FLOAT16);
......
......@@ -19,8 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
......@@ -28,8 +26,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct bfloat16;
struct complex128;
struct complex64;
template <typename T>
struct complex;
struct float16;
template <typename T>
struct complex;
......@@ -53,35 +51,31 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128);
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128);
// For the use of thrust, as index-type elements can be only integers.
#define _ForEachDataTypeTiny_(callback) \
......
......@@ -119,12 +119,12 @@ void TransComplexToReal(const proto::VarType::Type& dst_type,
// complex -> real
switch (src_type) {
case proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type,
CastDataType<platform::complex64>(in, out, ctx));
framework::VisitDataType(
dst_type, CastDataType<platform::complex<float>>(in, out, ctx));
break;
case proto::VarType::COMPLEX128:
framework::VisitDataType(
dst_type, CastDataType<platform::complex128>(in, out, ctx));
dst_type, CastDataType<platform::complex<double>>(in, out, ctx));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -29,9 +29,7 @@ namespace { // NOLINT
template <typename T>
constexpr uint8_t GetDLDataTypeCode() {
if (std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value ||
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value) {
std::is_same<T, platform::complex<double>>::value) {
return static_cast<uint8_t>(5);
}
......
......@@ -22,8 +22,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -1137,9 +1136,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
}
template <>
std::ostream& print_tensor<paddle::platform::complex64>(
std::ostream& print_tensor<paddle::platform::complex<float>>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex64>();
auto inspect = tensor.data<paddle::platform::complex<float>>();
auto element_num = tensor.numel();
os << " - data: [";
......@@ -1155,9 +1154,9 @@ std::ostream& print_tensor<paddle::platform::complex64>(
}
template <>
std::ostream& print_tensor<paddle::platform::complex128>(
std::ostream& print_tensor<paddle::platform::complex<double>>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex128>();
auto inspect = tensor.data<paddle::platform::complex<double>>();
auto element_num = tensor.numel();
os << " - data: [";
......
......@@ -14,8 +14,6 @@
#include "paddle/fluid/operators/abs_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......
......@@ -23,8 +23,6 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using complex64 = platform::complex64;
using complex128 = platform::complex128;
template <typename T, typename R>
struct P {
......
......@@ -64,9 +64,7 @@ using select_t = typename select<Head, Tail...>::type;
template <typename T>
using Real =
select_t<cond<std::is_same<T, platform::complex64>::value, float>,
cond<std::is_same<T, platform::complex128>::value, double>,
cond<std::is_same<T, platform::complex<float>>::value, float>,
select_t<cond<std::is_same<T, platform::complex<float>>::value, float>,
cond<std::is_same<T, platform::complex<double>>::value, double>,
T>;
......@@ -79,15 +77,11 @@ using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
template <typename T>
using EnableComplex = typename std::enable_if<
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value ||
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value &&
!std::is_same<T, platform::complex<float>>::value &&
!std::is_same<T, platform::complex<double>>::value>::type;
......
......@@ -65,18 +65,16 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex<float>); \
macro(::paddle::platform::complex<double>); \
macro(::paddle::platform::complex64); \
macro(::paddle::platform::complex128)
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex<float>); \
macro(::paddle::platform::complex<double>);
......@@ -45,8 +45,6 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::CPUDeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::CPUDeviceContext,
......@@ -61,35 +59,29 @@ template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
template struct SetConstant<platform::XPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::XPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::XPUDeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::XPUDeviceContext,
platform::complex<double>>;
#endif
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<float>, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<double>, RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex64, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex128, \
RANK>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<float>, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<double>, RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
......@@ -140,8 +132,6 @@ DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(platform::complex64);
DEFINE_CPU_TRANS_NORMAL(platform::complex128);
DEFINE_CPU_TRANS_NORMAL(platform::complex<float>);
DEFINE_CPU_TRANS_NORMAL(platform::complex<double>);
......
......@@ -20,8 +20,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -30,8 +28,6 @@ namespace math {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, platform::bfloat16>;
......@@ -41,27 +37,23 @@ template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex64>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
template struct SetConstant<platform::CUDADeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::CUDADeviceContext,
platform::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<float>, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<double>, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex64, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex128, RANK>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<float>, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<double>, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
......@@ -151,8 +143,6 @@ DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(complex64);
DEFINE_GPU_TRANS_NORMAL(complex128);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>);
......
......@@ -448,8 +448,9 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex64>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex128>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
platform::complex<double>>;
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
......
......@@ -20,8 +20,7 @@ limitations under the License. */
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -135,18 +134,18 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
#endif
CUDA_ATOMIC_WRAPPER(Add, complex64) {
CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
float *real = reinterpret_cast<float *>(address);
float *imag = real + 1;
return complex64(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
return complex<float>(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
}
CUDA_ATOMIC_WRAPPER(Add, complex128) {
CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
double *real = reinterpret_cast<double *>(address);
double *imag = real + 1;
return complex128(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
return complex<double>(CudaAtomicAdd(real, val.real),
CudaAtomicAdd(imag, val.imag));
}
// For atomicMax
......
......@@ -406,6 +406,7 @@ headers = (
# to `extension/incude`,
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex64.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex128.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/float16.h'])
if '${WITH_MKLDNN}' == 'ON':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册