未验证 提交 6506668e 编写于 作者: L limingshu 提交者: GitHub

Addition of fp16 type support for Compare OP (#44405)

* first commit

* add fp16 ctest files for compare op

* add cpu register of float16 for compare ops
上级 c693a027
......@@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal,
CPU,
ALL_LAYOUT,
......@@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than,
CPU,
ALL_LAYOUT,
......@@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal,
CPU,
ALL_LAYOUT,
......@@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal,
CPU,
ALL_LAYOUT,
......@@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal,
CPU,
ALL_LAYOUT,
......@@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal_all,
CPU,
......
......@@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal,
KPS,
ALL_LAYOUT,
......@@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than,
KPS,
ALL_LAYOUT,
......@@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal,
KPS,
ALL_LAYOUT,
......@@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal,
KPS,
ALL_LAYOUT,
......@@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal,
KPS,
ALL_LAYOUT,
......@@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal_all,
KPS,
......
......@@ -62,9 +62,11 @@ def create_test_class(op_type, typename, callback):
globals()[cls_name] = Cls
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
continue
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册