Unverified commit 6f7ceca0, authored by Qi Shao, committed by GitHub

Modify bf16 and fix the elementwise_max (#54799)

* modify the accuracy checking framework of the bf16 optest, including both forward and backward
Parent 4c5ce835
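
The heart of the optest change is that bf16 outputs and gradients are now checked against a float32 reference run, the same way fp16 already was. bf16 tensors travel through the test framework as uint16 bit patterns, so comparing them means widening those bits back to float32 first. Below is a minimal numpy sketch of that round trip; it truncates instead of rounding, so it only approximates the convert_float_to_uint16 / convert_uint16_to_float helpers that eager_op_test.py actually uses.

import numpy as np

# Sketch only: bfloat16 is the top 16 bits of an IEEE-754 float32. The real
# helpers in eager_op_test.py (convert_float_to_uint16 / convert_uint16_to_float)
# may round differently; this version simply drops the low mantissa bits.

def float32_to_bf16_bits(x):
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)   # keep the high half

def bf16_bits_to_float32(u):
    u = np.asarray(u, dtype=np.uint16).astype(np.uint32) << 16
    return u.astype(np.uint32).view(np.float32)           # widen back for comparison

x = np.linspace(0.1, 1.0, 5, dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))      # x kept to roughly 3 decimal digits
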
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
@@ -114,40 +115,42 @@ static void ElemwiseGradBroadcast1CPU(const T *x,
                                       DY_OP dy_op,
                                       T *dx,
                                       T *dy) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
   if (is_xsize_larger) {
-    for (int i = 0; i < h; ++i) {
-      for (int j = 0; j < w; ++j) {
+    for (int j = 0; j < w; ++j) {
+      MPType sum_y = static_cast<MPType>(0);
+      for (int i = 0; i < h; ++i) {
         int x_offset = i * w + j;
         if (dx != nullptr) {
           dx[x_offset] =
               dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
         }
         if (dy != nullptr) {
-          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
-          if (i == 0) {
-            dy[j] = tmp;
-          } else {
-            dy[j] += tmp;
-          }
+          sum_y += static_cast<MPType>(
+              dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
         }
       }
+      if (dy != nullptr) {
+        dy[j] = static_cast<T>(sum_y);
+      }
     }
-  } else {  // x.dims < y.dims, broadcast for x.
-    for (int i = 0; i < h; ++i) {
-      for (int j = 0; j < w; ++j) {
+  } else {
+    for (int j = 0; j < w; ++j) {
+      MPType sum_x = static_cast<MPType>(0);
+      for (int i = 0; i < h; ++i) {
         int y_offset = i * w + j;
         if (dy != nullptr) {
           dy[y_offset] =
               dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
         }
         if (dx != nullptr) {
-          T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
-          if (i == 0) {
-            dx[j] = tmp;
-          } else {
-            dx[j] += tmp;
-          }
+          sum_x += static_cast<MPType>(
+              dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
         }
       }
+      if (dx != nullptr) {
+        dx[j] = static_cast<T>(sum_x);
+      }
     }
   }
@@ -166,9 +169,12 @@ static void ElemwiseGradBroadcast2CPU(const T *x,
                                       DY_OP dy_op,
                                       T *dx,
                                       T *dy) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
   if (is_xsize_larger) {
-    for (int i = 0; i < pre; ++i) {
-      for (int j = 0; j < n; ++j) {
+    for (int j = 0; j < n; ++j) {
+      MPType sum_y = static_cast<MPType>(0);
+      for (int i = 0; i < pre; ++i) {
         for (int k = 0; k < post; ++k) {
           int x_offset = i * n * post + j * post + k;
           if (dx != nullptr) {
@@ -176,19 +182,19 @@ static void ElemwiseGradBroadcast2CPU(const T *x,
                 dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
           }
           if (dy != nullptr) {
-            T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
-            if (i == 0 && k == 0) {
-              dy[j] = tmp;
-            } else {
-              dy[j] += tmp;
-            }
+            sum_y += static_cast<MPType>(
+                dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
           }
         }
       }
+      if (dy != nullptr) {
+        dy[j] = static_cast<T>(sum_y);
+      }
     }
-  } else {  // x.dims < y.dims, broadcast for x.
-    for (int i = 0; i < pre; ++i) {
-      for (int j = 0; j < n; ++j) {
+  } else {
+    for (int j = 0; j < n; ++j) {
+      MPType sum_x = static_cast<MPType>(0);
+      for (int i = 0; i < pre; ++i) {
         for (int k = 0; k < post; ++k) {
           int y_offset = i * n * post + j * post + k;
           if (dy != nullptr) {
@@ -196,14 +202,13 @@ static void ElemwiseGradBroadcast2CPU(const T *x,
                 dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
           }
           if (dx != nullptr) {
-            T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
-            if (i == 0 && k == 0) {
-              dx[j] = tmp;
-            } else {
-              dx[j] += tmp;
-            }
+            sum_x += static_cast<MPType>(
+                dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
           }
         }
       }
+      if (dx != nullptr) {
+        dx[j] = static_cast<T>(sum_x);
+      }
     }
   }
......
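
The C++ hunks above stop accumulating the reduced gradient directly in T (which may be bfloat16 or float16) and instead keep a per-output running sum in MPType, the wider type from MPTypeTrait, casting back to T only once per element. The numpy sketch below illustrates why that matters; float16 stands in for bfloat16 here, since numpy has no native bf16 dtype.

import numpy as np

# Illustration of the MPType change: summing many small gradients in a 16-bit
# float stalls once the running total is large, while a float32 accumulator
# (cast back to 16 bits only once) stays close to the true sum.
rng = np.random.default_rng(0)
grads = rng.random(10000).astype(np.float16)         # gradients along the reduced axis

low = np.float16(0)                                  # old pattern: accumulate in T
for g in grads:
    low = np.float16(low + g)

high = np.float16(np.sum(grads, dtype=np.float32))   # new pattern: accumulate in MPType

print(low, high, float(np.sum(grads, dtype=np.float64)))
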
@@ -552,8 +552,20 @@ class OpTest(unittest.TestCase):
             not in op_accuracy_white_list.NO_FP16_COMPARED_WITH_FP32_OP_LIST
         )
 
+    def is_bf16_compared_with_fp32(self):
+        return self.is_bfloat16_op() and (
+            self.op_type
+            not in op_accuracy_white_list.NO_BF16_COMPARED_WITH_FP32_OP_LIST
+        )
+
+    def is_compared_with_fp32(self):
+        return (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        )
+
     def enable_cal_ref_output(self):
-        self.is_calc_ref = self.is_fp16_compared_with_fp32()
+        self.is_calc_ref = True
 
     def disable_cal_ref_output(self):
         self.is_calc_ref = False
@@ -654,7 +666,10 @@ class OpTest(unittest.TestCase):
             if isinstance(np_value, tuple):
                 tensor.set(np_value[0], place)
                 dtype = np.array(np_value[1]).dtype
-                if self.is_calc_ref and dtype == np.float16:
+                if self.is_calc_ref:
+                    # convert the float16 to float by numpy.astype
+                    if dtype == np.float16:
                         if isinstance(np_value[1], list):
                             tensor.set_recursive_sequence_lengths(
                                 np.array(np_value[1]).astype(np.float32)
@@ -663,11 +678,35 @@ class OpTest(unittest.TestCase):
                             )
                         else:
                             tensor.set_recursive_sequence_lengths(
                                 np_value[1].astype(np.float32)
                             )
+                    # convert the bfloat16 to float by convert_uint16_to_float
+                    # provided in this file
+                    elif dtype == np.uint16:
+                        if isinstance(np_value[1], list):
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(
+                                    np.array(np_value[1])
+                                )
+                            )
+                        else:
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(np_value[1])
+                            )
+                    else:
+                        tensor.set_recursive_sequence_lengths(
+                            np_value[1]
+                        )
                 else:
                     tensor.set_recursive_sequence_lengths(np_value[1])
             else:
-                if self.is_calc_ref and np_value.dtype == np.float16:
-                    tensor.set(np_value.astype(np.float32), place)
+                if self.is_calc_ref:
+                    if np_value.dtype == np.float16:
+                        tensor.set(np_value.astype(np.float32), place)
+                    elif np_value.dtype == np.uint16:
+                        tensor.set(
+                            convert_uint16_to_float(np_value), place
+                        )
+                    else:
+                        tensor.set(np_value, place)
                 else:
                     tensor.set(np_value, place)
             feed_map[name] = tensor
@@ -675,25 +714,57 @@ class OpTest(unittest.TestCase):
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
                     tensor.set(self.inputs[var_name][0], place)
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name][1].dtype == np.float16
-                    ):
-                        tensor.set_recursive_sequence_lengths(
-                            self.inputs[var_name][1].astype(np.float32)
-                        )
+                    if self.is_calc_ref:
+                        if isinstance(self.inputs[var_name][1], list):
+                            dtype = np.array(self.inputs[var_name][1]).dtype
+                            if dtype == np.float16:
+                                tensor.set_recursive_sequence_lengths(
+                                    np.array(self.inputs[var_name][1]).astype(
+                                        np.float32
+                                    )
+                                )
+                            elif dtype == np.uint16:
+                                tensor.set_recursive_sequence_lengths(
+                                    convert_uint16_to_float(
+                                        np.array(self.inputs[var_name][1])
+                                    )
+                                )
+                            else:
+                                tensor.set_recursive_sequence_lengths(
+                                    self.inputs[var_name][1]
+                                )
+                        elif self.inputs[var_name][1].dtype == np.float16:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1].astype(np.float32)
+                            )
+                        elif self.inputs[var_name][1].dtype == np.uint16:
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(
+                                    self.inputs[var_name][1]
+                                )
+                            )
+                        else:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1]
+                            )
                     else:
                         tensor.set_recursive_sequence_lengths(
                             self.inputs[var_name][1]
                         )
                 else:
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name].dtype == np.float16
-                    ):
-                        tensor.set(
-                            self.inputs[var_name].astype(np.float32), place
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name].dtype == np.float16:
+                            tensor.set(
+                                self.inputs[var_name].astype(np.float32), place
+                            )
+                        elif self.inputs[var_name].dtype == np.uint16:
+                            tensor.set(
+                                convert_uint16_to_float(self.inputs[var_name]),
+                                place,
+                            )
+                        else:
+                            tensor.set(self.inputs[var_name], place)
                     else:
                         tensor.set(self.inputs[var_name], place)
                 feed_map[var_name] = tensor
@@ -711,7 +782,8 @@ class OpTest(unittest.TestCase):
             self.__class__.use_xpu = True
 
         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
-        "infer datatype from inputs and outputs for this test case"
+
+        # "infer datatype from inputs and outputs for this test case"
         if self.is_float16_op():
             self.dtype = np.float16
             self.__class__.dtype = self.dtype
@@ -722,6 +794,7 @@ class OpTest(unittest.TestCase):
                 self.output_dtype = np.uint16
         else:
             self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+
         inputs = append_input_output(
             block, op_proto, self.inputs, True, self.dtype, self.is_calc_ref
         )
@@ -1809,7 +1882,7 @@ class OpTest(unittest.TestCase):
         def compare_single_output_with_expect(self, name, expect):
             actual, actual_np = self.find_actual_value(name)
             # expect_np = expect[0] if isinstance(expect, tuple) else expect
-            if self.op_test.is_fp16_compared_with_fp32():
+            if self.op_test.is_compared_with_fp32():
                 expect, expect_np = self.find_expect_value(name)
             else:
                 expect_np = (
@@ -1864,7 +1937,7 @@ class OpTest(unittest.TestCase):
             )
             self.outputs = outs
             self.fetch_list = fetch_list
-            if self.op_test.is_fp16_compared_with_fp32():
+            if self.op_test.is_compared_with_fp32():
                 self.op_test.enable_cal_ref_output()
                 ref_outs, ref_fetch_list = self.op_test._calc_output(
                     place, no_check_set=no_check_set
@@ -1931,7 +2004,7 @@ class OpTest(unittest.TestCase):
                 place, no_check_set=no_check_set
             )
             self.outputs = dygraph_outs
-            if self.op_test.is_fp16_compared_with_fp32():
+            if self.op_test.is_compared_with_fp32():
                 self.op_test.enable_cal_ref_output()
                 self.is_python_api_test = True
                 self.ref_outputs = self.op_test._calc_python_api_output(
@@ -2461,8 +2534,6 @@ class OpTest(unittest.TestCase):
         if self.is_mkldnn_op():
             check_dygraph = False
             atol = 1e-2 if atol < 1e-2 else atol
-        else:
-            atol = 1e-1 if atol < 1e-1 else atol
 
         if self.is_float16_op():
             atol = 1e-3 if atol < 1e-3 else atol
@@ -2492,7 +2563,6 @@ class OpTest(unittest.TestCase):
             if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"]:
                 op_attrs["use_mkldnn"] = False
                 use_onednn = True
-
         self.op = create_op(
             self.scope,
             self.op_type,
@@ -2538,8 +2608,9 @@ class OpTest(unittest.TestCase):
         if numeric_place is None:
             numeric_place = place
 
-        if user_defined_grads is None and self.is_fp16_compared_with_fp32():
+        if user_defined_grads is None and self.is_compared_with_fp32():
             self.enable_cal_ref_output()
+
             numeric_grads = self._get_gradient(
                 inputs_to_check,
                 place,
@@ -2573,6 +2644,7 @@ class OpTest(unittest.TestCase):
             )
+
             # comparison of bf16 results will happen as fp32
             # loop over list of grads and convert bf16 to fp32
             fp32_analytic_grads = []
             for grad in analytic_grads:
                 if grad.dtype == np.uint16:
@@ -2869,7 +2941,7 @@ class OpTest(unittest.TestCase):
         feed_dict = self.feed_var(inputs, place)
 
         if user_defined_grad_outputs is None:
-            if self.dtype == np.uint16:
+            if self.dtype == np.uint16 and not self.is_calc_ref:
                 cast_inputs = list(map(block.var, output_names))
                 if self.op_type in ["broadcast_tensors", "meshgrid"]:
                     output_names = self.cast_bf16_output(block, cast_inputs)
......
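
Taken together, the framework hunks above route bf16 (uint16-encoded) ops through the same fp32-reference path that fp16 ops already used: is_compared_with_fp32() decides whether a reference run is needed, enable_cal_ref_output() switches the test into reference mode, and the feed/variable code upcasts fp16 with astype(np.float32) and decodes bf16 with convert_uint16_to_float. A condensed, stand-alone sketch of that dispatch follows; the class, its constructor, and the whitelist arguments are stand-ins, not the OpTest class itself.

import numpy as np

class AccuracyCheckSketch:
    # Stand-in for the OpTest logic above; the dtype checks replace
    # is_float16_op()/is_bfloat16_op(), and the whitelists mirror
    # NO_FP16_COMPARED_WITH_FP32_OP_LIST / NO_BF16_COMPARED_WITH_FP32_OP_LIST.
    def __init__(self, op_type, dtype, no_fp16_list=(), no_bf16_list=()):
        self.op_type, self.dtype = op_type, dtype
        self.no_fp16_list, self.no_bf16_list = no_fp16_list, no_bf16_list
        self.is_calc_ref = False

    def is_fp16_compared_with_fp32(self):
        return self.dtype == np.float16 and self.op_type not in self.no_fp16_list

    def is_bf16_compared_with_fp32(self):
        return self.dtype == np.uint16 and self.op_type not in self.no_bf16_list

    def is_compared_with_fp32(self):
        return self.is_fp16_compared_with_fp32() or self.is_bf16_compared_with_fp32()

    def enable_cal_ref_output(self):
        # now unconditional; callers gate it with is_compared_with_fp32()
        self.is_calc_ref = True

# a bf16 elementwise_max test (dtype stored as uint16) takes the fp32-reference
# path, while an op on the new whitelist (e.g. 'dequantize') is skipped
checker = AccuracyCheckSketch('elementwise_max', np.uint16, no_bf16_list=('dequantize',))
assert checker.is_compared_with_fp32()
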
@@ -212,8 +212,6 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
            check_args = [check_option['grad'], 'Out']
            check_kwargs = {
                'no_grad_set': check_option['no_grad'],
-               'user_defined_grads': check_option['val_grad'],
-               'user_defined_grad_outputs': [self.grad_out],
                'check_dygraph': self.check_dygraph,
            }
            if self.place is None:
......
@@ -441,6 +441,8 @@ def create_test_bf16_class(parent, atol=0.01):
                    ['X'],
                    'Out',
                    no_grad_set={'Y'},
+                   max_relative_error=3e-2,
+                   atol=3e-2,
                    user_defined_grads=[numeric_grads],
                    check_cinn=self.check_cinn
                    if hasattr(self, 'check_cinn')
@@ -455,6 +457,8 @@ def create_test_bf16_class(parent, atol=0.01):
                    ['Y'],
                    'Out',
                    no_grad_set={'X'},
+                   max_relative_error=3e-2,
+                   atol=3e-2,
                    user_defined_grads=[numeric_grads],
                    check_cinn=self.check_cinn
                    if hasattr(self, 'check_cinn')
......
@@ -317,7 +317,9 @@ def create_test_bf16_class(parent):
            numeric_grads = self.get_numeric_grad(place, 'X')
            if core.is_bfloat16_supported(place):
                self.check_grad_with_place(
-                   place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
+                   place,
+                   {'X'},
+                   ['Out'],
                )
 
    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
......
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest, paddle_static_guard
+from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
 
 import paddle
 from paddle.fluid import core
@@ -147,6 +147,13 @@ class TestSortedUniqueOp(TestUniqueOp):
        self.dtype = np.float64
 
    def init_config(self):
+       if self.dtype == np.uint16:
+           self.inputs = {
+               'X': convert_float_to_uint16(
+                   np.array([2, 3, 3, 1, 5, 3], dtype=np.float32)
+               )
+           }
+       else:
            self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype=self.dtype)}
        unique, indices, inverse, count = np.unique(
            self.inputs['X'],
@@ -197,6 +204,13 @@ class TestUniqueOpAxisNone(TestUniqueOp):
        self.dtype = np.float64
 
    def init_config(self):
+       if self.dtype == np.uint16:
+           self.inputs = {
+               'X': convert_float_to_uint16(
+                   np.random.randint(0, 100, (4, 7, 10)).astype(np.float32)
+               )
+           }
+       else:
            self.inputs = {
                'X': np.random.randint(0, 100, (4, 7, 10)).astype(self.dtype)
            }
......
@@ -120,7 +120,7 @@ def append_input_output(
    if is_input:
        shape = list(np_value.shape)
        lod_level = 0
-       if is_calc_ref and dtype == np.float16:
+       if is_calc_ref and (dtype == np.float16 or dtype == np.uint16):
            dtype = np.float32
        return block.create_var(
            dtype=dtype, shape=shape, lod_level=lod_level, name=name
......
@@ -94,3 +94,7 @@ NO_FP16_COMPARED_WITH_FP32_OP_LIST = [
    'fake_quantize_moving_average_abs_max',
    'p_norm',
 ]
+
+NO_BF16_COMPARED_WITH_FP32_OP_LIST = [
+    'dequantize',
+]