Unverified commit 0c38708a, authored by Jiabin Yang and committed by GitHub

[Custom Op] Remove unsupport dtypes (#31232)

* remove remove_unsupport_dtype

* remove remove_unsupport_dtype

* remove test dtype

* add more include

* change dtype.h's enum as enum class to avoid conflict with inference lib

* make enum as enum class

* remove additional test

* merge develop

* polish code
Parent b8bce682
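One of the changes listed above turns the plain enum DataType in dtype.h into an enum class, so its enumerators are no longer injected into the enclosing namespace, where they could collide with identically named symbols in the inference library. Below is a minimal, self-contained sketch of that difference; the names used are illustrative and not taken from the Paddle headers.

#include <iostream>

namespace demo {
// With a plain enum, the enumerator FLOAT32 would live directly in namespace
// demo, so any other header that also defines demo::FLOAT32 clashes with it.
// enum DataType { FLOAT32, FLOAT64 };

// With an enum class the enumerators stay scoped: they must be spelled
// DataType::FLOAT32 and cannot collide with an unrelated FLOAT32 symbol.
enum class DataType { FLOAT32, FLOAT64 };
}  // namespace demo

int main() {
  demo::DataType t = demo::DataType::FLOAT32;
  // enum class values do not convert implicitly to int, which is likely why
  // the test change further down replaces CHECK_EQ(t2.type(), data_type)
  // with CHECK(t2.type() == data_type).
  std::cout << (t == demo::DataType::FLOAT32) << "\n";  // prints 1
  return 0;
}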
...@@ -69,23 +69,6 @@ namespace paddle {
} \
}()
///////// Complex Dispatch Marco ///////////
#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating and Integral Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
...@@ -112,57 +95,6 @@ namespace paddle {
} \
}()
///////// Floating and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating, Integral and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
// TODO(chenweihang): Add more Marcos in the future if needed
} // namespace paddle
...@@ -11,34 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
#include <cstdint>
#include <stdexcept>
#include <string>
namespace paddle {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
enum DataType {
enum class DataType {
BOOL,
INT8,
UINT8,
INT16,
INT32,
INT64,
FLOAT16,
BFLOAT16,
FLOAT32,
FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed.
};
...@@ -56,36 +44,24 @@ inline std::string ToString(DataType dtype) {
return "int32_t";
case DataType::INT64:
return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::BFLOAT16:
return "bfloat16";
case DataType::FLOAT32:
return "float";
case DataType::FLOAT64:
return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default:
throw std::runtime_error("Unsupported paddle enum data type.");
}
}
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float16, DataType::FLOAT16) \
_(bfloat16, DataType::BFLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64)
template <paddle::DataType T>
struct DataTypeToCPPType;
......
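For context, the DataTypeToCPPType trait declared in the lines above is normally filled in by expanding PD_FOR_EACH_DATA_TYPE once per listed dtype. The following is only a hedged sketch of that pattern; the helper macro name is illustrative, and the actual specialization code is elided from this diff.

// Illustrative only: expand PD_FOR_EACH_DATA_TYPE to generate one
// DataTypeToCPPType specialization per supported dtype, e.g.
// DataTypeToCPPType<DataType::FLOAT32>::type is float.
#define PD_SPECIALIZE_DataTypeToCPPType(cpp_type, data_type) \
  template <>                                                \
  struct DataTypeToCPPType<data_type> {                      \
    using type = cpp_type;                                   \
  };

PD_FOR_EACH_DATA_TYPE(PD_SPECIALIZE_DataTypeToCPPType)

#undef PD_SPECIALIZE_DataTypeToCPPType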
...@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class CustomTensorUtils;
} // namespace framework
class PD_DLL_DECL Tensor {
public:
/// \brief Construct a Tensor on target Place for CustomOp.
......
...@@ -159,17 +159,10 @@ DataType Tensor::type() const {
return DataType::UINT8;
} else if (type == framework::proto::VarType::FP64) {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
// TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32;
}
...@@ -207,14 +200,6 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
return target;
}
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
...@@ -238,14 +223,6 @@ template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::data<paddle::platform::bfloat16>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
...@@ -255,14 +232,6 @@ template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
...@@ -277,14 +246,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
const PlaceType &place);
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
...@@ -320,14 +281,6 @@ Tensor Tensor::cast(const DataType &target_type) const {
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
...@@ -356,14 +309,7 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type, CastDataType<platform::complex64>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type, CastDataType<platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
......
...@@ -91,7 +91,7 @@ void TestCast(paddle::DataType data_type) {
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
auto t2 = t1.cast(data_type);
CHECK_EQ(t2.type(), data_type);
CHECK(t2.type() == data_type);
}
void GroupTestCopy() {
...@@ -99,14 +99,6 @@ void GroupTestCopy() {
TestCopyTensor<float>();
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>();
VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::float16>();
VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::bfloat16>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex128>();
VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex64>();
VLOG(2) << "int cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>(); TestCopyTensor<int>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
...@@ -128,31 +120,17 @@ void GroupTestCast() {
TestCast<int64_t>(paddle::DataType::FLOAT32);
VLOG(2) << "double cast";
TestCast<double>(paddle::DataType::FLOAT32);
VLOG(2) << "bfloat16 cast";
TestCast<paddle::platform::bfloat16>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::platform::float16>(paddle::DataType::FLOAT32);
VLOG(2) << "bool cast"; VLOG(2) << "bool cast";
TestCast<bool>(paddle::DataType::FLOAT32); TestCast<bool>(paddle::DataType::FLOAT32);
VLOG(2) << "uint8 cast"; VLOG(2) << "uint8 cast";
TestCast<uint8_t>(paddle::DataType::FLOAT32); TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast"; VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32); TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
TestCast<float>(paddle::DataType::FLOAT32);
}
void GroupTestDtype() {
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<paddle::platform::float16>() == paddle::DataType::FLOAT16);
CHECK(TestDtype<paddle::platform::bfloat16>() == paddle::DataType::BFLOAT16);
CHECK(TestDtype<paddle::platform::complex128>() ==
paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::platform::complex64>() ==
paddle::DataType::COMPLEX64);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
...@@ -162,24 +140,12 @@ void GroupTestDtype() {
void GroupTestDtypeConvert() {
// enum -> proto
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX64) ==
paddle::framework::proto::VarType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT64) ==
paddle::framework::proto::VarType::FP64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT32) ==
paddle::framework::proto::VarType::FP32);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BFLOAT16) ==
paddle::framework::proto::VarType::BF16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::UINT8) ==
paddle::framework::proto::VarType::UINT8);
...@@ -197,24 +163,12 @@ void GroupTestDtypeConvert() {
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
// proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BF16) ==
paddle::DataType::BFLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64);
......
...@@ -39,18 +39,10 @@ class CustomTensorUtils {
static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
const paddle::DataType& dtype) {
switch (dtype) {
case paddle::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128;
case paddle::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64;
case paddle::DataType::FLOAT64:
return framework::proto::VarType::FP64;
case paddle::DataType::FLOAT32:
return framework::proto::VarType::FP32;
case paddle::DataType::FLOAT16:
return framework::proto::VarType::FP16;
case paddle::DataType::BFLOAT16:
return framework::proto::VarType::BF16;
case paddle::DataType::UINT8:
return framework::proto::VarType::UINT8;
case paddle::DataType::INT8:
...@@ -74,18 +66,10 @@ class CustomTensorUtils {
static paddle::DataType ConvertInnerDTypeToEnumDType(
const framework::proto::VarType::Type& dtype) {
switch (dtype) {
case framework::proto::VarType::COMPLEX128:
return paddle::DataType::COMPLEX128;
case framework::proto::VarType::COMPLEX64:
return paddle::DataType::COMPLEX64;
case framework::proto::VarType::FP64:
return paddle::DataType::FLOAT64;
case framework::proto::VarType::FP32:
return paddle::DataType::FLOAT32;
case framework::proto::VarType::FP16:
return paddle::DataType::FLOAT16;
case framework::proto::VarType::BF16:
return paddle::DataType::BFLOAT16;
case framework::proto::VarType::INT64:
return paddle::DataType::INT64;
case framework::proto::VarType::INT32:
......
...@@ -44,24 +44,6 @@ PD_BUILD_OP(dispatch_test_integer)
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestInterger));
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
...@@ -80,41 +62,3 @@ PD_BUILD_OP(dispatch_test_float_and_integer)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger));
std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_float_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
...@@ -55,11 +55,6 @@ class TestJitDispatch(unittest.TestCase):
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype)
def test_dispatch_complex(self):
dtypes = ["complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype)
def test_dispatch_float_and_integer(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16"
...@@ -68,21 +63,6 @@ class TestJitDispatch(unittest.TestCase):
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer,
dtype)
def test_dispatch_float_and_complex(self):
dtypes = ["float32", "float64", "complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex,
dtype)
def test_dispatch_float_and_integer_and_complex(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16",
"complex64", "complex128"
]
for dtype in dtypes:
self.run_dispatch_test(
dispatch_op.dispatch_test_float_and_integer_and_complex, dtype)
if __name__ == '__main__':
unittest.main()