From 09570b48dd40a52009b66e93af6108cb308e361d Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Tue, 30 Jan 2018 15:22:52 +0800
Subject: [PATCH] layer norm -> scale + bias

---
 paddle/operators/layer_norm_op.cc             | 19 ++++++-------
 .../v2/fluid/tests/test_layer_norm_op.py      | 27 ++++++++++---------
 2 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 125ac9f53ff..5821afe9f68 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -45,11 +45,12 @@ class LayerNormOp : public framework::OperatorWithKernel {
 
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
 
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], left);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], left);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
 
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Mean", {left});
@@ -143,10 +144,10 @@ class LayerNormKernel
 
     // TODO(zcd): Some thinking about output_map, is it appropriate that
    // `output_map` and `input_map` point to the same memory.
-    auto inv_std_scale = var_map.unaryExpr(inv_std_func);
+    auto inv_std = var_map.unaryExpr(inv_std_func);
     output_map = (input_map - mean_map.replicate(1, right))
-                     .cwiseProduct(inv_std_scale.replicate(1, right))
-                     .cwiseProduct(scale_map.replicate(left, 1)) -
+                     .cwiseProduct(inv_std.replicate(1, right))
+                     .cwiseProduct(scale_map.replicate(left, 1)) +
                      bias_map.replicate(left, 1);
   }
 };
@@ -230,7 +231,7 @@ class LayerNormGradKernel
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
       auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
-      d_bias_map = d_y_map.colwise().mean();
+      d_bias_map = d_y_map.colwise().sum();
     }
     if (d_scale) {
       d_scale->mutable_data<T>(ctx.GetPlace());
@@ -245,7 +246,7 @@ class LayerNormGradKernel
                        var_map.unaryExpr(inv_std_func).replicate(1, right))
                       .cwiseProduct(d_y_map))
                      .colwise()
-                     .mean();
+                     .sum();
     }
 
     if (d_x) {
@@ -269,14 +270,14 @@ class LayerNormGradKernel
                         .replicate(1, right);
       // dy_var_dx
       auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                               .cwiseProduct(scale_map.replicate(left, 1))
                                .cwiseProduct(d_y_map)
                                .rowwise()
                                .sum();
       auto dvar_end = var_map.unaryExpr(inv_std_func)
                           .unaryExpr(triple_product_func)
                           .cwiseProduct(dvar_end_part)
-                          .replicate(1, right)
-                          .cwiseProduct(scale_map.replicate(left, 1));
+                          .replicate(1, right);
       auto dx_var =
           (T(-1.0) / right) *
           (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
diff --git a/python/paddle/v2/fluid/tests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
index 9264cf4b799..d27d1d8138c 100644
--- a/python/paddle/v2/fluid/tests/test_layer_norm_op.py
+++ b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
@@ -49,35 +49,38 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
 
 def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
     x_shape = x.shape
+    scale_shape = scale.shape
     N = reduce(mul, x_shape[0:begin_norm_axis], 1)
     D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
     grad_y.shape = [N, D]
     x.shape = [N, D]
     mean.shape = [N, 1]
     var.shape = [N, 1]
+    scale.shape = [1, D]
 
-    d_scale = np.sum(grad_y, axis=1).reshape([1, D])
-    d_bias = scale.reshape([1, D]) * np.sum((
-        (x - mean) * np.sqrt(1 / var)) * grad_y,
-                                            axis=1).reshape([1, D])
+    d_bias = np.sum(grad_y, axis=0).reshape([1, D])
+    d_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y,
+                     axis=0).reshape([1, D])
 
-    dx_end = np.sqrt(1.0 / var) * grad_y
+    dx_end = scale * np.sqrt(1.0 / var) * grad_y
 
-    d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y, axis=1).reshape([N, 1])
+    d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
+        [N, 1])
     # d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
     #     [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
     #                np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
-    d_mean = 1.0 / D * (d_mean_0)
+    d_mean = 1.0 / D * d_mean_0
 
-    d_std = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape([N, 1]) * (
-        1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
+    d_std = np.sum(
+        -1.0 / var * (x - mean) * grad_y * scale, axis=1).reshape([N, 1]) * (
+            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
 
-    grad_x = scale.reshape([1, D]) * (dx_end + d_mean + d_std)
+    grad_x = dx_end + d_mean + d_std
 
     grad_y.shape = x_shape
     x.shape = x_shape
-
-    return grad_x, d_bias, d_scale
+    scale.shape = scale_shape
+    return grad_x, d_scale, d_bias
 
 
 def create_or_get_tensor(scope, var_name, var, place):
--
GitLab
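
For reference, below is a minimal NumPy sketch of the layer normalization semantics this patch enforces: the input is flattened to a [N, D] matrix at begin_norm_axis, each row is normalized with its own mean and variance, and a per-feature scale and bias of size D (the "right" dimension checked in LayerNormOp::InferShape above) are applied as y = normalized * scale + bias. The function name, the epsilon default, and the way epsilon is folded into the variance are illustrative assumptions, not code taken from the patch or from the actual PaddlePaddle kernel.

import numpy as np


def layer_norm_forward(x, scale, bias, epsilon=1e-5, begin_norm_axis=1):
    # Flatten x to [N, D]: N = product of dims before begin_norm_axis
    # ("left"), D = product of the remaining dims ("right").
    x_shape = x.shape
    N = int(np.prod(x_shape[:begin_norm_axis]))
    D = int(np.prod(x_shape[begin_norm_axis:]))
    x2d = x.reshape([N, D])

    # One mean and variance per row (per sample), as the op's Mean and
    # Variance outputs of shape {left} suggest.
    mean = x2d.mean(axis=1, keepdims=True)  # [N, 1]
    var = x2d.var(axis=1, keepdims=True)    # [N, 1]

    # Normalize, then apply the per-feature affine transform. After this
    # patch, scale and bias have D elements, so they broadcast across rows.
    y2d = (x2d - mean) / np.sqrt(var + epsilon)
    y2d = y2d * scale.reshape([1, D]) + bias.reshape([1, D])
    return y2d.reshape(x_shape), mean.reshape([N]), var.reshape([N])


# Example: a [2, 3, 4] input normalized over the last two axes needs
# scale/bias of size D = 12, not N = 2.
x = np.random.rand(2, 3, 4).astype(np.float32)
scale = np.random.rand(12).astype(np.float32)
bias = np.random.rand(12).astype(np.float32)
y, mean, var = layer_norm_forward(x, scale, bias, begin_norm_axis=1)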