Unverified · Commit 91119271 · Authored by Y Yiqun Liu · Committed by GitHub

Enhance OpTest for bfloat16. (#36079)

Parent: cb620ca6
@@ -94,24 +94,19 @@ class CastCUDAOpKernel : public framework::OpKernel<InT> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

-#ifdef PADDLE_WITH_HIP
-REGISTER_OP_CUDA_KERNEL(
-    cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
-    ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
-    ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
-    ops::CastCUDAOpKernel<uint8_t>,
-    ops::CastCUDAOpKernel<paddle::platform::float16>,
-    ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
-    ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
+#define REGISTER_CAST_CUDA_BASE(op_name, ...)                               \
+  REGISTER_OP_CUDA_KERNEL(                                                  \
+      op_name, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>, \
+      ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,           \
+      ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,          \
+      ops::CastCUDAOpKernel<uint8_t>, ops::CastCUDAOpKernel<plat::float16>, \
+      ops::CastCUDAOpKernel<plat::complex<float>>,                          \
+      ops::CastCUDAOpKernel<plat::complex<double>>, ##__VA_ARGS__);
+
+#if !defined(PADDLE_WITH_HIP)
+REGISTER_CAST_CUDA_BASE(cast, ops::CastCUDAOpKernel<plat::bfloat16>)
 #else
-REGISTER_OP_CUDA_KERNEL(
-    cast, ops::CastCUDAOpKernel<float>, ops::CastCUDAOpKernel<double>,
-    ops::CastCUDAOpKernel<int>, ops::CastCUDAOpKernel<int64_t>,
-    ops::CastCUDAOpKernel<int16_t>, ops::CastCUDAOpKernel<bool>,
-    ops::CastCUDAOpKernel<uint8_t>,
-    ops::CastCUDAOpKernel<paddle::platform::float16>,
-    ops::CastCUDAOpKernel<paddle::platform::bfloat16>,
-    ops::CastCUDAOpKernel<paddle::platform::complex<float>>,
-    ops::CastCUDAOpKernel<paddle::platform::complex<double>>);
+REGISTER_CAST_CUDA_BASE(cast)
 #endif
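A note on the macro above: the GNU `##__VA_ARGS__` extension swallows the trailing comma when no extra arguments are passed, which is what lets `REGISTER_CAST_CUDA_BASE(cast)` expand cleanly on HIP builds while CUDA builds append the bfloat16 kernel. A rough Python sketch of the same "fixed base list plus optional extras" pattern (the names below are illustrative, not Paddle APIs):

```python
# Illustrative sketch only: a fixed base kernel list plus optional extras,
# mirroring how REGISTER_CAST_CUDA_BASE appends the bfloat16 kernel on
# CUDA builds but omits it on HIP builds. These names are not Paddle APIs.
BASE_KERNEL_DTYPES = [
    'float32', 'float64', 'int32', 'int64', 'int16', 'bool',
    'uint8', 'float16', 'complex64', 'complex128'
]

def register_cast_kernels(*extra_dtypes):
    # *extra_dtypes plays the role of ##__VA_ARGS__: empty is fine.
    return BASE_KERNEL_DTYPES + list(extra_dtypes)

print(register_cast_kernels('bfloat16'))  # CUDA build: bfloat16 included
print(register_cast_kernels())            # HIP build: base kernels only
```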
@@ -147,6 +147,9 @@ def get_numeric_gradient(place,
         op.run(scope, place)
         for output_name in output_names:
             output_numpy = np.array(scope.find_var(output_name).get_tensor())
+            # numpy has no bfloat16 dtype, so bfloat16 data is stored as
+            # numpy.uint16 and must be converted to float before checking
+            # floating-point precision.
             if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
                 output_numpy = convert_uint16_to_float(output_numpy)
             sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
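The comment above states the convention this whole patch relies on: numpy has no native bfloat16, so bfloat16 tensors surface as `np.uint16` bit patterns. A minimal sketch of what a uint16-to-float widening does (the helper name `uint16_to_float` is local to this example; op_test's own `convert_uint16_to_float` is the real helper, and this body is an assumption based on the standard bfloat16 layout):

```python
import numpy as np

# Sketch, assuming the standard bfloat16 layout: a bfloat16 value is the
# high 16 bits of the equivalent float32, so widening the stored uint16
# bit pattern into the top half of a uint32 recovers the float32 value.
def uint16_to_float(u16):
    return (u16.astype(np.uint32) << np.uint32(16)).view(np.float32)

bits = np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16)
print(uint16_to_float(bits))  # [ 1.  2. -3.]
```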
@@ -362,11 +365,26 @@ class OpTest(unittest.TestCase):
         self.dtype = data_type

     def is_bfloat16_op(self):
+        # self.dtype is the dtype of the inputs and is set in infer_dtype_from_inputs_outputs.
+        # Make sure this function is called after infer_dtype_from_inputs_outputs.
         return self.dtype == np.uint16 or (
-            hasattr(self, 'mkldnn_data_type') and
-            getattr(self, 'mkldnn_data_type') is "bfloat16") or (
-                hasattr(self, 'attrs') and 'mkldnn_data_type' in self.attrs and
-                self.attrs['mkldnn_data_type'] == 'bfloat16')
+            hasattr(self, 'output_dtype') and
+            self.output_dtype == np.uint16) or (
+                hasattr(self, 'mkldnn_data_type') and
+                getattr(self, 'mkldnn_data_type') == "bfloat16") or (
+                    hasattr(self, 'attrs') and
+                    'mkldnn_data_type' in self.attrs and
+                    self.attrs['mkldnn_data_type'] == 'bfloat16')
+
+    def is_mkldnn_op(self):
+        return (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or (
+            hasattr(self, "attrs") and "use_mkldnn" in self.attrs and
+            self.attrs["use_mkldnn"] == True)
+
+    def is_xpu_op(self):
+        return (hasattr(self, "use_xpu") and self.use_xpu == True) or (
+            hasattr(self, "attrs") and "use_xpu" in self.attrs and
+            self.attrs["use_xpu"] == True)

     def infer_dtype_from_inputs_outputs(self, inputs, outputs):
         def is_np_data(input):
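The reworked helper now recognizes bfloat16 tests through three signals: the inferred input dtype, the newly inferred output dtype, or an explicit `mkldnn_data_type` request. A small stand-in illustrating those checks (`FakeOpTest` is hypothetical, not part of op_test):

```python
import numpy as np

# Hypothetical stand-in for an OpTest subclass; only the attributes that
# is_bfloat16_op() inspects are modeled here.
class FakeOpTest:
    dtype = np.uint16                          # inferred input precision
    output_dtype = np.uint16                   # inferred output precision
    attrs = {'mkldnn_data_type': 'bfloat16'}   # explicit oneDNN request

t = FakeOpTest()
is_bf16 = (t.dtype == np.uint16 or
           getattr(t, 'output_dtype', None) == np.uint16 or
           t.attrs.get('mkldnn_data_type') == 'bfloat16')
print(is_bf16)  # True -> bfloat16 tolerances apply to this test
```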
@@ -398,8 +416,8 @@ class OpTest(unittest.TestCase):
         # infer dtype from inputs; dtype determines the precision of the test
         # collect the dtypes of all inputs
-        dtype_set = set()
-        infer_dtype(inputs, dtype_set)
+        input_dtype_set = set()
+        infer_dtype(inputs, input_dtype_set)
         dtype_list = [
             np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
             np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.uint16),
@@ -408,12 +426,20 @@ class OpTest(unittest.TestCase):
         ]
         # check the dtypes in dtype_list in order; select the first one present in the dtype set
         for dtype in dtype_list:
-            if dtype in dtype_set:
+            if dtype in input_dtype_set:
                 self.dtype = dtype
                 break
-        # save dtype in class attr
+        # save input dtype in class attr
         self.__class__.dtype = self.dtype

+        # infer dtype of outputs
+        output_dtype_set = set()
+        infer_dtype(outputs, output_dtype_set)
+        for dtype in dtype_list:
+            if dtype in output_dtype_set:
+                self.output_dtype = dtype
+                break
+
     def feed_var(self, input_vars, place):
         feed_map = {}
         for var_name in input_vars:
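The output-dtype inference reuses the same priority scan as the input side: walk `dtype_list` in order and take the first dtype that actually occurs. A self-contained sketch of that rule (the example set is made up for illustration):

```python
import numpy as np

# Priority list as in the hunk above: wider float dtypes win over narrower
# ones, and all floats win over integers and uint16 (bfloat16) inputs.
dtype_list = [np.dtype(t) for t in (np.float64, np.float32, np.float16,
                                    np.int64, np.int32, np.uint16)]

# Made-up example: an op fed float32 data plus bfloat16 (uint16) data.
input_dtype_set = {np.dtype(np.float32), np.dtype(np.uint16)}

chosen = next(d for d in dtype_list if d in input_dtype_set)
print(chosen)  # float32 -- it appears earlier in dtype_list than uint16
```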
......@@ -439,14 +465,10 @@ class OpTest(unittest.TestCase):
def _append_ops(self, block):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now
if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
self.attrs["use_mkldnn"] == True):
if self.is_mkldnn_op():
self.__class__.use_mkldnn = True
if (hasattr(self, "use_xpu") and self.use_xpu == True) or \
(hasattr(self, "attrs") and "use_xpu" in self.attrs and \
self.attrs["use_xpu"] == True):
if self.is_xpu_op():
self.__class__.use_xpu = True
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
@@ -1092,12 +1114,15 @@ class OpTest(unittest.TestCase):
         atol = 0

         if self.is_bfloat16_op():
-            check_dygraph = False
-            if hasattr(self, 'force_fp32_output') and getattr(
-                    self, 'force_fp32_output'):
-                atol = 1e-2
-            else:
-                atol = 2
+            if self.is_mkldnn_op():
+                check_dygraph = False
+                if hasattr(self, 'force_fp32_output') and getattr(
+                        self, 'force_fp32_output'):
+                    atol = 1e-2
+                else:
+                    atol = 2
+            else:
+                atol = 1e-2

         if no_check_set is not None:
             if self.op_type not in no_check_set_white_list.no_check_set_white_list:
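The 1e-2 tolerance for non-oneDNN bfloat16 ops is roughly one representable step of the format: bfloat16 keeps only 7 fraction bits, so values near 1.0 are about 2^-7 apart. A quick check of that spacing (the helper below is illustrative, not part of op_test):

```python
import numpy as np

# Illustrative helper: distance from x to the next representable bfloat16,
# computed by bumping the truncated 16-bit pattern by one ulp.
def bf16_spacing(x):
    bits = np.array([x], dtype=np.float32).view(np.uint32) >> np.uint32(16)
    nxt = ((bits + np.uint32(1)) << np.uint32(16)).view(np.float32)
    return float(nxt[0]) - float(np.float32(x))

print(bf16_spacing(1.0))  # 0.0078125 == 2**-7, so atol=1e-2 is ~1 ulp
```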
@@ -1193,6 +1218,7 @@ class OpTest(unittest.TestCase):
             expect = self.outputs[out_name]
             expect_t = expect[0] if isinstance(expect, tuple) else expect

+            # np.uint16 represents bfloat16
             if actual_t.dtype == np.uint16 and expect_t.dtype in [
                     np.float32, np.float64
             ]:
@@ -1205,6 +1231,7 @@ class OpTest(unittest.TestCase):
                 expect_t = convert_uint16_to_float(expect_t)
                 actual_t = convert_uint16_to_float(actual_t)
+                atol = max(atol, 0.03)

             # NOTE(zhiqiu): np.allclose([], [1.]) returns True
             # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
             if expect_t.size == 0:
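When both sides are bfloat16 bit patterns, the check widens them to float32 and loosens `atol` to 0.03, a few ulps near 1.0. A sketch of that comparison path, using the same widening trick as earlier (helper name local to this example):

```python
import numpy as np

# Sketch of the bfloat16 comparison path: widen uint16 bit patterns to
# float32, then compare with the loosened tolerance from the hunk above.
def uint16_to_float(u16):
    return (u16.astype(np.uint32) << np.uint32(16)).view(np.float32)

actual = np.array([0x3FC0], dtype=np.uint16)   # 1.5 as a bfloat16 pattern
expect = np.array([0x3FC1], dtype=np.uint16)   # one ulp above 1.5
print(np.allclose(uint16_to_float(actual), uint16_to_float(expect),
                  atol=max(1e-5, 0.03)))       # True: diff ~0.008 < 0.03
```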
@@ -1214,13 +1241,19 @@ class OpTest(unittest.TestCase):
                     np.allclose(
                         actual_t,
                         expect_t,
-                        rtol=rtol,
                         atol=atol,
+                        rtol=rtol,
                         equal_nan=equal_nan),
                     "Output (" + out_name + ") has diff at " + str(place) +
                     "\nExpect " + str(expect_t) + "\n" + "But Got" +
                     str(actual_t) + " in class " + self.__class__.__name__)

                 if check_dygraph:
+                    if self.is_bfloat16_op():
+                        if imperative_actual_t.dtype == np.uint16:
+                            imperative_actual_t = convert_uint16_to_float(
+                                imperative_actual_t)
+                        if expect_t.dtype == np.uint16:
+                            expect_t = convert_uint16_to_float(expect_t)
                     if six.moves.reduce(
                             lambda x, y: x * y, imperative_actual_t.shape,
                             1) == 0 and six.moves.reduce(
@@ -1232,6 +1265,7 @@ class OpTest(unittest.TestCase):
                             imperative_actual_t,
                             expect_t,
                             atol=atol,
+                            rtol=rtol,
                             equal_nan=equal_nan),
                         "Output (" + out_name + ") has diff at " +
                         str(place) + "\nExpect " + str(expect_t) + "\n" +
@@ -1340,14 +1374,10 @@ class OpTest(unittest.TestCase):
                      check_dygraph=True,
                      inplace_atol=None):
         self.__class__.op_type = self.op_type
-        if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
-            (hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
-                self.attrs["use_mkldnn"] == True):
+        if self.is_mkldnn_op():
             self.__class__.use_mkldnn = True

-        if (hasattr(self, "use_xpu") and self.use_xpu == True) or \
-            (hasattr(self, "attrs") and "use_xpu" in self.attrs and \
-                self.attrs["use_xpu"] == True):
+        if self.is_xpu_op():
             self.__class__.use_xpu = True

         places = self._get_places()
@@ -1452,10 +1482,10 @@ class OpTest(unittest.TestCase):
         op_outputs = self.outputs if hasattr(self, "outputs") else dict()
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()

-        if self.is_bfloat16_op():
-            self._check_grad_helper()
+        if self.is_bfloat16_op() and self.is_mkldnn_op():
             check_dygraph = False

+        self._check_grad_helper()
         if self.dtype == np.float64 and \
             self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
             numeric_grad_delta = 1e-5
......
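For context on the `numeric_grad_delta` above: `get_numeric_gradient` perturbs each input element and differences the op's outputs. A minimal central-difference sketch of the idea (not the op_test implementation itself):

```python
import numpy as np

# Minimal sketch of a central-difference numeric gradient, the quantity
# check_grad compares analytic gradients against. fp64 ops can afford
# delta=1e-5; bfloat16 inputs need a far coarser step, since they carry
# only ~2-3 significant decimal digits.
def numeric_grad(f, x, delta=1e-5):
    grad = np.zeros_like(x)
    flat_x, flat_g = x.reshape(-1), grad.reshape(-1)  # views into x, grad
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + delta
        hi = f(x)
        flat_x[i] = orig - delta
        lo = f(x)
        flat_x[i] = orig
        flat_g[i] = (hi - lo) / (2.0 * delta)
    return grad

x = np.array([1.0, 2.0, 3.0])
print(numeric_grad(lambda v: (v ** 2).sum(), x))  # ~[2. 4. 6.]
```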
@@ -14,7 +14,6 @@

 from __future__ import print_function

-import op_test
 import unittest
 import numpy as np
@@ -22,9 +21,10 @@ import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard
+from op_test import OpTest, convert_uint16_to_float, convert_float_to_uint16


-class TestCastOp1(op_test.OpTest):
+class TestCastOpFp32ToFp64(OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
         self.inputs = {'X': ipt.astype('float32')}
@@ -42,7 +42,7 @@ class TestCastOp1(op_test.OpTest):
         self.check_grad(['X'], ['Out'])


-class TestCastOp2(op_test.OpTest):
+class TestCastOpFp16ToFp32(OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
         self.inputs = {'X': ipt.astype('float16')}
@@ -57,7 +57,7 @@ class TestCastOp2(op_test.OpTest):
         self.check_output(atol=1e-3)


-class TestCastOp3(op_test.OpTest):
+class TestCastOpFp32ToFp16(OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
         self.inputs = {'X': ipt.astype('float32')}
@@ -72,6 +72,36 @@ class TestCastOp3(op_test.OpTest):
         self.check_output(atol=1e-3)


+class TestCastOpBf16ToFp32(OpTest):
+    def setUp(self):
+        ipt = np.array(np.random.randint(10, size=[10, 10])).astype('uint16')
+        self.inputs = {'X': ipt}
+        self.outputs = {'Out': convert_uint16_to_float(ipt)}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.BF16),
+            'out_dtype': int(core.VarDesc.VarType.FP32)
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestCastOpFp32ToBf16(OpTest):
+    def setUp(self):
+        ipt = np.random.random(size=[10, 10]).astype('float32')
+        self.inputs = {'X': ipt}
+        self.outputs = {'Out': convert_float_to_uint16(ipt)}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.FP32),
+            'out_dtype': int(core.VarDesc.VarType.BF16)
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestCastOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
......
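The Fp32ToBf16 test leans on `convert_float_to_uint16` for its expected output. A sketch of the float32-to-bfloat16 direction under the simplest policy, truncation (op_test's helper may round instead; the function name below is local to this example):

```python
import numpy as np

# Sketch assuming plain truncation: drop the low 16 bits of the float32
# pattern. A rounding variant would add 0x8000 (plus a tie-break) first.
def float_to_bf16_bits(f32):
    return (f32.view(np.uint32) >> np.uint32(16)).astype(np.uint16)

x = np.array([1.0, 3.140625], dtype=np.float32)
print([hex(b) for b in float_to_bf16_bits(x)])  # ['0x3f80', '0x4049']
```

Note that the Bf16ToFp32 test above feeds small random integers (0-9) directly as bfloat16 bit patterns, which decode to subnormal values near zero; that still exercises the cast kernel's bit handling, just not typical magnitudes.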