提交 71a70f20 编写于 作者: C chengduoZH

refine gradient

上级 7e695ce8
...@@ -291,32 +291,28 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T> ...@@ -291,32 +291,28 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right); auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; }; auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); }; auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
// TODO(zcd): these code can be refined // TODO(zcd): these code can be refined
if (d_scale) { if (d_scale) {
auto scale_map = auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right); ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx // dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func) auto dx_end =
.replicate(1, right) inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
.cwiseProduct(d_y_map) scale_map.replicate(left, 1));
.cwiseProduct(scale_map.replicate(left, 1));
// dy_dmean_dx // dy_dmean_dx
auto dx_mean = (T(-1.0) / right) * auto dx_mean =
var_map.unaryExpr(inv_std_func) (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1))
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx // dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right)) auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) .cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map) .cwiseProduct(d_y_map)
.rowwise() .rowwise()
.sum(); .sum();
auto dvar_end = var_map.unaryExpr(inv_std_func) auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part) .cwiseProduct(dvar_end_part)
.replicate(1, right); .replicate(1, right);
auto dx_var = auto dx_var =
...@@ -326,24 +322,18 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T> ...@@ -326,24 +322,18 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
d_x_map = dx_end + dx_mean + dx_var; d_x_map = dx_end + dx_mean + dx_var;
} else { } else {
// dy_dx // dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func) auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);
.replicate(1, right)
.cwiseProduct(d_y_map);
// dy_dmean_dx // dy_dmean_dx
auto dx_mean = (T(-1.0) / right) * auto dx_mean =
var_map.unaryExpr(inv_std_func) (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx // dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right)) auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map) .cwiseProduct(d_y_map)
.rowwise() .rowwise()
.sum(); .sum();
auto dvar_end = var_map.unaryExpr(inv_std_func) auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part) .cwiseProduct(dvar_end_part)
.replicate(1, right); .replicate(1, right);
auto dx_var = auto dx_var =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册