未验证 提交 a7397e0c 编写于 作者: Y yeliang2258 提交者: GitHub

[AMP] Add bfloat16 and float16 tests for compare ops (#51978)

* add bf16 and fp16 tests

* fix dtype check
上级 9c853d1d
......@@ -94,7 +94,8 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \
float, \
double, \
phi::dtype::float16) { \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
......@@ -107,7 +108,8 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \
float, \
double, \
phi::dtype::float16) { \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
......
......@@ -136,6 +136,15 @@ def create_paddle_case(op_type, callback):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_dynamic_api_float16(self):
paddle.disable_static()
x = paddle.to_tensor(self.input_x, dtype="float16")
y = paddle.to_tensor(self.input_y, dtype="float16")
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_dynamic_api_inf_1(self):
if self.op_type == "equal":
paddle.disable_static()
......@@ -434,6 +443,39 @@ create_paddle_case('equal', lambda _a, _b: _a == _b)
create_paddle_case('not_equal', lambda _a, _b: _a != _b)
# add bf16 tests
def create_bf16_case(op_type, callback):
class TestCompareOpBF16Op(op_test.OpTest):
def setUp(self):
self.op_type = op_type
self.dtype = np.uint16
self.python_api = eval("paddle." + op_type)
x = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
y = np.random.uniform(0, 1, [5, 5]).astype(np.float32)
real_result = callback(x, y)
self.inputs = {
'X': op_test.convert_float_to_uint16(x),
'Y': op_test.convert_float_to_uint16(y),
}
self.outputs = {'Out': real_result}
def test_check_output(self):
self.check_output()
cls_name = "BF16TestCase_{}".format(op_type)
TestCompareOpBF16Op.__name__ = cls_name
globals()[cls_name] = TestCompareOpBF16Op
create_bf16_case('less_than', lambda _a, _b: _a < _b)
create_bf16_case('less_equal', lambda _a, _b: _a <= _b)
create_bf16_case('greater_than', lambda _a, _b: _a > _b)
create_bf16_case('greater_equal', lambda _a, _b: _a >= _b)
create_bf16_case('equal', lambda _a, _b: _a == _b)
create_bf16_case('not_equal', lambda _a, _b: _a != _b)
class TestCompareOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
......
......@@ -746,13 +746,13 @@ def not_equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float32", "float64", "int32", "int64"],
["bool", "float16", "float32", "float64", "int32", "int64"],
"not_equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float32", "float64", "int32", "int64"],
["bool", "float16", "float32", "float64", "int32", "int64"],
"not_equal",
)
helper = LayerHelper("not_equal", **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册