Unverified commit 1da67779, authored by xiaoguoguo626807, committed by GitHub

【prim】 optimize layer_norm_grad rules (#52308)

* add to sub & delete full scale

* decrease 1_div_shape_2 compute

* x_sub_mean_mul_sqrt_var_1

* delete log

* add mean var test

* nothing
Parent: 53f5edbd
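For reference, the x_grad rule touched by this commit can be written as follows (notation introduced here, not in the commit: N = shape_2, mu and sigma^2 the per-row mean and variance, $s = 1/\sqrt{\sigma^2+\epsilon}$, $\hat{x} = x - \mu$, $g$ = out_grad * scale). The old and new groupings are algebraically identical; the rewrite folds the negations into a single subtraction ("add to sub") and applies the 1/N factor once instead of twice ("decrease 1_div_shape_2 compute"):

    dx_{old} = s\,g \;+\; \frac{1}{N}\sum_j(-s\,g_j) \;+\; \Big(\sum_j \frac{-\hat{x}_j g_j}{\sigma^2+\epsilon}\Big)\,\frac{1}{N}\,s\,\hat{x}

    dx_{new} = s\,g \;-\; \frac{1}{N}\Big(\sum_j s\,g_j \;+\; \Big(\sum_j \frac{\hat{x}_j g_j}{\sigma^2+\epsilon}\Big)\,s\,\hat{x}\Big)

"delete full scale" refers to no longer materializing a full<T>(..., 1.0) scale tensor when no scale is given; the new code uses out_grad directly in that case.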
@@ -942,43 +942,60 @@ void layer_norm_grad(const Tensor& x,
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();
-  // cast dtype to float32 if dtype =float16
-  Tensor x_cast = x;
-  Tensor out_grad_cast = out_grad;
+  auto x_cast = reshape<T>(x, std::vector<int64_t>({shape_1, shape_2}));
+  auto out_grad_cast =
+      reshape<T>(out_grad, std::vector<int64_t>({shape_1, shape_2}));
+  auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
+  auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
   Tensor scale_cast;
   if (scale_ptr) {
     scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
   }
+  // cast dtype to float32 if dtype =float16
   if (x.dtype() == phi::DataType::FLOAT16) {
-    x_cast = cast<T>(x, phi::DataType::FLOAT32);
-    out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
+    x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
+    out_grad_cast = cast<T>(out_grad_cast, phi::DataType::FLOAT32);
     if (scale_ptr) {
       scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
     }
   }
-  x_cast = reshape<T>(x_cast, std::vector<int64_t>({shape_1, shape_2}));
-  out_grad_cast =
-      reshape<T>(out_grad_cast, std::vector<int64_t>({shape_1, shape_2}));
-  auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
-  auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
-  if (bias_grad) {
-    if (bias_ptr) {
-      auto bias_grad_tmp =
-          out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
-      bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
-      set_output<T>(bias_grad_tmp, bias_grad);
-    } else {
-      bias_grad = nullptr;
+  auto x_sub_mean = x_cast - mean_;          // M,N
+  auto tmp = (1.0 / (variance_ + epsilon));  // M,1
+  auto sqrt_var_1 = sqrt<T>(tmp);            // M,1
+  auto x_sub_mean_mul_sqrt_var_1 = x_sub_mean * sqrt_var_1;
+
+  if (x_grad) {
+    auto out_grad_scale = out_grad_cast;  // M,N
+    if (scale_ptr) {
+      out_grad_scale = out_grad_cast * scale_cast;  // M,N * 1,N = M,N
     }
+    auto dx_end = sqrt_var_1 * out_grad_scale;
+    auto d_mean =
+        dx_end.sum(std::vector<int64_t>({1}), x_cast.dtype(), true);  // M,1
+    auto d_std_1 =
+        (tmp * x_sub_mean * out_grad_scale)
+            .sum(std::vector<int64_t>({1}), x_cast.dtype(), true);  // M,1
+    auto d_std = d_std_1 * x_sub_mean_mul_sqrt_var_1;  // M,1 * M,N = M,N
+    auto d_mean_d_std = (1.0 / shape_2) * (d_mean + d_std);
+    auto x_grad_tmp = dx_end - d_mean_d_std;
+    x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
+    if (x.dtype() == phi::DataType::FLOAT16) {
+      x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
+    }
+    set_output<T>(x_grad_tmp, x_grad);
   }
-  auto x_sub_mean = x_cast - mean_;
-  auto tmp = (1.0 / (variance_ + epsilon));
-  auto sqrt_var_1 = sqrt<T>(tmp);
+
   if (scale_grad) {
     if (scale_ptr) {
       auto scale_grad_tmp =
-          (x_sub_mean * sqrt_var_1 * out_grad_cast)
+          (x_sub_mean_mul_sqrt_var_1 * out_grad_cast)
              .sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
       scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
       set_output<T>(scale_grad_tmp, scale_grad);
@@ -987,29 +1004,15 @@ void layer_norm_grad(const Tensor& x,
     }
   }
-  if (x_grad) {
-    if (!scale_ptr) {
-      scale_cast =
-          full<T>(std::vector<int64_t>({1, shape_2}), 1.0, x_cast.dtype());
-    }
-    auto out_grad_scale = out_grad_cast * scale_cast;
-    auto dx_end = (sqrt_var_1 * out_grad_scale);
-    auto d_mean_0 =
-        (-dx_end).sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
-    auto d_mean = (1.0 / shape_2) * d_mean_0;
-    auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
-                       .sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
-    auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
-    d_std_2 = reshape<T>(d_std_2, std::vector<int64_t>({shape_1, 1}));
-    d_std_2 = d_std_2 * x_sub_mean;
-    auto d_std = d_std_1 * d_std_2;
-    auto x_grad_tmp = dx_end + d_mean + d_std;
-    x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
-    if (x.dtype() == phi::DataType::FLOAT16) {
-      x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
-    }
-    set_output<T>(x_grad_tmp, x_grad);
+  if (bias_grad) {
+    if (bias_ptr) {
+      auto bias_grad_tmp =
+          out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
+      bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
+      set_output<T>(bias_grad_tmp, bias_grad);
+    } else {
+      bias_grad = nullptr;
+    }
   }
 }
...
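As a standalone sanity check on the x_grad rewrite above (not part of the patch; the NumPy names below mirror the variables in the diff, and the shapes, eps and random data are arbitrary), both groupings produce the same gradient:

import numpy as np

rng = np.random.default_rng(0)
M, N, eps = 4, 8, 1e-5
x = rng.standard_normal((M, N))
out_grad = rng.standard_normal((M, N))
scale = rng.standard_normal((1, N))

mean = x.mean(axis=1, keepdims=True)        # M,1
var = x.var(axis=1, keepdims=True)          # M,1
x_sub_mean = x - mean                       # M,N
tmp = 1.0 / (var + eps)                     # M,1
sqrt_var_1 = np.sqrt(tmp)                   # M,1
out_grad_scale = out_grad * scale           # M,N
dx_end = sqrt_var_1 * out_grad_scale        # M,N

# old rule: dx_end + d_mean + d_std, with 1/N and the minus signs in each term
d_mean_old = (1.0 / N) * (-dx_end).sum(axis=1, keepdims=True)
d_std_1_old = (-tmp * x_sub_mean * out_grad_scale).sum(axis=1, keepdims=True)
d_std_old = d_std_1_old * ((1.0 / N) * sqrt_var_1 * x_sub_mean)
x_grad_old = dx_end + d_mean_old + d_std_old

# new rule: dx_end - (1/N) * (d_mean + d_std), one shared 1/N, one subtraction
d_mean_new = dx_end.sum(axis=1, keepdims=True)
d_std_new = (tmp * x_sub_mean * out_grad_scale).sum(axis=1, keepdims=True) \
    * (x_sub_mean * sqrt_var_1)
x_grad_new = dx_end - (1.0 / N) * (d_mean_new + d_std_new)

np.testing.assert_allclose(x_grad_old, x_grad_new, rtol=1e-12, atol=1e-12)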
@@ -18,8 +18,10 @@ import numpy as np
 from utils import SUB_TOLERANCE

 import paddle
-import paddle.nn.functional as F
+from paddle import _C_ops
 from paddle.fluid import core, framework
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.layer_helper import LayerHelper
 from paddle.incubate.autograd import primapi
 from paddle.nn import LayerNorm
@@ -32,6 +34,70 @@ def generate_data(shape1, shape2, shape3, dtype="float32"):
     return np_data1, np_data2, np_data3


+def layer_norm_wrapper(
+    x, normalized_shape, weight=None, bias=None, epsilon=1e-05, name=None
+):
+    input_shape = list(x.shape)
+    input_ndim = len(input_shape)
+    normalized_ndim = len(normalized_shape)
+    begin_norm_axis = input_ndim - normalized_ndim
+    if (
+        input_ndim < normalized_ndim
+        or input_shape[begin_norm_axis:] != normalized_shape
+    ):
+        str_normalized_shape = str(normalized_shape)
+        raise ValueError(
+            'Given normalized_shape is '
+            + str_normalized_shape
+            + ', expected input with shape [*, '
+            + str_normalized_shape[1:]
+            + ', but got input shape '
+            + str(input_shape)
+        )
+
+    if in_dygraph_mode():
+        return _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
+    else:
+        inputs = {}
+        inputs['X'] = [x]
+        if weight:
+            inputs['Scale'] = [weight]
+        if bias:
+            inputs['Bias'] = [bias]
+        attrs = {"epsilon": epsilon, "begin_norm_axis": begin_norm_axis}
+
+        # create output
+        helper = LayerHelper('layer_norm', **locals())
+        from paddle.fluid.data_feeder import convert_dtype
+
+        param_dtype = (
+            x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
+        )
+        mean_out = helper.create_variable_for_type_inference(
+            dtype=param_dtype, stop_gradient=True
+        )
+        variance_out = helper.create_variable_for_type_inference(
+            dtype=param_dtype, stop_gradient=True
+        )
+        layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
+
+        helper.append_op(
+            type="layer_norm",
+            inputs=inputs,
+            outputs={
+                "Y": layer_norm_out,
+                "Mean": mean_out,
+                "Variance": variance_out,
+            },
+            attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
+        )
+
+        return layer_norm_out, mean_out, variance_out
+
+
 class Attr:
     def __init__(self) -> None:
         self.dtype = None
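For context, a rough NumPy reference for what the wrapper's three outputs are expected to hold (an assumption made here for illustration, not asserted by the patch; the real op's output shapes and dtypes may differ): with the input flattened to (rows, normalized_size), Mean and Variance are the per-row mean and biased variance used for normalization, and Y is the normalized input scaled and shifted.

import numpy as np

def layer_norm_reference(x2d, weight=None, bias=None, eps=1e-05):
    # x2d: input already flattened to (rows, normalized_size)
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)  # biased (divide-by-N) variance
    y = (x2d - mean) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight  # weight/bias broadcast over rows, shape (normalized_size,)
    if bias is not None:
        y = y + bias
    return y, mean.squeeze(1), var.squeeze(1)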
@@ -64,7 +130,7 @@ attrs = Attr()


 def fn(x, norm_shape, w, b):
-    return F.layer_norm(x, norm_shape, w, b)
+    return layer_norm_wrapper(x, norm_shape, w, b)


 def expect_forward(x, norm_shape, w, b):
@@ -92,7 +158,7 @@ class TestCompositelayer_norm(unittest.TestCase):
                 'w', shape=weight.shape, dtype=str(weight.dtype)
             )
             b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
-            y = fn(x, norm_shape, w, b)
+            out, mean, var = fn(x, norm_shape, w, b)

             blocks = main_program.blocks
@@ -115,7 +181,7 @@ class TestCompositelayer_norm(unittest.TestCase):
                 'w': weight,
                 'b': bias,
             },
-            fetch_list=[y],
+            fetch_list=[out, mean, var],
         )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
@@ -131,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase):
                 'x', shape=inputs.shape, dtype=str(inputs.dtype)
             )
-            y = fn(x, norm_shape, weight, bias)
+            out, mean, var = fn(x, norm_shape, weight, bias)

             blocks = main_program.blocks
@@ -152,7 +218,7 @@ class TestCompositelayer_norm(unittest.TestCase):
             feed={
                 'x': inputs,
             },
-            fetch_list=[y],
+            fetch_list=[out, mean, var],
         )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
@@ -167,26 +233,28 @@ class TestCompositelayer_norm(unittest.TestCase):
         w_p = paddle.to_tensor(w)
         b_p = paddle.to_tensor(b)

-        expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
-        actual = self.cal_composite(x, n_shape, w, b)[0]
-        assert expect.dtype == actual.dtype
-        np.testing.assert_allclose(
-            expect,
-            actual,
-            rtol=attrs.get_rtol("forward"),
-            atol=attrs.get_atol("forward"),
-        )
+        expect = expect_forward(x_p, n_shape, w_p, b_p)
+        actual = self.cal_composite(x, n_shape, w, b)
+        assert expect[0].numpy().dtype == actual[0].dtype
+        for i in range(3):
+            np.testing.assert_allclose(
+                expect[i].numpy(),
+                actual[i],
+                rtol=attrs.get_rtol("forward"),
+                atol=attrs.get_atol("forward"),
+            )

-        expect_2 = expect_forward(x_p, n_shape, None, None).numpy()
-        actual_2 = self.cal2_composite(x, n_shape, None, None)[0]
-        assert expect_2.dtype == actual_2.dtype
-        np.testing.assert_allclose(
-            expect_2,
-            actual_2,
-            rtol=attrs.get_rtol("forward"),
-            atol=attrs.get_atol("forward"),
-        )
+        expect_2 = expect_forward(x_p, n_shape, None, None)
+        actual_2 = self.cal2_composite(x, n_shape, None, None)
+        assert expect_2[0].numpy().dtype == actual_2[0].dtype
+        for i in range(3):
+            np.testing.assert_allclose(
+                expect_2[i].numpy(),
+                actual_2[i],
+                rtol=attrs.get_rtol("forward"),
+                atol=attrs.get_atol("forward"),
+            )

     def test_forward(self):
         for j in self.dtypes:
...
@@ -492,7 +492,7 @@ class TestCompositelayer_norm(unittest.TestCase):

 class TestCompositelayer_normPrimBackward(unittest.TestCase):
     def setUp(self):
         core._set_prim_backward_enabled(True)
-        self.dtypes = ["float16", "float32"]
+        self.dtypes = ["float32"]
         self.n_shape = [[4], [64, 128], [64]]
         self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
         self.shape2s = [[4], [64 * 128], [64]]
@@ -576,7 +576,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
         b_p = paddle.to_tensor(b)
         y_g_p = paddle.to_tensor(y_g)

-        expect = dygraph_fused_backward_withNone(x_p, n_shape, w_p, b_p, y_g_p)[
+        expect = dygraph_fused_backward(x_p, n_shape, w_p, b_p, y_g_p)[
            0
        ].numpy()
        actual = self.static_comp_forward_and_backward(x, n_shape, w, b)[0]
...