diff --git a/paddle/operators/layer_norm_op.h b/paddle/operators/layer_norm_op.h index 608447b1ff846d86039d4ce3d5bc652fe4d2088a..3c436b89263758bbc0abcd1bb71cef3e1370d2a5 100644 --- a/paddle/operators/layer_norm_op.h +++ b/paddle/operators/layer_norm_op.h @@ -97,15 +97,15 @@ class LayerNormKernel : public framework::OpKernel { auto &dev_ctx = ctx.template device_context(); math::RowwiseMean row_mean; - // functor-> get mean + // get mean row_mean(dev_ctx, x, mean); - // functor-> get variance + // get variance ElementwiseComputeEx, DeviceContext, T>( ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor(), &out); row_mean(dev_ctx, out, var); - // functor-> get norm_out + // get x_norm ElementwiseComputeEx, DeviceContext, T>( ctx, &x, mean, /*axis*/ 0, SubFunctor(), &out); ElementwiseComputeEx, DeviceContext, T>( @@ -129,9 +129,11 @@ class LayerNormGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { const float epsilon = ctx.Attr("epsilon"); auto x = *ctx.Input("X"); - auto mean = *ctx.Input("Mean"); - auto var = *ctx.Input("Variance"); - auto scale = *ctx.Input("Scale"); + auto *y = ctx.Input("Y"); + auto *mean = ctx.Input("Mean"); + auto *var = ctx.Input("Variance"); + auto *scale = ctx.Input("Scale"); + auto *bias = ctx.Input("Bias"); auto d_y = *ctx.Input(framework::GradVarName("Y")); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); @@ -155,14 +157,19 @@ class LayerNormGradKernel : public framework::OpKernel { if (d_scale || d_x) { x.Resize(matrix_shape); temp.mutable_data(matrix_shape, ctx.GetPlace()); - temp_norm.mutable_data(matrix_shape, ctx.GetPlace()); - // get x_norm - ElementwiseComputeEx, DeviceContext, T>( - ctx, &x, &mean, /*axis*/ 0, SubFunctor(), &temp_norm); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp_norm, &var, /*axis*/ 0, - DivAndSqrtFunctor(static_cast(epsilon)), &temp_norm); + if (!(bias && scale)) { + temp_norm.ShareDataWith(*y); + temp_norm.Resize(matrix_shape); + } else { + temp_norm.mutable_data(matrix_shape, ctx.GetPlace()); + // get x_norm + ElementwiseComputeEx, DeviceContext, T>( + ctx, &x, mean, /*axis*/ 0, SubFunctor(), &temp_norm); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &temp_norm, var, /*axis*/ 0, + DivAndSqrtFunctor(static_cast(epsilon)), &temp_norm); + } } if (d_bias) { @@ -188,7 +195,7 @@ class LayerNormGradKernel : public framework::OpKernel { if (d_scale) { // dy_dx ElementwiseComputeEx, DeviceContext, T>( - ctx, &d_y, &scale, /*axis*/ 1, MulFunctor(), &temp); + ctx, &d_y, scale, /*axis*/ 1, MulFunctor(), &temp); framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x); // dy_dmean_dx @@ -199,7 +206,6 @@ class LayerNormGradKernel : public framework::OpKernel { // dy_var_dx ElementwiseComputeEx, DeviceContext, T>( ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); - } else { // dy_dx framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); @@ -216,12 +222,12 @@ class LayerNormGradKernel : public framework::OpKernel { // dy_var_dx row_mean(dev_ctx, temp, &temp_vec); ElementwiseComputeEx, DeviceContext, T>( - ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor(), &temp_norm); + ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor(), &temp); ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor(), d_x); + ctx, d_x, &temp, /*axis*/ 0, SubFunctor(), d_x); ElementwiseComputeEx, DeviceContext, T>( - ctx, d_x, &var, /*axis*/ 0, + ctx, d_x, var, /*axis*/ 0, DivAndSqrtFunctor(static_cast(epsilon)), d_x); d_x->Resize(dx_dim); }