From 0f47703dd5db01a7510031e810f963e09a8c9c13 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 28 Jan 2018 16:41:18 +0800 Subject: [PATCH] add begin_norm_axis --- paddle/operators/layer_norm_op.cc | 47 ++++++++++++------- .../v2/fluid/tests/test_layer_norm_op.py | 28 ++++++----- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc index 0b0c760e57..9e618d10d2 100644 --- a/paddle/operators/layer_norm_op.cc +++ b/paddle/operators/layer_norm_op.cc @@ -42,10 +42,17 @@ class LayerNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], 1); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], 1); + auto x_dim = ctx->GetInputDim("X"); + auto begin_norm_axis = ctx->Attrs().Get("begin_norm_axis"); + PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(), + "'begin_norm_axis' must be less than the rank of X"); + + auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis); + int left = static_cast(matrix_dim[0]); ctx->SetOutputDim("Y", ctx->GetInputDim("X")); - ctx->SetOutputDim("Mean", {ctx->GetInputDim("X")[0]}); - ctx->SetOutputDim("Variance", {ctx->GetInputDim("X")[0]}); + ctx->SetOutputDim("Mean", {left}); + ctx->SetOutputDim("Variance", {left}); ctx->ShareLoD("X", "Y"); } @@ -72,10 +79,14 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, "'epsilon' should be between 0.0 and 0.001."); }); - AddAttr>("axis", - "(vector default:{1, 1, 1}), the " - "axis to normalize.") - .SetDefault({1, 2, 3}); // todo(zcd) : who to set axis + AddAttr("begin_norm_axis", + "(int default:1), the " + "axis of `begin_norm_axis ... Rank(X) - 1` will be normalized") + .SetDefault(1) + .AddCustomChecker([](const int &begin_norm_axis) { + PADDLE_ENFORCE_GT(begin_norm_axis, 0, + "'begin_norm_axis' should be greater than zero."); + }); AddComment(R"DOC( Layer Normalization. @@ -97,9 +108,7 @@ class LayerNormKernel const auto *bias = ctx.Input("Bias"); const auto *x = ctx.Input("X"); const auto &x_dims = x->dims(); - - const int N = x_dims[0]; - const int sample_size = x->numel() / N; + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); auto scale_data = scale->data()[0]; auto bias_data = bias->data()[0]; @@ -111,7 +120,9 @@ class LayerNormKernel mean->mutable_data(ctx.GetPlace()); var->mutable_data(ctx.GetPlace()); - int left = N, right = sample_size; + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); auto input_map = ConstEigenMatrixMapRowMajor(x->data(), left, right); auto mean_map = EigenMatrixMapRowMajor(mean->data(), left, 1); auto var_map = EigenMatrixMapRowMajor(var->data(), left, 1); @@ -131,7 +142,8 @@ class LayerNormKernel return std::sqrt(1 / ele) * scale_data; }; auto sub_bias = [bias_data](T ele) { return bias_data - ele; }; - + // TODO(zcd): Some thinking about output_map, is it appropriate that + // `output_map` and `input_map` point to the same memory. output_map = (var_map.unaryExpr(scale_inv_std).replicate(1, right)) .cwiseProduct(input_map) + var_map.unaryExpr(scale_inv_std) @@ -198,13 +210,14 @@ class LayerNormGradKernel const auto *var = ctx.Input("Variance"); const auto *scale = ctx.Input("Scale"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); + auto scale_data = scale->data()[0]; const auto &x_dims = x->dims(); - const int N = x_dims[0]; - const int sample_size = x->numel() / N; - int left = N, right = sample_size; - auto scale_data = scale->data()[0]; + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]), + right = static_cast(matrix_dim[1]); // init output auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -223,11 +236,13 @@ class LayerNormGradKernel if (d_scale) { d_scale->mutable_data(ctx.GetPlace()); auto inv_std = [](T ele) { return std::sqrt(1 / ele); }; + // There are two equation to compute d_scale. One uses "Y" and the other + // does not use "Y" d_scale->data()[0] = ((x_map - mean_map.replicate(1, right)) .cwiseProduct(var_map.unaryExpr(inv_std).replicate(1, right)) .cwiseProduct(d_y_map)) - .sum(); // also can use `y` to get d_scale_map + .sum(); } if (d_x) { diff --git a/python/paddle/v2/fluid/tests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/test_layer_norm_op.py index caa3b944eb..8ce327436f 100644 --- a/python/paddle/v2/fluid/tests/test_layer_norm_op.py +++ b/python/paddle/v2/fluid/tests/test_layer_norm_op.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import unittest import numpy as np @@ -33,23 +32,24 @@ def get_backward_op(scope, op, no_grad_set): return backward_op -def _reference_layer_norm_naive(x, scale, beta, epsilon): +def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1): old_shape = x.shape - N = x.shape[0] - D = reduce(mul, old_shape, 1) / N + N = reduce(mul, old_shape[0:begin_norm_axis], 1) + D = reduce(mul, old_shape[begin_norm_axis:len(old_shape)], 1) x.shape = [N, D] mean = np.mean(x, axis=1) var = np.var(x, axis=1) + epsilon output = scale * np.divide((x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1])) + beta output.shape = old_shape + x.shape = old_shape return output, mean, var -def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon): +def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1): x_shape = x.shape - N = x_shape[0] - D = reduce(mul, x_shape, 1) / N + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) grad_y.shape = [N, D] x.shape = [N, D] mean.shape = [N, 1] @@ -140,7 +140,9 @@ class TestLayerNormdOp(OpTest): self.assertLessEqual(max_diff, max_relative_error, err_msg()) def test_forward_backward(self): - def test_with_place(place, shape): + def test_with_place(place, shape, begin_norm_axis=1): + assert begin_norm_axis > 0 and begin_norm_axis < len( + shape), 'begin_norm_axis must be between 0 and len(shape)-1.' # attr epsilon = 0.00001 x_shape = shape @@ -152,13 +154,13 @@ class TestLayerNormdOp(OpTest): # run forward y_out, saved_mean, var_ref = _reference_layer_norm_naive( - x_val, scale_val, bias_val, epsilon) + x_val, scale_val, bias_val, epsilon, begin_norm_axis) # for gradient test y_grad = np.random.random_sample(x_shape).astype(np.float32) x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad( - x_val, y_grad, scale_val, saved_mean, var_ref, epsilon) + x_val, y_grad, scale_val, saved_mean, var_ref, begin_norm_axis) scope = core.Scope() @@ -185,7 +187,8 @@ class TestLayerNormdOp(OpTest): Mean="Mean", Variance="Variance", # attrs - epsilon=epsilon) + epsilon=epsilon, + begin_norm_axis=begin_norm_axis) layer_norm_op.run(scope, place) @@ -228,7 +231,8 @@ class TestLayerNormdOp(OpTest): places.append(core.CUDAPlace(0)) for place in places: - test_with_place(place, [2, 3, 4, 5]) + test_with_place(place, [2, 3, 4, 5], begin_norm_axis=1) + test_with_place(place, [2, 3, 4, 5], begin_norm_axis=3) if __name__ == '__main__': -- GitLab