From 6506668ebe345ea1cc9fa27aa85cc658dcdbc3c5 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Thu, 4 Aug 2022 15:56:41 +0800 Subject: [PATCH] Addition of fp16 type support for Compare OP (#44405) * first commit * add fp16 ctest files for compare op * add cpu register of float16 for compare ops --- paddle/phi/kernels/cpu/compare_kernel.cc | 18 ++++++++++++------ paddle/phi/kernels/kps/compare_kernel.cu | 18 ++++++++++++------ .../fluid/tests/unittests/test_compare_op.py | 4 +++- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index 694b44c16d..ae6c3fd5cb 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(less_equal, CPU, ALL_LAYOUT, @@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_than, CPU, ALL_LAYOUT, @@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_equal, CPU, ALL_LAYOUT, @@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal, CPU, ALL_LAYOUT, @@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(not_equal, CPU, ALL_LAYOUT, @@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal_all, CPU, diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 0b0990627f..b981d80225 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, @@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, @@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(greater_equal, KPS, ALL_LAYOUT, @@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, @@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, @@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal, int, int64_t, float, - double) {} + double, + phi::dtype::float16) {} PD_REGISTER_KERNEL(equal_all, KPS, diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index a893b65f5a..f1bacfbb6f 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -62,9 +62,11 @@ def create_test_class(op_type, typename, callback): globals()[cls_name] = Cls -for _type_name in {'float32', 'float64', 'int32', 'int64'}: +for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}: if _type_name == 'float64' and core.is_compiled_with_rocm(): _type_name = 'float32' + if _type_name == 'float16' and (not core.is_compiled_with_cuda()): + continue create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) -- GitLab