Unverified · Commit e3fcbb8f authored by Charles-hit, committed by GitHub

[AMP Prim OP]support bf16 dtype for layer_norm prim op (#54236)

* support layer_norm prim op bf16 dtype

* polish code

* resolve conflict
Parent: effebd41
@@ -934,8 +934,9 @@ void layer_norm_grad(const Tensor& x,
     scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
   }
-  // cast dtype to float32 if dtype =float16
-  if (x.dtype() == phi::DataType::FLOAT16) {
+  // cast dtype to float32 if dtype =float16 or bfloat16
+  if (x.dtype() == phi::DataType::FLOAT16 ||
+      x.dtype() == phi::DataType::BFLOAT16) {
     x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
     out_grad_cast = cast<T>(out_grad_cast, phi::DataType::FLOAT32);
     if (scale_ptr) {
@@ -967,7 +968,8 @@ void layer_norm_grad(const Tensor& x,
     auto x_grad_tmp = dx_end - d_mean_d_std;
     x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
-    if (x.dtype() == phi::DataType::FLOAT16) {
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
       x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
     }
     set_output<T>(x_grad_tmp, x_grad);
@@ -979,6 +981,10 @@ void layer_norm_grad(const Tensor& x,
           (x_sub_mean_mul_sqrt_var_1 * out_grad_cast)
               .sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
       scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
+      if (scale_ptr->dtype() == phi::DataType::FLOAT16 ||
+          scale_ptr->dtype() == phi::DataType::BFLOAT16) {
+        scale_grad_tmp = cast<T>(scale_grad_tmp, scale_ptr->dtype());
+      }
       set_output<T>(scale_grad_tmp, scale_grad);
     } else {
       scale_grad = nullptr;
@@ -990,6 +996,10 @@ void layer_norm_grad(const Tensor& x,
       auto bias_grad_tmp =
           out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
       bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
+      if (bias_ptr->dtype() == phi::DataType::FLOAT16 ||
+          bias_ptr->dtype() == phi::DataType::BFLOAT16) {
+        bias_grad_tmp = cast<T>(bias_grad_tmp, bias_ptr->dtype());
+      }
       set_output<T>(bias_grad_tmp, bias_grad);
     } else {
       bias_grad = nullptr;
...
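Note on the backward change above: it follows the usual AMP recipe. When x is float16 or bfloat16, x and out_grad are upcast to float32, the composite gradient math runs in float32, and each result is cast back to the dtype of the tensor it belongs to (x for x_grad, the scale/bias parameters for scale_grad/bias_grad). Below is a rough Python sketch of that final cast-back step; the helper name cast_grad_back is illustrative, not Paddle API, and only paddle.cast and convert_dtype (both used elsewhere in this commit) are assumed.

import paddle
from paddle.fluid.data_feeder import convert_dtype


def cast_grad_back(grad_fp32, ref):
    # Gradients are computed in float32; cast back only when the reference
    # tensor (x, scale, or bias) is low precision. convert_dtype reports
    # bfloat16 tensors as "uint16".
    if convert_dtype(ref.dtype) in ["float16", "uint16"]:
        return paddle.cast(grad_fp32, ref.dtype)
    return grad_fp32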
@@ -150,9 +150,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype

-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
+        scale = cast(scale, "float32") if scale else scale
+        bias = cast(bias, "float32") if bias else bias

     axis = tuple(range(begin_norm_axis, len(x.shape)))
     mean_ = mean(x, axis=axis, keepdim=True)
@@ -175,8 +178,7 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     mean_ = reshape(mean_, [-1])
     variance = reshape(variance, [-1])
     if is_amp:
-        out = cast(out, "float16")
+        out = cast(out, dtype)
     return out, mean_, variance
@@ -632,7 +634,7 @@ def rsqrt_composite(x):
     from paddle.fluid.data_feeder import convert_dtype

     dtype = convert_dtype(x.dtype)
-    if dtype == "float16" or dtype == "uint16":
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
     y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
...
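The forward composite rule mirrors the backward change: convert_dtype reports bfloat16 as "uint16", so both "float16" and "uint16" inputs (now including scale and bias) are upcast to float32, layer norm is computed in float32, and only the normalized output is cast back while mean and variance stay float32. The following is a minimal dygraph sketch of that cast-compute-cast pattern; the function name and body are illustrative, not the actual composite rule.

import paddle
from paddle.fluid.data_feeder import convert_dtype


def layer_norm_fp32_sketch(x, scale=None, bias=None, epsilon=1e-5, begin_norm_axis=1):
    orig_dtype = x.dtype
    is_amp = convert_dtype(x.dtype) in ["float16", "uint16"]  # fp16 or bf16
    if is_amp:
        # Upcast everything to float32 before the normalization math.
        x = paddle.cast(x, "float32")
        scale = paddle.cast(scale, "float32") if scale is not None else None
        bias = paddle.cast(bias, "float32") if bias is not None else None

    axis = tuple(range(begin_norm_axis, len(x.shape)))
    mean_ = paddle.mean(x, axis=axis, keepdim=True)
    var = paddle.mean((x - mean_) ** 2, axis=axis, keepdim=True)
    out = (x - mean_) / paddle.sqrt(var + epsilon)
    if scale is not None:
        out = out * paddle.reshape(scale, x.shape[begin_norm_axis:])
    if bias is not None:
        out = out + paddle.reshape(bias, x.shape[begin_norm_axis:])

    if is_amp:
        # Only the normalized output goes back to fp16/bf16.
        out = paddle.cast(out, orig_dtype)
    return out, mean_, var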
@@ -958,7 +958,7 @@ else()
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
-  set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150)
+  set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 250)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
 endif()
 set_tests_properties(test_imperative_selected_rows_to_lod_tensor
...
@@ -17,7 +17,11 @@ from functools import reduce
 from operator import mul

 import numpy as np
-from eager_op_test import OpTest, _set_use_system_allocator
+from eager_op_test import (
+    OpTest,
+    _set_use_system_allocator,
+    convert_float_to_uint16,
+)

 import paddle
 import paddle.nn.functional as F
@@ -212,6 +216,96 @@ class TestLayerNormOpByOpTest(OpTest):
         }


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestLayerNormBF16OpByOpTest(OpTest):
+    def setUp(self):
+        self.python_api = layer_norm_wrapper
+        self.public_python_api = layer_norm_wrapper
+        self.op_type = "layer_norm"
+        self.prim_op_type = "comp"
+        self.python_out_sig = ["Y"]
+        self.initConfig()
+        self.initTestCase()
+
+    def test_check_output(self):
+        self.check_output_with_place(
+            place=core.CUDAPlace(0),
+            no_check_set=["Mean", "Variance"],
+            atol=self.ori_atol,
+            rtol=self.ori_rtol,
+            check_prim=True,
+        )
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            core.CUDAPlace(0),
+            self.check_grad_input_list,
+            ['Y'],
+            max_relative_error=self.max_relative_error,
+            check_prim=True,
+        )
+
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+        self.max_relative_error = 1e-5
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = True
+        self.has_bias = True
+
+    def initTestCase(self):
+        np.random.seed(123)
+        self.D = reduce(
+            mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1
+        )
+        self.scale_shape = [self.D]
+        x = np.random.random(self.x_shape).astype("float32")
+        scale = (
+            np.random.random(self.scale_shape).astype("float32")
+            if self.has_scale
+            else None
+        )
+        bias = (
+            np.random.random(self.scale_shape).astype("float32")
+            if self.has_bias
+            else None
+        )
+        self.inputs = {
+            "X": convert_float_to_uint16(x),
+        }
+        self.check_grad_input_list = ['X']
+
+        if self.has_scale:
+            self.inputs.update({"Scale": convert_float_to_uint16(scale)})
+            self.check_grad_input_list.append('Scale')
+        if self.has_bias:
+            self.inputs.update({"Bias": convert_float_to_uint16(bias)})
+            self.check_grad_input_list.append('Bias')
+
+        self.attrs = {
+            "epsilon": self.epsilon,
+            "begin_norm_axis": self.begin_norm_axis,
+        }
+
+        y, mean, variance = _reference_layer_norm_naive(
+            x, scale, bias, self.epsilon, self.begin_norm_axis
+        )
+        self.outputs = {
+            "Y": convert_float_to_uint16(y),
+            "Mean": convert_float_to_uint16(mean),
+            "Variance": convert_float_to_uint16(variance),
+        }
+
+
 class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-6
@@ -234,6 +328,21 @@ class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
         self.has_bias = False


+class TestLayerNormBF16OpByOpTest_case2(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+        self.max_relative_error = 1e-5
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = False
+        self.has_bias = False
+
+
 class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-7
@@ -256,6 +365,21 @@ class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
         self.has_bias = False


+class TestLayerNormBF16OpByOpTest_case3(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+        self.max_relative_error = 1e-5
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = True
+        self.has_bias = False
+
+
 class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-6
@@ -278,6 +402,21 @@ class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
         self.has_bias = True


+class TestLayerNormBF16OpByOpTest_case4(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+        self.max_relative_error = 1e-5
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = False
+        self.has_bias = True
+
+
 class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-5
...
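For reference, the new OpTest subclasses feed bfloat16 data as np.uint16 arrays: convert_float_to_uint16 packs each float32 value into its 16-bit bfloat16 representation, and the reference outputs are converted the same way before comparison. The numpy-only snippet below illustrates that packing; it is a simplified, truncating sketch, and the actual helper in eager_op_test may differ in rounding.

import numpy as np


def float32_to_bf16_bits(arr):
    # View each float32 as its raw 32-bit pattern and keep the high 16 bits,
    # which form the bfloat16 payload (no rounding applied in this sketch).
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)


x = np.random.random([2, 6, 6, 3]).astype("float32")
x_bits = float32_to_bf16_bits(x)   # dtype uint16, same shape as the input
print(x_bits.dtype, x_bits.shape)  # uint16 (2, 6, 6, 3)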