Unverified commit f7eb03c6 authored by Charles-hit, committed by GitHub

support group_norm and cumsum prim ops bf16 dtype (#54580)

Parent 3d4d995f
@@ -693,11 +693,13 @@ void group_norm_grad(const Tensor& x,
   Tensor x_data = x;
   Tensor out_grad_data = out_grad;

-  if (x.dtype() == phi::DataType::FLOAT16) {
+  if (x.dtype() == phi::DataType::FLOAT16 ||
+      x.dtype() == phi::DataType::BFLOAT16) {
     x_data = cast<T>(x, phi::DataType::FLOAT32);
   }

-  if (out_grad.dtype() == phi::DataType::FLOAT16) {
+  if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+      out_grad.dtype() == phi::DataType::BFLOAT16) {
     out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
   }
@@ -728,7 +730,8 @@ void group_norm_grad(const Tensor& x,
     Tensor p1;
     if (scale_ptr) {
       auto scale_data = scale.get();
-      if (scale_data.dtype() == phi::DataType::FLOAT16) {
+      if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+          scale_data.dtype() == phi::DataType::BFLOAT16) {
        scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
      }
      d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
@@ -757,7 +760,8 @@ void group_norm_grad(const Tensor& x,
    auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
    auto x_grad_data = tmp_1 + tmp_2;
    x_grad_data = reshape<T>(x_grad_data, x.shape());
-    if (x.dtype() == phi::DataType::FLOAT16) {
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
      x_grad_data = cast<T>(x_grad_data, x.dtype());
    }
@@ -770,9 +774,9 @@ void group_norm_grad(const Tensor& x,
                reshape<T>(sum_y_grad, shape_group) *
                    reshape<T>(mean, third_shape)) *
               reshape<T>(inv_std, third_shape);
-      auto scale_grad_tmp =
-          reshape<T>(tmp1.sum(std::vector<int64_t>({0}), dtype, false),
-                     IntArray(std::vector<int64_t>({C})));
+      auto scale_grad_tmp = reshape<T>(
+          tmp1.sum(std::vector<int64_t>({0}), scale_ptr->dtype(), false),
+          IntArray(std::vector<int64_t>({C})));
       set_output<T>(scale_grad_tmp, scale_grad);
     } else {
       scale_grad = nullptr;
@@ -782,7 +786,7 @@ void group_norm_grad(const Tensor& x,
   if (bias_grad) {
     if (bias_ptr) {
       auto bias_grad_tmp =
-          sum_y_grad.sum(std::vector<int64_t>({0}), dtype, false);
+          sum_y_grad.sum(std::vector<int64_t>({0}), bias_ptr->dtype(), false);
       set_output<T>(bias_grad_tmp, bias_grad);
     } else {
       bias_grad = nullptr;
......
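The hunks above do two things: cast FLOAT16/BFLOAT16 inputs up to FLOAT32 before the composite backward math, and accumulate the scale/bias gradients in the parameter's own dtype. Those parameter gradients are plain reductions of the incoming gradient; a minimal NumPy sketch of the formulas, with illustrative helper names that are not Paddle code:

```python
import numpy as np

def group_norm_param_grads(dy, x_hat):
    """Group-norm gradients w.r.t. scale and bias for NCHW tensors."""
    # scale_grad[c] = sum over n, h, w of dy * x_hat
    scale_grad = (dy * x_hat).sum(axis=(0, 2, 3))
    # bias_grad[c] = sum over n, h, w of dy
    bias_grad = dy.sum(axis=(0, 2, 3))
    return scale_grad, bias_grad

dy = np.random.rand(2, 6, 4, 4).astype(np.float32)     # upstream gradient
x_hat = np.random.rand(2, 6, 4, 4).astype(np.float32)  # normalized activations
scale_grad, bias_grad = group_norm_param_grads(dy, x_hat)
print(scale_grad.shape, bias_grad.shape)  # (6,) (6,)
```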
@@ -298,8 +298,13 @@ namespace phi {
     }                                                                       \
   }()

-#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(                 \
-    SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...)     \
+#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(SPECIFIED_TYPE1, \
+                                                           SPECIFIED_TYPE2, \
+                                                           SPECIFIED_TYPE3, \
+                                                           SPECIFIED_TYPE4, \
+                                                           TYPE,            \
+                                                           NAME,            \
+                                                           ...)             \
   [&] {                                                                     \
     const auto& __dtype__ = TYPE;                                           \
     switch (__dtype__) {                                                    \
@@ -328,6 +333,10 @@ namespace phi {
                            SPECIFIED_TYPE3,                                 \
                            ::phi::DataTypeToCppType<SPECIFIED_TYPE3>::type, \
                            __VA_ARGS__)                                     \
+      PD_PRIVATE_CASE_TYPE(NAME,                                            \
+                           SPECIFIED_TYPE4,                                 \
+                           ::phi::DataTypeToCppType<SPECIFIED_TYPE4>::type, \
+                           __VA_ARGS__)                                     \
       default:                                                              \
         PD_THROW("function " #NAME " is not implemented for data type `",   \
                  __dtype__,                                                 \
......
@@ -47,10 +47,11 @@ void Reduce(const KPDevice& dev_ctx,
 #ifndef PADDLE_WITH_XPU_KP
   if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
     auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
-    PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
+    PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
         phi::DataType::INT32,
         phi::DataType::INT64,
         phi::DataType::FLOAT16,
+        phi::DataType::BFLOAT16,
         out_dtype,
         "ReduceKernel",
         ([&] {
......
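A rough Python analogue of what the `PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES` visitor does for `ReduceKernel` above: map the runtime dtype onto a concrete element type and run the kernel body on it, with BFLOAT16 now in the visited set. This is an illustration only, not the actual macro expansion:

```python
import numpy as np

VISITED_TYPES = {
    "bool": np.bool_,
    "float32": np.float32,
    "float64": np.float64,
    "complex64": np.complex64,
    "complex128": np.complex128,
    # the four "specified" types; bfloat16 has no native NumPy dtype, so the
    # Python-side test utilities carry it around as uint16 bit patterns
    "int32": np.int32,
    "int64": np.int64,
    "float16": np.float16,
    "bfloat16": np.uint16,
}

def visit(dtype_name, body):
    """Dispatch a runtime dtype name to a concrete element type and run `body`."""
    elem_type = VISITED_TYPES.get(dtype_name)
    if elem_type is None:
        raise TypeError(
            f"function ReduceKernel is not implemented for data type `{dtype_name}`"
        )
    return body(elem_type)

# e.g. build a buffer in the visited output type
print(visit("bfloat16", lambda ty: np.zeros(3, dtype=ty).dtype))  # uint16
```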
@@ -656,8 +656,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype

-    # when inputs are float16, convert to float32 in computing
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    # when inputs are float16 or bfloat16, convert to float32 in computing
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
         scale = cast(scale, "float32")
@@ -676,9 +677,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
     out = out + reshape(bias, (-1, 1, 1))
     ret_mean_ = reshape(mean_, (N, groups))
     ret_var_ = reshape(var_, (N, groups))
-    # return output in float16, mean and var in float32
+    # return output in float16 or bfloat16, mean and var in float32
     if is_amp:
-        out = cast(out, "float16")
+        out = cast(out, dtype)
     return out, ret_mean_, ret_var_
......
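The composite rule above follows the usual AMP pattern: cast float16/bfloat16 inputs to float32, compute, cast the output back to the input dtype, and keep mean/var in float32. A minimal NumPy sketch of that pattern, with an illustrative function that stands in for (but is not) the Paddle composite rule:

```python
import numpy as np

def group_norm_like(x, scale, bias, groups, eps=1e-5):
    """Group norm with low-precision I/O and float32 compute (NCHW layout)."""
    in_dtype = x.dtype
    is_amp = in_dtype == np.float16  # Paddle treats bf16 (stored as uint16) the same way
    if is_amp:
        x = x.astype(np.float32)
        scale = scale.astype(np.float32)
        bias = bias.astype(np.float32)
    N, C, H, W = x.shape
    xg = x.reshape(N, groups, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    var = xg.var(axis=-1, keepdims=True)
    out = ((xg - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)
    out = out * scale.reshape(1, C, 1, 1) + bias.reshape(1, C, 1, 1)
    if is_amp:
        out = out.astype(in_dtype)  # output goes back to the input dtype
    # mean/var stay in the compute dtype (float32 when AMP kicked in)
    return out, mean.reshape(N, groups), var.reshape(N, groups)

x = np.random.rand(2, 6, 4, 4).astype(np.float16)
out, mean, var = group_norm_like(
    x, np.ones(6, np.float16), np.zeros(6, np.float16), groups=3
)
print(out.dtype, mean.dtype, var.dtype)  # float16 float32 float32
```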
@@ -1033,7 +1033,7 @@ set_tests_properties(
   PROPERTIES TIMEOUT 120)
 set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
 set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
-set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300)
 set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
 set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
 set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
......
@@ -122,7 +122,7 @@ class TestSumOp1(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cumsum_wrapper
         self.public_python_api = paddle.cumsum
-        self.set_enable_cinn()
+        self.if_enable_cinn()
         self.init_dtype()
         self.set_attrs_input_output()
         if self.dtype == np.uint16:
@@ -141,7 +141,7 @@ class TestSumOp1(OpTest):
     def init_dtype(self):
         self.dtype = self.dtype_ = np.float64

-    def set_enable_cinn(self):
+    def if_enable_cinn(self):
         pass

     def set_attrs_input_output(self):
@@ -221,7 +221,7 @@ class TestSumOpExclusive1(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cumsum_wrapper
         self.public_python_api = paddle.cumsum
-        self.set_enable_cinn()
+        self.if_enable_cinn()
         self.init_dtype()
         self.set_attrs_input_output()
         if self.dtype == np.uint16:
@@ -240,7 +240,7 @@ class TestSumOpExclusive1(OpTest):
     def init_dtype(self):
         self.dtype = self.dtype_ = np.float64

-    def set_enable_cinn(self):
+    def if_enable_cinn(self):
         pass

     def set_attrs_input_output(self):
@@ -346,7 +346,7 @@ class TestSumOpReverseExclusive(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cumsum_wrapper
         self.public_python_api = paddle.cumsum
-        self.set_enable_cinn()
+        self.if_enable_cinn()
         self.init_dtype()
         self.attrs = {
             'axis': 2,
@@ -378,7 +378,7 @@ class TestSumOpReverseExclusive(OpTest):
     def init_dtype(self):
         self.dtype = self.dtype_ = np.float64

-    def set_enable_cinn(self):
+    def if_enable_cinn(self):
         pass

@@ -387,7 +387,7 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
         def init_dtype(self):
             self.dtype = self.dtype_ = np.float16

-        def set_enable_cinn(self):
+        def if_enable_cinn(self):
             pass

         def test_check_output(self):
@@ -430,7 +430,7 @@ def create_test_bf16_class(parent):
             self.dtype = np.uint16
             self.dtype_ = np.float32

-        def set_enable_cinn(self):
+        def if_enable_cinn(self):
             self.enable_cinn = False

         def test_check_output(self):
@@ -439,7 +439,9 @@ def create_test_bf16_class(parent):
         def test_check_grad(self):
             place = paddle.CUDAPlace(0)
-            self.check_grad_with_place(place, ["X"], "Out", check_prim=True)
+            self.check_grad_with_place(
+                place, ["X"], "Out", check_prim=True, numeric_grad_delta=0.05
+            )

     cls_name = "{}_{}".format(parent.__name__, "BF16")
     TestCumsumBF16Op.__name__ = cls_name
......
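For reference, the analytic gradient that the cumsum prim decomposition and `check_grad_with_place` check against is just a reversed cumulative sum of the upstream gradient; the looser `numeric_grad_delta` above only compensates for bfloat16's coarse precision in the finite-difference comparison, not for the formula. A NumPy sketch:

```python
import numpy as np

def cumsum_grad(dy, axis=-1):
    """Gradient of y = cumsum(x): dL/dx_j = sum over i >= j of dy_i along `axis`."""
    return np.flip(np.cumsum(np.flip(dy, axis=axis), axis=axis), axis=axis)

x = np.random.rand(5).astype(np.float32)
dy = np.ones_like(x)          # pretend dL/dy is all ones
print(cumsum_grad(dy))        # [5. 4. 3. 2. 1.]
```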
@@ -19,6 +19,7 @@ import parameterized as param
 from eager_op_test import (
     OpTest,
     convert_float_to_uint16,
+    convert_uint16_to_float,
     paddle_static_guard,
     skip_check_grad_ci,
 )
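`convert_float_to_uint16` / `convert_uint16_to_float` are needed because NumPy has no native bfloat16 dtype: bf16 is the upper 16 bits of an IEEE float32, so the test helpers carry it around as uint16 bit patterns. A simplified sketch of that round-trip (truncation instead of round-to-nearest-even, so only an approximation of the real helpers):

```python
import numpy as np

def float_to_bf16_bits(x):
    # keep only the high 16 bits of each float32 value
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # put the 16 bits back into the high half of a float32
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, 1e-3], dtype=np.float32)
roundtrip = bf16_bits_to_float(float_to_bf16_bits(x))
# relative error stays below ~1e-2, since bf16 keeps about 8 significand bits
print(np.abs(roundtrip - x) / np.abs(x))
```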
@@ -810,6 +811,96 @@ def apply_to_static(net, use_cinn):
             [[1e-3, 1e-3, 1e-3]],  # gpu thresholds for static, jit, jit_cinn
             None,
         ),
+        (
+            'test0_bfp16',
+            (2, 100, 3, 5),
+            1e-5,
+            2,
+            'NCHW',
+            places,
+            'bfloat16',
+            [
+                [
+                    1e-2,
+                    1e-2,
+                    1e-2,
+                ],  # cpu thresholds for static, jit, jit_cinn
+                [1e-2, 1e-2, 1e-2],
+            ],  # gpu thresholds for static, jit, jit_cinn
+            None,
+        ),
+        (
+            'test1_bfp16',
+            (2, 100, 3, 5),
+            1e-5,
+            1,
+            'NCHW',
+            places,
+            'bfloat16',
+            [
+                [
+                    1e-2,
+                    1e-2,
+                    1e-2,
+                ],  # cpu thresholds for static, jit, jit_cinn
+                [1e-2, 1e-2, 1e-2],
+            ],  # gpu thresholds for static, jit, jit_cinn
+            None,
+        ),
+        (
+            'test2_bfp16',
+            (2, 100, 3, 5),
+            1e-5,
+            4,
+            'NCHW',
+            places,
+            'bfloat16',
+            [
+                [
+                    1e-2,
+                    1e-2,
+                    1e-2,
+                ],  # cpu thresholds for static, jit, jit_cinn
+                [1e-2, 1e-2, 1e-2],
+            ],  # gpu thresholds for static, jit, jit_cinn
+            None,
+        ),
+        (
+            'bigeps3_bfp16',
+            (2, 100, 3, 5),
+            0.5,
+            2,
+            'NCHW',
+            places,
+            'bfloat16',
+            [
+                [
+                    1e-2,
+                    1e-2,
+                    1e-2,
+                ],  # cpu thresholds for static, jit, jit_cinn
+                [1e-2, 1e-2, 1e-2],
+            ],  # gpu thresholds for static, jit, jit_cinn
+            None,
+        ),
+        (
+            'largedata_bfp16',
+            (2, 32, 64, 64),
+            1e-5,
+            4,
+            'NCHW',
+            places,
+            'bfloat16',
+            [
+                [
+                    1e-2,
+                    1e-2,
+                    1e-2,
+                ],  # cpu thresholds for static, jit, jit_cinn
+                [1e-2, 1e-2, 1e-2],
+            ],  # gpu thresholds for static, jit, jit_cinn
+            None,
+        ),
     ),
 )
 class TestCompositeGroupNorm(unittest.TestCase):
@@ -825,12 +916,23 @@ class TestCompositeGroupNorm(unittest.TestCase):
         np.random.seed(1234)
         self.fwd_desire = []
         self.rev_desire = []
-        self.x = np.random.random(self.shape).astype(self.dtype)
-        self.scale = np.random.random([self.shape[1]]).astype(self.dtype)
-        self.bias = np.random.random([self.shape[1]]).astype(self.dtype)
+        if self.dtype != "bfloat16":
+            self.x = np.random.random(self.shape).astype(self.dtype)
+            self.scale = np.random.random([self.shape[1]]).astype(self.dtype)
+            self.bias = np.random.random([self.shape[1]]).astype(self.dtype)
+        else:
+            self.x = convert_float_to_uint16(
+                np.random.random(self.shape).astype("float32")
+            )
+            self.scale = convert_float_to_uint16(
+                np.random.random([self.shape[1]]).astype("float32")
+            )
+            self.bias = convert_float_to_uint16(
+                np.random.random([self.shape[1]]).astype("float32")
+            )
         self.num_channels = self.shape[1]

-        if self.dtype == 'float16':
+        if self.dtype in ['float16', 'bfloat16']:
             self.places = []
             if paddle.is_compiled_with_cuda():
                 self.places.append(paddle.CUDAPlace(0))
@@ -879,7 +981,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
             paddle.assign(bias_, group_norm.bias)
             output = group_norm(input_)
             grad = paddle.grad(output, input_)
+            if self.dtype == "bfloat16":
+                output = paddle.cast(output, "float32")
+                grad = paddle.utils.map_structure(
+                    lambda x: paddle.cast(x, "float32"), grad
+                )
             return output, grad[0]

     def get_static_desire(self, place):
@@ -923,7 +1029,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
             output = group_norm(input_)

             blocks = mp.blocks
-
             names = dict(
                 zip(
                     blocks[0].ops[2].output_names,
@@ -964,7 +1069,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
         )

         paddle.disable_static()
         core._set_prim_all_enabled(True)
+        if self.dtype == "bfloat16":
+            out_list[0] = convert_uint16_to_float(out_list[0])
+            i = 3
+            for i in range(3, len(out_list)):
+                out_list[i] = convert_uint16_to_float(out_list[i])
         return out_list[:3], out_list[3:]

     def test_static_comp(self):
@@ -1051,6 +1160,11 @@ class TestCompositeGroupNorm(unittest.TestCase):
                     },
                     fetch_list=vars_list + [grads],
                 )
+                if self.dtype == "bfloat16":
+                    out_list[0] = convert_uint16_to_float(out_list[0])
+                    i = 3
+                    for i in range(3, len(out_list)):
+                        out_list[i] = convert_uint16_to_float(out_list[i])
                 fwd_actual[-1].append(out_list[0])
                 fwd_actual[-1].append(out_list[1])
                 fwd_actual[-1].append(out_list[2])
@@ -1075,12 +1189,14 @@ class TestCompositeGroupNorm(unittest.TestCase):
             atol = self.threshold_list[i][0]
             rtol = self.threshold_list[i][0]
             for j in range(len(self.static_fwd_desire[i])):
-                # in float16 type, Y is float16, mean and var are float16
+                # in float16 type, Y is float16, mean and var are float32
                 # so check mean and var with float32 gpu threshold
-                if self.dtype == 'float16' and j > 0:
+                if self.dtype == "float16" and j > 0:
                     atol = 1e-5
                     rtol = 1e-5
+                elif self.dtype == "bfloat16" and j > 0:
+                    atol = 5e-3
+                    rtol = 5e-3
                 np.testing.assert_allclose(
                     self.static_fwd_desire[i][j],
                     fwd_actual[i][j],
@@ -1091,13 +1207,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
                 max_abs_diff = np.max(
                     np.abs(self.static_fwd_desire[i][j] - fwd_actual[i][j])
                 )
-                print(
-                    self.shape,
-                    self.dtype,
-                    self.places[i],
-                    vars_name[j],
-                    max_abs_diff,
-                )
             # compare with eager_desire
             np.testing.assert_allclose(
                 self.fwd_desire[i],
@@ -1121,14 +1230,6 @@ class TestCompositeGroupNorm(unittest.TestCase):
                     np.abs(self.static_rev_desire[i][j] - rev_actual[i][j])
                 )
-
-                print(
-                    self.shape,
-                    self.dtype,
-                    self.places[i],
-                    vars_name[j + 3],
-                    max_abs_diff,
-                )
                 np.testing.assert_allclose(
                     self.static_rev_desire[i][j],
                     rev_actual[i][j],
@@ -1183,8 +1284,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
                 net = apply_to_static(net, False)
             output = net(input_)
             grad = paddle.grad(output, input_)
-            fwd_actual.append(output.numpy())
-            rev_actual.append(grad[0].numpy())
+            fwd_actual.append(
+                convert_uint16_to_float(output.numpy())
+                if self.dtype == "bfloat16"
+                else output.numpy()
+            )
+            rev_actual.append(
+                convert_uint16_to_float(grad[0].numpy())
+                if self.dtype == "bfloat16"
+                else grad[0].numpy()
+            )

         for i in range(len(self.places)):
             atol = self.threshold_list[i][1]
@@ -1244,8 +1353,16 @@ class TestCompositeGroupNorm(unittest.TestCase):
                 net = apply_to_static(net, True)
             output = net(input_)
             grad = paddle.grad(output, input_)
-            fwd_actual.append(output.numpy())
-            rev_actual.append(grad[0].numpy())
+            fwd_actual.append(
+                convert_uint16_to_float(output.numpy())
+                if self.dtype == "bfloat16"
+                else output.numpy()
+            )
+            rev_actual.append(
+                convert_uint16_to_float(grad[0].numpy())
+                if self.dtype == "bfloat16"
+                else grad[0].numpy()
+            )

         i = 0
         for place in self.places:
......