From 2ac6a7e4d0fed274dfe650cb1a3454ef42154b3e Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Mon, 14 Aug 2023 10:42:56 +0800 Subject: [PATCH] Add rmsnorm residual bias add and quant (#55965) * add rmsnorm residual bias add and quant * refine python interface * add rmsnorm unittest * Add layernorm * fix layernorm unittest * refine unittest * fix example code * fix review comment --- paddle/phi/api/yaml/ops.yaml | 16 +- paddle/phi/infermeta/binary.cc | 32 - paddle/phi/infermeta/binary.h | 7 - paddle/phi/infermeta/multiary.cc | 110 ++ paddle/phi/infermeta/multiary.h | 31 + .../fusion/gpu/fused_dropout_act_bias.h | 1 + .../kernels/fusion/gpu/fused_dropout_helper.h | 20 +- .../fusion/gpu/fused_layernorm_kernel.cu | 1094 +++++++++++++++++ .../fusion/gpu/fused_layernorm_kernel.h | 43 + .../fused_layernorm_residual_dropout_bias.h | 41 +- .../fusion/gpu/fused_residual_dropout_bias.h | 21 +- paddle/phi/kernels/gpu/rms_norm_kernel.cu | 349 +----- paddle/phi/kernels/rms_norm_kernel.h | 69 +- .../paddle/incubate/nn/functional/__init__.py | 6 +- .../nn/functional/fused_layer_norm.py | 125 ++ .../incubate/nn/functional/fused_rms_norm.py | 114 ++ .../paddle/incubate/nn/functional/rms_norm.py | 59 - test/legacy_test/CMakeLists.txt | 2 + test/legacy_test/test_fused_layernorm_op.py | 623 ++++++++++ test/legacy_test/test_rms_norm_op.py | 400 +++++- 20 files changed, 2655 insertions(+), 508 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h create mode 100644 python/paddle/incubate/nn/functional/fused_layer_norm.py create mode 100644 python/paddle/incubate/nn/functional/fused_rms_norm.py delete mode 100644 python/paddle/incubate/nn/functional/rms_norm.py create mode 100644 test/legacy_test/test_fused_layernorm_op.py diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index ecc29de613d..37a5368f8ee 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ 
b/paddle/phi/api/yaml/ops.yaml @@ -1017,6 +1017,16 @@ data_type : dtype backend : place +- op : fused_bias_residual_layernorm + args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) + output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance) + infer_meta : + func : FusedLayerNormInferMeta + kernel : + func : fused_bias_residual_layernorm + data_type : x + optional : bias, residual, norm_weight, norm_bias, residual_out + - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) @@ -2071,14 +2081,14 @@ backward : reverse_grad - op : rms_norm - args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis) - output : Tensor(out) + args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) + output : Tensor(out), Tensor(residual_out) infer_meta : func : RmsNormInferMeta kernel : func : rms_norm data_type : x - optional : bias + optional : bias, residual, norm_bias, residual_out - op : rmsprop_ args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index cfc88c5c2d5..fee5882787e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3239,38 +3239,6 @@ void Unpool3dInferMeta(const MetaTensor& x, } } -void RmsNormInferMeta(const MetaTensor& x, - const MetaTensor& weight, - const MetaTensor& bias, - const float epsilon, - const int begin_norm_axis, - MetaTensor* out) { - std::vector x_dims_vec 
= phi::vectorize(x.dims()); - auto x_dims_size = x_dims_vec.size(); - - size_t normalized_dims = 1; - for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { - normalized_dims *= x_dims_vec[i]; - } - - PADDLE_ENFORCE_EQ(normalized_dims, - weight.dims()[0], - phi::errors::InvalidArgument( - "The normalized size of Input(X) must equal to be" - "the size of Weight, but received" - "normalized size of Input(X) is [%d], received size" - "of Weight is [%d]", - normalized_dims, - weight.dims()[0])); - - auto out_dims = phi::make_ddim(x_dims_vec); - - out->set_dims(out_dims); - out->set_dtype(x.dtype()); - out->set_layout(x.layout()); - out->share_lod(x); -} - } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 48615cc22c5..8aa4114e740 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -490,11 +490,4 @@ void Unpool3dInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void RmsNormInferMeta(const MetaTensor& x, - const MetaTensor& weight, - const MetaTensor& bias, - const float epsilon, - const int begin_norm_axis, - MetaTensor* out); - } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index ee84f6d169d..9b3ad135cf7 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1506,6 +1506,68 @@ void FusedBiasActInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void FusedLayerNormInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& residual, + const MetaTensor& norm_weight, + const MetaTensor& norm_bias, + const float epsilon, + const float residual_alpha, + const int begin_norm_axis, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + MetaTensor* out, + MetaTensor* residual_out, + MetaTensor* mean, + MetaTensor* 
variance) { + std::vector x_dims_vec = phi::vectorize(x.dims()); + auto x_dims_size = x_dims_vec.size(); + + size_t normalized_dims = 1; + for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { + normalized_dims *= x_dims_vec[i]; + } + + int32_t rows = 1; + for (int i = 0; i < begin_norm_axis; i++) { + rows *= x.dims()[i]; + } + + PADDLE_ENFORCE_EQ(normalized_dims, + norm_weight.dims()[0], + phi::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be " + "the size of Weight, but received " + "normalized size of Input(X) is [%d], received size " + "of Weight is [%d]", + normalized_dims, + norm_weight.dims()[0])); + + auto out_dims = phi::make_ddim(x_dims_vec); + + out->set_dims(out_dims); + if (quant_scale <= 0.0f) { + out->set_dtype(x.dtype()); + } else { + out->set_dtype(phi::DataType::INT8); + } + out->set_layout(x.layout()); + + residual_out->set_dims(out_dims); + residual_out->set_dtype(x.dtype()); + residual_out->set_layout(x.layout()); + + mean->set_dims(phi::make_ddim({rows})); + mean->set_dtype(DataType::FLOAT32); + mean->set_layout(x.layout()); + + variance->set_dims(phi::make_ddim({rows})); + variance->set_dtype(DataType::FLOAT32); + variance->set_layout(x.layout()); +} + void FusedLinearParamGradAddInferMeta(const MetaTensor& x, const MetaTensor& dout, const MetaTensor& dweight, @@ -2918,6 +2980,54 @@ void PsroiPoolInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& residual, + const MetaTensor& norm_weight, + const MetaTensor& norm_bias, + const float epsilon, + const int begin_norm_axis, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + MetaTensor* out, + MetaTensor* residual_out) { + std::vector x_dims_vec = phi::vectorize(x.dims()); + auto x_dims_size = x_dims_vec.size(); + + size_t normalized_dims = 1; + for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { + 
normalized_dims *= x_dims_vec[i]; + } + + PADDLE_ENFORCE_EQ(normalized_dims, + norm_weight.dims()[0], + phi::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be " + "the size of Weight, but received " + "normalized size of Input(X) is [%d], received size " + "of Weight is [%d]", + normalized_dims, + norm_weight.dims()[0])); + + auto out_dims = phi::make_ddim(x_dims_vec); + + out->set_dims(out_dims); + if (quant_scale <= 0.0f) { + out->set_dtype(x.dtype()); + } else { + out->set_dtype(phi::DataType::INT8); + } + out->set_layout(x.layout()); + out->share_lod(x); + + residual_out->set_dims(out_dims); + residual_out->set_dtype(x.dtype()); + residual_out->set_layout(x.layout()); + residual_out->share_lod(x); +} + void RmspropInferMeta(const MetaTensor& param, const MetaTensor& mean_square, const MetaTensor& grad, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 2d24b2252a5..f1ade56c309 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -301,6 +301,23 @@ void FusedBiasActInferMeta(const MetaTensor& x, float quant_min_bound, MetaTensor* out); +void FusedLayerNormInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& residual, + const MetaTensor& norm_weight, + const MetaTensor& norm_bias, + const float epsilon, + const float residual_alpha, + const int begin_norm_axis, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + MetaTensor* out, + MetaTensor* residual_out, + MetaTensor* mean, + MetaTensor* variance); + void FusedLinearParamGradAddInferMeta(const MetaTensor& x, const MetaTensor& dout, const MetaTensor& dweight, @@ -516,6 +533,20 @@ void PsroiPoolInferMeta(const MetaTensor& x, float spatial_scale, MetaTensor* out); +void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& residual, + const MetaTensor& norm_weight, + const MetaTensor& norm_bias, + const 
float epsilon, + const int begin_norm_axis, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + MetaTensor* out, + MetaTensor* residual_out); + void RmspropInferMeta(const MetaTensor& param, const MetaTensor& mean_square, const MetaTensor& grad, diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h index 8868a4435b4..e5f5c9ba50b 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h @@ -124,6 +124,7 @@ __global__ void FusedDropoutActBias( nullptr, nullptr, act, + 1.0, /*Since Dropout Act bias do not use residual alpha, we set 1.0*/ quant_last_in_scale, dequant_out_scale_data, quant_next_in_scale, diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h index c73a35d2265..681e6cdac57 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h @@ -123,10 +123,12 @@ class FusedDropoutHelper { FusedDropoutHelper(const phi::GPUContext& ctx, const int rows, const int cols, - const DropoutParam& dropout_param) { + const DropoutParam& dropout_param, + const float residual_alpha = 1.0) { rows_ = rows; cols_ = cols; dropout_param_ = dropout_param; + residual_alpha_ = residual_alpha; } // out = residual + dropout( src + bias ) @@ -156,7 +158,8 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, dequant_out_scale_data, - quant_next_in_scale); + quant_next_in_scale, + residual_alpha_); } void ResidualDropoutBiasGrad(const phi::GPUContext& ctx, @@ -336,6 +339,7 @@ class FusedDropoutHelper { int rows_; int cols_; DropoutParam dropout_param_; + float residual_alpha_; }; template ; this->rows_ = rows; this->cols_ = cols; epsilon_ = epsilon; + this->residual_alpha_ = residual_alpha; } FusedDropoutLayerNormHelper(const phi::GPUContext& ctx, 
const int rows, const int cols, const DropoutParam& dropout_param, - const float epsilon) + const float epsilon, + const float residual_alpha = 1.0) : FusedDropoutHelper( - ctx, rows, cols, dropout_param) { + ctx, rows, cols, dropout_param, residual_alpha) { using U = phi::funcs::LayerNormParamType; epsilon_ = epsilon; } @@ -476,7 +483,8 @@ class FusedDropoutLayerNormHelper quant_next_in_scale, quant_round_type, quant_max_bound, - quant_min_bound); + quant_min_bound, + this->residual_alpha_); } template , diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu new file mode 100644 index 00000000000..138e5583a3a --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu @@ -0,0 +1,1094 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// Original OneFlow copyright notice: + +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh +// The following code is modified from OneFlow's implementation and changed to +// use a single-pass algorithm. Supports Int8 quant/dequant Load/Store. + +#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#ifndef PADDLE_WITH_HIP +#include +#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" +#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h" +#endif + +namespace phi { + +namespace fusion { + +namespace { + +#ifndef PADDLE_WITH_HIP + +constexpr int kWarpSize = 32; + +template <typename T> +struct SumOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +template <typename T> +struct MaxOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return max(a, b); + } +}; + +template