Unverified commit f7eb03c6, authored by Charles-hit, committed via GitHub

support group_norm and cumsum prim ops bf16 dtype (#54580)

Parent 3d4d995f
......@@ -693,11 +693,13 @@ void group_norm_grad(const Tensor& x,
Tensor x_data = x;
Tensor out_grad_data = out_grad;
if (x.dtype() == phi::DataType::FLOAT16) {
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
x_data = cast<T>(x, phi::DataType::FLOAT32);
}
if (out_grad.dtype() == phi::DataType::FLOAT16) {
if (out_grad.dtype() == phi::DataType::FLOAT16 ||
out_grad.dtype() == phi::DataType::BFLOAT16) {
out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
}
......@@ -728,7 +730,8 @@ void group_norm_grad(const Tensor& x,
Tensor p1;
if (scale_ptr) {
auto scale_data = scale.get();
if (scale_data.dtype() == phi::DataType::FLOAT16) {
if (scale_data.dtype() == phi::DataType::FLOAT16 ||
scale_data.dtype() == phi::DataType::BFLOAT16) {
scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
}
d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
......@@ -757,7 +760,8 @@ void group_norm_grad(const Tensor& x,
auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
auto x_grad_data = tmp_1 + tmp_2;
x_grad_data = reshape<T>(x_grad_data, x.shape());
if (x.dtype() == phi::DataType::FLOAT16) {
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
x_grad_data = cast<T>(x_grad_data, x.dtype());
}
......@@ -770,9 +774,9 @@ void group_norm_grad(const Tensor& x,
reshape<T>(sum_y_grad, shape_group) *
reshape<T>(mean, third_shape)) *
reshape<T>(inv_std, third_shape);
auto scale_grad_tmp =
reshape<T>(tmp1.sum(std::vector<int64_t>({0}), dtype, false),
IntArray(std::vector<int64_t>({C})));
auto scale_grad_tmp = reshape<T>(
tmp1.sum(std::vector<int64_t>({0}), scale_ptr->dtype(), false),
IntArray(std::vector<int64_t>({C})));
set_output<T>(scale_grad_tmp, scale_grad);
} else {
scale_grad = nullptr;
......@@ -782,7 +786,7 @@ void group_norm_grad(const Tensor& x,
if (bias_grad) {
if (bias_ptr) {
auto bias_grad_tmp =
sum_y_grad.sum(std::vector<int64_t>({0}), dtype, false);
sum_y_grad.sum(std::vector<int64_t>({0}), bias_ptr->dtype(), false);
set_output<T>(bias_grad_tmp, bias_grad);
} else {
bias_grad = nullptr;
......
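Note on the hunks above: this is the standard low-precision pattern for composite gradients. When x, out_grad, or scale arrives as float16 or bfloat16 it is upcast to float32, all intermediate math runs in float32, and only the final x_grad is cast back to the input dtype; the scale and bias gradients are now also reduced in the parameters' own dtype (scale_ptr->dtype() / bias_ptr->dtype()) instead of the previously captured dtype variable. A minimal NumPy sketch of that flow, with illustrative (non-Paddle) helper names:

```python
import numpy as np

def compute_in_fp32(x, grad_fn, low_precision_dtypes=("float16", "bfloat16")):
    """Sketch of the upcast-compute-downcast pattern in group_norm_grad.

    `grad_fn` stands in for the float32 gradient math; this is not Paddle API.
    """
    orig_dtype = x.dtype
    if str(orig_dtype) in low_precision_dtypes:
        x = x.astype(np.float32)        # cast<T>(x, FLOAT32)
    grad = grad_fn(x)                   # all intermediates stay in float32
    if str(orig_dtype) in low_precision_dtypes:
        grad = grad.astype(orig_dtype)  # cast<T>(x_grad_data, x.dtype())
    return grad
```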
......@@ -298,8 +298,13 @@ namespace phi {
} \
}()
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \
SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(SPECIFIED_TYPE1, \
SPECIFIED_TYPE2, \
SPECIFIED_TYPE3, \
SPECIFIED_TYPE4, \
TYPE, \
NAME, \
...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
......@@ -328,6 +333,10 @@ namespace phi {
SPECIFIED_TYPE3, \
::phi::DataTypeToCppType<SPECIFIED_TYPE3>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
SPECIFIED_TYPE4, \
::phi::DataTypeToCppType<SPECIFIED_TYPE4>::type, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
......
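The renamed macro extends the dtype visitor from three extra types to four so BFLOAT16 can be dispatched alongside INT32, INT64, and FLOAT16. Functionally it is a switch over the runtime DataType that instantiates the lambda body with the matching C++ type; a rough Python analogue of that dispatch, with assumed names, is:

```python
def visit_dtype(dtype, handlers, name="ReduceKernel"):
    """Dict-based analogue of PD_VISIT_..._AND_4_TYPES: call the handler
    registered for `dtype`, or fail like the PD_THROW default branch."""
    handler = handlers.get(dtype)
    if handler is None:
        raise NotImplementedError(
            f"function {name} is not implemented for data type `{dtype}`"
        )
    return handler()
```

Each registered handler plays the role of one PD_PRIVATE_CASE_TYPE branch; supporting bfloat16 amounts to one more entry in the table.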
......@@ -47,10 +47,11 @@ void Reduce(const KPDevice& dev_ctx,
#ifndef PADDLE_WITH_XPU_KP
if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
out_dtype,
"ReduceKernel",
([&] {
......
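The Reduce wrapper first casts x to out_dtype when the two differ, then dispatches the reduction on the casted dtype; adding BFLOAT16 to the visitor list is what lets a reduction with a bfloat16 out_dtype reach a kernel instantiation. A NumPy sketch of that control flow (illustrative simplification only):

```python
import numpy as np

def reduce_sum(x, out_dtype=None, axis=None):
    """Cast-then-reduce, mirroring the branch above."""
    if out_dtype is not None and out_dtype != x.dtype:
        x = x.astype(out_dtype)   # phi::Cast<T>(dev_ctx, x, out_dtype)
    return x.sum(axis=axis)       # dispatched on the casted dtype
```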
......@@ -656,8 +656,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
# when inputs are float16, convert to float32 in computing
if convert_dtype(x.dtype) == "float16":
dtype = convert_dtype(x.dtype)
# when inputs are float16 or bfloat16, convert to float32 in computing
if dtype in ["float16", "uint16"]:
is_amp = True
x = cast(x, "float32")
scale = cast(scale, "float32")
......@@ -676,9 +677,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
out = out + reshape(bias, (-1, 1, 1))
ret_mean_ = reshape(mean_, (N, groups))
ret_var_ = reshape(var_, (N, groups))
# return output in float16, mean and var in float32
# return output in float16 or bfloat16, mean and var in float32
if is_amp:
out = cast(out, "float16")
out = cast(out, dtype)
return out, ret_mean_, ret_var_
......
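The composite rule now treats uint16 (how bfloat16 tensors report their dtype on the Python side) the same as float16: upcast inputs to float32, normalize per group, cast the output back to the original dtype, and keep mean and variance in float32. A NumPy reference of the same computation for NCHW input, as a sketch rather than Paddle code:

```python
import numpy as np

def group_norm_ref(x, scale, bias, groups, epsilon=1e-5):
    """Float32 NCHW group-norm reference matching the composite rule."""
    n, c, h, w = x.shape
    xg = x.astype(np.float32).reshape(n, groups, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    var = xg.var(axis=-1, keepdims=True)
    out = (xg - mean) / np.sqrt(var + epsilon)
    out = out.reshape(n, c, h, w)
    out = out * scale.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)
    # mean/var are reported per (N, groups) and stay float32, as in the rule.
    return out, mean.reshape(n, groups), var.reshape(n, groups)
```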
......@@ -1033,7 +1033,7 @@ set_tests_properties(
PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
......
......@@ -122,7 +122,7 @@ class TestSumOp1(OpTest):
self.prim_op_type = "prim"
self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum
self.set_enable_cinn()
self.if_enable_cinn()
self.init_dtype()
self.set_attrs_input_output()
if self.dtype == np.uint16:
......@@ -141,7 +141,7 @@ class TestSumOp1(OpTest):
def init_dtype(self):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
def if_enable_cinn(self):
pass
def set_attrs_input_output(self):
......@@ -221,7 +221,7 @@ class TestSumOpExclusive1(OpTest):
self.prim_op_type = "prim"
self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum
self.set_enable_cinn()
self.if_enable_cinn()
self.init_dtype()
self.set_attrs_input_output()
if self.dtype == np.uint16:
......@@ -240,7 +240,7 @@ class TestSumOpExclusive1(OpTest):
def init_dtype(self):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
def if_enable_cinn(self):
pass
def set_attrs_input_output(self):
......@@ -346,7 +346,7 @@ class TestSumOpReverseExclusive(OpTest):
self.prim_op_type = "prim"
self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum
self.set_enable_cinn()
self.if_enable_cinn()
self.init_dtype()
self.attrs = {
'axis': 2,
......@@ -378,7 +378,7 @@ class TestSumOpReverseExclusive(OpTest):
def init_dtype(self):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
def if_enable_cinn(self):
pass
......@@ -387,7 +387,7 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
def init_dtype(self):
self.dtype = self.dtype_ = np.float16
def set_enable_cinn(self):
def if_enable_cinn(self):
pass
def test_check_output(self):
......@@ -430,7 +430,7 @@ def create_test_bf16_class(parent):
self.dtype = np.uint16
self.dtype_ = np.float32
def set_enable_cinn(self):
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_output(self):
......@@ -439,7 +439,9 @@ def create_test_bf16_class(parent):
def test_check_grad(self):
place = paddle.CUDAPlace(0)
self.check_grad_with_place(place, ["X"], "Out", check_prim=True)
self.check_grad_with_place(
place, ["X"], "Out", check_prim=True, numeric_grad_delta=0.05
)
cls_name = "{}_{}".format(parent.__name__, "BF16")
TestCumsumBF16Op.__name__ = cls_name
......
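The only numerical change in the BF16 gradient check above is numeric_grad_delta=0.05: OpTest estimates gradients by finite differences, and with bfloat16's roughly three significant decimal digits a larger perturbation step keeps the difference from being swallowed by rounding. A toy central-difference sketch (not the OpTest implementation):

```python
import numpy as np

def numeric_grad(f, x, delta=0.05):
    """Central-difference gradient of a scalar-valued f at x (illustrative)."""
    x = x.astype(np.float64).copy()
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + delta
        f_plus = f(x)
        x[idx] = orig - delta
        f_minus = f(x)
        x[idx] = orig
        grad[idx] = (f_plus - f_minus) / (2.0 * delta)
    return grad
```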
......@@ -19,6 +19,7 @@ import parameterized as param
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
paddle_static_guard,
skip_check_grad_ci,
)
......@@ -810,6 +811,96 @@ def apply_to_static(net, use_cinn):
[[1e-3, 1e-3, 1e-3]], # gpu thresholds for static, jit, jit_cinn
None,
),
(
'test0_bfp16',
(2, 100, 3, 5),
1e-5,
2,
'NCHW',
places,
'bfloat16',
[
[
1e-2,
1e-2,
1e-2,
], # cpu thresholds for static, jit, jit_cinn
[1e-2, 1e-2, 1e-2],
], # gpu thresholds for static, jit, jit_cinn
None,
),
(
'test1_bfp16',
(2, 100, 3, 5),
1e-5,
1,
'NCHW',
places,
'bfloat16',
[
[
1e-2,
1e-2,
1e-2,
], # cpu thresholds for static, jit, jit_cinn
[1e-2, 1e-2, 1e-2],
], # gpu thresholds for static, jit, jit_cinn
None,
),
(
'test2_bfp16',
(2, 100, 3, 5),
1e-5,
4,
'NCHW',
places,
'bfloat16',
[
[
1e-2,
1e-2,
1e-2,
], # cpu thresholds for static, jit, jit_cinn
[1e-2, 1e-2, 1e-2],
], # gpu thresholds for static, jit, jit_cinn
None,
),
(
'bigeps3_bfp16',
(2, 100, 3, 5),
0.5,
2,
'NCHW',
places,
'bfloat16',
[
[
1e-2,
1e-2,
1e-2,
], # cpu thresholds for static, jit, jit_cinn
[1e-2, 1e-2, 1e-2],
], # gpu thresholds for static, jit, jit_cinn
None,
),
(
'largedata_bfp16',
(2, 32, 64, 64),
1e-5,
4,
'NCHW',
places,
'bfloat16',
[
[
1e-2,
1e-2,
1e-2,
], # cpu thresholds for static, jit, jit_cinn
[1e-2, 1e-2, 1e-2],
], # gpu thresholds for static, jit, jit_cinn
None,
),
),
)
class TestCompositeGroupNorm(unittest.TestCase):
......@@ -825,12 +916,23 @@ class TestCompositeGroupNorm(unittest.TestCase):
np.random.seed(1234)
self.fwd_desire = []
self.rev_desire = []
self.x = np.random.random(self.shape).astype(self.dtype)
self.scale = np.random.random([self.shape[1]]).astype(self.dtype)
self.bias = np.random.random([self.shape[1]]).astype(self.dtype)
if self.dtype != "bfloat16":
self.x = np.random.random(self.shape).astype(self.dtype)
self.scale = np.random.random([self.shape[1]]).astype(self.dtype)
self.bias = np.random.random([self.shape[1]]).astype(self.dtype)
else:
self.x = convert_float_to_uint16(
np.random.random(self.shape).astype("float32")
)
self.scale = convert_float_to_uint16(
np.random.random([self.shape[1]]).astype("float32")
)
self.bias = convert_float_to_uint16(
np.random.random([self.shape[1]]).astype("float32")
)
self.num_channels = self.shape[1]
if self.dtype == 'float16':
if self.dtype in ['float16', 'bfloat16']:
self.places = []
if paddle.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
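Because NumPy has no native bfloat16 dtype, the test stores bfloat16 data as its raw uint16 bit pattern; convert_float_to_uint16 and convert_uint16_to_float from eager_op_test handle the encoding and decoding. A minimal truncation-based sketch of the idea (the real helpers may round rather than truncate):

```python
import numpy as np

def float32_to_bf16_bits(x):
    """Keep only the high 16 bits of each float32, stored as uint16."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) // np.uint32(1 << 16)).astype(np.uint16)

def bf16_bits_to_float32(bits):
    """Place the 16-bit pattern back into the high half of a float32."""
    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) * np.uint32(1 << 16)).view(np.float32)

# e.g. building test data the same way setUp does:
x_bf16 = float32_to_bf16_bits(np.random.random((2, 100, 3, 5)).astype("float32"))
```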
......@@ -879,7 +981,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
paddle.assign(bias_, group_norm.bias)
output = group_norm(input_)
grad = paddle.grad(output, input_)
if self.dtype == "bfloat16":
output = paddle.cast(output, "float32")
grad = paddle.utils.map_structure(
lambda x: paddle.cast(x, "float32"), grad
)
return output, grad[0]
def get_static_desire(self, place):
......@@ -923,7 +1029,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
output = group_norm(input_)
blocks = mp.blocks
names = dict(
zip(
blocks[0].ops[2].output_names,
......@@ -964,7 +1069,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
)
paddle.disable_static()
core._set_prim_all_enabled(True)
if self.dtype == "bfloat16":
out_list[0] = convert_uint16_to_float(out_list[0])
i = 3
for i in range(3, len(out_list)):
out_list[i] = convert_uint16_to_float(out_list[i])
return out_list[:3], out_list[3:]
def test_static_comp(self):
......@@ -1051,6 +1160,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
},
fetch_list=vars_list + [grads],
)
if self.dtype == "bfloat16":
out_list[0] = convert_uint16_to_float(out_list[0])
i = 3
for i in range(3, len(out_list)):
out_list[i] = convert_uint16_to_float(out_list[i])
fwd_actual[-1].append(out_list[0])
fwd_actual[-1].append(out_list[1])
fwd_actual[-1].append(out_list[2])
......@@ -1075,12 +1189,14 @@ class TestCompositeGroupNorm(unittest.TestCase):
atol = self.threshold_list[i][0]
rtol = self.threshold_list[i][0]
for j in range(len(self.static_fwd_desire[i])):
# in float16 type, Y is float16, mean and var are float16
# in float16 type, Y is float16, mean and var are float32
# so check mean and var with float32 gpu threshold
if self.dtype == 'float16' and j > 0:
if self.dtype == "float16" and j > 0:
atol = 1e-5
rtol = 1e-5
elif self.dtype == "bfloat16" and j > 0:
atol = 5e-3
rtol = 5e-3
np.testing.assert_allclose(
self.static_fwd_desire[i][j],
fwd_actual[i][j],
......@@ -1091,13 +1207,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
max_abs_diff = np.max(
np.abs(self.static_fwd_desire[i][j] - fwd_actual[i][j])
)
print(
self.shape,
self.dtype,
self.places[i],
vars_name[j],
max_abs_diff,
)
# compare with eager_desire
np.testing.assert_allclose(
self.fwd_desire[i],
......@@ -1121,14 +1230,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
np.abs(self.static_rev_desire[i][j] - rev_actual[i][j])
)
print(
self.shape,
self.dtype,
self.places[i],
vars_name[j + 3],
max_abs_diff,
)
np.testing.assert_allclose(
self.static_rev_desire[i][j],
rev_actual[i][j],
......@@ -1183,8 +1284,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
net = apply_to_static(net, False)
output = net(input_)
grad = paddle.grad(output, input_)
fwd_actual.append(output.numpy())
rev_actual.append(grad[0].numpy())
fwd_actual.append(
convert_uint16_to_float(output.numpy())
if self.dtype == "bfloat16"
else output.numpy()
)
rev_actual.append(
convert_uint16_to_float(grad[0].numpy())
if self.dtype == "bfloat16"
else grad[0].numpy()
)
for i in range(len(self.places)):
atol = self.threshold_list[i][1]
......@@ -1244,8 +1353,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
net = apply_to_static(net, True)
output = net(input_)
grad = paddle.grad(output, input_)
fwd_actual.append(output.numpy())
rev_actual.append(grad[0].numpy())
fwd_actual.append(
convert_uint16_to_float(output.numpy())
if self.dtype == "bfloat16"
else output.numpy()
)
rev_actual.append(
convert_uint16_to_float(grad[0].numpy())
if self.dtype == "bfloat16"
else grad[0].numpy()
)
i = 0
for place in self.places:
......