提交 09570b48 编写于 作者: C chengduoZH

layer norm -> scale + bias

上级 7e0d21de
......@@ -45,11 +45,12 @@ class LayerNormOp : public framework::OperatorWithKernel {
auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], left);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], left);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->SetOutputDim("Mean", {left});
......@@ -143,10 +144,10 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
// TODO(zcd): Some thinking about output_map, is it appropriate that
// `output_map` and `input_map` point to the same memory.
auto inv_std_scale = var_map.unaryExpr(inv_std_func);
auto inv_std = var_map.unaryExpr(inv_std_func);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std_scale.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) -
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) +
bias_map.replicate(left, 1);
}
};
......@@ -230,7 +231,7 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
d_bias_map = d_y_map.colwise().mean();
d_bias_map = d_y_map.colwise().sum();
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
......@@ -245,7 +246,7 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
var_map.unaryExpr(inv_std_func).replicate(1, right))
.cwiseProduct(d_y_map))
.colwise()
.mean();
.sum();
}
if (d_x) {
......@@ -269,14 +270,14 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
.replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right)
.cwiseProduct(scale_map.replicate(left, 1));
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
......
......@@ -49,35 +49,38 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
x_shape = x.shape
scale_shape = scale.shape
N = reduce(mul, x_shape[0:begin_norm_axis], 1)
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
grad_y.shape = [N, D]
x.shape = [N, D]
mean.shape = [N, 1]
var.shape = [N, 1]
scale.shape = [1, D]
d_scale = np.sum(grad_y, axis=1).reshape([1, D])
d_bias = scale.reshape([1, D]) * np.sum((
(x - mean) * np.sqrt(1 / var)) * grad_y,
axis=1).reshape([1, D])
d_bias = np.sum(grad_y, axis=0).reshape([1, D])
d_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y,
axis=0).reshape([1, D])
dx_end = np.sqrt(1.0 / var) * grad_y
dx_end = scale * np.sqrt(1.0 / var) * grad_y
d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y, axis=1).reshape([N, 1])
d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
[N, 1])
# d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
# [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
# np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
d_mean = 1.0 / D * (d_mean_0)
d_mean = 1.0 / D * d_mean_0
d_std = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape([N, 1]) * (
d_std = np.sum(
-1.0 / var * (x - mean) * grad_y * scale, axis=1).reshape([N, 1]) * (
1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
grad_x = scale.reshape([1, D]) * (dx_end + d_mean + d_std)
grad_x = dx_end + d_mean + d_std
grad_y.shape = x_shape
x.shape = x_shape
return grad_x, d_bias, d_scale
scale.shape = scale_shape
return grad_x, d_scale, d_bias
def create_or_get_tensor(scope, var_name, var, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册