From 0c38708a90019bfe72f06483ab14128eaca1a867 Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Fri, 26 Feb 2021 18:19:28 +0800 Subject: [PATCH] [Custom Op] Remove unsupport dtypes (#31232) * remove remove_unsupport_dtype * remove remove_unsupport_dtype * remove test dtype * add more include * change dtype.h's enum as enum class to avoid conflict with inference lib * make enum as enum class * remove additional test * merge develop * polish code --- paddle/fluid/extension/include/dispatch.h | 68 ------------------- paddle/fluid/extension/include/dtype.h | 50 ++++---------- paddle/fluid/extension/include/tensor.h | 1 + paddle/fluid/extension/src/tensor.cc | 58 +--------------- paddle/fluid/framework/custom_tensor_test.cc | 48 +------------ paddle/fluid/framework/custom_tensor_utils.h | 16 ----- .../fluid/tests/custom_op/dispatch_test_op.cc | 56 --------------- .../tests/custom_op/test_dispatch_jit.py | 20 ------ 8 files changed, 17 insertions(+), 300 deletions(-) diff --git a/paddle/fluid/extension/include/dispatch.h b/paddle/fluid/extension/include/dispatch.h index c229710395..3da64ad07a 100644 --- a/paddle/fluid/extension/include/dispatch.h +++ b/paddle/fluid/extension/include/dispatch.h @@ -69,23 +69,6 @@ namespace paddle { } \ }() -///////// Complex Dispatch Marco /////////// - -#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, __VA_ARGS__) \ - default: \ - throw std::runtime_error("function " #NAME \ - " not implemented for data type `" + \ - ::paddle::ToString(__dtype__) + "`"); \ - } \ - }() - ///////// Floating and Integral Dispatch Marco /////////// #define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ @@ -112,57 +95,6 @@ namespace paddle { } \ }() -///////// Floating and Complex Dispatch Marco /////////// - -#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, __VA_ARGS__) \ - default: \ - throw std::runtime_error("function " #NAME \ - " not implemented for data type `" + \ - ::paddle::ToString(__dtype__) + "`"); \ - } \ - }() - -///////// Floating, Integral and Complex Dispatch Marco /////////// - -#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, __VA_ARGS__) \ - default: \ - throw std::runtime_error("function " #NAME \ - " not implemented for data type `" + \ - ::paddle::ToString(__dtype__) + "`"); \ - } \ - }() - // TODO(chenweihang): Add more Marcos in the future if needed } // namespace paddle diff --git a/paddle/fluid/extension/include/dtype.h b/paddle/fluid/extension/include/dtype.h index c5d2e0f820..38c836c6fc 100644 --- a/paddle/fluid/extension/include/dtype.h +++ b/paddle/fluid/extension/include/dtype.h @@ -11,34 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once - -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" -#include "paddle/fluid/platform/float16.h" +#include +#include +#include namespace paddle { -using float16 = paddle::platform::float16; -using bfloat16 = paddle::platform::bfloat16; -using complex64 = paddle::platform::complex64; -using complex128 = paddle::platform::complex128; - -enum DataType { +enum class DataType { BOOL, INT8, UINT8, INT16, INT32, INT64, - FLOAT16, - BFLOAT16, FLOAT32, FLOAT64, - COMPLEX64, - COMPLEX128, // TODO(JiabinYang) support more data types if needed. }; @@ -56,36 +44,24 @@ inline std::string ToString(DataType dtype) { return "int32_t"; case DataType::INT64: return "int64_t"; - case DataType::FLOAT16: - return "float16"; - case DataType::BFLOAT16: - return "bfloat16"; case DataType::FLOAT32: return "float"; case DataType::FLOAT64: return "double"; - case DataType::COMPLEX64: - return "complex64"; - case DataType::COMPLEX128: - return "complex128"; default: throw std::runtime_error("Unsupported paddle enum data type."); } } -#define PD_FOR_EACH_DATA_TYPE(_) \ - _(bool, DataType::BOOL) \ - _(int8_t, DataType::INT8) \ - _(uint8_t, DataType::UINT8) \ - _(int16_t, DataType::INT16) \ - _(int, DataType::INT32) \ - _(int64_t, DataType::INT64) \ - _(float16, DataType::FLOAT16) \ - _(bfloat16, DataType::BFLOAT16) \ - _(float, DataType::FLOAT32) \ - _(double, DataType::FLOAT64) \ - _(complex64, DataType::COMPLEX64) \ - _(complex128, DataType::COMPLEX128) +#define PD_FOR_EACH_DATA_TYPE(_) \ + _(bool, DataType::BOOL) \ + _(int8_t, DataType::INT8) \ + _(uint8_t, DataType::UINT8) \ + _(int16_t, DataType::INT16) \ + _(int, DataType::INT32) \ + _(int64_t, DataType::INT64) \ + _(float, DataType::FLOAT32) \ + _(double, DataType::FLOAT64) template struct DataTypeToCPPType; diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h index 47af4dc70a..061dc3ded2 100644 --- a/paddle/fluid/extension/include/tensor.h +++ b/paddle/fluid/extension/include/tensor.h @@ -24,6 +24,7 @@ namespace paddle { namespace framework { class CustomTensorUtils; } // namespace framework + class PD_DLL_DECL Tensor { public: /// \brief Construct a Tensor on target Place for CustomOp. diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc index 39ed274864..dc7e3607bd 100644 --- a/paddle/fluid/extension/src/tensor.cc +++ b/paddle/fluid/extension/src/tensor.cc @@ -159,17 +159,10 @@ DataType Tensor::type() const { return DataType::UINT8; } else if (type == framework::proto::VarType::FP64) { return DataType::FLOAT64; - } else if (type == framework::proto::VarType::BF16) { - return DataType::BFLOAT16; - } else if (type == framework::proto::VarType::FP16) { - return DataType::FLOAT16; - } else if (type == framework::proto::VarType::COMPLEX64) { - return DataType::COMPLEX64; - } else if (type == framework::proto::VarType::COMPLEX128) { - return DataType::COMPLEX128; } else if (type == framework::proto::VarType::BOOL) { return DataType::BOOL; } + // TODO(JiabinYang) Support more dtype here return DataType::FLOAT32; } @@ -207,14 +200,6 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const { return target; } -template PD_DLL_DECL Tensor -Tensor::copy_to(const PlaceType &target_place) const; -template PD_DLL_DECL Tensor Tensor::copy_to( - const PlaceType &target_place) const; -template PD_DLL_DECL Tensor Tensor::copy_to( - const PlaceType &target_place) const; -template PD_DLL_DECL Tensor Tensor::copy_to( - const PlaceType &target_place) const; template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; template PD_DLL_DECL Tensor @@ -238,14 +223,6 @@ template PD_DLL_DECL int64_t *Tensor::data() const; template PD_DLL_DECL int32_t *Tensor::data() const; template PD_DLL_DECL uint8_t *Tensor::data() const; template PD_DLL_DECL int8_t *Tensor::data() const; -template PD_DLL_DECL paddle::platform::float16 * -Tensor::data() const; -template PD_DLL_DECL paddle::platform::bfloat16 * -Tensor::data() const; -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::data() const; -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::data() const; template PD_DLL_DECL int16_t *Tensor::data() const; template PD_DLL_DECL bool *Tensor::data() const; @@ -255,14 +232,6 @@ template PD_DLL_DECL int64_t *Tensor::mutable_data(); template PD_DLL_DECL int32_t *Tensor::mutable_data(); template PD_DLL_DECL uint8_t *Tensor::mutable_data(); template PD_DLL_DECL int8_t *Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::float16 * -Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::bfloat16 * -Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::mutable_data(); -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::mutable_data(); template PD_DLL_DECL int16_t *Tensor::mutable_data(); template PD_DLL_DECL bool *Tensor::mutable_data(); @@ -277,14 +246,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data( const PlaceType &place); template PD_DLL_DECL int8_t *Tensor::mutable_data( const PlaceType &place); -template PD_DLL_DECL paddle::platform::float16 * -Tensor::mutable_data(const PlaceType &place); -template PD_DLL_DECL paddle::platform::bfloat16 * -Tensor::mutable_data(const PlaceType &place); -template PD_DLL_DECL paddle::platform::complex128 * -Tensor::mutable_data(const PlaceType &place); -template PD_DLL_DECL paddle::platform::complex64 * -Tensor::mutable_data(const PlaceType &place); template PD_DLL_DECL int16_t *Tensor::mutable_data( const PlaceType &place); template PD_DLL_DECL bool *Tensor::mutable_data(const PlaceType &place); @@ -320,14 +281,6 @@ Tensor Tensor::cast(const DataType &target_type) const { auto dst_type = framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type); switch (src_type) { - case framework::proto::VarType::FP16: - framework::VisitDataType( - dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); - break; - case framework::proto::VarType::BF16: - framework::VisitDataType(dst_type, CastDataType( - *tensor, rlt_tensor_, ctx)); - break; case framework::proto::VarType::FP32: framework::VisitDataType(dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); @@ -356,14 +309,7 @@ Tensor Tensor::cast(const DataType &target_type) const { framework::VisitDataType( dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); break; - case framework::proto::VarType::COMPLEX64: - framework::VisitDataType(dst_type, CastDataType( - *tensor, rlt_tensor_, ctx)); - break; - case framework::proto::VarType::COMPLEX128: - framework::VisitDataType(dst_type, CastDataType( - *tensor, rlt_tensor_, ctx)); - break; + // TODO(JiabinYang) Support more dtype here default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when casting data type.", diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index 33b6624542..0f351c3bbd 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -91,7 +91,7 @@ void TestCast(paddle::DataType data_type) { t1.reshape(tensor_shape); t1.template mutable_data(); auto t2 = t1.cast(data_type); - CHECK_EQ(t2.type(), data_type); + CHECK(t2.type() == data_type); } void GroupTestCopy() { @@ -99,14 +99,6 @@ void GroupTestCopy() { TestCopyTensor(); VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); - VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu"; - TestCopyTensor(); - VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu"; - TestCopyTensor(); - VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; - TestCopyTensor(); - VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu"; - TestCopyTensor(); VLOG(2) << "int cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; @@ -128,31 +120,17 @@ void GroupTestCast() { TestCast(paddle::DataType::FLOAT32); VLOG(2) << "double cast"; TestCast(paddle::DataType::FLOAT32); - VLOG(2) << "bfloat16 cast"; - TestCast(paddle::DataType::FLOAT32); - VLOG(2) << "float16 cast"; - TestCast(paddle::DataType::FLOAT32); VLOG(2) << "bool cast"; TestCast(paddle::DataType::FLOAT32); VLOG(2) << "uint8 cast"; TestCast(paddle::DataType::FLOAT32); VLOG(2) << "float cast"; TestCast(paddle::DataType::FLOAT32); - VLOG(2) << "complex64 cast"; - TestCast(paddle::DataType::FLOAT32); - VLOG(2) << "complex128 cast"; - TestCast(paddle::DataType::FLOAT32); } void GroupTestDtype() { CHECK(TestDtype() == paddle::DataType::FLOAT32); CHECK(TestDtype() == paddle::DataType::FLOAT64); - CHECK(TestDtype() == paddle::DataType::FLOAT16); - CHECK(TestDtype() == paddle::DataType::BFLOAT16); - CHECK(TestDtype() == - paddle::DataType::COMPLEX128); - CHECK(TestDtype() == - paddle::DataType::COMPLEX64); CHECK(TestDtype() == paddle::DataType::INT32); CHECK(TestDtype() == paddle::DataType::INT64); CHECK(TestDtype() == paddle::DataType::INT16); @@ -162,24 +140,12 @@ void GroupTestDtype() { void GroupTestDtypeConvert() { // enum -> proto - CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( - paddle::DataType::COMPLEX128) == - paddle::framework::proto::VarType::COMPLEX128); - CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( - paddle::DataType::COMPLEX64) == - paddle::framework::proto::VarType::COMPLEX64); CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( paddle::DataType::FLOAT64) == paddle::framework::proto::VarType::FP64); CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( paddle::DataType::FLOAT32) == paddle::framework::proto::VarType::FP32); - CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( - paddle::DataType::FLOAT16) == - paddle::framework::proto::VarType::FP16); - CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( - paddle::DataType::BFLOAT16) == - paddle::framework::proto::VarType::BF16); CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( paddle::DataType::UINT8) == paddle::framework::proto::VarType::UINT8); @@ -197,24 +163,12 @@ void GroupTestDtypeConvert() { CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL); // proto -> enum - CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( - paddle::framework::proto::VarType::COMPLEX128) == - paddle::DataType::COMPLEX128); - CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( - paddle::framework::proto::VarType::COMPLEX64) == - paddle::DataType::COMPLEX64); CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( paddle::framework::proto::VarType::FP64) == paddle::DataType::FLOAT64); CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( paddle::framework::proto::VarType::FP32) == paddle::DataType::FLOAT32); - CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( - paddle::framework::proto::VarType::FP16) == - paddle::DataType::FLOAT16); - CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( - paddle::framework::proto::VarType::BF16) == - paddle::DataType::BFLOAT16); CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( paddle::framework::proto::VarType::INT64) == paddle::DataType::INT64); diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h index 4b465d3911..1dc4e06e57 100644 --- a/paddle/fluid/framework/custom_tensor_utils.h +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -39,18 +39,10 @@ class CustomTensorUtils { static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType( const paddle::DataType& dtype) { switch (dtype) { - case paddle::DataType::COMPLEX128: - return framework::proto::VarType::COMPLEX128; - case paddle::DataType::COMPLEX64: - return framework::proto::VarType::COMPLEX64; case paddle::DataType::FLOAT64: return framework::proto::VarType::FP64; case paddle::DataType::FLOAT32: return framework::proto::VarType::FP32; - case paddle::DataType::FLOAT16: - return framework::proto::VarType::FP16; - case paddle::DataType::BFLOAT16: - return framework::proto::VarType::BF16; case paddle::DataType::UINT8: return framework::proto::VarType::UINT8; case paddle::DataType::INT8: @@ -74,18 +66,10 @@ class CustomTensorUtils { static paddle::DataType ConvertInnerDTypeToEnumDType( const framework::proto::VarType::Type& dtype) { switch (dtype) { - case framework::proto::VarType::COMPLEX128: - return paddle::DataType::COMPLEX128; - case framework::proto::VarType::COMPLEX64: - return paddle::DataType::COMPLEX64; case framework::proto::VarType::FP64: return paddle::DataType::FLOAT64; case framework::proto::VarType::FP32: return paddle::DataType::FLOAT32; - case framework::proto::VarType::FP16: - return paddle::DataType::FLOAT16; - case framework::proto::VarType::BF16: - return paddle::DataType::BFLOAT16; case framework::proto::VarType::INT64: return paddle::DataType::INT64; case framework::proto::VarType::INT32: diff --git a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc index 720be8b4e3..33ca6ee86f 100644 --- a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc +++ b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc @@ -44,24 +44,6 @@ PD_BUILD_OP(dispatch_test_integer) .Outputs({"Out"}) .SetKernelFn(PD_KERNEL(DispatchTestInterger)); -std::vector DispatchTestComplex(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU); - out.reshape(x.shape()); - - PD_DISPATCH_COMPLEX_TYPES( - x.type(), "assign_cpu_kernel", ([&] { - assign_cpu_kernel( - x.data(), out.mutable_data(), x.size()); - })); - - return {out}; -} - -PD_BUILD_OP(dispatch_test_complex) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(DispatchTestComplex)); - std::vector DispatchTestFloatAndInteger( const paddle::Tensor& x) { auto out = paddle::Tensor(paddle::PlaceType::kCPU); @@ -80,41 +62,3 @@ PD_BUILD_OP(dispatch_test_float_and_integer) .Inputs({"X"}) .Outputs({"Out"}) .SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger)); - -std::vector DispatchTestFloatAndComplex( - const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU); - out.reshape(x.shape()); - - PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES( - x.type(), "assign_cpu_kernel", ([&] { - assign_cpu_kernel( - x.data(), out.mutable_data(), x.size()); - })); - - return {out}; -} - -PD_BUILD_OP(dispatch_test_float_and_complex) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex)); - -std::vector DispatchTestFloatAndIntegerAndComplex( - const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU); - out.reshape(x.shape()); - - PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES( - x.type(), "assign_cpu_kernel", ([&] { - assign_cpu_kernel( - x.data(), out.mutable_data(), x.size()); - })); - - return {out}; -} - -PD_BUILD_OP(dispatch_test_float_and_integer_and_complex) - .Inputs({"X"}) - .Outputs({"Out"}) - .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex)); diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py index 54d317c37f..05808d3d22 100644 --- a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py @@ -55,11 +55,6 @@ class TestJitDispatch(unittest.TestCase): for dtype in dtypes: self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype) - def test_dispatch_complex(self): - dtypes = ["complex64", "complex128"] - for dtype in dtypes: - self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype) - def test_dispatch_float_and_integer(self): dtypes = [ "float32", "float64", "int32", "int64", "int8", "uint8", "int16" @@ -68,21 +63,6 @@ class TestJitDispatch(unittest.TestCase): self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer, dtype) - def test_dispatch_float_and_complex(self): - dtypes = ["float32", "float64", "complex64", "complex128"] - for dtype in dtypes: - self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex, - dtype) - - def test_dispatch_float_and_integer_and_complex(self): - dtypes = [ - "float32", "float64", "int32", "int64", "int8", "uint8", "int16", - "complex64", "complex128" - ] - for dtype in dtypes: - self.run_dispatch_test( - dispatch_op.dispatch_test_float_and_integer_and_complex, dtype) - if __name__ == '__main__': unittest.main() -- GitLab