未验证 提交 6506668e 编写于 作者: L limingshu 提交者: GitHub

Addition of fp16 type support for Compare OP (#44405)

* first commit

* add fp16 ctest files for compare op

* add cpu register of float16 for compare ops
上级 c693a027
...@@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than, ...@@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal, PD_REGISTER_KERNEL(less_equal,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal, ...@@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than, PD_REGISTER_KERNEL(greater_than,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than, ...@@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal, PD_REGISTER_KERNEL(greater_equal,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal, ...@@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal, PD_REGISTER_KERNEL(equal,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal, ...@@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal, PD_REGISTER_KERNEL(not_equal,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal, ...@@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal_all, PD_REGISTER_KERNEL(equal_all,
CPU, CPU,
......
...@@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than, ...@@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal, PD_REGISTER_KERNEL(less_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal, ...@@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than, PD_REGISTER_KERNEL(greater_than,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than, ...@@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal, PD_REGISTER_KERNEL(greater_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal, ...@@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal, PD_REGISTER_KERNEL(equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal, ...@@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal, PD_REGISTER_KERNEL(not_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal, ...@@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal,
int, int,
int64_t, int64_t,
float, float,
double) {} double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal_all, PD_REGISTER_KERNEL(equal_all,
KPS, KPS,
......
...@@ -62,9 +62,11 @@ def create_test_class(op_type, typename, callback): ...@@ -62,9 +62,11 @@ def create_test_class(op_type, typename, callback):
globals()[cls_name] = Cls globals()[cls_name] = Cls
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
if _type_name == 'float64' and core.is_compiled_with_rocm(): if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32' _type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
continue
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册