提交 71a70f20 编写于 作者: C chengduoZH

refine gradient

上级 7e695ce8
......@@ -291,32 +291,28 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
// TODO(zcd): these code can be refined
if (d_scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1));
auto dx_end =
inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
scale_map.replicate(left, 1));
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1))
.rowwise()
.sum()
.replicate(1, right);
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
......@@ -326,24 +322,18 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
d_x_map = dx_end + dx_mean + dx_var;
} else {
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map);
auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum()
.replicate(1, right);
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册