diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc index 1c6d2ae4d05becaeed34d66cad398cc90f9d3ece..8fcac00e08fc27a92a20e829013e0f66af73f613 100644 --- a/paddle/operators/layer_norm_op.cc +++ b/paddle/operators/layer_norm_op.cc @@ -291,32 +291,28 @@ class LayerNormGradKernel auto d_x_map = EigenMatrixMapRowMajor(d_x->data(), left, right); auto triple_product_func = [](T ele) { return ele * ele * ele; }; auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); }; + + auto inv_std_map = var_map.unaryExpr(inv_std_func).eval(); // TODO(zcd): these code can be refined if (d_scale) { auto scale_map = ConstEigenMatrixMapRowMajor(scale->data(), 1, right); // dy_dx - auto dx_end = var_map.unaryExpr(inv_std_func) - .replicate(1, right) - .cwiseProduct(d_y_map) - .cwiseProduct(scale_map.replicate(left, 1)); + auto dx_end = + inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct( + scale_map.replicate(left, 1)); + // dy_dmean_dx - auto dx_mean = (T(-1.0) / right) * - var_map.unaryExpr(inv_std_func) - .replicate(1, right) - .cwiseProduct(d_y_map) - .cwiseProduct(scale_map.replicate(left, 1)) - .rowwise() - .sum() - .replicate(1, right); + auto dx_mean = + (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right); + // dy_var_dx auto dvar_end_part = (x_map - mean_map.replicate(1, right)) .cwiseProduct(scale_map.replicate(left, 1)) .cwiseProduct(d_y_map) .rowwise() .sum(); - auto dvar_end = var_map.unaryExpr(inv_std_func) - .unaryExpr(triple_product_func) + auto dvar_end = inv_std_map.unaryExpr(triple_product_func) .cwiseProduct(dvar_end_part) .replicate(1, right); auto dx_var = @@ -326,24 +322,18 @@ class LayerNormGradKernel d_x_map = dx_end + dx_mean + dx_var; } else { // dy_dx - auto dx_end = var_map.unaryExpr(inv_std_func) - .replicate(1, right) - .cwiseProduct(d_y_map); + auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map); + // dy_dmean_dx - auto dx_mean = (T(-1.0) / right) * - var_map.unaryExpr(inv_std_func) - .replicate(1, right) - .cwiseProduct(d_y_map) - .rowwise() - .sum() - .replicate(1, right); + auto dx_mean = + (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right); + // dy_var_dx auto dvar_end_part = (x_map - mean_map.replicate(1, right)) .cwiseProduct(d_y_map) .rowwise() .sum(); - auto dvar_end = var_map.unaryExpr(inv_std_func) - .unaryExpr(triple_product_func) + auto dvar_end = inv_std_map.unaryExpr(triple_product_func) .cwiseProduct(dvar_end_part) .replicate(1, right); auto dx_var =