未验证 提交 0c38708a 编写于 作者: J Jiabin Yang 提交者: GitHub

[Custom Op] Remove unsupport dtypes (#31232)

* remove remove_unsupport_dtype

* remove remove_unsupport_dtype

* remove test dtype

* add more include

* change dtype.h's enum as enum class to avoid conflict with inference lib

* make enum as enum class

* remove additional test

* merge develop

* polish code
上级 b8bce682
......@@ -69,23 +69,6 @@ namespace paddle {
} \
}()
///////// Complex Dispatch Marco ///////////
#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating and Integral Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
......@@ -112,57 +95,6 @@ namespace paddle {
} \
}()
///////// Floating and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating, Integral and Complex Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
// TODO(chenweihang): Add more Marcos in the future if needed
} // namespace paddle
......@@ -11,34 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
#include <cstdint>
#include <stdexcept>
#include <string>
namespace paddle {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
enum DataType {
enum class DataType {
BOOL,
INT8,
UINT8,
INT16,
INT32,
INT64,
FLOAT16,
BFLOAT16,
FLOAT32,
FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed.
};
......@@ -56,36 +44,24 @@ inline std::string ToString(DataType dtype) {
return "int32_t";
case DataType::INT64:
return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::BFLOAT16:
return "bfloat16";
case DataType::FLOAT32:
return "float";
case DataType::FLOAT64:
return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default:
throw std::runtime_error("Unsupported paddle enum data type.");
}
}
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float16, DataType::FLOAT16) \
_(bfloat16, DataType::BFLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64)
template <paddle::DataType T>
struct DataTypeToCPPType;
......
......@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class CustomTensorUtils;
} // namespace framework
class PD_DLL_DECL Tensor {
public:
/// \brief Construct a Tensor on target Place for CustomOp.
......
......@@ -159,17 +159,10 @@ DataType Tensor::type() const {
return DataType::UINT8;
} else if (type == framework::proto::VarType::FP64) {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
// TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32;
}
......@@ -207,14 +200,6 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
return target;
}
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
......@@ -238,14 +223,6 @@ template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::data<paddle::platform::bfloat16>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
......@@ -255,14 +232,6 @@ template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
......@@ -277,14 +246,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
const PlaceType &place);
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
......@@ -320,14 +281,6 @@ Tensor Tensor::cast(const DataType &target_type) const {
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
......@@ -356,14 +309,7 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type, CastDataType<platform::complex64>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type, CastDataType<platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
......
......@@ -91,7 +91,7 @@ void TestCast(paddle::DataType data_type) {
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
auto t2 = t1.cast(data_type);
CHECK_EQ(t2.type(), data_type);
CHECK(t2.type() == data_type);
}
void GroupTestCopy() {
......@@ -99,14 +99,6 @@ void GroupTestCopy() {
TestCopyTensor<float>();
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>();
VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::float16>();
VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::bfloat16>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex128>();
VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex64>();
VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
......@@ -128,31 +120,17 @@ void GroupTestCast() {
TestCast<int64_t>(paddle::DataType::FLOAT32);
VLOG(2) << "double cast";
TestCast<double>(paddle::DataType::FLOAT32);
VLOG(2) << "bfloat16 cast";
TestCast<paddle::platform::bfloat16>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::platform::float16>(paddle::DataType::FLOAT32);
VLOG(2) << "bool cast";
TestCast<bool>(paddle::DataType::FLOAT32);
VLOG(2) << "uint8 cast";
TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
TestCast<float>(paddle::DataType::FLOAT32);
}
void GroupTestDtype() {
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<paddle::platform::float16>() == paddle::DataType::FLOAT16);
CHECK(TestDtype<paddle::platform::bfloat16>() == paddle::DataType::BFLOAT16);
CHECK(TestDtype<paddle::platform::complex128>() ==
paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::platform::complex64>() ==
paddle::DataType::COMPLEX64);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
......@@ -162,24 +140,12 @@ void GroupTestDtype() {
void GroupTestDtypeConvert() {
// enum -> proto
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX64) ==
paddle::framework::proto::VarType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT64) ==
paddle::framework::proto::VarType::FP64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT32) ==
paddle::framework::proto::VarType::FP32);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BFLOAT16) ==
paddle::framework::proto::VarType::BF16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::UINT8) ==
paddle::framework::proto::VarType::UINT8);
......@@ -197,24 +163,12 @@ void GroupTestDtypeConvert() {
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
// proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BF16) ==
paddle::DataType::BFLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64);
......
......@@ -39,18 +39,10 @@ class CustomTensorUtils {
static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
const paddle::DataType& dtype) {
switch (dtype) {
case paddle::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128;
case paddle::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64;
case paddle::DataType::FLOAT64:
return framework::proto::VarType::FP64;
case paddle::DataType::FLOAT32:
return framework::proto::VarType::FP32;
case paddle::DataType::FLOAT16:
return framework::proto::VarType::FP16;
case paddle::DataType::BFLOAT16:
return framework::proto::VarType::BF16;
case paddle::DataType::UINT8:
return framework::proto::VarType::UINT8;
case paddle::DataType::INT8:
......@@ -74,18 +66,10 @@ class CustomTensorUtils {
static paddle::DataType ConvertInnerDTypeToEnumDType(
const framework::proto::VarType::Type& dtype) {
switch (dtype) {
case framework::proto::VarType::COMPLEX128:
return paddle::DataType::COMPLEX128;
case framework::proto::VarType::COMPLEX64:
return paddle::DataType::COMPLEX64;
case framework::proto::VarType::FP64:
return paddle::DataType::FLOAT64;
case framework::proto::VarType::FP32:
return paddle::DataType::FLOAT32;
case framework::proto::VarType::FP16:
return paddle::DataType::FLOAT16;
case framework::proto::VarType::BF16:
return paddle::DataType::BFLOAT16;
case framework::proto::VarType::INT64:
return paddle::DataType::INT64;
case framework::proto::VarType::INT32:
......
......@@ -44,24 +44,6 @@ PD_BUILD_OP(dispatch_test_integer)
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestInterger));
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
......@@ -80,41 +62,3 @@ PD_BUILD_OP(dispatch_test_float_and_integer)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger));
std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_float_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
......@@ -55,11 +55,6 @@ class TestJitDispatch(unittest.TestCase):
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype)
def test_dispatch_complex(self):
dtypes = ["complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype)
def test_dispatch_float_and_integer(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16"
......@@ -68,21 +63,6 @@ class TestJitDispatch(unittest.TestCase):
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer,
dtype)
def test_dispatch_float_and_complex(self):
dtypes = ["float32", "float64", "complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex,
dtype)
def test_dispatch_float_and_integer_and_complex(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16",
"complex64", "complex128"
]
for dtype in dtypes:
self.run_dispatch_test(
dispatch_op.dispatch_test_float_and_integer_and_complex, dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册