Unverified commit 8077d79a, authored by Qi Shao, committed by GitHub

[Cherry-Pick] Modify the bf16 accuracy checking framework in OpTest (#54658)

* modify the bf16 accuracy checking framework in OpTest

* modify the bf16 accuracy checking framework in OpTest

* modify the bf16 accuracy checking framework in OpTest

* modify the bf16 accuracy checking framework in OpTest

* modify the bf16 accuracy checking framework in OpTest

* modify the bf16 accuracy checking framework in OpTest
Parent 0abd9ffd
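Background for the change below: in this test framework a bfloat16 tensor is carried around as a raw np.uint16 payload, so comparing a bf16 op against an fp32 reference hinges on reinterpreting that payload as float32. The following is a minimal sketch of the bit-level idea behind the convert_uint16_to_float helper that the diff relies on; the function names here are illustrative, not the framework's API.

import numpy as np

def bf16_bits_to_float32(bits):
    # bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening the
    # raw uint16 payload, shifting it into the high half of a uint32 and
    # reinterpreting the bytes recovers the exact value the bf16 encodes.
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << np.uint32(16)).view(np.float32)

def float32_to_bf16_bits(values):
    # Truncating conversion (no rounding) to keep the sketch minimal.
    values = np.asarray(values, dtype=np.float32)
    return (values.view(np.uint32) >> np.uint32(16)).astype(np.uint16)

assert bf16_bits_to_float32([0x3F80])[0] == np.float32(1.0)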
@@ -550,8 +550,17 @@ class OpTest(unittest.TestCase):
             not in op_accuracy_white_list.NO_FP16_COMPARED_WITH_FP32_OP_LIST
         )
 
+    def is_bf16_compared_with_fp32(self):
+        return self.is_bfloat16_op() and (
+            self.op_type
+            not in op_accuracy_white_list.NO_BF16_COMPARED_WITH_FP32_OP_LIST
+        )
+
     def enable_cal_ref_output(self):
-        self.is_calc_ref = self.is_fp16_compared_with_fp32()
+        self.is_calc_ref = (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        )
 
     def disable_cal_ref_output(self):
         self.is_calc_ref = False
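For context, a hedged sketch of how a test opts into this new path: an OpTest subclass that sets self.dtype = np.uint16 and feeds uint16 (bf16) arrays is treated as a bfloat16 op, so is_bf16_compared_with_fp32() returns True unless its op_type is whitelisted. The test class, op, shapes, and import path below are illustrative assumptions, not part of this diff.

import numpy as np

# Assumed import: OpTest and convert_float_to_uint16 come from the op-test
# framework this diff modifies; the exact module path may differ.
from eager_op_test import OpTest, convert_float_to_uint16


class TestAddBF16(OpTest):  # hypothetical example test
    def setUp(self):
        self.op_type = "elementwise_add"
        self.dtype = np.uint16  # marks this as a bfloat16 op test
        x = np.random.random((3, 4)).astype(np.float32)
        y = np.random.random((3, 4)).astype(np.float32)
        # Inputs/outputs are stored as raw uint16 bf16 payloads; with the new
        # gate, OpTest will also run an fp32 reference pass for comparison.
        self.inputs = {
            'X': convert_float_to_uint16(x),
            'Y': convert_float_to_uint16(y),
        }
        self.outputs = {'Out': convert_float_to_uint16(x + y)}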
@@ -652,7 +661,10 @@ class OpTest(unittest.TestCase):
                     if isinstance(np_value, tuple):
                         tensor.set(np_value[0], place)
                         dtype = np.array(np_value[1]).dtype
-                        if self.is_calc_ref and dtype == np.float16:
+                        if self.is_calc_ref:
+                            # convert the float16 to float by numpy.astype
+                            if dtype == np.float16:
                                 if isinstance(np_value[1], list):
                                     tensor.set_recursive_sequence_lengths(
                                         np.array(np_value[1]).astype(np.float32)
@@ -661,11 +673,35 @@ class OpTest(unittest.TestCase):
                                     tensor.set_recursive_sequence_lengths(
                                         np_value[1].astype(np.float32)
                                     )
+                            # convert the bfloat16 to float by convert_uint16_to_float
+                            # provided in this file
+                            elif dtype == np.uint16:
+                                if isinstance(np_value[1], list):
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(
+                                            np.array(np_value[1])
+                                        )
+                                    )
+                                else:
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(np_value[1])
+                                    )
+                            else:
+                                tensor.set_recursive_sequence_lengths(
+                                    np_value[1]
+                                )
                         else:
                             tensor.set_recursive_sequence_lengths(np_value[1])
                     else:
-                        if self.is_calc_ref and np_value.dtype == np.float16:
+                        if self.is_calc_ref:
+                            if np_value.dtype == np.float16:
                                 tensor.set(np_value.astype(np.float32), place)
+                            elif np_value.dtype == np.uint16:
+                                tensor.set(
+                                    convert_uint16_to_float(np_value), place
+                                )
+                            else:
+                                tensor.set(np_value, place)
                         else:
                             tensor.set(np_value, place)
                     feed_map[name] = tensor
@@ -673,25 +709,38 @@ class OpTest(unittest.TestCase):
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
                     tensor.set(self.inputs[var_name][0], place)
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name][1].dtype == np.float16
-                    ):
-                        tensor.set_recursive_sequence_lengths(
-                            self.inputs[var_name][1].astype(np.float32)
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name][1].dtype == np.float16:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1].astype(np.float32)
+                            )
+                        elif self.inputs[var_name][1].dtype == np.uint16:
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(
+                                    self.inputs[var_name][1]
+                                )
+                            )
+                        else:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1]
+                            )
                     else:
                         tensor.set_recursive_sequence_lengths(
                             self.inputs[var_name][1]
                         )
                 else:
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name].dtype == np.float16
-                    ):
-                        tensor.set(
-                            self.inputs[var_name].astype(np.float32), place
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name].dtype == np.float16:
+                            tensor.set(
+                                self.inputs[var_name].astype(np.float32), place
+                            )
+                        elif self.inputs[var_name].dtype == np.uint16:
+                            tensor.set(
+                                convert_uint16_to_float(self.inputs[var_name]),
+                                place,
+                            )
+                        else:
+                            tensor.set(self.inputs[var_name], place)
                     else:
                         tensor.set(self.inputs[var_name], place)
                 feed_map[var_name] = tensor
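The two feed_var hunks above widen bf16 inputs whenever a reference run is active: with is_calc_ref enabled, a uint16 (bf16) array is converted to float32 before being set on the LoDTensor, so the fp32 reference computation receives ordinary float data. A minimal illustration of that widening step, with hypothetical values and variable names:

import numpy as np

# Hypothetical bf16 input as OpTest stores it: raw uint16 payloads.
bf16_payload = np.array([0x3F80, 0x3FC0, 0xC000], dtype=np.uint16)  # 1.0, 1.5, -2.0

# What the reference pass feeds instead of the raw payload: the float32
# equivalents, mirroring what convert_uint16_to_float produces.
fp32_equivalent = (bf16_payload.astype(np.uint32) << np.uint32(16)).view(np.float32)
print(fp32_equivalent)  # [ 1.   1.5 -2. ]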
@@ -1761,7 +1810,10 @@ class OpTest(unittest.TestCase):
             def compare_single_output_with_expect(self, name, expect):
                 actual, actual_np = self.find_actual_value(name)
                 # expect_np = expect[0] if isinstance(expect, tuple) else expect
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     expect, expect_np = self.find_expect_value(name)
                 else:
                     expect_np = (
@@ -1816,7 +1868,10 @@ class OpTest(unittest.TestCase):
                 )
                 self.outputs = outs
                 self.fetch_list = fetch_list
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     self.op_test.enable_cal_ref_output()
                     ref_outs, ref_fetch_list = self.op_test._calc_output(
                         place, no_check_set=no_check_set
@@ -1883,7 +1938,10 @@ class OpTest(unittest.TestCase):
                     place, no_check_set=no_check_set
                 )
                 self.outputs = dygraph_outs
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     self.op_test.enable_cal_ref_output()
                     self.is_python_api_test = True
                     self.ref_outputs = self.op_test._calc_python_api_output(
@@ -2228,9 +2286,8 @@ class OpTest(unittest.TestCase):
                 atol=atol,
                 equal_nan=False,
                 err_msg=(
-                    "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
-                )
-                % (
+                    "Operator {} error, {} variable {} (shape: {}, dtype: {}) max gradient diff over limit"
+                ).format(
                     self.op_type,
                     msg_prefix,
                     name,
@@ -2486,7 +2543,10 @@ class OpTest(unittest.TestCase):
         if numeric_place is None:
             numeric_place = place
 
-        if user_defined_grads is None and self.is_fp16_compared_with_fp32():
+        if user_defined_grads is None and (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        ):
             self.enable_cal_ref_output()
             numeric_grads = self._get_gradient(
                 inputs_to_check,
@@ -2769,7 +2829,7 @@ class OpTest(unittest.TestCase):
         feed_dict = self.feed_var(inputs, place)
         if user_defined_grad_outputs is None:
-            if self.dtype == np.uint16:
+            if self.dtype == np.uint16 and not self.is_calc_ref:
                 cast_inputs = list(map(block.var, output_names))
                 if self.op_type in ["broadcast_tensors", "meshgrid"]:
                     output_names = self.cast_bf16_output(block, cast_inputs)
...
@@ -120,7 +120,7 @@ def append_input_output(
             if is_input:
                 shape = list(np_value.shape)
                 lod_level = 0
-        if is_calc_ref and dtype == np.float16:
+        if is_calc_ref and (dtype == np.float16 or dtype == np.uint16):
            dtype = np.float32
         return block.create_var(
             dtype=dtype, shape=shape, lod_level=lod_level, name=name
...
@@ -94,3 +94,11 @@ NO_FP16_COMPARED_WITH_FP32_OP_LIST = [
     'fake_quantize_moving_average_abs_max',
     'p_norm',
 ]
+
+
+NO_BF16_COMPARED_WITH_FP32_OP_LIST = [
+    'unique',
+    'fusion_gru',
+    'fusion_lstm',
+    'dequantize',
+]
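Ops listed in NO_BF16_COMPARED_WITH_FP32_OP_LIST keep their hand-written expected outputs rather than being checked against an fp32 reference run. For every other bf16 op the check amounts to widening the bf16 result and comparing it with the fp32 reference under a loosened tolerance, because bfloat16 keeps only 8 significand bits. A rough sketch, with illustrative tolerances rather than the framework's defaults:

import numpy as np

def check_bf16_against_fp32(actual_bits, expect_fp32, rtol=1e-2, atol=1e-2):
    # Widen the bf16 payload (uint16) to float32, then compare against the
    # fp32 reference with a tolerance loose enough for bf16's precision.
    actual_fp32 = (
        np.asarray(actual_bits, dtype=np.uint16).astype(np.uint32) << np.uint32(16)
    ).view(np.float32)
    np.testing.assert_allclose(actual_fp32, expect_fp32, rtol=rtol, atol=atol)

# Example: a bf16 "result" encoding 1.5 checked against an fp32 reference.
check_bf16_against_fp32(np.array([0x3FC0], dtype=np.uint16), np.array([1.5], np.float32))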