Commit 263e0197 authored by chengduoZH

follow comments

Parent 09570b48
@@ -33,29 +33,35 @@ class LayerNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"), "");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"),
+                   "Output(Y) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
+                   "Output(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
+                   "Output(Variance) of LayerNormOp should not be null.");
 
     auto x_dim = ctx->GetInputDim("X");
     auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
     PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
-                      "'begin_norm_axis' must be less than the rank of X");
+                      "'begin_norm_axis' must be less than the rank of X.");
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
+    if (ctx->HasInput("Scale")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
+    }
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
+    }
 
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Mean", {left});
     ctx->SetOutputDim("Variance", {left});
     ctx->ShareLoD("X", "Y");
   }
 };
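For reference, the shape checks above can be summarized as follows. This is a sketch under the assumption that framework::flatten_to_2d (not shown in this diff) views X as a matrix by multiplying out the dimensions before and after begin_norm_axis. For X with dimensions d_0, ..., d_{R-1} and a = begin_norm_axis:

    \text{left} = N = \prod_{i=0}^{a-1} d_i, \qquad \text{right} = H = \prod_{i=a}^{R-1} d_i

so Scale and Bias, when supplied, must be 1-D tensors of size H, while Mean and Variance come out with shape [N], one value per normalized row.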
@@ -64,18 +70,26 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor");
+    AddInput("X", "(LoDTensor) The input tensor.");
     AddInput("Scale",
-             "Scale is a 1-dimensional tensor of size H "
-             "that is applied to the output");
+             "(Tensor, optional) Scale is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
     AddInput("Bias",
-             "Bias is a 1-dimensional tensor of size H "
-             "that is applied to the output");
-    AddOutput("Y", "result after normalization");
-    AddOutput("Mean", "Mean of the current mini batch.");
-    AddOutput("Variance", "Variance of the current mini batch.");
+             "(Tensor, optional) Bias is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
+    AddOutput("Y", "(LoDTensor) Result after normalization.");
+    AddOutput("Mean", "(Tensor) Mean of the current mini batch.")
+        .AsIntermediate();
+    AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
+        .AsIntermediate();
 
-    AddAttr<float>("epsilon", "")
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-5) Constant for "
+                   "numerical stability")
         .SetDefault(1e-5)
         .AddCustomChecker([](const float &epsilon) {
           PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
@@ -83,7 +97,9 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
         });
     AddAttr<int>("begin_norm_axis",
                  "(int default:1), the "
-                 "axis of `begin_norm_axis ... Rank(X) - 1` will be normalized")
+                 "axis of `begin_norm_axis ... Rank(X) - 1` will be "
+                 "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
+                 "matrix [N,H].")
         .SetDefault(1)
         .AddCustomChecker([](const int &begin_norm_axis) {
           PADDLE_ENFORCE_GT(begin_norm_axis, 0,
@@ -124,8 +140,7 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
     int right = static_cast<int>(matrix_dim[1]);
 
     auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
     auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
     auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
     auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
@@ -141,14 +156,32 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
                       .unaryExpr(add_epslion);
     auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
     // TODO(zcd): Some thinking about output_map, is it appropriate that
     // `output_map` and `input_map` point to the same memory.
     auto inv_std = var_map.unaryExpr(inv_std_func);
-    output_map = (input_map - mean_map.replicate(1, right))
-                     .cwiseProduct(inv_std.replicate(1, right))
-                     .cwiseProduct(scale_map.replicate(left, 1)) +
-                 bias_map.replicate(left, 1);
+    if (scale && bias) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1)) +
+                   bias_map.replicate(left, 1);
+    } else if (scale) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1));
+    } else if (bias) {
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right)) +
+                   bias_map.replicate(left, 1);
+    } else {
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right));
+    }
   }
 };
@@ -158,11 +191,16 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     // check input
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Mean"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Scale"),
+                   "Input(Scale) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Mean"),
+                   "Input(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Variance"),
+                   "Input(Variance) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) of LayerNormOp should not be null.");
 
     // check output
     if (ctx->HasOutput(framework::GradVarName("X"))) {
@@ -222,7 +260,6 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
 
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
     auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
     auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
     auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
@@ -254,35 +291,67 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
       auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
       auto triple_product_func = [](T ele) { return ele * ele * ele; };
       auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-      // dy_dx
-      auto dx_end = var_map.unaryExpr(inv_std_func)
-                        .replicate(1, right)
-                        .cwiseProduct(d_y_map)
-                        .cwiseProduct(scale_map.replicate(left, 1));
-      // dy_dmean_dx
-      auto dx_mean = (T(-1.0) / right) *
-                     var_map.unaryExpr(inv_std_func)
-                         .replicate(1, right)
-                         .cwiseProduct(d_y_map)
-                         .cwiseProduct(scale_map.replicate(left, 1))
-                         .rowwise()
-                         .sum()
-                         .replicate(1, right);
-      // dy_var_dx
-      auto dvar_end_part = (x_map - mean_map.replicate(1, right))
-                               .cwiseProduct(scale_map.replicate(left, 1))
-                               .cwiseProduct(d_y_map)
-                               .rowwise()
-                               .sum();
-      auto dvar_end = var_map.unaryExpr(inv_std_func)
-                          .unaryExpr(triple_product_func)
-                          .cwiseProduct(dvar_end_part)
-                          .replicate(1, right);
-      auto dx_var =
-          (T(-1.0) / right) *
-          (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
-      d_x_map = dx_end + dx_mean + dx_var;
+      // TODO(zcd): these code can be refined
+      if (d_scale) {
+        auto scale_map =
+            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map)
+                          .cwiseProduct(scale_map.replicate(left, 1));
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .cwiseProduct(scale_map.replicate(left, 1))
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(scale_map.replicate(left, 1))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      } else {
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map);
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      }
     }
   }
 };
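Both branches above assemble d_x from the same three contributions (dx_end, dx_mean, dx_var); the else branch is simply the case without a Scale input, i.e. gamma fixed to 1. As a sketch of what the Eigen expressions compute, with H = right, sigma_i the square root of the saved Variance (which already contains epsilon from the forward pass), and dy the Y@GRAD input:

    \frac{\partial L}{\partial x_{ij}}
      = \frac{\gamma_j \, dy_{ij}}{\sigma_i}
      - \frac{1}{H} \sum_{k} \frac{\gamma_k \, dy_{ik}}{\sigma_i}
      - \frac{x_{ij} - \mu_i}{H \, \sigma_i^{3}} \sum_{k} (x_{ik} - \mu_i) \, \gamma_k \, dy_{ik}

The three terms correspond, in order, to dx_end (direct path), dx_mean (path through the mean), and dx_var (path through the variance).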