Commit 87b5559c authored by chengduoZH

fix scale and bias dim

Parent 0f47703d
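For orientation: this commit turns Scale and Bias from size-1 tensors into vectors over the normalized dimensions, as described by the OpMaker text ("size H") and the Python reference below. The operation is then standard layer normalization with elementwise scale and bias, applied to each row of the `[left, right]` matrix produced by `flatten_to_2d` (the notation below is mine, not from the diff):

$$
\mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad
\sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i
$$

Here H is the number of elements normalized per row (`right`), and γ, β are the Scale and Bias vectors; before this change they were single scalars.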
@@ -38,10 +38,6 @@ class LayerNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Bias"), "");
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], 1);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], 1);
     auto x_dim = ctx->GetInputDim("X");
     auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
     PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
@@ -50,6 +46,11 @@ class LayerNormOp : public framework::OperatorWithKernel {
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], left);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], left);
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Mean", {left});
     ctx->SetOutputDim("Variance", {left});
@@ -64,10 +65,10 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor");
     AddInput("Scale",
-             "Scale is a 1-dimensional tensor of size 1 "
+             "Scale is a 1-dimensional tensor of size H "
              "that is applied to the output");
     AddInput("Bias",
-             "Bias is a 1-dimensional tensor of size 1 "
+             "Bias is a 1-dimensional tensor of size H "
              "that is applied to the output");
     AddOutput("Y", "result after normalization");
     AddOutput("Mean", "Mean of the current mini batch.");
@@ -110,9 +111,6 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
     const auto &x_dims = x->dims();
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    auto scale_data = scale->data<T>()[0];
-    auto bias_data = bias->data<T>()[0];
     auto *output = ctx.Output<Tensor>("Y");
     auto *mean = ctx.Output<Tensor>("Mean");
     auto *var = ctx.Output<Tensor>("Variance");
@@ -123,7 +121,10 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
     auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
+    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), left, 1);
+    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), left, 1);
     auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
     auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
     auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
@@ -138,18 +139,15 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
             .mean()
             .unaryExpr(add_epslion);
-    auto scale_inv_std = [scale_data](T ele) {
-      return std::sqrt(1 / ele) * scale_data;
-    };
-    auto sub_bias = [bias_data](T ele) { return bias_data - ele; };
+    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
     // TODO(zcd): Some thinking about output_map, is it appropriate that
     // `output_map` and `input_map` point to the same memory.
-    output_map = (var_map.unaryExpr(scale_inv_std).replicate(1, right))
-                     .cwiseProduct(input_map) +
-                 var_map.unaryExpr(scale_inv_std)
-                     .cwiseProduct(mean_map)
-                     .unaryExpr(sub_bias)
-                     .replicate(1, right);
+    auto inv_std_scale =
+        var_map.unaryExpr(inv_std_func).cwiseProduct(scale_map);
+    output_map =
+        inv_std_scale.replicate(1, right).cwiseProduct(input_map) +
+        (bias_map - inv_std_scale.cwiseProduct(mean_map)).replicate(1, right);
   }
 };
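The rewritten forward expression above is just the regrouping y = s·(x − μ)/σ + b = (s/σ)·x + (b − (s/σ)·μ), which replaces the captured-scalar lambdas with elementwise products against the new `scale_map` and `bias_map`. A rough NumPy analogue of that Eigen expression (a sketch only; `x` is `(left, right)` and the other operands are `(left, 1)` column vectors, matching the maps in this hunk, and the function name is mine):

```python
import numpy as np

def layer_norm_forward_sketch(x, scale, bias, mean, var):
    # x: (left, right); scale, bias, mean, var: (left, 1), as mapped above.
    # var is assumed to already include epsilon, like var_map after add_epslion.
    inv_std_scale = np.sqrt(1.0 / var) * scale
    # Same regrouping as the Eigen code:
    #   y = inv_std_scale * x + (bias - inv_std_scale * mean)
    return inv_std_scale * x + (bias - inv_std_scale * mean)
```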
@@ -165,17 +163,17 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
-    const auto x_dims = ctx->GetInputDim("X");
     // check output
     if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
     }
     if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-      ctx->SetOutputDim(framework::GradVarName("Scale"), {1});
+      ctx->SetOutputDim(framework::GradVarName("Scale"),
+                        ctx->GetInputDim("Scale"));
     }
     if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Bias"), {1});
+      ctx->SetOutputDim(framework::GradVarName("Bias"),
+                        ctx->GetInputDim("Bias"));
     }
   }
@@ -210,20 +208,20 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     const auto *var = ctx.Input<Tensor>("Variance");
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto scale_data = scale->data<T>()[0];
     const auto &x_dims = x->dims();
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]),
-        right = static_cast<int>(matrix_dim[1]);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
     // init output
     auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), left, 1);
     auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
     auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
     auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
@@ -231,36 +229,38 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
-      d_bias->data<T>()[0] = d_y_map.sum();
+      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), left, 1);
+      d_bias_map = d_y_map.colwise().mean();
     }
     if (d_scale) {
       d_scale->mutable_data<T>(ctx.GetPlace());
-      auto inv_std = [](T ele) { return std::sqrt(1 / ele); };
+      auto d_scale_map = EigenMatrixMapRowMajor<T>(d_scale->data<T>(), left, 1);
+      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
       // There are two equation to compute d_scale. One uses "Y" and the other
       // does not use "Y"
-      d_scale->data<T>()[0] =
+      d_scale_map =
           ((x_map - mean_map.replicate(1, right))
-               .cwiseProduct(var_map.unaryExpr(inv_std).replicate(1, right))
+               .cwiseProduct(
+                   var_map.unaryExpr(inv_std_func).replicate(1, right))
               .cwiseProduct(d_y_map))
-              .sum();
+              .colwise()
+              .mean();
     }
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
       auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
       auto triple_product_func = [](T ele) { return ele * ele * ele; };
-      auto scale_func = [scale_data](T ele) { return ele * scale_data; };
       auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-      auto inv_std_scale_func = [scale_data](T ele) {
-        return std::sqrt(1 / ele) * scale_data;
-      };
       // dy_dx
-      auto dx_end = var_map.unaryExpr(inv_std_scale_func)
+      auto dx_end = var_map.unaryExpr(inv_std_func)
+                        .cwiseProduct(scale_map)
                         .replicate(1, right)
                         .cwiseProduct(d_y_map);
       // dy_dmean_dx
       auto dx_mean = (T(-1.0) / right) *
-                     var_map.unaryExpr(inv_std_scale_func)
+                     var_map.unaryExpr(inv_std_func)
+                         .cwiseProduct(scale_map)
                          .replicate(1, right)
                          .cwiseProduct(d_y_map)
                          .rowwise()
@@ -274,11 +274,11 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
       auto dvar_end = var_map.unaryExpr(inv_std_func)
                           .unaryExpr(triple_product_func)
                           .cwiseProduct(dvar_end_part)
+                          .cwiseProduct(scale_map)
                           .replicate(1, right);
-      auto dx_var = (T(-1.0) / right) *
-                    (x_map - mean_map.replicate(1, right))
-                        .cwiseProduct(dvar_end)
-                        .unaryExpr(scale_func);
+      auto dx_var =
+          (T(-1.0) / right) *
+          (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
       d_x_map = dx_end + dx_mean + dx_var;
     }
......
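For the parameter gradients, the updated kernel replaces the scalar `.sum()` reductions with column-wise means over the flattened matrix. A NumPy sketch of just those two reductions (it mirrors the `d_bias` and `d_scale` expressions in the hunks above; the shapes assume the same `(left, right)` flattening and the helper name is mine, not the kernel's):

```python
import numpy as np

def layer_norm_param_grads_sketch(x, d_y, mean, var):
    # x, d_y: (left, right); mean, var: (left, 1).
    inv_std = np.sqrt(1.0 / var)
    d_bias = d_y.mean(axis=0, keepdims=True)               # d_y_map.colwise().mean()
    d_scale = ((x - mean) * inv_std * d_y).mean(axis=0, keepdims=True)
    return d_scale, d_bias
```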
@@ -39,8 +39,9 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
     x.shape = [N, D]
     mean = np.mean(x, axis=1)
     var = np.var(x, axis=1) + epsilon
-    output = scale * np.divide((x - mean.reshape([N, 1])),
-                               (np.sqrt(var)).reshape([N, 1])) + beta
+    output = scale.reshape([1, D]) * np.divide(
+        (x - mean.reshape([N, 1])),
+        (np.sqrt(var)).reshape([N, 1])) + beta.reshape([1, D])
     output.shape = old_shape
     x.shape = old_shape
     return output, mean, var
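As a quick sanity check of the updated reference, the naive forward with vector scale and bias can be exercised on its own. A standalone sketch (the body follows the hunk above, restricted to an input that is already flattened to `[N, D]`; the toy values are mine):

```python
import numpy as np

def reference_layer_norm_naive(x, scale, beta, epsilon=1e-5):
    N, D = x.shape
    mean = np.mean(x, axis=1)
    var = np.var(x, axis=1) + epsilon
    output = scale.reshape([1, D]) * np.divide(
        (x - mean.reshape([N, 1])),
        (np.sqrt(var)).reshape([N, 1])) + beta.reshape([1, D])
    return output, mean, var

x = np.random.random_sample((2, 3)).astype(np.float32)
y, mean, var = reference_layer_norm_naive(x, np.ones(3), np.zeros(3))
print(y.shape)  # (2, 3); with unit scale and zero bias, each row is ~zero-mean
```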
@@ -55,8 +56,10 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
     mean.shape = [N, 1]
     var.shape = [N, 1]
-    d_scale = np.sum(grad_y).reshape([1, ])
-    d_bias = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y).reshape([1, ])
+    d_scale = np.sum(grad_y, axis=1).reshape([1, D])
+    d_bias = scale.reshape([1, D]) * np.sum((
+        (x - mean) * np.sqrt(1 / var)) * grad_y,
+        axis=1).reshape([1, D])
     dx_end = np.sqrt(1.0 / var) * grad_y
@@ -69,7 +72,7 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
     d_std = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape([N, 1]) * (
         1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
-    grad_x = scale * (dx_end + d_mean + d_std)
+    grad_x = scale.reshape([1, D]) * (dx_end + d_mean + d_std)
     grad_y.shape = x_shape
     x.shape = x_shape
@@ -146,7 +149,8 @@ class TestLayerNormdOp(OpTest):
         # attr
         epsilon = 0.00001
         x_shape = shape
-        scale_shape = [1]
+        D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+        scale_shape = [D]
         np.random.random(123)
         x_val = np.random.random_sample(x_shape).astype(np.float32)
         scale_val = np.random.random_sample(scale_shape).astype(np.float32)
......
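The test now sizes Scale and Bias from the product of the normalized dimensions rather than hard-coding `[1]`. A standalone sketch of that computation (`reduce` and `mul` are presumably `functools.reduce` and `operator.mul` in the test file; the example shape is mine):

```python
from functools import reduce  # a builtin in Python 2; explicit import for Python 3
from operator import mul

x_shape = [2, 3, 4, 5]
begin_norm_axis = 2
# Product of the dimensions that are normalized over, as in the test change above.
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
scale_shape = [D]
print(D, scale_shape)  # 20 [20]
```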