diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index 694b44c16d80e409493cbc04b03dcabff5f63903..ae6c3fd5cb020110d5b8b95ab313e94d6724b843 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(less_equal, CPU, ALL_LAYOUT, @@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_than, CPU, ALL_LAYOUT, @@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_equal, CPU, ALL_LAYOUT, @@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal, CPU, ALL_LAYOUT, @@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(not_equal, CPU, ALL_LAYOUT, @@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal_all, CPU, diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 0b0990627f0beb8a15715d5224e81843d4645acc..b981d802255a2a4550034b69fe7d7fc10e341587 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, @@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, @@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_equal, KPS, ALL_LAYOUT, @@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, @@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, @@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal_all, KPS, diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index a893b65f5a4211a05e163e7b733a1b9e417b4371..f1bacfbb6f8f8a203ec0295ffd6a5ea43ffd7522 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -62,9 +62,11 @@ def create_test_class(op_type, typename, callback): globals()[cls_name] = Cls -for _type_name in {'float32', 'float64', 'int32', 'int64'}: +for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}: if _type_name == 'float64' and core.is_compiled_with_rocm(): _type_name = 'float32' + if _type_name == 'float16' and (not core.is_compiled_with_cuda()): + continue create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)