diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 080819256598788f74857cbc9d88d080d1aef181..0b0c760e57e16b90d257e10590a27634fcab7399 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -233,39 +233,37 @@ class LayerNormGradKernel
     if (d_x) {
       d_x->mutable_data(ctx.GetPlace());
       auto d_x_map = EigenMatrixMapRowMajor(d_x->data(), left, right);
-      auto triple_product = [](T ele) { return ele * ele; };
-      auto neg_inv_std = [](T ele) { return -std::sqrt(1 / ele); };
+      auto triple_product_func = [](T ele) { return ele * ele * ele; };
+      auto scale_func = [scale_data](T ele) { return ele * scale_data; };
+      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
       auto inv_std_scale_func = [scale_data](T ele) {
         return std::sqrt(1 / ele) * scale_data;
       };
-      auto neg_inv_std_scale_func = [scale_data](T ele) {
-        return -std::sqrt(1 / ele) * scale_data;
-      };
       // dy_dx
       auto dx_end = var_map.unaryExpr(inv_std_scale_func)
                         .replicate(1, right)
                         .cwiseProduct(d_y_map);
       // dy_dmean_dx
-      auto dmean_end = var_map.unaryExpr(neg_inv_std_scale_func)
-                           .replicate(1, right)
-                           .cwiseProduct(d_y_map)
-                           .rowwise()
-                           .sum();
-      auto dx_mean = (T(1.0) / right) * dmean_end.replicate(1, right);
+      auto dx_mean = (T(-1.0) / right) *
+                     var_map.unaryExpr(inv_std_scale_func)
+                         .replicate(1, right)
+                         .cwiseProduct(d_y_map)
+                         .rowwise()
+                         .sum()
+                         .replicate(1, right);
       // dy_var_dx
-      auto dvar_end_0 = (x_map - mean_map.replicate(1, right))
-                            .cwiseProduct(d_y_map)
-                            .rowwise()
-                            .sum();
-      auto dvar_end = var_map.unaryExpr(neg_inv_std)
-                          .unaryExpr(triple_product)
-                          .cwiseProduct(dvar_end_0);
-      auto dx_var = (T(1.0) / right) *
+      auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                               .cwiseProduct(d_y_map)
+                               .rowwise()
+                               .sum();
+      auto dvar_end = var_map.unaryExpr(inv_std_func)
+                          .unaryExpr(triple_product_func)
+                          .cwiseProduct(dvar_end_part)
+                          .replicate(1, right);
+      auto dx_var = (T(-1.0) / right) *
                     (x_map - mean_map.replicate(1, right))
-                        .cwiseProduct(dvar_end.replicate(1, right));
-
-      // d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
-      //       - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
+                        .cwiseProduct(dvar_end)
+                        .unaryExpr(scale_func);
 
       d_x_map = dx_end + dx_mean + dx_var;
     }
diff --git a/python/paddle/v2/fluid/tests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
index 4ca9754f32683dd37c1df3ce8aa42d4cfbeaa44d..caa3b944ebfcc13251168de23549918bb95642bf 100644
--- a/python/paddle/v2/fluid/tests/test_layer_norm_op.py
+++ b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
@@ -52,18 +52,19 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon):
     D = reduce(mul, x_shape, 1) / N
     grad_y.shape = [N, D]
     x.shape = [N, D]
-    grad_offset = np.sum(grad_y)
     mean.shape = [N, 1]
     var.shape = [N, 1]
-    grad_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y)
+
+    d_scale = np.sum(grad_y).reshape([1, ])
+    d_bias = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y).reshape([1, ])
 
     dx_end = np.sqrt(1.0 / var) * grad_y
 
     d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y, axis=1).reshape([N, 1])
-    d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
-        [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
-                   np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
-    d_mean = 1.0 / D * (d_mean_0 + d_mean_1)
+    # d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
+    #     [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
+    #                np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
+    d_mean = 1.0 / D * (d_mean_0)
 
     d_std = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape([N, 1]) * (
         1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
@@ -73,7 +74,7 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon):
     grad_y.shape = x_shape
     x.shape = x_shape
 
-    return grad_x, grad_scale, grad_offset
+    return grad_x, d_bias, d_scale
 
 
 def create_or_get_tensor(scope, var_name, var, place):
@@ -144,7 +145,7 @@ class TestLayerNormdOp(OpTest):
         epsilon = 0.00001
         x_shape = shape
         scale_shape = [1]
-
+        np.random.seed(123)
         x_val = np.random.random_sample(x_shape).astype(np.float32)
         scale_val = np.random.random_sample(scale_shape).astype(np.float32)
         bias_val = np.random.random_sample(scale_shape).astype(np.float32)
@@ -154,7 +155,6 @@ class TestLayerNormdOp(OpTest):
             x_val, scale_val, bias_val, epsilon)
 
         # for gradient test
-        # y_grad = np.ones(x_shape).astype(np.float32) * 0.00277778
         y_grad = np.random.random_sample(x_shape).astype(np.float32)
 
         x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad(
@@ -229,7 +229,6 @@ class TestLayerNormdOp(OpTest):
 
         for place in places:
            test_with_place(place, [2, 3, 4, 5])
-            test_with_place(place, [2, 3])
 
 
 if __name__ == '__main__':
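For reference, the `d_x` the patched kernel evaluates is exactly the formula from the removed comment, with the minus signs now carried by the `(T(-1.0) / right)` factors: `d_x = scale * inv_std * (d_y - mean(d_y) - (x - mean) * inv_std^2 * mean(d_y * (x - mean)))`, the means taken over the normalized dimension (`right` in the kernel, `D` in the Python reference). Below is a minimal NumPy sketch of that same computation; the helper name, the shapes, and the assumption that `var` already carries epsilon from the forward pass are illustrative, not Paddle API.

```python
import numpy as np


def layer_norm_grad_x(x, d_y, scale, mean, var):
    """Sketch of d_x = dx_end + dx_mean + dx_var as the kernel computes it.

    Shapes: x, d_y are [N, D]; mean, var are [N, 1]; scale is a scalar
    (this PR tests scale_shape = [1]). var is assumed to already include
    epsilon, matching what the forward pass stores.
    """
    N, D = x.shape
    inv_std = np.sqrt(1.0 / var)  # [N, 1]
    # dy_dx: the direct path through y (dx_end in the kernel).
    dx_end = scale * inv_std * d_y
    # dy_dmean_dx: the path through the per-row mean (dx_mean).
    dx_mean = (-1.0 / D) * np.sum(
        scale * inv_std * d_y, axis=1, keepdims=True)
    # dy_var_dx: the path through the per-row variance (dx_var).
    dx_var = (-1.0 / D) * (x - mean) * scale * inv_std ** 3 * np.sum(
        d_y * (x - mean), axis=1, keepdims=True)
    return dx_end + dx_mean + dx_var
```

A central-difference check of the sketch (input values made up for the check):

```python
np.random.seed(0)
N, D, eps, scale = 2, 5, 1e-5, 0.7
x = np.random.random_sample((N, D))
d_y = np.random.random_sample((N, D))
mean = x.mean(axis=1, keepdims=True)
var = x.var(axis=1, keepdims=True) + eps
analytic = layer_norm_grad_x(x, d_y, scale, mean, var)

numeric = np.zeros_like(x)
h = 1e-6
for i in range(N):
    for j in range(D):
        def loss(v):
            xp = x.copy()
            xp[i, j] = v
            m = xp.mean(axis=1, keepdims=True)
            s = np.sqrt(xp.var(axis=1, keepdims=True) + eps)
            return np.sum(scale * (xp - m) / s * d_y)
        numeric[i, j] = (loss(x[i, j] + h) - loss(x[i, j] - h)) / (2 * h)

assert np.allclose(analytic, numeric, atol=1e-4)
```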
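One naming quirk in the reference gradient is worth flagging: the locals are swapped relative to the math (`d_scale = np.sum(grad_y)` is really the bias gradient, and `d_bias = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y)` is really the scale gradient), but the swapped return order `grad_x, d_bias, d_scale` means the call site still unpacks `scale_grad_ref` and `bias_grad_ref` correctly. A sketch with the conventional naming, under the same scalar-parameter assumption; the helper name is hypothetical:

```python
import numpy as np


def layer_norm_grad_scale_bias(x, d_y, mean, var):
    # With y = scale * x_norm + bias and scalar scale/bias
    # (scale_shape == [1] in this test), both parameter gradients
    # reduce over every element.
    x_norm = (x - mean) * np.sqrt(1.0 / var)
    d_scale = np.sum(d_y * x_norm).reshape([1])  # the reference's "d_bias"
    d_bias = np.sum(d_y).reshape([1])            # the reference's "d_scale"
    return d_scale, d_bias
```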