Commit 7e0d21de, authored by chengduoZH

fix scale and bias dim

Parent: 87b5559c
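Context for the change (a summary, not part of the original commit message): the CPU kernel views the input as a left x right matrix and normalizes each row independently. In the standard layer-norm formulation,

$$\mu_i = \frac{1}{right}\sum_{j=1}^{right} x_{ij}, \qquad \sigma_i^2 = \frac{1}{right}\sum_{j=1}^{right}\left(x_{ij}-\mu_i\right)^2, \qquad y_{ij} = g_j\,\frac{x_{ij}-\mu_i}{\sigma_i} + b_j,$$

the mean and variance are per-row statistics of shape (left, 1), while the scale g and bias b hold one value per feature, i.e. right elements. The pre-fix code mapped the scale and bias buffers as (left, 1) matrices, which misreads them whenever left != right; the fix maps them as (1, right) row vectors.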
@@ -123,8 +123,8 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
     int right = static_cast<int>(matrix_dim[1]);
     auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), left, 1);
-    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), left, 1);
+    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
     auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
     auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
     auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
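An Eigen::Map never copies; it reinterprets an existing buffer with the shape it is given, so the map dimensions must match the tensor's element count. A minimal sketch of the failure mode (the typedef below is an assumption about how ConstEigenMatrixMapRowMajor is defined elsewhere in this file, inferred from its name):

#include <Eigen/Core>

// Presumed definition (not shown in the hunks above):
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

int main() {
  const int left = 4, right = 3;           // hypothetical batch / feature sizes
  const float scale[3] = {1.f, 2.f, 3.f};  // the Scale tensor holds `right` values

  // Post-fix mapping: a (1, right) view over exactly `right` elements.
  ConstEigenMatrixMapRowMajor<float> ok(scale, 1, right);

  // Pre-fix mapping: a (left, 1) view would claim `left` elements and read
  // past the 3-element buffer whenever left > right.
  // ConstEigenMatrixMapRowMajor<float> bad(scale, left, 1);

  return static_cast<int>(ok(0, 2));
}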
@@ -143,11 +143,11 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
     // TODO(zcd): Some thinking about output_map, is it appropriate that
     // `output_map` and `input_map` point to the same memory.
-    auto inv_std_scale =
-        var_map.unaryExpr(inv_std_func).cwiseProduct(scale_map);
-    output_map =
-        inv_std_scale.replicate(1, right).cwiseProduct(input_map) +
-        (bias_map - inv_std_scale.cwiseProduct(mean_map)).replicate(1, right);
+    auto inv_std_scale = var_map.unaryExpr(inv_std_func);
+    output_map = (input_map - mean_map.replicate(1, right))
+                     .cwiseProduct(inv_std_scale.replicate(1, right))
+                     .cwiseProduct(scale_map.replicate(left, 1)) -
+                 bias_map.replicate(left, 1);
   }
 };
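A standalone sketch of the broadcast pattern the rewritten forward pass relies on (assumptions: Eigen 3.x, the textbook "+ bias" form, epsilon omitted). replicate(1, right) stretches the per-row statistics across features, while replicate(left, 1) stretches the per-feature row vectors down the rows:

#include <Eigen/Core>
#include <iostream>

using MatRM =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
  const int left = 2, right = 3;  // hypothetical sizes
  MatRM x(left, right), scale(1, right), bias(1, right);
  x << 1, 2, 3, 4, 5, 6;
  scale << 1, 1, 1;
  bias << 0, 0, 0;

  MatRM mean = x.rowwise().mean();                               // (left, 1)
  MatRM centered = x - mean.replicate(1, right);                 // (left, right)
  MatRM var = centered.cwiseProduct(centered).rowwise().mean();  // (left, 1)
  MatRM inv_std = var.cwiseSqrt().cwiseInverse();  // stands in for inv_std_func

  // Same shape algebra as the kernel: row statistics broadcast rightward,
  // row-vector scale/bias broadcast downward.
  MatRM y = centered.cwiseProduct(inv_std.replicate(1, right))
                .cwiseProduct(scale.replicate(left, 1)) +
            bias.replicate(left, 1);
  std::cout << y << "\n";  // each row now has zero mean and unit variance
  return 0;
}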
@@ -221,7 +221,7 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), left, 1);
+    auto scale_map = ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
     auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
     auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
     auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
@@ -229,12 +229,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
-      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), left, 1);
+      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
       d_bias_map = d_y_map.colwise().mean();
     }
     if (d_scale) {
       d_scale->mutable_data<T>(ctx.GetPlace());
-      auto d_scale_map = EigenMatrixMapRowMajor<T>(d_scale->data<T>(), left, 1);
+      auto d_scale_map =
+          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
       auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
       // There are two equation to compute d_scale. One uses "Y" and the other
       // does not use "Y"
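For reference, the standard parameter gradients (a summary, not quoted from the file): with rows i running over left, columns j over right, and \hat{x} the normalized input,

$$\hat{x}_{ij} = \frac{x_{ij}-\mu_i}{\sigma_i}, \qquad \frac{\partial L}{\partial b_j} = \sum_i \frac{\partial L}{\partial y_{ij}}, \qquad \frac{\partial L}{\partial g_j} = \sum_i \frac{\partial L}{\partial y_{ij}}\,\hat{x}_{ij}.$$

Both reductions run over the left rows, which is why d_bias and d_scale come out as (1, right) row vectors, matching the fixed maps above (the kernel shown reduces with colwise().mean() rather than a plain sum). The "uses Y" variant mentioned in the comment rewrites \hat{x} as (y - b) / g, since y = g * \hat{x} + b.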
@@ -254,15 +255,15 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
       auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
       // dy_dx
       auto dx_end = var_map.unaryExpr(inv_std_func)
-                        .cwiseProduct(scale_map)
                         .replicate(1, right)
-                        .cwiseProduct(d_y_map);
+                        .cwiseProduct(d_y_map)
+                        .cwiseProduct(scale_map.replicate(left, 1));
       // dy_dmean_dx
       auto dx_mean = (T(-1.0) / right) *
                      var_map.unaryExpr(inv_std_func)
-                         .cwiseProduct(scale_map)
                          .replicate(1, right)
                          .cwiseProduct(d_y_map)
+                         .cwiseProduct(scale_map.replicate(left, 1))
                          .rowwise()
                          .sum()
                          .replicate(1, right);
@@ -274,8 +275,8 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
       auto dvar_end = var_map.unaryExpr(inv_std_func)
                           .unaryExpr(triple_product_func)
                           .cwiseProduct(dvar_end_part)
-                          .cwiseProduct(scale_map)
-                          .replicate(1, right);
+                          .replicate(1, right)
+                          .cwiseProduct(scale_map.replicate(left, 1));
       auto dx_var =
           (T(-1.0) / right) *
           (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
...
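The three dx_* terms above are the usual chain-rule split of the input gradient through y, the mean, and the variance. Under the standard formulation (a reading of the code, not quoted from it; triple_product_func presumably cubes the inverse std, and dvar_end_part is a row-wise reduction defined outside the shown hunks):

$$\frac{\partial L}{\partial x_{ij}} = \underbrace{\frac{g_j}{\sigma_i}\,\frac{\partial L}{\partial y_{ij}}}_{dx\_end} - \underbrace{\frac{1}{right\,\sigma_i}\sum_k g_k\,\frac{\partial L}{\partial y_{ik}}}_{dx\_mean} - \underbrace{\frac{x_{ij}-\mu_i}{right\,\sigma_i^3}\sum_k g_k\,\frac{\partial L}{\partial y_{ik}}\,(x_{ik}-\mu_i)}_{dx\_var}.$$

In each term the scale multiplies per feature column, which is why the fix replaces .cwiseProduct(scale_map) on a (left, 1) column vector with .cwiseProduct(scale_map.replicate(left, 1)) on the broadcast (1, right) row vector.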