Unverified commit 06c63ca0, authored by chentianyu03, committed by GitHub

replace and remove complex64/128 types in custom OP and other files (#33195)

* replace and remove complex64/128 types in custom OP and other files

* fix custom_tensor_test fail bug

* fix custom_conj_test fail bug

* fix dispatch_test_op build fail bug
Parent commit: dfce571c
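
For code that consumed these types, the migration is mechanical: the two concrete types paddle::platform::complex64 and paddle::platform::complex128 become the complex<float> and complex<double> specializations of the single class template in paddle/fluid/platform/complex.h. A minimal sketch of the pattern, assuming downstream code like the Scale function below (which is illustrative, not part of this commit):

    // Before: #include "paddle/fluid/platform/complex64.h" and the
    // concrete type paddle::platform::complex64.
    // After: one header, one class template.
    #include "paddle/fluid/platform/complex.h"

    using c64 = paddle::platform::complex<float>;    // was platform::complex64
    using c128 = paddle::platform::complex<double>;  // was platform::complex128

    // complex<T> exposes .real and .imag members, which the hunks below rely on.
    c64 Scale(const c64 &z, float k) { return c64(z.real * k, z.imag * k); }
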
@@ -193,10 +193,7 @@ copy(inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 copy(inference_lib_dist
-  SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
-  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
-copy(inference_lib_dist
-  SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
+  SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
 copy(inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h

@@ -16,15 +16,14 @@ limitations under the License. */
 #include <cstdint>
 #include <string>

-#include "complex128.h"  // NOLINT
-#include "complex64.h"   // NOLINT
+#include "complex.h"        // NOLINT
 #include "ext_exception.h"  // NOLINT
 #include "float16.h"        // NOLINT

 namespace paddle {

-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;
+using complex64 = paddle::platform::complex<float>;
+using complex128 = paddle::platform::complex<double>;
 using float16 = paddle::platform::float16;

 enum class DataType {

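The extension's public aliases paddle::complex64 and paddle::complex128 are deliberately kept, so custom operators written against them continue to compile; only the underlying definitions change. A compile-time sketch of that guarantee (assuming the patched header above is the extension dtype header; the static_asserts are illustrative, not part of the commit):

    #include <type_traits>

    // Both aliases now resolve to specializations of the complex<T> template.
    static_assert(std::is_same<paddle::complex64,
                               paddle::platform::complex<float>>::value, "");
    static_assert(std::is_same<paddle::complex128,
                               paddle::platform::complex<double>>::value, "");
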
@@ -19,8 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/custom_tensor_utils.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/transform.h"

@@ -238,9 +237,9 @@ template PD_DLL_DECL Tensor
 Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
 template PD_DLL_DECL Tensor
 Tensor::copy_to<bool>(const PlaceType &target_place) const;
-template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
+template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<float>>(
     const PlaceType &target_place) const;
-template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
+template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<double>>(
     const PlaceType &target_place) const;
 template PD_DLL_DECL Tensor
 Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;

@@ -253,10 +252,10 @@ template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
 template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
 template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
 template PD_DLL_DECL bool *Tensor::data<bool>() const;
-template PD_DLL_DECL paddle::platform::complex64 *
-Tensor::data<paddle::platform::complex64>() const;
-template PD_DLL_DECL paddle::platform::complex128 *
-Tensor::data<paddle::platform::complex128>() const;
+template PD_DLL_DECL paddle::platform::complex<float>
+    *Tensor::data<paddle::platform::complex<float>>() const;
+template PD_DLL_DECL paddle::platform::complex<double>
+    *Tensor::data<paddle::platform::complex<double>>() const;
 template PD_DLL_DECL paddle::platform::float16 *
 Tensor::data<paddle::platform::float16>() const;

@@ -268,10 +267,10 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
 template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
 template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
 template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
-template PD_DLL_DECL paddle::platform::complex64 *
-Tensor::mutable_data<paddle::platform::complex64>();
-template PD_DLL_DECL paddle::platform::complex128 *
-Tensor::mutable_data<paddle::platform::complex128>();
+template PD_DLL_DECL paddle::platform::complex<float>
+    *Tensor::mutable_data<paddle::platform::complex<float>>();
+template PD_DLL_DECL paddle::platform::complex<double>
+    *Tensor::mutable_data<paddle::platform::complex<double>>();
 template PD_DLL_DECL paddle::platform::float16 *
 Tensor::mutable_data<paddle::platform::float16>();

@@ -289,10 +288,10 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
 template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
     const PlaceType &place);
 template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
-template PD_DLL_DECL paddle::platform::complex64 *
-Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
-template PD_DLL_DECL paddle::platform::complex128 *
-Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
+template PD_DLL_DECL paddle::platform::complex<float> *
+Tensor::mutable_data<paddle::platform::complex<float>>(const PlaceType &place);
+template PD_DLL_DECL paddle::platform::complex<double> *
+Tensor::mutable_data<paddle::platform::complex<double>>(const PlaceType &place);
 template PD_DLL_DECL paddle::platform::float16 *
 Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);

@@ -356,13 +355,13 @@ Tensor Tensor::cast(const DataType &target_type) const {
           dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
       break;
     case framework::proto::VarType::COMPLEX64:
-      framework::VisitDataType(
-          dst_type,
-          CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
+      framework::VisitDataType(dst_type,
+                               CastDataType<paddle::platform::complex<float>>(
+                                   *tensor, rlt_tensor_, ctx));
       break;
     case framework::proto::VarType::COMPLEX128:
       framework::VisitDataType(dst_type,
-                               CastDataType<paddle::platform::complex128>(
+                               CastDataType<paddle::platform::complex<double>>(
                                    *tensor, rlt_tensor_, ctx));
       break;
     case framework::proto::VarType::FP16:

@@ -109,9 +109,9 @@ void GroupTestCopy() {
   TestCopyTensor<int8_t>();
   VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<uint8_t>();
-  VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
+  VLOG(2) << "complex<float> cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::complex64>();
-  VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
+  VLOG(2) << "complex<double> cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::complex128>();
   VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
   TestCopyTensor<paddle::float16>();

@@ -132,9 +132,9 @@ void GroupTestCast() {
   TestCast<uint8_t>(paddle::DataType::FLOAT32);
   VLOG(2) << "float cast";
   TestCast<float>(paddle::DataType::FLOAT32);
-  VLOG(2) << "complex64 cast";
+  VLOG(2) << "complex<float> cast";
   TestCast<paddle::complex64>(paddle::DataType::FLOAT32);
-  VLOG(2) << "complex128 cast";
+  VLOG(2) << "complex<double> cast";
   TestCast<paddle::complex128>(paddle::DataType::FLOAT32);
   VLOG(2) << "float16 cast";
   TestCast<paddle::float16>(paddle::DataType::FLOAT16);

@@ -19,8 +19,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/eigen_ext.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"

@@ -28,8 +26,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 struct bfloat16;
-struct complex128;
-struct complex64;
+template <typename T>
+struct complex;
 struct float16;
 template <typename T>
 struct complex;

@@ -53,35 +51,31 @@ struct DataTypeTrait<void> {
 #define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
   callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);

 #define _ForEachDataType_(callback)                                      \
   _ForEachDataTypeHelper_(callback, float, FP32);                        \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
   _ForEachDataTypeHelper_(callback, double, FP64);                       \
   _ForEachDataTypeHelper_(callback, int, INT32);                         \
   _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
   _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
   _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
   _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
   _ForEachDataTypeHelper_(callback, int8_t, INT8);                       \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
                           COMPLEX64);                                    \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
-                          COMPLEX128);                                   \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
+                          COMPLEX128);

 #define _ForEachDataTypeSmall_(callback)                                 \
   _ForEachDataTypeHelper_(callback, float, FP32);                        \
   _ForEachDataTypeHelper_(callback, double, FP64);                       \
   _ForEachDataTypeHelper_(callback, int, INT32);                         \
   _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
                           COMPLEX64);                                    \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
-                          COMPLEX128);                                   \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
+                          COMPLEX128);

 // For the use of thrust, as index-type elements can be only integers.
 #define _ForEachDataTypeTiny_(callback) \

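Before this hunk, _ForEachDataType_ listed the COMPLEX64 and COMPLEX128 proto types twice, once for the complex<T> specializations and once for the legacy types, so any callback expanded through the X-macro ran twice per complex proto type. A minimal sketch of how such a callback is used (PRINT_SIZE is illustrative, not an existing macro):

    #include <iostream>

    // Expands once per (cpp_type, proto_type) pair registered in the macro.
    #define PRINT_SIZE(cpp_type, proto_type) \
      std::cout << #cpp_type << " -> " << sizeof(cpp_type) << " bytes\n";

    void DumpRegisteredTypes() {
      _ForEachDataType_(PRINT_SIZE);  // one line of output per entry
    }
    #undef PRINT_SIZE
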
@@ -119,12 +119,12 @@ void TransComplexToReal(const proto::VarType::Type& dst_type,
   // complex -> real
   switch (src_type) {
     case proto::VarType::COMPLEX64:
-      framework::VisitDataType(dst_type,
-                               CastDataType<platform::complex64>(in, out, ctx));
+      framework::VisitDataType(
+          dst_type, CastDataType<platform::complex<float>>(in, out, ctx));
       break;
     case proto::VarType::COMPLEX128:
       framework::VisitDataType(
-          dst_type, CastDataType<platform::complex128>(in, out, ctx));
+          dst_type, CastDataType<platform::complex<double>>(in, out, ctx));
       break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(

@@ -29,9 +29,7 @@ namespace {  // NOLINT

 template <typename T>
 constexpr uint8_t GetDLDataTypeCode() {
   if (std::is_same<T, platform::complex<float>>::value ||
-      std::is_same<T, platform::complex<double>>::value ||
-      std::is_same<T, platform::complex64>::value ||
-      std::is_same<T, platform::complex128>::value) {
+      std::is_same<T, platform::complex<double>>::value) {
     return static_cast<uint8_t>(5);
   }

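The literal 5 is DLPack's type code for complex numbers (kDLComplex in dlpack.h), so dropping the legacy types does not change the exported code; it only removes two redundant std::is_same branches. A compile-time sketch, assuming the function is visible in the current translation unit:

    // Both remaining specializations still map to DLPack's complex code.
    static_assert(GetDLDataTypeCode<platform::complex<float>>() == 5, "");
    static_assert(GetDLDataTypeCode<platform::complex<double>>() == 5, "");
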
@@ -22,8 +22,7 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {

@@ -1137,9 +1136,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
 }

 template <>
-std::ostream& print_tensor<paddle::platform::complex64>(
+std::ostream& print_tensor<paddle::platform::complex<float>>(
     std::ostream& os, const framework::Tensor& tensor) {
-  auto inspect = tensor.data<paddle::platform::complex64>();
+  auto inspect = tensor.data<paddle::platform::complex<float>>();
   auto element_num = tensor.numel();

   os << " - data: [";

@@ -1155,9 +1154,9 @@ std::ostream& print_tensor<paddle::platform::complex64>(
 }

 template <>
-std::ostream& print_tensor<paddle::platform::complex128>(
+std::ostream& print_tensor<paddle::platform::complex<double>>(
     std::ostream& os, const framework::Tensor& tensor) {
-  auto inspect = tensor.data<paddle::platform::complex128>();
+  auto inspect = tensor.data<paddle::platform::complex<double>>();
   auto element_num = tensor.numel();

   os << " - data: [";

@@ -14,8 +14,6 @@
 #include "paddle/fluid/operators/abs_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {

@@ -23,8 +23,6 @@ namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
-using complex64 = platform::complex64;
-using complex128 = platform::complex128;

 template <typename T, typename R>
 struct P {

@@ -64,9 +64,7 @@ using select_t = typename select<Head, Tail...>::type;

 template <typename T>
 using Real =
-    select_t<cond<std::is_same<T, platform::complex64>::value, float>,
-             cond<std::is_same<T, platform::complex128>::value, double>,
-             cond<std::is_same<T, platform::complex<float>>::value, float>,
+    select_t<cond<std::is_same<T, platform::complex<float>>::value, float>,
              cond<std::is_same<T, platform::complex<double>>::value, double>,
              T>;

@@ -79,15 +77,11 @@ using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;

 template <typename T>
 using EnableComplex = typename std::enable_if<
-    std::is_same<T, platform::complex64>::value ||
-    std::is_same<T, platform::complex128>::value ||
     std::is_same<T, platform::complex<float>>::value ||
     std::is_same<T, platform::complex<double>>::value>::type;

 template <typename T>
 using DisableComplex = typename std::enable_if<
-    !std::is_same<T, platform::complex64>::value &&
-    !std::is_same<T, platform::complex128>::value &&
     !std::is_same<T, platform::complex<float>>::value &&
     !std::is_same<T, platform::complex<double>>::value>::type;

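With the legacy types gone, Real<T> maps complex<float> to float, complex<double> to double, and any other T to itself, while EnableComplex and DisableComplex select overloads via SFINAE over the same two specializations. A reduced sketch of the dispatch pattern these aliases enable (AbsValue is illustrative, not the file's actual functor):

    #include <cmath>

    template <typename T, typename Enable = void>
    struct AbsValue;

    // Chosen when T is platform::complex<float> or platform::complex<double>:
    // returns the magnitude as the underlying real type.
    template <typename T>
    struct AbsValue<T, EnableComplex<T>> {
      Real<T> operator()(const T &z) const {
        return std::sqrt(z.real * z.real + z.imag * z.imag);
      }
    };

    // Chosen for every non-complex T; Real<T> is simply T here.
    template <typename T>
    struct AbsValue<T, DisableComplex<T>> {
      Real<T> operator()(const T &x) const { return x; }
    };
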
@@ -65,18 +65,16 @@ class SplitFunctor {
 }  // namespace operators
 }  // namespace paddle

 #define FOR_ALL_TYPES(macro)                    \
   macro(int);                                   \
   macro(float);                                 \
   macro(double);                                \
   macro(bool);                                  \
   macro(int64_t);                               \
   macro(int16_t);                               \
   macro(uint8_t);                               \
   macro(int8_t);                                \
   macro(::paddle::platform::float16);           \
   macro(::paddle::platform::bfloat16);          \
   macro(::paddle::platform::complex<float>);    \
-  macro(::paddle::platform::complex<double>);   \
-  macro(::paddle::platform::complex64);         \
-  macro(::paddle::platform::complex128)
+  macro(::paddle::platform::complex<double>);

@@ -45,8 +45,6 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
 template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
-template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
-template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
 template struct SetConstant<platform::CPUDeviceContext,
                             platform::complex<float>>;
 template struct SetConstant<platform::CPUDeviceContext,

@@ -61,35 +59,29 @@ template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
 template struct SetConstant<platform::XPUDeviceContext, int>;
 template struct SetConstant<platform::XPUDeviceContext, int64_t>;
 template struct SetConstant<platform::XPUDeviceContext, bool>;
-template struct SetConstant<platform::XPUDeviceContext, platform::complex64>;
-template struct SetConstant<platform::XPUDeviceContext, platform::complex128>;
 template struct SetConstant<platform::XPUDeviceContext,
                             platform::complex<float>>;
 template struct SetConstant<platform::XPUDeviceContext,
                             platform::complex<double>>;
 #endif

 #define DEFINE_CPU_TRANS(RANK)                                              \
   template struct Transpose<platform::CPUDeviceContext, platform::float16,  \
                             RANK>;                                          \
   template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
                             RANK>;                                          \
   template struct Transpose<platform::CPUDeviceContext, float, RANK>;       \
   template struct Transpose<platform::CPUDeviceContext, double, RANK>;      \
   template struct Transpose<platform::CPUDeviceContext, int, RANK>;         \
   template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;     \
   template struct Transpose<platform::CPUDeviceContext, bool, RANK>;        \
   template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;     \
   template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;     \
   template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;      \
   template struct Transpose<platform::CPUDeviceContext,                     \
                             platform::complex<float>, RANK>;                \
   template struct Transpose<platform::CPUDeviceContext,                     \
-                            platform::complex<double>, RANK>;               \
-  template struct Transpose<platform::CPUDeviceContext, platform::complex64, \
-                            RANK>;                                          \
-  template struct Transpose<platform::CPUDeviceContext, platform::complex128, \
-                            RANK>;
+                            platform::complex<double>, RANK>;

 DEFINE_CPU_TRANS(1);
 DEFINE_CPU_TRANS(2);

@@ -140,8 +132,6 @@ DEFINE_CPU_TRANS_NORMAL(bool);
 DEFINE_CPU_TRANS_NORMAL(int16_t);
 DEFINE_CPU_TRANS_NORMAL(uint8_t);
 DEFINE_CPU_TRANS_NORMAL(int8_t);
-DEFINE_CPU_TRANS_NORMAL(platform::complex64);
-DEFINE_CPU_TRANS_NORMAL(platform::complex128);
 DEFINE_CPU_TRANS_NORMAL(platform::complex<float>);
 DEFINE_CPU_TRANS_NORMAL(platform::complex<double>);

@@ -20,8 +20,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
 #include "paddle/fluid/platform/bfloat16.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {

@@ -30,8 +28,6 @@ namespace math {
 using float16 = paddle::platform::float16;
 using bfloat16 = paddle::platform::bfloat16;
-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;

 template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
 template struct SetConstant<platform::CUDADeviceContext, platform::bfloat16>;

@@ -41,27 +37,23 @@ template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
 template struct SetConstant<platform::CUDADeviceContext, int>;
 template struct SetConstant<platform::CUDADeviceContext, int64_t>;
 template struct SetConstant<platform::CUDADeviceContext, bool>;
-template struct SetConstant<platform::CUDADeviceContext, platform::complex64>;
-template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
 template struct SetConstant<platform::CUDADeviceContext,
                             platform::complex<float>>;
 template struct SetConstant<platform::CUDADeviceContext,
                             platform::complex<double>>;

 #define DEFINE_GPU_TRANS(RANK)                                              \
   template struct Transpose<platform::CUDADeviceContext, float, RANK>;      \
   template struct Transpose<platform::CUDADeviceContext, double, RANK>;     \
   template struct Transpose<platform::CUDADeviceContext, float16, RANK>;    \
   template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>;   \
   template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>;     \
   template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>;    \
   template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;    \
   template struct Transpose<platform::CUDADeviceContext,                    \
                             paddle::platform::complex<float>, RANK>;        \
   template struct Transpose<platform::CUDADeviceContext,                    \
-                            paddle::platform::complex<double>, RANK>;       \
-  template struct Transpose<platform::CUDADeviceContext, complex64, RANK>;  \
-  template struct Transpose<platform::CUDADeviceContext, complex128, RANK>;
+                            paddle::platform::complex<double>, RANK>;

 DEFINE_GPU_TRANS(1);
 DEFINE_GPU_TRANS(2);

@@ -151,8 +143,6 @@ DEFINE_GPU_TRANS_NORMAL(bool);
 DEFINE_GPU_TRANS_NORMAL(int16_t);
 DEFINE_GPU_TRANS_NORMAL(uint8_t);
 DEFINE_GPU_TRANS_NORMAL(int8_t);
-DEFINE_GPU_TRANS_NORMAL(complex64);
-DEFINE_GPU_TRANS_NORMAL(complex128);
 DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>);
 DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>);

@@ -448,8 +448,9 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
 template struct MergeAdd<platform::CUDADeviceContext, int>;
 template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
-template struct MergeAdd<platform::CUDADeviceContext, platform::complex64>;
-template struct MergeAdd<platform::CUDADeviceContext, platform::complex128>;
+template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
+template struct MergeAdd<platform::CUDADeviceContext,
+                         platform::complex<double>>;

 template <typename T, int block_size>
 __global__ void UpdateToTensorKernel(const T* selected_rows,

@@ -20,8 +20,7 @@ limitations under the License. */
 #include <hip/hip_runtime.h>
 #endif
 #include <stdio.h>
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {

@@ -135,18 +134,18 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 }
 #endif

-CUDA_ATOMIC_WRAPPER(Add, complex64) {
+CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);
   float *imag = real + 1;
-  return complex64(CudaAtomicAdd(real, val.real),
-                   CudaAtomicAdd(imag, val.imag));
+  return complex<float>(CudaAtomicAdd(real, val.real),
+                        CudaAtomicAdd(imag, val.imag));
 }

-CUDA_ATOMIC_WRAPPER(Add, complex128) {
+CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
   double *real = reinterpret_cast<double *>(address);
   double *imag = real + 1;
-  return complex128(CudaAtomicAdd(real, val.real),
-                    CudaAtomicAdd(imag, val.imag));
+  return complex<double>(CudaAtomicAdd(real, val.real),
+                         CudaAtomicAdd(imag, val.imag));
 }

 // For atomicMax

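The rewritten wrappers keep the same trick: complex<T> is laid out as two adjacent scalars, so an atomic complex add decomposes into two independent scalar atomics on the real and imaginary parts, each returning the previously stored value. The pair update is not atomic as a whole, which is acceptable for reduction-style sums. The same idea in self-contained CUDA, written against float2 rather than Paddle's type (AtomicAddComplex is illustrative):

    #include <cuda_runtime.h>

    __device__ float2 AtomicAddComplex(float2 *address, float2 val) {
      // Treat the complex value as two adjacent floats and update each
      // component with its own hardware atomic; atomicAdd returns the
      // value stored before the update, matching the wrappers above.
      float *real = reinterpret_cast<float *>(address);
      float *imag = real + 1;
      return make_float2(atomicAdd(real, val.x), atomicAdd(imag, val.y));
    }
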
@@ -406,6 +406,7 @@ headers = (
     # to `extension/incude`,
     ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex64.h'] +
     ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex128.h'] +
+    ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex.h'] +
    ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/float16.h'])

 if '${WITH_MKLDNN}' == 'ON':