diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 5821afe9f681002c20886a70333b872a16a9a8dd..1c6d2ae4d05becaeed34d66cad398cc90f9d3ece 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -33,29 +33,35 @@ class LayerNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"), "");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"),
+                   "Output(Y) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
+                   "Output(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
+                   "Output(Variance) of LayerNormOp should not be null.");
 
     auto x_dim = ctx->GetInputDim("X");
     auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
     PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
-                      "'begin_norm_axis' must be less than the rank of X");
+                      "'begin_norm_axis' must be less than the rank of X.");
 
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
-
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
+    if (ctx->HasInput("Scale")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
+    }
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
+    }
 
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Mean", {left});
     ctx->SetOutputDim("Variance", {left});
-
     ctx->ShareLoD("X", "Y");
   }
 };
@@ -64,18 +70,26 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor");
+    AddInput("X", "(LoDTensor) The input tensor.");
     AddInput("Scale",
-             "Scale is a 1-dimensional tensor of size H "
-             "that is applied to the output");
+             "(Tensor, optional) Scale is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
     AddInput("Bias",
-             "Bias is a 1-dimensional tensor of size H "
-             "that is applied to the output");
-    AddOutput("Y", "result after normalization");
-    AddOutput("Mean", "Mean of the current mini batch.");
-    AddOutput("Variance", "Variance of the current mini batch.");
-
-    AddAttr<float>("epsilon", "")
+             "(Tensor, optional) Bias is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
+    AddOutput("Y", "(LoDTensor) Result after normalization.");
+    AddOutput("Mean", "(Tensor) Mean of the current mini batch.")
+        .AsIntermediate();
+    AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
+        .AsIntermediate();
+
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-5) Constant for "
+                   "numerical stability")
         .SetDefault(1e-5)
         .AddCustomChecker([](const float &epsilon) {
           PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
@@ -83,7 +97,9 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
         });
     AddAttr<int>("begin_norm_axis",
                  "(int default:1), the "
-                 "axis of `begin_norm_axis ... Rank(X) - 1` will be normalized")
+                 "axis of `begin_norm_axis ... Rank(X) - 1` will be "
+                 "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
+                 "matrix [N,H].")
         .SetDefault(1)
         .AddCustomChecker([](const int &begin_norm_axis) {
           PADDLE_ENFORCE_GT(begin_norm_axis, 0,
@@ -124,8 +140,7 @@ class LayerNormKernel
     int right = static_cast<int>(matrix_dim[1]);
 
     auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+
     auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
     auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
     auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
@@ -141,14 +156,32 @@ class LayerNormKernel
                   .unaryExpr(add_epslion);
 
     auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-
     // TODO(zcd): Some thinking about output_map, is it appropriate that
     // `output_map` and `input_map` point to the same memory.
     auto inv_std = var_map.unaryExpr(inv_std_func);
-    output_map = (input_map - mean_map.replicate(1, right))
-                     .cwiseProduct(inv_std.replicate(1, right))
-                     .cwiseProduct(scale_map.replicate(left, 1)) +
-                 bias_map.replicate(left, 1);
+    if (scale && bias) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1)) +
+                   bias_map.replicate(left, 1);
+    } else if (scale) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1));
+    } else if (bias) {
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right)) +
+                   bias_map.replicate(left, 1);
+    } else {
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right));
+    }
   }
 };
 
@@ -158,11 +191,16 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     // check input
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Mean"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Scale"),
+                   "Input(Scale) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Mean"),
+                   "Input(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Variance"),
+                   "Input(Variance) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) of LayerNormOp should not be null.");
 
     // check output
     if (ctx->HasOutput(framework::GradVarName("X"))) {
@@ -222,7 +260,6 @@ class LayerNormGradKernel
     auto *d_scale = ctx.Output(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output(framework::GradVarName("Bias"));
 
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
     auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
     auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
     auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
@@ -254,35 +291,67 @@ class LayerNormGradKernel
       auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
       auto triple_product_func = [](T ele) { return ele * ele * ele; };
       auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-      // dy_dx
-      auto dx_end = var_map.unaryExpr(inv_std_func)
-                        .replicate(1, right)
-                        .cwiseProduct(d_y_map)
-                        .cwiseProduct(scale_map.replicate(left, 1));
-      // dy_dmean_dx
-      auto dx_mean = (T(-1.0) / right) *
-                     var_map.unaryExpr(inv_std_func)
-                         .replicate(1, right)
-                         .cwiseProduct(d_y_map)
-                         .cwiseProduct(scale_map.replicate(left, 1))
-                         .rowwise()
-                         .sum()
-                         .replicate(1, right);
-      // dy_var_dx
-      auto dvar_end_part = (x_map - mean_map.replicate(1, right))
-                               .cwiseProduct(scale_map.replicate(left, 1))
-                               .cwiseProduct(d_y_map)
-                               .rowwise()
-                               .sum();
-      auto dvar_end = var_map.unaryExpr(inv_std_func)
-                          .unaryExpr(triple_product_func)
-                          .cwiseProduct(dvar_end_part)
-                          .replicate(1, right);
-      auto dx_var =
-          (T(-1.0) / right) *
-          (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
-      d_x_map = dx_end + dx_mean + dx_var;
+      // TODO(zcd): these code can be refined
+      if (d_scale) {
+        auto scale_map =
+            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map)
+                          .cwiseProduct(scale_map.replicate(left, 1));
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .cwiseProduct(scale_map.replicate(left, 1))
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(scale_map.replicate(left, 1))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      } else {
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map);
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      }
     }
   }
 };