Unverified commit 1da67779 authored by xiaoguoguo626807, committed by GitHub

【prim】 optimize layer_norm_grad rules (#52308)

* add to sub & delete full scale

* decrease 1_div_shape_2 compute

* x_sub_mean_mul_sqrt_var_1

* delete log

* add mean var test

* nothing
Parent 53f5edbd
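For context, a sketch of the standard LayerNorm gradient identities that this refactor folds into fewer primitive ops (derived from the usual definition, not quoted from the patch). With \hat{x}_{ij} = (x_{ij} - \mu_i)/\sqrt{\sigma_i^2 + \epsilon} and g_{ij} = \gamma_j \, \partial L/\partial y_{ij} (or simply g = \partial L/\partial y when no scale is given):

\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sqrt{\sigma_i^2 + \epsilon}} \Big( g_{ij} - \frac{1}{N} \sum_k g_{ik} - \frac{\hat{x}_{ij}}{N} \sum_k g_{ik} \hat{x}_{ik} \Big), \qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i \frac{\partial L}{\partial y_{ij}} \, \hat{x}_{ij}, \qquad
\frac{\partial L}{\partial \beta_j} = \sum_i \frac{\partial L}{\partial y_{ij}}

In the diff below, dx_end is the leading g/\sqrt{\sigma^2 + \epsilon} term, d_mean and d_std are the two row sums, and x_sub_mean_mul_sqrt_var_1 is \hat{x}, now computed once and reused by both the x and scale gradients.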
......@@ -942,43 +942,60 @@ void layer_norm_grad(const Tensor& x,
auto scale_ptr = scale.get_ptr();
auto bias_ptr = bias.get_ptr();
// cast dtype to float32 if dtype = float16
Tensor x_cast = x;
Tensor out_grad_cast = out_grad;
auto x_cast = reshape<T>(x, std::vector<int64_t>({shape_1, shape_2}));
auto out_grad_cast =
reshape<T>(out_grad, std::vector<int64_t>({shape_1, shape_2}));
auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
Tensor scale_cast;
if (scale_ptr) {
scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
}
// cast dtype to float32 if dtype = float16
if (x.dtype() == phi::DataType::FLOAT16) {
x_cast = cast<T>(x, phi::DataType::FLOAT32);
out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
out_grad_cast = cast<T>(out_grad_cast, phi::DataType::FLOAT32);
if (scale_ptr) {
scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
}
}
x_cast = reshape<T>(x_cast, std::vector<int64_t>({shape_1, shape_2}));
out_grad_cast =
reshape<T>(out_grad_cast, std::vector<int64_t>({shape_1, shape_2}));
auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
if (bias_grad) {
if (bias_ptr) {
auto bias_grad_tmp =
out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
set_output<T>(bias_grad_tmp, bias_grad);
} else {
bias_grad = nullptr;
auto x_sub_mean = x_cast - mean_; // M,N
auto tmp = (1.0 / (variance_ + epsilon)); // M,1
auto sqrt_var_1 = sqrt<T>(tmp); // M,1
auto x_sub_mean_mul_sqrt_var_1 = x_sub_mean * sqrt_var_1;
if (x_grad) {
auto out_grad_scale = out_grad_cast; // M,N
if (scale_ptr) {
out_grad_scale = out_grad_cast * scale_cast; // M,N * 1,N = M,N
}
auto dx_end = sqrt_var_1 * out_grad_scale;
auto d_mean =
dx_end.sum(std::vector<int64_t>({1}), x_cast.dtype(), true); // M,1
auto d_std_1 =
(tmp * x_sub_mean * out_grad_scale)
.sum(std::vector<int64_t>({1}), x_cast.dtype(), true); // M,1
auto d_std = d_std_1 * x_sub_mean_mul_sqrt_var_1; // M,1 * M,N = M,N
auto d_mean_d_std = (1.0 / shape_2) * (d_mean + d_std);
auto x_grad_tmp = dx_end - d_mean_d_std;
x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
}
set_output<T>(x_grad_tmp, x_grad);
}
auto x_sub_mean = x_cast - mean_;
auto tmp = (1.0 / (variance_ + epsilon));
auto sqrt_var_1 = sqrt<T>(tmp);
if (scale_grad) {
if (scale_ptr) {
auto scale_grad_tmp =
(x_sub_mean * sqrt_var_1 * out_grad_cast)
(x_sub_mean_mul_sqrt_var_1 * out_grad_cast)
.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
set_output<T>(scale_grad_tmp, scale_grad);
......@@ -987,29 +1004,15 @@ void layer_norm_grad(const Tensor& x,
}
}
if (x_grad) {
if (!scale_ptr) {
scale_cast =
full<T>(std::vector<int64_t>({1, shape_2}), 1.0, x_cast.dtype());
}
auto out_grad_scale = out_grad_cast * scale_cast;
auto dx_end = (sqrt_var_1 * out_grad_scale);
auto d_mean_0 =
(-dx_end).sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_mean = (1.0 / shape_2) * d_mean_0;
auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
.sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
d_std_2 = reshape<T>(d_std_2, std::vector<int64_t>({shape_1, 1}));
d_std_2 = d_std_2 * x_sub_mean;
auto d_std = d_std_1 * d_std_2;
auto x_grad_tmp = dx_end + d_mean + d_std;
x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
if (bias_grad) {
if (bias_ptr) {
auto bias_grad_tmp =
out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
set_output<T>(bias_grad_tmp, bias_grad);
} else {
bias_grad = nullptr;
}
set_output<T>(x_grad_tmp, x_grad);
}
}
......
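A minimal NumPy reference of the composite rule after this change, assuming the inputs have already been collapsed to the 2D shapes (shape_1, shape_2) used above. Names mirror the C++ code, but this is an illustrative sketch, not the Paddle implementation.

import numpy as np

def layer_norm_grad_ref(x, out_grad, mean, variance, scale=None, epsilon=1e-5):
    # x, out_grad: (M, N); mean, variance: (M,); scale: (N,) or None
    M, N = x.shape
    mean_ = mean.reshape(M, 1)
    variance_ = variance.reshape(M, 1)
    x_sub_mean = x - mean_                         # M,N
    tmp = 1.0 / (variance_ + epsilon)              # M,1
    sqrt_var_1 = np.sqrt(tmp)                      # M,1 (reciprocal std)
    x_hat = x_sub_mean * sqrt_var_1                # x_sub_mean_mul_sqrt_var_1

    out_grad_scale = out_grad if scale is None else out_grad * scale.reshape(1, N)
    dx_end = sqrt_var_1 * out_grad_scale
    d_mean = dx_end.sum(axis=1, keepdims=True)                                      # M,1
    d_std = (tmp * x_sub_mean * out_grad_scale).sum(axis=1, keepdims=True) * x_hat  # M,N
    x_grad = dx_end - (1.0 / N) * (d_mean + d_std)

    scale_grad = None if scale is None else (x_hat * out_grad).sum(axis=0)
    bias_grad = out_grad.sum(axis=0)
    return x_grad, scale_grad, bias_grad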
......@@ -18,8 +18,10 @@ import numpy as np
from utils import SUB_TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle import _C_ops
from paddle.fluid import core, framework
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.incubate.autograd import primapi
from paddle.nn import LayerNorm
......@@ -32,6 +34,70 @@ def generate_data(shape1, shape2, shape3, dtype="float32"):
return np_data1, np_data2, np_data3
def layer_norm_wrapper(
x, normalized_shape, weight=None, bias=None, epsilon=1e-05, name=None
):
input_shape = list(x.shape)
input_ndim = len(input_shape)
normalized_ndim = len(normalized_shape)
begin_norm_axis = input_ndim - normalized_ndim
if (
input_ndim < normalized_ndim
or input_shape[begin_norm_axis:] != normalized_shape
):
str_normalized_shape = str(normalized_shape)
raise ValueError(
'Given normalized_shape is '
+ str_normalized_shape
+ ', expected input with shape [*, '
+ str_normalized_shape[1:]
+ ', but got input shape '
+ str(input_shape)
)
if in_dygraph_mode():
return _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
else:
inputs = {}
inputs['X'] = [x]
if weight:
inputs['Scale'] = [weight]
if bias:
inputs['Bias'] = [bias]
attrs = {"epsilon": epsilon, "begin_norm_axis": begin_norm_axis}
# create output
helper = LayerHelper('layer_norm', **locals())
from paddle.fluid.data_feeder import convert_dtype
param_dtype = (
x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
)
mean_out = helper.create_variable_for_type_inference(
dtype=param_dtype, stop_gradient=True
)
variance_out = helper.create_variable_for_type_inference(
dtype=param_dtype, stop_gradient=True
)
layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
)
return layer_norm_out, mean_out, variance_out
class Attr:
def __init__(self) -> None:
self.dtype = None
......@@ -64,7 +130,7 @@ attrs = Attr()
def fn(x, norm_shape, w, b):
return F.layer_norm(x, norm_shape, w, b)
return layer_norm_wrapper(x, norm_shape, w, b)
def expect_forward(x, norm_shape, w, b):
......@@ -92,7 +158,7 @@ class TestCompositelayer_norm(unittest.TestCase):
'w', shape=weight.shape, dtype=str(weight.dtype)
)
b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype))
y = fn(x, norm_shape, w, b)
out, mean, var = fn(x, norm_shape, w, b)
blocks = main_program.blocks
......@@ -115,7 +181,7 @@ class TestCompositelayer_norm(unittest.TestCase):
'w': weight,
'b': bias,
},
fetch_list=[y],
fetch_list=[out, mean, var],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
......@@ -131,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase):
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
y = fn(x, norm_shape, weight, bias)
out, mean, var = fn(x, norm_shape, weight, bias)
blocks = main_program.blocks
......@@ -152,7 +218,7 @@ class TestCompositelayer_norm(unittest.TestCase):
feed={
'x': inputs,
},
fetch_list=[y],
fetch_list=[out, mean, var],
)
paddle.disable_static()
core._set_prim_forward_enabled(False)
......@@ -167,23 +233,25 @@ class TestCompositelayer_norm(unittest.TestCase):
w_p = paddle.to_tensor(w)
b_p = paddle.to_tensor(b)
expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
actual = self.cal_composite(x, n_shape, w, b)[0]
expect = expect_forward(x_p, n_shape, w_p, b_p)
actual = self.cal_composite(x, n_shape, w, b)
assert expect.dtype == actual.dtype
assert expect[0].numpy().dtype == actual[0].dtype
for i in range(3):
np.testing.assert_allclose(
expect,
actual,
expect[i].numpy(),
actual[i],
rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"),
)
expect_2 = expect_forward(x_p, n_shape, None, None).numpy()
actual_2 = self.cal2_composite(x, n_shape, None, None)[0]
assert expect_2.dtype == actual_2.dtype
expect_2 = expect_forward(x_p, n_shape, None, None)
actual_2 = self.cal2_composite(x, n_shape, None, None)
assert expect_2[0].numpy().dtype == actual_2[0].dtype
for i in range(3):
np.testing.assert_allclose(
expect_2,
actual_2,
expect_2[i].numpy(),
actual_2[i],
rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"),
)
......
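For illustration, a hedged usage sketch of the layer_norm_wrapper helper added above in static-graph mode, which is why the tests now fetch out, mean and var together (the shapes, feed values, and variable names here are made up):

import numpy as np
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', shape=[3, 4], dtype='float32')
    w = paddle.static.data('w', shape=[4], dtype='float32')
    b = paddle.static.data('b', shape=[4], dtype='float32')
    # the wrapper returns the op's three outputs, matching fetch_list below
    out, mean, var = layer_norm_wrapper(x, [4], w, b)

exe = paddle.static.Executor()
res = exe.run(
    main_program,
    feed={
        'x': np.random.rand(3, 4).astype('float32'),
        'w': np.ones([4], dtype='float32'),
        'b': np.zeros([4], dtype='float32'),
    },
    fetch_list=[out, mean, var],
)
paddle.disable_static()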
......@@ -492,7 +492,7 @@ class TestCompositelayer_norm(unittest.TestCase):
class TestCompositelayer_normPrimBackward(unittest.TestCase):
def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float16", "float32"]
self.dtypes = ["float32"]
self.n_shape = [[4], [64, 128], [64]]
self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
self.shape2s = [[4], [64 * 128], [64]]
......@@ -576,7 +576,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
b_p = paddle.to_tensor(b)
y_g_p = paddle.to_tensor(y_g)
expect = dygraph_fused_backward_withNone(x_p, n_shape, w_p, b_p, y_g_p)[
expect = dygraph_fused_backward(x_p, n_shape, w_p, b_p, y_g_p)[
0
].numpy()
actual = self.static_comp_forward_and_backward(x, n_shape, w, b)[0]
......