diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 3656bd1a181671bcdc853267135b623df3238f20..0c1f58a2f30f68c184906a0cebd78da98a83d952 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -169,10 +169,16 @@ __inline__ __device__ half rsqrt_(const half val) { } #endif -template -__global__ void LayerNormForward(const T *x, const U *scale, const U *bias, - T *y, U *mean, U *var, float epsilon, - int64_t feature_size) { +template +using LayerNormScaleBiasT = + typename std::conditional::type; + +template +__global__ void LayerNormForward( + const T *x, const LayerNormScaleBiasT *scale, + const LayerNormScaleBiasT *bias, T *y, + U *mean, U *var, float epsilon, int64_t feature_size) { __shared__ U mean_share; __shared__ U var_share; __shared__ U shared_mean[32]; // threadIdx.x / warpSize <= kMaxBlockDim / @@ -212,14 +218,15 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, if (bias != nullptr) { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast( - scale[j] * (static_cast(x[i]) - mean_val) * invvar + bias[j]); + y[i] = static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * invvar + + static_cast(bias[j])); } } else { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast(scale[j] * (static_cast(x[i]) - mean_val) * - invvar); + y[i] = static_cast(static_cast(scale[j]) * + (static_cast(x[i]) - mean_val) * invvar); } } } else { // scale == nullptr @@ -227,7 +234,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar + - bias[j]); + static_cast(bias[j])); } } else { for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx; @@ -336,12 +343,15 @@ __global__ void LayerNormBackwardPartGradGammaBeta( } } -template +template __global__ void LayerNormBackwardSumGradGammaBeta( const U *part_grad_gamma, const U *part_grad_beta, const int part_size, // const int n1, const int n2, T* grad_gamma, T* grad_beta) { - const int n1, const int n2, U *grad_gamma, U *grad_beta) { + const int n1, const int n2, + LayerNormScaleBiasT *grad_gamma, + LayerNormScaleBiasT *grad_beta) { // sum partial gradients for gamma and beta + using ScaleBiasT = LayerNormScaleBiasT; __shared__ U buf[BDIMX * BDIMY]; int64_t i2 = blockIdx.x * BDIMX + threadIdx.x; if (i2 < n2) { @@ -378,20 +388,18 @@ __global__ void LayerNormBackwardSumGradGammaBeta( } // write out fully summed gradients if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - grad_beta[i2] = sum_beta; + grad_gamma[i2] = static_cast(sum_gamma); + grad_beta[i2] = static_cast(sum_beta); } } } -template +template __global__ void LayerNormBackwardComputeGradInput( const T *__restrict__ dout, const T *__restrict__ input, const int n1, - const int n2, - // const U* __restrict__ mean, const U* __restrict__ var, const float - // epsilon, const T* gamma, - const U *__restrict__ mean, const U *__restrict__ var, const float epsilon, - const U *gamma, T *grad_input) { + const int n2, const U *__restrict__ mean, const U *__restrict__ var, + const float epsilon, + const LayerNormScaleBiasT *gamma, T *grad_input) { #ifdef __HIPCC__ for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) { #else @@ -411,15 +419,17 @@ __global__ void 
LayerNormBackwardComputeGradInput( for (int k = 0; k < 4; ++k) { const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); - sum_loss1 += c_loss * gamma[l + k]; - sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + sum_loss1 += c_loss * static_cast(gamma[l + k]); + sum_loss2 += + c_loss * static_cast(gamma[l + k]) * (c_h - c_mean) * c_invvar; } } for (; l < n2; ++l) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); - sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + sum_loss1 += c_loss * static_cast(gamma[l]); + sum_loss2 += + c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; } } else { int l = 4 * thrx; @@ -491,7 +501,7 @@ __global__ void LayerNormBackwardComputeGradInput( for (int l = thrx; l < n2; l += numx) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; + U f_grad_input = fH * c_loss * static_cast(gamma[l]); f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; f_grad_input *= term1; @@ -513,11 +523,17 @@ __global__ void LayerNormBackwardComputeGradInput( // Make sure that d_scale != nullptr && d_bias != nullptr // Since d_scale != nullptr, scale would not be nullptr -template +template __global__ void LayerNormBackwardGradientAll( - const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean, - const U *var, const U *scale, float epsilon, int64_t batch_size, - int64_t feature_size, int64_t col_offset) { + const T *x, const T *d_y, + LayerNormScaleBiasT *d_scale, + LayerNormScaleBiasT *d_bias, T *d_x, + const U *mean, const U *var, + const LayerNormScaleBiasT *scale, + float epsilon, int64_t batch_size, int64_t feature_size, + int64_t col_offset) { + using ScaleBiasT = LayerNormScaleBiasT; int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset); int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset); int64_t stride = BlockDim * feature_size; @@ -532,7 +548,8 @@ __global__ void LayerNormBackwardGradientAll( d_bias_partial += static_cast(d_y[i]); if (HasDx) { d_x[i] = static_cast(static_cast(d_y[i]) * - scale[blockIdx.x + col_offset] / var_val); + static_cast(scale[blockIdx.x + col_offset]) / + var_val); } } @@ -543,19 +560,24 @@ __global__ void LayerNormBackwardGradientAll( d_bias_partial = BlockReduceSum(d_bias_partial, shared_bias); if (threadIdx.x == 0) { - d_scale[blockIdx.x + col_offset] = d_scale_partial; - d_bias[blockIdx.x + col_offset] = d_bias_partial; + d_scale[blockIdx.x + col_offset] = static_cast(d_scale_partial); + d_bias[blockIdx.x + col_offset] = static_cast(d_bias_partial); } } // Make sure that there is only one true expression: d_scale != nullptr // or d_bias != nullptr // Notice: scale may be nullptr -template +template __global__ void LayerNormBackwardGradientScaleOrBias( - const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean, - const U *var, const U *scale, float epsilon, int64_t batch_size, - int64_t feature_size, int col_offset) { + const T *x, const T *d_y, + LayerNormScaleBiasT *d_scale, + LayerNormScaleBiasT *d_bias, T *d_x, + const U *mean, const U *var, + const LayerNormScaleBiasT *scale, + float epsilon, int64_t batch_size, int64_t feature_size, int col_offset) { + using ScaleBiasT = LayerNormScaleBiasT; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + 
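The backward kernels above accumulate two per-row sums, sum_loss1 = Σ dy·γ and sum_loss2 = Σ dy·γ·(x − mean)·invvar, and combine them into the input gradient, while the gamma/beta gradients are reduced over the batch in the compute type U and only cast to the scale/bias type on the final write. Below is a minimal NumPy sketch of the same math for reference; the helper name layer_norm_grad_reference is illustrative only and not part of the Paddle API.

import numpy as np

def layer_norm_grad_reference(x, dy, gamma, epsilon=1e-5, scale_bias_dtype=np.float16):
    # x, dy: [batch, feature]; all arithmetic is done in float32, mirroring the
    # kernels that compute in U even when T is float16.
    x32, dy32, g32 = (a.astype(np.float32) for a in (x, dy, gamma))
    n = x32.shape[1]
    mean = x32.mean(axis=1, keepdims=True)
    var = x32.var(axis=1, keepdims=True)
    invvar = 1.0 / np.sqrt(var + epsilon)
    xhat = (x32 - mean) * invvar

    # Per-row reductions used by LayerNormBackwardComputeGradInput.
    sum_loss1 = (dy32 * g32).sum(axis=1, keepdims=True)
    sum_loss2 = (dy32 * g32 * xhat).sum(axis=1, keepdims=True)
    dx = (invvar / n) * (n * dy32 * g32 - sum_loss1 - xhat * sum_loss2)

    # Gamma/beta gradients: reduced over the batch in float32, then cast to the
    # scale/bias dtype, as the kernels do on the final write.
    dgamma = (dy32 * xhat).sum(axis=0).astype(scale_bias_dtype)
    dbeta = dy32.sum(axis=0).astype(scale_bias_dtype)
    return dx.astype(x.dtype), dgamma, dbeta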
col_offset; @@ -578,7 +600,8 @@ __global__ void LayerNormBackwardGradientScaleOrBias( if (HasDx) { if (scale != nullptr) { d_x[i] = static_cast(static_cast(d_y[i]) * - scale[blockIdx.x + col_offset] / var_val); + static_cast(scale[blockIdx.x + col_offset]) / + var_val); } else { d_x[i] = static_cast(static_cast(d_y[i]) / var_val); } @@ -590,9 +613,11 @@ __global__ void LayerNormBackwardGradientScaleOrBias( if (threadIdx.x == 0) { if (HasDScale) { - d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial; + d_scale[blockIdx.x + col_offset] = + static_cast(d_scale_or_d_bias_partial); } else { - d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial; + d_bias[blockIdx.x + col_offset] = + static_cast(d_scale_or_d_bias_partial); } } } @@ -640,12 +665,12 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX( } // Here, we only calculate d_x -template -__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y, - T *d_x, const U *mean, - const U *var, const U *scale, - float epsilon, - int64_t feature_size) { +template +__global__ void LayerNormBackwardGradientOnlyDX( + const T *x, const T *d_y, T *d_x, const U *mean, const U *var, + const LayerNormScaleBiasT *scale, + float epsilon, int64_t feature_size) { + using ScaleBiasT = LayerNormScaleBiasT; using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ U d_x_reduce_tmp[2]; @@ -660,8 +685,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y, static_cast(real_sqrt(static_cast(block_var) + epsilon)); if (scale != nullptr) { int col_idx = i % feature_size; - d_x[i] = - static_cast(static_cast(d_y[i]) * scale[col_idx] / var_val); + d_x[i] = static_cast(static_cast(d_y[i]) * + static_cast(scale[col_idx]) / var_val); } else { d_x[i] = static_cast(static_cast(d_y[i]) / var_val); } @@ -692,11 +717,16 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y, } } -template +template __global__ void LayerNormBackwardWhenBatchSizeIsOne( - const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean, - const U *var, const U *scale, float epsilon, int64_t feature_size) { + const T *x, const T *d_y, T *d_x, + LayerNormScaleBiasT *d_scale, + LayerNormScaleBiasT *d_bias, const U *mean, + const U *var, + const LayerNormScaleBiasT *scale, + float epsilon, int64_t feature_size) { int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + using ScaleBiasT = LayerNormScaleBiasT; if (idx < feature_size) { auto var_val = static_cast(real_sqrt(static_cast(var[0]) + epsilon)); @@ -704,26 +734,32 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne( if (d_scale == nullptr) { d_x[idx] = static_cast(static_cast(d_y[idx]) / var_val); } else { - d_x[idx] = - static_cast(static_cast(d_y[idx]) * scale[idx] / var_val); + d_x[idx] = static_cast(static_cast(d_y[idx]) * + static_cast(scale[idx]) / var_val); } } if (d_scale != nullptr) { - d_scale[idx] = static_cast(d_y[idx]) * - (static_cast(x[idx]) - mean[0]) / var_val; + d_scale[idx] = + static_cast(static_cast(d_y[idx]) * + (static_cast(x[idx]) - mean[0]) / var_val); } - if (d_bias != nullptr) d_bias[idx] = static_cast(d_y[idx]); + if (d_bias != nullptr) { + d_bias[idx] = static_cast(d_y[idx]); + } } } -template -static void LayerNormBackward(const T *x, const T *d_y, const U *scale, - const U *mean, const U *var, T *d_x, U *d_scale, - U *d_bias, float epsilon, int64_t batch_size, - int64_t feature_size, - const platform::CUDADeviceContext &dev_ctx) { +template +static void 
LayerNormBackward( + const T *x, const T *d_y, + const LayerNormScaleBiasT *scale, + const U *mean, const U *var, T *d_x, + LayerNormScaleBiasT *d_scale, + LayerNormScaleBiasT *d_bias, float epsilon, + int64_t batch_size, int64_t feature_size, + const platform::CUDADeviceContext &dev_ctx) { auto stream = dev_ctx.stream(); #ifdef __HIPCC__ const int kMaxBlockDim = 256; @@ -737,10 +773,10 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, if (gradient_flag == 0) return; if (batch_size == 1) { - LayerNormBackwardWhenBatchSizeIsOne< - T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, - 0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, - epsilon, feature_size); + LayerNormBackwardWhenBatchSizeIsOne<<< + (feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0, + stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon, + feature_size); if (d_x != nullptr) { switch (GetDesiredBlockDim(feature_size)) { @@ -759,8 +795,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, U, kBlockDim, false, - false><<>>( + T, U, kBlockDim, false, false, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -770,8 +806,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, U, kBlockDim, false, - true><<>>( + T, U, kBlockDim, false, true, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -781,7 +817,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientAll< - T, U, kBlockDim, false><<>>( + T, U, kBlockDim, false, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -790,7 +827,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( LayerNormBackwardGradientOnlyDX< - T, U, kBlockDim><<>>( + T, U, kBlockDim, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_x, mean, var, scale, epsilon, feature_size)); } break; @@ -799,8 +837,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, U, kBlockDim, true, - false><<>>( + T, U, kBlockDim, true, false, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -816,8 +854,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, U, kBlockDim, true, - true><<>>( + T, U, kBlockDim, true, true, + ScaleBiasWithSameTypeX><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -854,7 +892,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, dim3 threads3(BDIMX3, BDIMY3, 1); const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1); LayerNormBackwardSumGradGammaBeta< - T, U, BDIMX3, BDIMY3><<>>( + T, U, BDIMX3, BDIMY3, 
+ ScaleBiasWithSameTypeX><<>>( part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size, d_scale, d_bias); @@ -862,7 +901,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, constexpr int BDIMY1 = 4; dim3 threads1(BDIMX1, BDIMY1, 1); LayerNormBackwardComputeGradInput< - T, U, BDIMX1, BDIMY1><<>>( + T, U, BDIMX1, BDIMY1, + ScaleBiasWithSameTypeX><<>>( d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x); break; } diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 444478c2eadabaa9cc831d16bba9aeecfcdfe3b9..192ad3db8bebde7f72aefebef8e3009eefead47e 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -102,24 +102,6 @@ class LayerNormOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - // By default, the type of the scale, bias, mean, - // and var tensors should both be float. (For float or float16 input tensor) - // or double (For double input tensor). - auto ln_param_type = framework::proto::VarType::FP32; - if (input_data_type == framework::proto::VarType::FP64) { - ln_param_type = framework::proto::VarType::FP64; - } - if (ctx.HasInput("Scale")) { - PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input("Scale")->type(), - platform::errors::InvalidArgument( - "Scale input should be of float type")); - } - if (ctx.HasInput("Bias")) { - PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input("Bias")->type(), - platform::errors::InvalidArgument( - "Bias input should be of float type")); - } - framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 4be0c03f2de0aeaa87ef8c50f2973260eaea7b01..3fe453bda2d9ec8d6d25fdccfa657be7cb5c6aaf 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -63,8 +63,32 @@ class LayerNormKernel auto *y_data = y->mutable_data(ctx.GetPlace()); auto *mean_data = mean->mutable_data(ctx.GetPlace()); auto *var_data = var->mutable_data(ctx.GetPlace()); - auto *scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *bias_data = (bias == nullptr ? nullptr : bias->data()); + + auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); + + framework::proto::VarType::Type x_dtype = x->type(); + framework::proto::VarType::Type scale_bias_dtype; + if (void_scale_data != nullptr) { + scale_bias_dtype = scale->type(); + if (void_bias_data != nullptr) { + PADDLE_ENFORCE_EQ(scale_bias_dtype, bias->type(), + platform::errors::InvalidArgument( + "Thie Scale and Bias of layer_norm op " + "should have the same data type.")); + } + } else { + scale_bias_dtype = (void_bias_data != nullptr ? 
bias->type() : x_dtype); + } + + bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype; + if (!is_scale_bias_same_dtype_with_x) { + PADDLE_ENFORCE_EQ(scale_bias_dtype, + framework::DataTypeTrait::DataType(), + platform::errors::InvalidArgument( + "Unsupported data type of Scale and Bias: %s", + framework::DataTypeToString(scale_bias_dtype))); + } auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); int64_t batch_size = static_cast(matrix_dim[0]); @@ -72,17 +96,28 @@ class LayerNormKernel auto stream = ctx.cuda_device_context().stream(); - switch (GetDesiredBlockDim(feature_size)) { - FIXED_BLOCK_DIM_CASE( - LayerNormForward<<>>( - x_data, scale_data, bias_data, y_data, mean_data, var_data, - epsilon, feature_size)); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Product from begin_norm_axis to end must be larger than 1")); - break; +#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ + do { \ + switch (GetDesiredBlockDim(feature_size)) { \ + FIXED_BLOCK_DIM_CASE( \ + LayerNormForward<<< \ + batch_size, kBlockDim, 0, stream>>>( \ + x_data, static_cast(void_scale_data), \ + static_cast(void_bias_data), y_data, \ + mean_data, var_data, epsilon, feature_size)); \ + default: \ + PADDLE_THROW(platform::errors::InvalidArgument( \ + "Product from begin_norm_axis to end must be larger than 1")); \ + break; \ + } \ + } while (0) + + if (is_scale_bias_same_dtype_with_x) { + PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_FWD(U, false); } +#undef PADDLE_LAUNCH_LAYERNORM_FWD } }; @@ -102,32 +137,64 @@ class LayerNormGradKernel auto *mean = ctx.Input("Mean"); auto *var = ctx.Input("Variance"); auto *scale = ctx.Input("Scale"); + auto *bias = ctx.Input("Bias"); auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto &x_dims = x->dims(); + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int64_t batch_size = static_cast(matrix_dim[0]); + int64_t feature_size = static_cast(matrix_dim[1]); + auto *x_data = x->data(); auto *d_y_data = d_y->data(); + auto *mean_data = mean->data(); auto *var_data = var->data(); - auto *scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *d_scale_data = - (d_scale == nullptr ? nullptr - : d_scale->mutable_data(ctx.GetPlace())); - auto *d_bias_data = - (d_bias == nullptr ? nullptr : d_bias->mutable_data(ctx.GetPlace())); auto *d_x_data = (d_x == nullptr ? nullptr : d_x->mutable_data(ctx.GetPlace())); - const auto &x_dims = x->dims(); - const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); - auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); + framework::proto::VarType::Type x_dtype = x->type(); + framework::proto::VarType::Type scale_bias_dtype; + if (scale != nullptr) { + scale_bias_dtype = scale->type(); + } else { + // FIXME(zengjinle): do not find a better way to get the right + // data type of the d_scale and d_bias if scale == nullptr. + auto *bias = ctx.Input("Bias"); + if (bias != nullptr) { + scale_bias_dtype = bias->saved_type(); + } else { + scale_bias_dtype = x_dtype; + } + } + +#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ + do { \ + auto *scale_data = \ + (scale == nullptr ? nullptr : scale->data()); \ + auto *d_scale_data = \ + (d_scale == nullptr ? 
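The forward dispatch above instantiates LayerNormForward once with the scale/bias type equal to T and once with U, but in both branches the normalization itself is performed in the compute type U (float for float16 inputs), so the two paths agree up to the rounding of the parameters. A small NumPy sketch of that forward reference, assuming a float16 input; layer_norm_fwd_reference is an illustrative helper, not a Paddle symbol.

import numpy as np

def layer_norm_fwd_reference(x, scale=None, bias=None, epsilon=1e-5):
    # x: [batch, feature] float16; scale/bias may be float16 or float32.
    # Mean, variance and the normalization are computed in float32 (U).
    x32 = x.astype(np.float32)
    mean = x32.mean(axis=1, keepdims=True)
    var = x32.var(axis=1, keepdims=True)
    y32 = (x32 - mean) / np.sqrt(var + epsilon)
    if scale is not None:
        y32 = y32 * scale.astype(np.float32)
    if bias is not None:
        y32 = y32 + bias.astype(np.float32)
    # Only the output is cast back to the input dtype T.
    return y32.astype(x.dtype), mean.ravel(), var.ravel()

x = np.random.rand(4, 8).astype(np.float16)
w16 = np.random.rand(8).astype(np.float16)
# float16 and float32 parameter paths normalize identically when the float32
# parameters are exact casts of the float16 ones.
y_a, _, _ = layer_norm_fwd_reference(x, w16, w16)
y_b, _, _ = layer_norm_fwd_reference(x, w16.astype(np.float32), w16.astype(np.float32))
assert np.array_equal(y_a, y_b)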
nullptr : d_scale->mutable_data( \ + ctx.GetPlace())); \ + auto *d_bias_data = \ + (d_bias == nullptr ? nullptr : d_bias->mutable_data( \ + ctx.GetPlace())); \ + auto *d_x_data = \ + (d_x == nullptr ? nullptr : d_x->mutable_data(ctx.GetPlace())); \ + LayerNormBackward( \ + x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, \ + d_scale_data, d_bias_data, epsilon, batch_size, feature_size, \ + ctx.cuda_device_context()); \ + } while (0) + + if (scale_bias_dtype == x_dtype) { + PADDLE_LAUNCH_LAYERNORM_BWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_BWD(U, false); + } - LayerNormBackward(x_data, d_y_data, scale_data, mean_data, var_data, - d_x_data, d_scale_data, d_bias_data, epsilon, - batch_size, feature_size, - ctx.cuda_device_context()); +#undef PADDLE_LAUNCH_LAYERNORM_BWD } }; diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc index 8adf0dea7eb34dcd92c3b207859ba51d04845f62..f9c3f9c3582b333f28cad21359df563167c108c5 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.cc +++ b/paddle/fluid/operators/optimizers/lamb_op.cc @@ -152,6 +152,14 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Moment2", "(Tensor) Input second moment."); AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator."); AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator."); + AddInput("MasterParam", + "(LoDTensor, default LoDTensor) " + "Input master parameter that has to be updated.") + .AsDispensable(); + AddInput( + "SkipUpdate", + "(Tensor) Input tensor to determine whether to update the parameter.") + .AsDispensable(); AddOutput("ParamOut", "(Tensor) Output parameter."); AddOutput("Moment1Out", "(Tensor) Output first moment."); @@ -160,6 +168,8 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator") .AsDispensable(); + AddOutput("MasterParamOut", "(Tensor) Output master parameter.") + .AsDispensable(); AddAttr("weight_decay", "(float) Weight decay rate."); AddAttr("beta1", "(float, default 0.9) The exponential decay rate for the " @@ -173,6 +183,10 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker { "(float, default 1.0e-6) " "Constant for numerical stability.") .SetDefault(1.0e-6f); + AddAttr( + "multi_precision", + "(bool, default false) Whether to enable multi-precision mode.") + .SetDefault(false); AddComment(R"DOC( LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer. diff --git a/paddle/fluid/operators/optimizers/lamb_op.cu b/paddle/fluid/operators/optimizers/lamb_op.cu index 9ffb62926a4fffd95ca014947282a7a32e92e4b8..b46fa19ea135207ec889db20d9d4a03593f01b62 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.cu +++ b/paddle/fluid/operators/optimizers/lamb_op.cu @@ -16,5 +16,7 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - lamb, ops::LambOpKernel, + lamb, ops::LambOpKernel, + ops::LambOpKernel, ops::LambOpKernel); diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h index 749b9e795560c12af460049f1ab3f24f26998822..9eba8df9992fc02efa217c63d5e779c72b7c83a6 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -17,8 +17,10 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/eigen_ext.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -26,30 +28,35 @@ namespace operators { namespace scatter = paddle::operators::math::scatter; -template +template struct LambMomentREGUpdateFunctor { - T weight_decay_; - T beta1_; - T beta2_; - T epsilon_; - - T beta1_pow_; - T* beta1_pow_out_; - T beta2_pow_; - T* beta2_pow_out_; - const T* moment1_; - T* moment1_out_; - const T* moment2_; - T* moment2_out_; + using MT = typename std::conditional< + IsMultiPrecision, typename details::MPTypeTrait::Type, T>::type; + + MT weight_decay_; + MT beta1_; + MT beta2_; + MT epsilon_; + + MT beta1_pow_; + MT* beta1_pow_out_; + MT beta2_pow_; + MT* beta2_pow_out_; + const MT* moment1_; + MT* moment1_out_; + const MT* moment2_; + MT* moment2_out_; const T* grad_; - const T* param_; - T* trust_ratio_div_; - - LambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, - T beta1_pow, T* beta1_pow_out, T beta2_pow, - T* beta2_pow_out, const T* mom1, T* mom1_out, - const T* mom2, T* mom2_out, const T* grad, - const T* param, T* trust_ratio_div) + const MT* param_; + MT* trust_ratio_div_; + const bool* skip_update_; + + LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon, + MT beta1_pow, MT* beta1_pow_out, MT beta2_pow, + MT* beta2_pow_out, const MT* mom1, MT* mom1_out, + const MT* mom2, MT* mom2_out, const T* grad, + const MT* param, MT* trust_ratio_div, + const bool* skip_update) : weight_decay_(weight_decay), beta1_(beta1), beta2_(beta2), @@ -64,26 +71,30 @@ struct LambMomentREGUpdateFunctor { moment2_out_(mom2_out), grad_(grad), param_(param), - trust_ratio_div_(trust_ratio_div) {} + trust_ratio_div_(trust_ratio_div), + skip_update_(skip_update) {} inline HOSTDEVICE void operator()(size_t i) const { - T g = grad_[i]; - T mom1 = moment1_[i]; - T mom2 = moment2_[i]; - T beta1_pow = beta1_pow_; - T beta2_pow = beta2_pow_; - T p = param_[i]; + if (skip_update_ && *skip_update_) return; - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + MT g = static_cast(grad_[i]); + MT mom1 = moment1_[i]; + MT mom2 = moment2_[i]; + MT beta1_pow = beta1_pow_; + MT beta2_pow = beta2_pow_; + MT p = param_[i]; + + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; moment1_out_[i] = mom1; moment2_out_[i] = mom2; - T mom1_unbiased = mom1 / (1 - beta1_pow); - T mom2_unbiased = mom2 / (1 - beta2_pow); + MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); trust_ratio_div_[i] = - mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p; + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; if (beta1_pow_out_ && beta2_pow_out_) { beta1_pow_out_[0] = beta1_pow * beta1_; beta2_pow_out_[0] = beta2_pow * beta2_; @@ -91,31 +102,35 @@ struct LambMomentREGUpdateFunctor { } }; -template +template struct LambMomentMENUpdateFunctor { - T weight_decay_; - T beta1_; - T beta2_; - T epsilon_; - - const T* beta1_pow_; - T* beta1_pow_out_; - const T* beta2_pow_; - T* beta2_pow_out_; - const T* moment1_; - T* moment1_out_; - const T* moment2_; - T* moment2_out_; + using MT = typename std::conditional< + 
IsMultiPrecision, typename details::MPTypeTrait::Type, T>::type; + + MT weight_decay_; + MT beta1_; + MT beta2_; + MT epsilon_; + + const MT* beta1_pow_; + MT* beta1_pow_out_; + const MT* beta2_pow_; + MT* beta2_pow_out_; + const MT* moment1_; + MT* moment1_out_; + const MT* moment2_; + MT* moment2_out_; const T* grad_; - const T* param_; - T* trust_ratio_div_; - - LambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, - const T* beta1_pow, T* beta1_pow_out, - const T* beta2_pow, T* beta2_pow_out, - const T* mom1, T* mom1_out, const T* mom2, - T* mom2_out, const T* grad, const T* param, - T* trust_ratio_div) + const MT* param_; + MT* trust_ratio_div_; + const bool* skip_update_; + + LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon, + const MT* beta1_pow, MT* beta1_pow_out, + const MT* beta2_pow, MT* beta2_pow_out, + const MT* mom1, MT* mom1_out, const MT* mom2, + MT* mom2_out, const T* grad, const MT* param, + MT* trust_ratio_div, const bool* skip_update) : weight_decay_(weight_decay), beta1_(beta1), beta2_(beta2), @@ -130,26 +145,29 @@ struct LambMomentMENUpdateFunctor { moment2_out_(mom2_out), grad_(grad), param_(param), - trust_ratio_div_(trust_ratio_div) {} + trust_ratio_div_(trust_ratio_div), + skip_update_(skip_update) {} inline HOSTDEVICE void operator()(size_t i) const { - T g = grad_[i]; - T mom1 = moment1_[i]; - T mom2 = moment2_[i]; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; - T p = param_[i]; + if (skip_update_ && *skip_update_) return; + MT g = static_cast(grad_[i]); + MT mom1 = moment1_[i]; + MT mom2 = moment2_[i]; + MT beta1_pow = *beta1_pow_; + MT beta2_pow = *beta2_pow_; + MT p = param_[i]; - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; moment1_out_[i] = mom1; moment2_out_[i] = mom2; - T mom1_unbiased = mom1 / (1 - beta1_pow); - T mom2_unbiased = mom2 / (1 - beta2_pow); + MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); trust_ratio_div_[i] = - mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p; + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; if (beta1_pow_out_ && beta2_pow_out_) { beta1_pow_out_[0] = beta1_pow * beta1_; beta2_pow_out_[0] = beta2_pow * beta2_; @@ -180,13 +198,15 @@ struct SparseLambMomentREGUpdateFunctor { int64_t row_numel_; int64_t row_count_; + const bool* skip_update_; + SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, T beta1_pow, T* beta1_pow_out, T beta2_pow, T* beta2_pow_out, const T* mom1, T* mom1_out, const T* mom2, T* mom2_out, const T* grad, const T* param, T* trust_ratio_div, const int64_t* rows, int64_t row_numel, - int64_t row_count) + int64_t row_count, const bool* skip_update) : weight_decay_(weight_decay), beta1_(beta1), beta2_(beta2), @@ -204,7 +224,8 @@ struct SparseLambMomentREGUpdateFunctor { trust_ratio_div_(trust_ratio_div), rows_(rows), row_numel_(row_numel), - row_count_(row_count) {} + row_count_(row_count), + skip_update_(skip_update) {} inline HOSTDEVICE void update(size_t i, T g) const { // The following code is same as dense @@ -214,16 +235,17 @@ struct SparseLambMomentREGUpdateFunctor { T beta2_pow = beta2_pow_; T p = param_[i]; - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + mom1 = beta1_ * mom1 + (static_cast(1) - 
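Both moment functors read the raw gradient in T but keep the moments, beta powers and trust_ratio_div in the multi-precision type MT, and return early when the SkipUpdate flag is set. A NumPy sketch of one step of that moment update, assuming a float16 gradient and float32 optimizer state; lamb_moment_update is an illustrative name only.

import numpy as np

def lamb_moment_update(grad_fp16, param_fp32, mom1, mom2, beta1_pow, beta2_pow,
                       beta1=0.9, beta2=0.999, epsilon=1e-6, weight_decay=0.01,
                       skip_update=False):
    # Mirrors the dense moment functors: state stays in MT (float32), the
    # gradient is cast up from T (float16) before use.
    if skip_update:
        return mom1, mom2, None
    g = grad_fp16.astype(np.float32)
    mom1 = beta1 * mom1 + (1.0 - beta1) * g
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g
    # Bias correction uses the beta powers read before this step.
    mom1_unbiased = mom1 / (1.0 - beta1_pow)
    mom2_unbiased = mom2 / (1.0 - beta2_pow)
    trust_ratio_div = (mom1_unbiased / (np.sqrt(mom2_unbiased) + epsilon)
                       + weight_decay * param_fp32)
    return mom1, mom2, trust_ratio_div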
beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; moment1_out_[i] = mom1; moment2_out_[i] = mom2; - T mom1_unbiased = mom1 / (1 - beta1_pow); - T mom2_unbiased = mom2 / (1 - beta2_pow); + T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + T mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); trust_ratio_div_[i] = - mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p; + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; if (beta1_pow_out_ && beta1_pow_out_) { beta1_pow_out_[0] = beta1_pow * beta1_; beta2_pow_out_[0] = beta2_pow * beta2_; @@ -231,9 +253,11 @@ struct SparseLambMomentREGUpdateFunctor { } inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; auto row_idx = math::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] + : static_cast(0); update(i, g); } }; @@ -261,13 +285,16 @@ struct SparseLambMomentMENUpdateFunctor { int64_t row_numel_; int64_t row_count_; + const bool* skip_update_; + SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, const T* beta1_pow, T* beta1_pow_out, const T* beta2_pow, T* beta2_pow_out, const T* mom1, T* mom1_out, const T* mom2, T* mom2_out, const T* grad, const T* param, T* trust_ratio_div, const int64_t* rows, - int64_t row_numel, int64_t row_count) + int64_t row_numel, int64_t row_count, + const bool* skip_update) : weight_decay_(weight_decay), beta1_(beta1), beta2_(beta2), @@ -285,7 +312,8 @@ struct SparseLambMomentMENUpdateFunctor { trust_ratio_div_(trust_ratio_div), rows_(rows), row_numel_(row_numel), - row_count_(row_count) {} + row_count_(row_count), + skip_update_(skip_update) {} inline HOSTDEVICE void update(size_t i, T g) const { // The following code is same as dense @@ -295,16 +323,17 @@ struct SparseLambMomentMENUpdateFunctor { T beta2_pow = *beta2_pow_; T p = param_[i]; - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; moment1_out_[i] = mom1; moment2_out_[i] = mom2; - T mom1_unbiased = mom1 / (1 - beta1_pow); - T mom2_unbiased = mom2 / (1 - beta2_pow); + T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + T mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); trust_ratio_div_[i] = - mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p; + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; if (beta1_pow_out_ && beta1_pow_out_) { beta1_pow_out_[0] = beta1_pow * beta1_; beta2_pow_out_[0] = beta2_pow * beta2_; @@ -312,40 +341,61 @@ struct SparseLambMomentMENUpdateFunctor { } inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; auto row_idx = math::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + T g = row_idx >= 0 ? 
grad_[row_idx * row_numel_ + i % row_numel_] + : static_cast(0); update(i, g); } }; -template +template struct LambParamUpateFunctor { - const T* lr_; + using MT = typename std::conditional< + IsMultiPrecision, typename details::MPTypeTrait::Type, T>::type; + + const MT* lr_; const T* param_; - const T* param_norm_; - const T* trust_ratio_div_; - const T* trust_ratio_div_norm_; + const MT* master_param_; + const MT* param_norm_; + const MT* trust_ratio_div_; + const MT* trust_ratio_div_norm_; T* param_out_; + MT* master_param_out_; + + const bool* skip_update_; - LambParamUpateFunctor(const T* lr, const T* param, const T* param_norm, - const T* trust_ratio_div, const T* trust_ratio_div_norm, - T* param_out) + LambParamUpateFunctor(const MT* lr, const T* param, const MT* master_param, + const MT* param_norm, const MT* trust_ratio_div, + const MT* trust_ratio_div_norm, T* param_out, + MT* master_param_out, const bool* skip_update) : lr_(lr), param_(param), + master_param_(master_param), param_norm_(param_norm), trust_ratio_div_(trust_ratio_div), trust_ratio_div_norm_(trust_ratio_div_norm), - param_out_(param_out) {} + param_out_(param_out), + master_param_out_(master_param_out), + skip_update_(skip_update) {} inline HOSTDEVICE void operator()(size_t i) const { - T lr = *lr_; - T p = *param_norm_; - T t = *trust_ratio_div_norm_; - - T r = (p > 0 && t > 0) ? p / t : 1.0; + if (skip_update_ && *skip_update_) return; + MT lr = *lr_; + MT pn = *param_norm_; + MT tn = *trust_ratio_div_norm_; + + MT r = (pn > static_cast(0) && tn > static_cast(0)) + ? pn / tn + : static_cast(1); lr *= r; - param_out_[i] = param_[i] - lr * trust_ratio_div_[i]; + MT p = IsMultiPrecision ? master_param_[i] : static_cast(param_[i]); + MT param_out = p - lr * trust_ratio_div_[i]; + param_out_[i] = static_cast(param_out); + if (IsMultiPrecision) { + master_param_out_[i] = param_out; + } } }; @@ -353,86 +403,146 @@ template class LambOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - - using paddle::framework::LoDTensor; - - T weight_decay = static_cast(ctx.Attr("weight_decay")); - T beta1 = static_cast(ctx.Attr("beta1")); - T beta2 = static_cast(ctx.Attr("beta2")); - T epsilon = static_cast(ctx.Attr("epsilon")); - auto& param = GET_DATA_SAFELY(ctx.Input("Param"), "Input", - "Param", "Lamb"); - auto* grad_var = ctx.InputVar("Grad"); - auto& mom1 = GET_DATA_SAFELY(ctx.Input("Moment1"), "Input", - "Moment1", "Lamb"); - auto& mom2 = GET_DATA_SAFELY(ctx.Input("Moment2"), "Input", - "Moment2", "Lamb"); - auto& lr = GET_DATA_SAFELY(ctx.Input("LearningRate"), "Input", - "LearningRate", "Lamb"); - - auto& beta1_pow = GET_DATA_SAFELY(ctx.Input("Beta1Pow"), "Input", - "Beta1Pow", "Lamb"); - auto& beta2_pow = GET_DATA_SAFELY(ctx.Input("Beta2Pow"), "Input", - "Beta2Pow", "Lamb"); - - auto& param_out = GET_DATA_SAFELY(ctx.Output("ParamOut"), - "Output", "ParamOut", "Lamb"); - auto& mom1_out = GET_DATA_SAFELY(ctx.Output("Moment1Out"), - "Output", "Moment1Out", "Lamb"); - auto& mom2_out = GET_DATA_SAFELY(ctx.Output("Moment2Out"), - "Output", "Moment2Out", "Lamb"); - auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output("Beta1PowOut"), - "Output", "Beta1PowOut", "Lamb"); - auto& beta2_pow_out 
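LambParamUpateFunctor then scales the learning rate by the layer-wise trust ratio ||p|| / ||trust_ratio_div|| and, in multi-precision mode, applies the update to the float32 master copy before casting the result back to the float16 parameter. A NumPy sketch of that final step, continuing the hypothetical helper above; keeping the master copy avoids losing small updates to float16 rounding.

import numpy as np

def lamb_param_update(param_fp16, master_fp32, trust_ratio_div, lr,
                      skip_update=False):
    # Mirrors the multi-precision branch: norms and the update are computed on
    # the float32 master copy; the float16 parameter is a cast of the result.
    if skip_update:
        return param_fp16, master_fp32
    p_norm = np.sqrt((master_fp32 ** 2).sum(dtype=np.float32))
    t_norm = np.sqrt((trust_ratio_div ** 2).sum(dtype=np.float32))
    r = p_norm / t_norm if (p_norm > 0 and t_norm > 0) else np.float32(1.0)
    master_out = master_fp32 - (lr * r) * trust_ratio_div
    return master_out.astype(np.float16), master_out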
= GET_DATA_SAFELY(ctx.Output("Beta2PowOut"), - "Output", "Beta2PowOut", "Lamb"); + using MT = typename details::MPTypeTrait::Type; + bool multi_precision = ctx.Attr("multi_precision"); + if (multi_precision) { + ComputeImpl(ctx); + } else { + ComputeImpl(ctx); + } + } + + private: + template + void ComputeImpl(const framework::ExecutionContext& ctx) const { + if (!IsMultiPrecision) { + constexpr auto kIsSameType = std::is_same::value; + PADDLE_ENFORCE_EQ( + kIsSameType, true, + platform::errors::InvalidArgument( + "When multi_precision=False, T and MT must be the same type.")); + } + const auto* skip_update = ctx.Input("SkipUpdate"); + const bool* skip_update_flag = skip_update && skip_update->IsInitialized() + ? skip_update->data() + : nullptr; + if (skip_update_flag && platform::is_cpu_place(skip_update->place()) && + (*skip_update_flag)) { + return; + } + + auto weight_decay = static_cast(ctx.Attr("weight_decay")); + auto beta1 = static_cast(ctx.Attr("beta1")); + auto beta2 = static_cast(ctx.Attr("beta2")); + auto epsilon = static_cast(ctx.Attr("epsilon")); + const auto& param = GET_DATA_SAFELY( + ctx.Input("Param"), "Input", "Param", "Lamb"); + const auto* grad_var = ctx.InputVar("Grad"); + const auto& mom1 = GET_DATA_SAFELY( + ctx.Input("Moment1"), "Input", "Moment1", "Lamb"); + const auto& mom2 = GET_DATA_SAFELY( + ctx.Input("Moment2"), "Input", "Moment2", "Lamb"); + const auto& lr = + GET_DATA_SAFELY(ctx.Input("LearningRate"), + "Input", "LearningRate", "Lamb"); + + const auto& beta1_pow = + GET_DATA_SAFELY(ctx.Input("Beta1Pow"), "Input", + "Beta1Pow", "Lamb"); + const auto& beta2_pow = + GET_DATA_SAFELY(ctx.Input("Beta2Pow"), "Input", + "Beta2Pow", "Lamb"); + + auto& param_out = + GET_DATA_SAFELY(ctx.Output("ParamOut"), "Output", + "ParamOut", "Lamb"); + auto& mom1_out = + GET_DATA_SAFELY(ctx.Output("Moment1Out"), + "Output", "Moment1Out", "Lamb"); + auto& mom2_out = + GET_DATA_SAFELY(ctx.Output("Moment2Out"), + "Output", "Moment2Out", "Lamb"); + auto& beta1_pow_out = + GET_DATA_SAFELY(ctx.Output("Beta1PowOut"), + "Output", "Beta1PowOut", "Lamb"); + auto& beta2_pow_out = + GET_DATA_SAFELY(ctx.Output("Beta2PowOut"), + "Output", "Beta2PowOut", "Lamb"); + const auto* master_param = + IsMultiPrecision ? ctx.Input("MasterParam") + : nullptr; + auto* master_param_out = + IsMultiPrecision ? ctx.Output("MasterParamOut") + : nullptr; + + if (IsMultiPrecision) { + PADDLE_ENFORCE_NOT_NULL(master_param, + platform::errors::InvalidArgument( + "Input(MasterParam) must be provided when " + "multi_precision=True.")); + PADDLE_ENFORCE_NOT_NULL(master_param_out, + platform::errors::InvalidArgument( + "Output(MasterParamOut) must be provided " + "when multi_precision=True.")); + } auto& dev_ctx = ctx.template device_context(); platform::ForRange for_range(dev_ctx, param.numel()); - framework::Tensor trust_ratio_div = - ctx.AllocateTmpTensor(param.dims(), dev_ctx); + auto trust_ratio_div = + ctx.AllocateTmpTensor(param.dims(), dev_ctx); + + const void* param_ptr = param.template data(); + const void* master_param_ptr = + master_param ? master_param->template data() : nullptr; + void* param_out_ptr = param_out.template mutable_data(ctx.GetPlace()); + void* master_param_out_ptr = + master_param_out + ? 
master_param_out->template mutable_data(ctx.GetPlace()) + : nullptr; // Update moments if (grad_var->IsType()) { - auto& grad = *ctx.Input("Grad"); + auto& grad = grad_var->Get(); if (platform::is_gpu_place(ctx.GetPlace()) && beta1_pow.place() == platform::CPUPlace() && beta2_pow.place() == platform::CPUPlace()) { - LambMomentREGUpdateFunctor moment_update_functor( - weight_decay, beta1, beta2, epsilon, *beta1_pow.template data(), - nullptr, *beta2_pow.template data(), nullptr, - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad.template data(), param.template data(), - trust_ratio_div.template data()); + LambMomentREGUpdateFunctor moment_update_functor( + weight_decay, beta1, beta2, epsilon, *beta1_pow.template data(), + nullptr, *beta2_pow.template data(), nullptr, + mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + grad.template data(), + static_cast(IsMultiPrecision ? master_param_ptr + : param_ptr), + trust_ratio_div.template data(), skip_update_flag); for_range(moment_update_functor); - beta1_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow.template data()[0]; - beta2_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow.template data()[0]; + beta1_pow_out.template mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow.template data()[0]; + beta2_pow_out.template mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow.template data()[0]; } else { - LambMomentMENUpdateFunctor moment_update_functor( - weight_decay, beta1, beta2, epsilon, beta1_pow.template data(), - beta1_pow_out.template mutable_data(ctx.GetPlace()), - beta2_pow.template data(), - beta2_pow_out.template mutable_data(ctx.GetPlace()), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad.template data(), param.template data(), - trust_ratio_div.template data()); + LambMomentMENUpdateFunctor moment_update_functor( + weight_decay, beta1, beta2, epsilon, beta1_pow.template data(), + beta1_pow_out.template mutable_data(ctx.GetPlace()), + beta2_pow.template data(), + beta2_pow_out.template mutable_data(ctx.GetPlace()), + mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + grad.template data(), + static_cast(IsMultiPrecision ? 
master_param_ptr + : param_ptr), + trust_ratio_div.template data(), skip_update_flag); for_range(moment_update_functor); } } else if (grad_var->IsType()) { + PADDLE_ENFORCE_EQ(IsMultiPrecision, false, + platform::errors::Unimplemented( + "SelectedRows gradient is not supported when " + "multi_precision=True")); auto& grad = GET_DATA_SAFELY(ctx.Input("Grad"), "Input", "Grad", "Lamb"); if (grad.rows().size() == 0) { @@ -470,22 +580,25 @@ class LambOpKernel : public framework::OpKernel { beta1_pow.place() == platform::CPUPlace() && beta2_pow.place() == platform::CPUPlace()) { SparseLambMomentREGUpdateFunctor moment_update_functor( - weight_decay, beta1, beta2, epsilon, *beta1_pow.template data(), - nullptr, *beta2_pow.template data(), nullptr, - mom1.template data(), + static_cast(weight_decay), static_cast(beta1), + static_cast(beta2), static_cast(epsilon), + *beta1_pow.template data(), nullptr, + *beta2_pow.template data(), nullptr, mom1.template data(), mom1_out.template mutable_data(ctx.GetPlace()), mom2.template data(), mom2_out.template mutable_data(ctx.GetPlace()), grad_data, param.template data(), trust_ratio_div.template data(), rows, - row_numel, grad_merge.rows().size()); + row_numel, grad_merge.rows().size(), skip_update_flag); for_range(moment_update_functor); beta1_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow.template data()[0]; + static_cast(beta1) * beta1_pow.template data()[0]; beta2_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow.template data()[0]; + static_cast(beta2) * beta2_pow.template data()[0]; } else { SparseLambMomentMENUpdateFunctor moment_update_functor( - weight_decay, beta1, beta2, epsilon, beta1_pow.template data(), + static_cast(weight_decay), static_cast(beta1), + static_cast(beta2), static_cast(epsilon), + beta1_pow.template data(), beta1_pow_out.template mutable_data(ctx.GetPlace()), beta2_pow.template data(), beta2_pow_out.template mutable_data(ctx.GetPlace()), @@ -494,36 +607,45 @@ class LambOpKernel : public framework::OpKernel { mom2.template data(), mom2_out.template mutable_data(ctx.GetPlace()), grad_data, param.template data(), trust_ratio_div.template data(), rows, - row_numel, grad_merge.rows().size()); + row_numel, grad_merge.rows().size(), skip_update_flag); for_range(moment_update_functor); } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Variable type not supported by lamb_op. Expect LoDTensor or " "SelectedRows, but got %s", - framework::ToTypeName(param_var->Type()))); + framework::ToTypeName(grad_var->Type()))); } // Update parameter - framework::Tensor p_norm_t = - ctx.AllocateTmpTensor({1}, dev_ctx); - framework::Tensor trust_ratio_div_norm_t = - ctx.AllocateTmpTensor({1}, dev_ctx); - auto p_norm = framework::EigenScalar::From(p_norm_t); - auto trust_ratio_div_norm = - framework::EigenScalar::From(trust_ratio_div_norm_t); + auto p_norm_t = ctx.AllocateTmpTensor({1}, dev_ctx); + auto trust_ratio_div_norm_t = + ctx.AllocateTmpTensor({1}, dev_ctx); - auto p = framework::EigenVector::Flatten(param); - auto t = framework::EigenVector::Flatten(trust_ratio_div); + auto p_norm = framework::EigenScalar::From(p_norm_t); + auto trust_ratio_div_norm = + framework::EigenScalar::From(trust_ratio_div_norm_t); + auto t = framework::EigenVector::Flatten(trust_ratio_div); + // TODO(zengjinle): remove the following Eigen operations when + // *skip_update == true. 
auto* place = dev_ctx.eigen_device(); - p_norm.device(*place) = p.square().sum().sqrt(); + if (IsMultiPrecision) { + auto mp = framework::EigenVector::Flatten(*master_param); + p_norm.device(*place) = mp.square().sum().sqrt(); + } else { + auto p = framework::EigenVector::Flatten(param); + p_norm.device(*place) = p.square().sum().sqrt(); + } trust_ratio_div_norm.device(*place) = t.square().sum().sqrt(); - LambParamUpateFunctor param_update_functor( - lr.template data(), param.template data(), - p_norm_t.template data(), trust_ratio_div.template data(), - trust_ratio_div_norm_t.template data(), - param_out.template mutable_data(ctx.GetPlace())); + + LambParamUpateFunctor param_update_functor( + lr.template data(), static_cast(param_ptr), + static_cast(master_param_ptr), p_norm_t.template data(), + trust_ratio_div.template data(), + trust_ratio_div_norm_t.template data(), + static_cast(param_out_ptr), static_cast(master_param_out_ptr), + skip_update_flag); for_range(param_update_functor); } }; diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index f148c971432c00fda18a70e26fe16f64a25a0c82..6fd5f659a9955f69fb82687cd20ee45065553532 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -72,6 +72,9 @@ std::map> op_ins_map = { {"adamw", {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow", "MasterParam"}}, + {"lamb", + {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", + "Beta2Pow", "MasterParam"}}, {"sparse_attention", {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}}, }; @@ -112,8 +115,6 @@ std::map> op_outs_map = { {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, {"sparse_momentum", {"ParamOut", "VelocityOut"}}, {"rnn", {"DropoutState", "Reserve", "Out", "State"}}, - {"lamb", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"run_program", {"DOut"}}, {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", @@ -121,6 +122,9 @@ std::map> op_outs_map = { {"adamw", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"lamb", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are @@ -142,6 +146,9 @@ std::map> op_passing_outs_map = { {"adamw", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"lamb", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, {"average_accumulates", {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates", "out_old_num_accumulates", "out_num_updates"}}, @@ -173,8 +180,6 @@ std::map> op_passing_outs_map = { {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}}, {"moving_average_abs_max_scale", {"Out", "OutScale", "OutAccum", "OutState"}}, - {"lamb", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"rnn", {"DropoutState"}}, {"run_program", {"Out", "DOut", "OutScope"}}, {"clear_float_status", {"FloatStatusOut"}}, diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 293d6119e75046b4269e2c713239c776bd40fc39..a4187d4a143d335cdefa5da33edd277f547be951 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -371,6 +371,20 @@ class ClipGradByNorm(ClipGradBase): return param, new_grad +_allow_pure_fp16_global_norm_clip_flag = False + + +def 
_allow_pure_fp16_global_norm_clip(*args): + global _allow_pure_fp16_global_norm_clip_flag + if len(args) == 0: + return _allow_pure_fp16_global_norm_clip_flag + else: + assert len(args) == 1 and isinstance(args[0], bool) + old_value = _allow_pure_fp16_global_norm_clip_flag + _allow_pure_fp16_global_norm_clip_flag = args[0] + return old_value + + class ClipGradByGlobalNorm(ClipGradBase): r""" Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in @@ -537,8 +551,12 @@ class ClipGradByGlobalNorm(ClipGradBase): global_norm_var = [] if len(sum_square_list_fp16) > 0: global_norm_var_fp16 = layers.sums(sum_square_list_fp16) - global_norm_var.append( - global_norm_var_fp16.astype(sum_dtype)) + if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip( + ): + global_norm_var.append( + global_norm_var_fp16.astype(sum_dtype)) + else: + global_norm_var.append(global_norm_var_fp16) if len(sum_square_list_fp32) > 0: global_norm_var_fp32 = layers.sums(sum_square_list_fp32) if sum_dtype == 'float32': @@ -573,8 +591,9 @@ class ClipGradByGlobalNorm(ClipGradBase): with p.block.program._optimized_guard([p, g]): # inplace scale_input = (scale_var.astype('float16') - if g.dtype == core.VarDesc.VarType.FP16 else - scale_var) + if g.dtype == core.VarDesc.VarType.FP16 and + scale_var.dtype != core.VarDesc.VarType.FP16 + else scale_var) # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g # will be in different blocks with the gradient clip related ops. # We need to handle the correct block, otherwise will encounter diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 563c394c9fbfe82e491170a65fbf273ac6f5104f..b737b14aa6d15292b0a8c8d6c7ccfd116a11db24 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -411,6 +411,8 @@ class OptimizerWithMixedPrecision(object): found_inf = paddle.tensor.creation._memcpy(found_inf, paddle.CPUPlace()) real_optimizer._set_auxiliary_var('found_inf', found_inf) + elif hasattr(real_optimizer, "_set_auxiliary_var"): + real_optimizer._set_auxiliary_var('found_inf', found_inf) optimize_ops = self._optimizer.apply_gradients(params_grads) return optimize_ops diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 36546c1de12048d0327e859b83016fc73cffd4f7..e3e5bc4f3270343902e11766c42e34dc96e18352 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -80,11 +80,27 @@ def _dtype_to_str(dtype): return 'fp32' +_keep_layer_norm_scale_bias_to_fp32_flag = True + + +def _keep_layer_norm_scale_bias_to_fp32(*args): + global _keep_layer_norm_scale_bias_to_fp32_flag + if len(args) == 0: + return _keep_layer_norm_scale_bias_to_fp32_flag + else: + assert len(args) == 1 and isinstance(args[0], bool) + old_value = _keep_layer_norm_scale_bias_to_fp32_flag + _keep_layer_norm_scale_bias_to_fp32_flag = args[0] + return old_value + + def _keep_fp32_input(op, in_name): op_type = op.type - if op_type in ['batch_norm', 'layer_norm']: + if op_type == 'batch_norm': # Scale, Bias, Mean, Variance should be float32. 
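The clip.py change lets the squared sum of float16 gradients stay in float16 when every gradient is float16 and _allow_pure_fp16_global_norm_clip() is enabled, instead of always casting it up before the final reduction; the clipping math itself is unchanged. A NumPy sketch of that math with mixed-precision gradients; clip_by_global_norm is an illustrative helper, not the Paddle implementation.

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # grads: list of arrays, possibly a mix of float16 and float32.
    # Here the squared sums are simply combined in float32, which is what the
    # pre-existing path does via the cast.
    squared = [np.square(g.astype(np.float32)).sum() for g in grads]
    global_norm = np.sqrt(np.sum(squared, dtype=np.float32))
    scale = clip_norm / max(global_norm, clip_norm)
    out = []
    for g in grads:
        # A float16 gradient is scaled with a float16 factor, roughly mirroring
        # the scale_input cast in ClipGradByGlobalNorm.
        s = np.float16(scale) if g.dtype == np.float16 else np.float32(scale)
        out.append(g * s)
    return out

grads = [np.random.rand(3, 4).astype(np.float16),
         np.random.rand(5).astype(np.float32)]
clipped = clip_by_global_norm(grads, clip_norm=1.0)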
return in_name != 'X' + if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32(): + return in_name != 'X' if op_type == 'fused_bn_add_activation': return in_name not in {'X', 'Z'} if op_type == 'resnet_unit': @@ -98,7 +114,9 @@ def _keep_fp32_input(op, in_name): def _keep_fp32_output(op, out_name): op_type = op.type - if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']: + if op_type in ['batch_norm', 'fused_bn_add_activation']: + return out_name != 'Y' + if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32(): return out_name != 'Y' if op_type == 'resnet_unit': return out_name not in {'Y', 'ConvX', 'ConvZ'} diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index f849d61c5d70081f9977c764aed055269df6e708..ae2c87938c682c0807484c9d4aa3aad77bd20a7e 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -3554,14 +3554,14 @@ class LambOptimizer(AdamOptimizer): else: weight_decay = self._weight_decay lr = self._create_param_lr(param_and_grad) - + master_weight = None if framework.in_dygraph_mode(): - _, _, _, _, _ = _C_ops.lamb( - param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, - moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1, - 'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay', - weight_decay) + _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1, + moment2, beta1_pow_acc, beta2_pow_acc, master_weight, + param_and_grad[0], moment1, moment2, beta1_pow_acc, + beta2_pow_acc, master_weight, 'beta1', self._beta1, + 'beta2', self._beta2, 'epsilon', self._epsilon, + 'weight_decay', weight_decay) return None # create the lamb optimize op diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py index 29735f1c89c857f67874bbc8c442b5659682c6ba..7984ca5571658c96b24acaf49fe6de853bc345b3 100644 --- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py +++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py @@ -21,6 +21,7 @@ import paddle.fluid.core as core import paddle.fluid as fluid import six from fake_reader import fake_imdb_reader +from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip paddle.enable_static() @@ -566,5 +567,35 @@ class TestDygraphGradientClipFP64(unittest.TestCase): % (a, b)) +class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase): + def check_main(self, expected_has_cast_op): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + names = ["p0", "p1"] + shapes = [[2, 3], [4, 5]] + + param_and_grads = [] + main_block = main_prog.global_block() + for name, shape in zip(names, shapes): + p = main_block.create_parameter( + name=name, shape=shape, dtype='float16') + g = main_block.create_parameter( + name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype) + param_and_grads.append((p, g)) + + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) + clip(param_and_grads) + actual_has_cast = any(op.type == 'cast' for op in main_block.ops) + self.assertEqual(actual_has_cast, expected_has_cast_op) + + def test_main(self): + self.check_main(True) + _allow_pure_fp16_global_norm_clip(True) + self.check_main(False) + _allow_pure_fp16_global_norm_clip(False) + self.check_main(True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lambv2_op.py 
diff --git a/python/paddle/fluid/tests/unittests/test_lambv2_op.py b/python/paddle/fluid/tests/unittests/test_lambv2_op.py
index 861418679a36620d2a31bf375de50c65cc10b5ea..24a22f802ce92f2efffe15169ef36496f82664b4 100644
--- a/python/paddle/fluid/tests/unittests/test_lambv2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lambv2_op.py
@@ -19,6 +19,7 @@ import numpy as np
 from op_test import OpTest
 from paddle.fluid import core
 from paddle.fluid.op import Operator
+from paddle.fluid.dygraph.base import switch_to_static_graph
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -181,5 +182,58 @@ class TestLambOpV2Group(TestLambOpV2):
             adam.clear_gradients()
 
 
+class TestLambOpMultiPrecision(unittest.TestCase):
+    def check_main(self, x_np, place, multi_precision=False, seed=10, n=10):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, startup_prog):
+            paddle.seed(seed)
+            with paddle.static.amp.fp16_guard():
+                x = paddle.static.data(
+                    name='x', shape=[None, 10], dtype='float32')
+                linear = paddle.nn.Linear(10, 2)
+                hidden = linear(x)
+                loss = paddle.mean(hidden)
+
+            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            optimizer._multi_precision = multi_precision
+            if multi_precision:
+                optimizer = paddle.static.amp.decorate(
+                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            optimizer.minimize(loss)
+
+        weight, bias = linear.weight, linear.bias
+        scope = paddle.static.Scope()
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        x = main_prog.global_block().var(x.name)
+        if x.dtype == core.VarDesc.VarType.FP16:
+            x_np = x_np.astype(np.float16)
+
+        with paddle.static.scope_guard(scope):
+            exe.run(startup_prog)
+            if multi_precision:
+                optimizer.amp_init(place)
+            weight_np, bias_np = None, None
+            for i in range(n):
+                feed_dict = {x.name: x_np}
+                weight_np, bias_np = exe.run(main_prog,
+                                             feed=feed_dict,
+                                             fetch_list=[weight, bias])
+        return weight_np.astype('float32'), bias_np.astype('float32')
+
+    @switch_to_static_graph
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        place = paddle.CUDAPlace(0)
+        x_np = np.random.random(size=[5, 10]).astype('float32')
+        weight_1, bias_1 = self.check_main(x_np, place, multi_precision=False)
+        weight_2, bias_2 = self.check_main(x_np, place, multi_precision=True)
+        self.assertTrue(np.all(np.abs(weight_1 - weight_2) < 1e-3))
+        self.assertTrue(np.all(np.abs(bias_1 - bias_2) < 1e-7))
+
+
 if __name__ == "__main__":
     unittest.main()
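The test above enables the feature through the still-private `_multi_precision` switch plus the pure-FP16 AMP decorator. A condensed sketch of the same setup in user code (CUDA only; the tiny network is illustrative):

import paddle

paddle.enable_static()

if paddle.is_compiled_with_cuda():
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.static.amp.fp16_guard():
            x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
            loss = paddle.mean(paddle.nn.Linear(10, 2)(x))

        optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
        optimizer._multi_precision = True  # private switch, see TODO in lamb.py
        optimizer = paddle.static.amp.decorate(
            optimizer, use_pure_fp16=True, use_fp16_guard=True)
        optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)
    optimizer.amp_init(place)  # casts parameters and sets up FP32 master weights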
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index 98a503eb1ea6f6dd6ecd49579a231cf0e52b7b73..d2d931f148078d124a25ddbb888b3e9cb5911211 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -20,9 +20,13 @@ import paddle
 from operator import mul
 import paddle.fluid.core as core
 import paddle.fluid as fluid
+import paddle.nn.functional as F
 from functools import reduce
 from op_test import _set_use_system_allocator
 from paddle.fluid import Program, program_guard
+from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
+
+paddle.enable_static()
 
 np.random.random(123)
 
@@ -325,5 +329,58 @@ class TestDygraphLayerNormAPIError(unittest.TestCase):
             self.assertRaises(TypeError, layer_norm, x2)
 
 
+class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
+    def check_main(self, x_np, weight_np, bias_np, dtype):
+        paddle.disable_static()
+
+        weight_np = weight_np.astype(dtype)
+        bias_np = bias_np.astype(dtype)
+
+        x = paddle.to_tensor(x_np)
+        weight = paddle.to_tensor(weight_np)
+        bias = paddle.to_tensor(bias_np)
+        x.stop_gradient = False
+        weight.stop_gradient = False
+        bias.stop_gradient = False
+        y = F.layer_norm(x, x.shape[1:], weight, bias)
+        x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g.numpy().astype('float32')
+        w_g_np = w_g.numpy().astype('float16')
+        b_g_np = b_g.numpy().astype('float32')
+
+        paddle.enable_static()
+        return y_np, x_g_np, w_g_np, b_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        x_np = np.random.random([10, 20]).astype('float16')
+        weight_np = np.random.random([20]).astype('float16')
+        bias_np = np.random.random([20]).astype('float16')
+
+        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
+            x_np, weight_np, bias_np, 'float16')
+        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
+            x_np, weight_np, bias_np, 'float32')
+
+        def assert_equal(x, y):
+            self.assertTrue(np.array_equal(x, y))
+
+        assert_equal(y_np_1, y_np_2)
+        assert_equal(x_g_np_1, x_g_np_2)
+        assert_equal(w_g_np_1, w_g_np_2)
+        assert_equal(b_g_np_1, b_g_np_2)
+
+
+class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
+    def test_main(self):
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(False)
+        self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
+        _keep_layer_norm_scale_bias_to_fp32(True)
+        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
+
+
 if __name__ == '__main__':
     unittest.main()
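With the kernel changes at the top of this patch, `F.layer_norm` accepts FP16 scale and bias directly in dygraph, which is exactly what the new test compares against the FP32 path. A minimal sketch (CUDA only; shapes are illustrative):

import numpy as np
import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    paddle.disable_static()
    x = paddle.to_tensor(np.random.random([10, 20]).astype('float16'))
    weight = paddle.to_tensor(np.random.random([20]).astype('float16'))
    bias = paddle.to_tensor(np.random.random([20]).astype('float16'))
    for t in (x, weight, bias):
        t.stop_gradient = False

    # Scale and bias stay float16; no cast to float32 is required any more.
    y = F.layer_norm(x, x.shape[1:], weight, bias)
    x_grad, w_grad, b_grad = paddle.grad(y, [x, weight, bias])
    print(y.dtype, w_grad.dtype)  # both float16
    paddle.enable_static()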
diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py
index 43d4d326bd7e9a7a23d4fcc5c44a166a1d6b9d8e..894c829f58830540d7e8b74a9ce1da6e287dded5 100644
--- a/python/paddle/optimizer/lamb.py
+++ b/python/paddle/optimizer/lamb.py
@@ -16,6 +16,9 @@ from .optimizer import Optimizer
 from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable
+from ..fluid import layers
+from ..fluid import unique_name
+from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
 
 __all__ = []
@@ -127,6 +130,36 @@ class Lamb(Optimizer):
             'lamb_weight_decay': lamb_weight_decay,
             'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
         }
+        self._master_weights = {}
+        # TODO(zengjinle): expose API as soon as possible
+        self._multi_precision = False
+
+    def _create_master_weight(self, param):
+        assert self._multi_precision
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
 
     def _create_accumulators(self, block, parameters):
         assert isinstance(block, framework.Block)
@@ -135,18 +168,51 @@ class Lamb(Optimizer):
 
         # Create accumulator tensors for first and second moments
         for p in parameters:
-            self._add_accumulator(self._moment1_acc_str, p)
-            self._add_accumulator(self._moment2_acc_str, p)
-            self._add_accumulator(
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_moments_pows(master_p)
+            else:
+                self._add_moments_pows(p)
+
+    def _get_accumulator(self, name, param):
+        """Utility function to fetch an accumulator for a parameter
+        Args:
+            name: name of the accumulator
+            param: parameter variable for which accumulator is to be fetched
+        Returns:
+            accumulator variable for the parameter
+        """
+        if self._name is not None:
+            name = self._name + "_" + name
+        find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
+        target_param = self._master_weights[
+            param.name] if find_master else param
+        target_name = target_param.name
+        if (name not in self._accumulators or
+                target_name not in self._accumulators[name]):
+            raise Exception("Accumulator {} does not exist for parameter {}".
+                            format(name, target_name))
+        return self._accumulators[name][target_name]
+
+    def _add_moments_pows(self, p):
+        acc_dtype = p.dtype
+        if acc_dtype == core.VarDesc.VarType.FP16:
+            acc_dtype = core.VarDesc.VarType.FP32
+
+        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
+           dtype=acc_dtype,
            fill_value=0.9 if isinstance(self._beta1, Variable) \
                else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
-        self._add_accumulator(
+        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
+           dtype=acc_dtype,
            fill_value=0.999 if isinstance(self._beta2, Variable) \
                else self._beta2,
            shape=[1],
@@ -175,13 +241,20 @@ class Lamb(Optimizer):
             weight_decay = self._lamb_weight_decay
         lr = self._create_param_lr(param_and_grad)
 
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = self._master_weights[param_and_grad[0]
+                                             .name] if find_master else None
+        found_inf = self._get_auxiliary_var('found_inf')
+
         if framework.in_dygraph_mode():
-            _, _, _, _, _ = _C_ops.lamb(
-                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
-                beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
-                moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
-                'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
-                weight_decay)
+            _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
+                        moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
+                        param_and_grad[0], moment1, moment2, beta1_pow_acc,
+                        beta2_pow_acc, master_weight, 'beta1', self._beta1,
+                        'beta2', self._beta2, 'epsilon', self._epsilon,
+                        'weight_decay', weight_decay, 'multi_precision',
+                        find_master)
             return None
 
         # create the lamb optimize op
@@ -205,9 +278,17 @@ class Lamb(Optimizer):
             "beta1": self._beta1,
             "beta2": self._beta2,
             "epsilon": self._epsilon,
-            "weight_decay": weight_decay
+            "weight_decay": weight_decay,
+            "multi_precision": find_master,
         }
 
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
+        if found_inf:
+            inputs["SkipUpdate"] = found_inf
+
         lamb_op = block.append_op(
             type=self.type,
             inputs=inputs,
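For reference, the master-weight bookkeeping in `_create_master_weight` and `_get_accumulator` boils down to: keep one FP32 copy per FP16 parameter, create it once via a startup-program `cast` op, and route accumulator lookups through it. A stripped-down NumPy analogy of that bookkeeping (purely illustrative, not Paddle code):

import numpy as np

class _MasterWeightBook(object):
    """Mimics the master-weight map kept by the patched Lamb optimizer."""

    def __init__(self):
        self._master_weights = {}

    def create(self, name, fp16_param):
        # One FP32 copy per FP16 parameter, created once and reused,
        # just like `_create_master_weight` does with a startup cast op.
        if name not in self._master_weights:
            self._master_weights[name] = fp16_param.astype(np.float32)
        return self._master_weights[name]


book = _MasterWeightBook()
p = np.random.random([4]).astype(np.float16)
master = book.create("linear_0.w_0", p)
assert master.dtype == np.float32
assert book.create("linear_0.w_0", p) is master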
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index e76ea622148fb91a0411682b69f6e1e21fe4921e..abfaf489822fcde785e69205e073046ffca83776 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -217,6 +217,14 @@ class Optimizer(object):
         else:
             self._param_groups = self._parameter_list
 
+        self._auxiliary_vars = {}
+
+    def _set_auxiliary_var(self, key, val):
+        self._auxiliary_vars[key] = val
+
+    def _get_auxiliary_var(self, key):
+        return self._auxiliary_vars.get(key, None)
+
     @framework.dygraph_only
     def state_dict(self):
         '''