Unverified commit 2ac6a7e4, authored by MarDino, committed by GitHub

Add rmsnorm residual bias add and quant (#55965)

* add rmsnorm residual bias add and quant

* refine python interface

* add rmsnorm unittest

* Add layernorm

* fix layernorm unittest

* refine unittest

* fix example code

* fix review comment
Parent 1ad502df
...@@ -1017,6 +1017,16 @@ ...@@ -1017,6 +1017,16 @@
data_type : dtype data_type : dtype
backend : place backend : place
- op : fused_bias_residual_layernorm
args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance)
infer_meta :
func : FusedLayerNormInferMeta
kernel :
func : fused_bias_residual_layernorm
data_type : x
optional : bias, residual, norm_weight, norm_bias, residual_out
- op : gather - op : gather
args : (Tensor x, Tensor index, Scalar axis=0) args : (Tensor x, Tensor index, Scalar axis=0)
output : Tensor(out) output : Tensor(out)
...@@ -2071,14 +2081,14 @@ ...@@ -2071,14 +2081,14 @@
backward : reverse_grad backward : reverse_grad
- op : rms_norm - op : rms_norm
args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis) args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
output : Tensor(out) output : Tensor(out), Tensor(residual_out)
infer_meta : infer_meta :
func : RmsNormInferMeta func : RmsNormInferMeta
kernel : kernel :
func : rms_norm func : rms_norm
data_type : x data_type : x
optional : bias optional : bias, residual, norm_bias, residual_out
- op : rmsprop_ - op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false) args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
......
...@@ -3239,38 +3239,6 @@ void Unpool3dInferMeta(const MetaTensor& x, ...@@ -3239,38 +3239,6 @@ void Unpool3dInferMeta(const MetaTensor& x,
} }
} }
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi } // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
...@@ -490,11 +490,4 @@ void Unpool3dInferMeta(const MetaTensor& x, ...@@ -490,11 +490,4 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out);
} // namespace phi } // namespace phi
...@@ -1506,6 +1506,68 @@ void FusedBiasActInferMeta(const MetaTensor& x, ...@@ -1506,6 +1506,68 @@ void FusedBiasActInferMeta(const MetaTensor& x,
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out,
MetaTensor* mean,
MetaTensor* variance) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
int32_t rows = 1;
for (int i = 0; i < begin_norm_axis; i++) {
rows *= x.dims()[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
norm_weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
norm_weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
if (quant_scale <= 0.0f) {
out->set_dtype(x.dtype());
} else {
out->set_dtype(phi::DataType::INT8);
}
out->set_layout(x.layout());
residual_out->set_dims(out_dims);
residual_out->set_dtype(x.dtype());
residual_out->set_layout(x.layout());
mean->set_dims(phi::make_ddim({rows}));
mean->set_dtype(DataType::FLOAT32);
mean->set_layout(x.layout());
variance->set_dims(phi::make_ddim({rows}));
variance->set_dtype(DataType::FLOAT32);
variance->set_layout(x.layout());
}
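For readers tracing the new infer-meta logic, the following is a minimal Python sketch of the shape and dtype rules FusedLayerNormInferMeta applies (an illustration only; the function name and return layout are invented for this note and are not part of the PR):

def fused_layer_norm_infer_shapes(x_shape, begin_norm_axis, quant_scale, x_dtype="float16"):
    # mean/variance hold one value per row, where rows is the product of the
    # dims in front of begin_norm_axis.
    rows = 1
    for d in x_shape[:begin_norm_axis]:
        rows *= d
    # out becomes int8 only when quantization is enabled; residual_out keeps x's dtype.
    out_dtype = x_dtype if quant_scale <= 0.0 else "int8"
    return {
        "out": (list(x_shape), out_dtype),
        "residual_out": (list(x_shape), x_dtype),
        "mean": ([rows], "float32"),
        "variance": ([rows], "float32"),
    }

# For x of shape [32, 256], begin_norm_axis=1 and quant_scale=0.15:
# out -> ([32, 256], "int8"); mean and variance -> ([32], "float32")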
void FusedLinearParamGradAddInferMeta(const MetaTensor& x, void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout, const MetaTensor& dout,
const MetaTensor& dweight, const MetaTensor& dweight,
...@@ -2918,6 +2980,54 @@ void PsroiPoolInferMeta(const MetaTensor& x, ...@@ -2918,6 +2980,54 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
norm_weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
norm_weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
if (quant_scale <= 0.0f) {
out->set_dtype(x.dtype());
} else {
out->set_dtype(phi::DataType::INT8);
}
out->set_layout(x.layout());
out->share_lod(x);
residual_out->set_dims(out_dims);
residual_out->set_dtype(x.dtype());
residual_out->set_layout(x.layout());
residual_out->share_lod(x);
}
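RmsNormInferMeta follows the same rules minus the per-row statistics; a one-line sketch of the dtype decision (hypothetical helper name, for illustration only):

def rms_norm_out_dtypes(x_dtype, quant_scale):
    # out switches to int8 when quantization is enabled; residual_out always matches x.
    return (x_dtype if quant_scale <= 0.0 else "int8"), x_dtype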
void RmspropInferMeta(const MetaTensor& param, void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square, const MetaTensor& mean_square,
const MetaTensor& grad, const MetaTensor& grad,
......
...@@ -301,6 +301,23 @@ void FusedBiasActInferMeta(const MetaTensor& x, ...@@ -301,6 +301,23 @@ void FusedBiasActInferMeta(const MetaTensor& x,
float quant_min_bound, float quant_min_bound,
MetaTensor* out); MetaTensor* out);
void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out,
MetaTensor* mean,
MetaTensor* variance);
void FusedLinearParamGradAddInferMeta(const MetaTensor& x, void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout, const MetaTensor& dout,
const MetaTensor& dweight, const MetaTensor& dweight,
...@@ -516,6 +533,20 @@ void PsroiPoolInferMeta(const MetaTensor& x, ...@@ -516,6 +533,20 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale, float spatial_scale,
MetaTensor* out); MetaTensor* out);
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out);
void RmspropInferMeta(const MetaTensor& param, void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square, const MetaTensor& mean_square,
const MetaTensor& grad, const MetaTensor& grad,
......
...@@ -124,6 +124,7 @@ __global__ void FusedDropoutActBias( ...@@ -124,6 +124,7 @@ __global__ void FusedDropoutActBias(
nullptr, nullptr,
nullptr, nullptr,
act, act,
1.0, /*Since Dropout Act bias do not use residual alpha, we set 1.0*/
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_next_in_scale, quant_next_in_scale,
......
...@@ -123,10 +123,12 @@ class FusedDropoutHelper { ...@@ -123,10 +123,12 @@ class FusedDropoutHelper {
FusedDropoutHelper(const phi::GPUContext& ctx, FusedDropoutHelper(const phi::GPUContext& ctx,
const int rows, const int rows,
const int cols, const int cols,
const DropoutParam& dropout_param) { const DropoutParam& dropout_param,
const float residual_alpha = 1.0) {
rows_ = rows; rows_ = rows;
cols_ = cols; cols_ = cols;
dropout_param_ = dropout_param; dropout_param_ = dropout_param;
residual_alpha_ = residual_alpha;
} }
// out = residual + dropout( src + bias ) // out = residual + dropout( src + bias )
...@@ -156,7 +158,8 @@ class FusedDropoutHelper { ...@@ -156,7 +158,8 @@ class FusedDropoutHelper {
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_next_in_scale); quant_next_in_scale,
residual_alpha_);
} }
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx, void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
...@@ -336,6 +339,7 @@ class FusedDropoutHelper { ...@@ -336,6 +339,7 @@ class FusedDropoutHelper {
int rows_; int rows_;
int cols_; int cols_;
DropoutParam dropout_param_; DropoutParam dropout_param_;
float residual_alpha_;
}; };
template <typename T, template <typename T,
...@@ -348,20 +352,23 @@ class FusedDropoutLayerNormHelper ...@@ -348,20 +352,23 @@ class FusedDropoutLayerNormHelper
FusedDropoutLayerNormHelper() {} FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows, FusedDropoutLayerNormHelper(const int rows,
const int cols, const int cols,
const float epsilon) { const float epsilon,
const float residual_alpha = 1.0) {
using U = phi::funcs::LayerNormParamType<T>; using U = phi::funcs::LayerNormParamType<T>;
this->rows_ = rows; this->rows_ = rows;
this->cols_ = cols; this->cols_ = cols;
epsilon_ = epsilon; epsilon_ = epsilon;
this->residual_alpha_ = residual_alpha;
} }
FusedDropoutLayerNormHelper(const phi::GPUContext& ctx, FusedDropoutLayerNormHelper(const phi::GPUContext& ctx,
const int rows, const int rows,
const int cols, const int cols,
const DropoutParam& dropout_param, const DropoutParam& dropout_param,
const float epsilon) const float epsilon,
const float residual_alpha = 1.0)
: FusedDropoutHelper<T, MaskType, InType, OutType>( : FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) { ctx, rows, cols, dropout_param, residual_alpha) {
using U = phi::funcs::LayerNormParamType<T>; using U = phi::funcs::LayerNormParamType<T>;
epsilon_ = epsilon; epsilon_ = epsilon;
} }
...@@ -476,7 +483,8 @@ class FusedDropoutLayerNormHelper ...@@ -476,7 +483,8 @@ class FusedDropoutLayerNormHelper
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound,
this->residual_alpha_);
} }
template <typename P = phi::funcs::LayerNormParamType<T>, template <typename P = phi::funcs::LayerNormParamType<T>,
......
This diff is collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedLayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& residual,
const paddle::optional<DenseTensor>& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out,
DenseTensor* mean,
DenseTensor* variance);
} // namespace fusion
} // namespace phi
...@@ -132,7 +132,8 @@ __global__ void FusedLayernormResidualDropoutBias( ...@@ -132,7 +132,8 @@ __global__ void FusedLayernormResidualDropoutBias(
T *dst, T *dst,
T *layernorm_dst, T *layernorm_dst,
LayerNormParamType<T> *mean, LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) { LayerNormParamType<T> *var,
const float residual_alpha = 1.0) {
int col_id = threadIdx.x; int col_id = threadIdx.x;
int row_id = blockIdx.x; int row_id = blockIdx.x;
int idx = row_id * cols + col_id; int idx = row_id * cols + col_id;
...@@ -175,7 +176,8 @@ __global__ void FusedLayernormResidualDropoutBias( ...@@ -175,7 +176,8 @@ __global__ void FusedLayernormResidualDropoutBias(
is_test, is_test,
&mean_val, &mean_val,
&var_val, &var_val,
relu); relu,
residual_alpha);
} }
mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean); mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
...@@ -233,7 +235,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel( ...@@ -233,7 +235,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
T *dst, T *dst,
T *layernorm_dst, T *layernorm_dst,
LayerNormParamType<T> *mean, LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) { LayerNormParamType<T> *var,
const float residual_alpha = 1.0) {
if (dropout_prob != 0.0f) { if (dropout_prob != 0.0f) {
FusedLayernormResidualDropoutBias<T, FusedLayernormResidualDropoutBias<T,
MaskType, MaskType,
...@@ -258,7 +261,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel( ...@@ -258,7 +261,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
dst, dst,
layernorm_dst, layernorm_dst,
mean, mean,
var); var,
residual_alpha);
} else { } else {
FusedLayernormResidualDropoutBias<T, FusedLayernormResidualDropoutBias<T,
MaskType, MaskType,
...@@ -283,7 +287,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel( ...@@ -283,7 +287,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
dst, dst,
layernorm_dst, layernorm_dst,
mean, mean,
var); var,
residual_alpha);
} }
} }
...@@ -539,7 +544,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -539,7 +544,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) { const float quant_min_bound = -127.0,
const float residual_alpha = 1.0) {
__shared__ U smem[WARPS_M * WARPS_N]; __shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>; using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>; using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
...@@ -641,13 +647,13 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -641,13 +647,13 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
dequant_out_scale[it][jt]) + dequant_out_scale[it][jt]) +
bias[it][jt]) * bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt]; residual[it][jt] * static_cast<T>(residual_alpha);
x[it][jt] = tmp; x[it][jt] = tmp;
xf[it * VecSize + jt] = U(tmp); xf[it * VecSize + jt] = U(tmp);
} else { } else {
x[it][jt] = (static_cast<T>(x_input[it][jt]) + bias[it][jt]) * x[it][jt] = (static_cast<T>(x_input[it][jt]) + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt]; residual[it][jt] * static_cast<T>(residual_alpha);
xf[it * VecSize + jt] = U(x[it][jt]); xf[it * VecSize + jt] = U(x[it][jt]);
} }
} }
...@@ -663,12 +669,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -663,12 +669,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) * T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
dequant_out_scale[it][jt]) * dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt]; residual[it][jt] * static_cast<T>(residual_alpha);
x[it][jt] = tmp; x[it][jt] = tmp;
} else { } else {
x[it][jt] = static_cast<T>(x_input[it][jt]) * x[it][jt] = static_cast<T>(x_input[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt]; residual[it][jt] * static_cast<T>(residual_alpha);
} }
xf[it * VecSize + jt] = U(x[it][jt]); xf[it * VecSize + jt] = U(x[it][jt]);
} }
...@@ -848,7 +854,8 @@ void LaunchLayernormResidualDropoutBias( ...@@ -848,7 +854,8 @@ void LaunchLayernormResidualDropoutBias(
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) { const float quant_min_bound = -127.0,
const float residual_alpha = 1.0) {
// dropout_prob == 1.0f // dropout_prob == 1.0f
// NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0 // NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0
if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (std::abs(dropout_prob - 1.0f) < 1e-5) {
...@@ -942,7 +949,8 @@ void LaunchLayernormResidualDropoutBias( ...@@ -942,7 +949,8 @@ void LaunchLayernormResidualDropoutBias(
quant_next_in_scale, \ quant_next_in_scale, \
quant_round_type, \ quant_round_type, \
quant_max_bound, \ quant_max_bound, \
quant_min_bound); \ quant_min_bound, \
residual_alpha); \
} else { \ } else { \
fused_fast_ln_fwd_kernel< \ fused_fast_ln_fwd_kernel< \
false, \ false, \
...@@ -986,7 +994,8 @@ void LaunchLayernormResidualDropoutBias( ...@@ -986,7 +994,8 @@ void LaunchLayernormResidualDropoutBias(
quant_next_in_scale, \ quant_next_in_scale, \
quant_round_type, \ quant_round_type, \
quant_max_bound, \ quant_max_bound, \
quant_min_bound); \ quant_min_bound, \
residual_alpha); \
} \ } \
} break } break
...@@ -1036,7 +1045,8 @@ void LaunchLayernormResidualDropoutBias( ...@@ -1036,7 +1045,8 @@ void LaunchLayernormResidualDropoutBias(
dst, dst,
reinterpret_cast<T *>(layernorm_dst), reinterpret_cast<T *>(layernorm_dst),
mean, mean,
var); var,
residual_alpha);
} else { } else {
if (can_call_fast_ln_kernel) { if (can_call_fast_ln_kernel) {
switch (cols) { switch (cols) {
...@@ -1074,7 +1084,8 @@ void LaunchLayernormResidualDropoutBias( ...@@ -1074,7 +1084,8 @@ void LaunchLayernormResidualDropoutBias(
dst, dst,
reinterpret_cast<T *>(layernorm_dst), reinterpret_cast<T *>(layernorm_dst),
mean, mean,
var); var,
residual_alpha);
} }
} }
} }
......
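The new residual_alpha argument threaded through the fused LayerNorm/residual/dropout kernels scales only the residual term. A hedged per-element reference of the updated expression (names chosen for this note, not kernel identifiers):

def residual_dropout_bias(x, bias, residual, mask, factor, residual_alpha=1.0):
    # Mirrors the updated kernel line: (x + bias) * mask * factor + residual_alpha * residual
    return (x + bias) * mask * factor + residual_alpha * residual

# With dropout disabled (mask == 1, factor == 1) this reduces to
# x + bias + residual_alpha * residual, which is the value LayerNorm then normalizes.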
...@@ -53,6 +53,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( ...@@ -53,6 +53,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
typename phi::dtype::MPTypeTrait<T>::Type *mean_val, typename phi::dtype::MPTypeTrait<T>::Type *mean_val,
typename phi::dtype::MPTypeTrait<T>::Type *var_val, typename phi::dtype::MPTypeTrait<T>::Type *var_val,
Functor act_func, Functor act_func,
const float residual_alpha = 1.0,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
...@@ -121,10 +122,11 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( ...@@ -121,10 +122,11 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
tmp = act_func(tmp); tmp = act_func(tmp);
} }
if (HasDropout) { if (HasDropout) {
dest_vec[ii] = dest_vec[ii] = tmp * static_cast<T>(mask_vec[ii]) * factor +
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii]; residual_vec[ii] * static_cast<T>(residual_alpha);
} else { } else {
dest_vec[ii] = tmp * factor + residual_vec[ii]; dest_vec[ii] =
tmp * factor + residual_vec[ii] * static_cast<T>(residual_alpha);
} }
if (ComputeLayerNorm) { if (ComputeLayerNorm) {
U tmp = static_cast<U>(dest_vec[ii]); U tmp = static_cast<U>(dest_vec[ii]);
...@@ -274,7 +276,8 @@ __global__ void FusedResidualDropoutBias( ...@@ -274,7 +276,8 @@ __global__ void FusedResidualDropoutBias(
const bool is_test, const bool is_test,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) { const float quant_next_in_scale = 1.0,
const float residual_alpha = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x; int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y; int row_id = blockIdx.y;
int idx = row_id * cols + col_id; int idx = row_id * cols + col_id;
...@@ -316,6 +319,7 @@ __global__ void FusedResidualDropoutBias( ...@@ -316,6 +319,7 @@ __global__ void FusedResidualDropoutBias(
nullptr, nullptr,
nullptr, nullptr,
relu, relu,
residual_alpha,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_next_in_scale); quant_next_in_scale);
...@@ -345,7 +349,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -345,7 +349,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const phi::GPUContext &ctx, const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) { const float quant_next_in_scale = 1.0,
const float residual_alpha = 1.0) {
// dropout_prob == 1.0f // dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (std::abs(dropout_prob - 1.0f) < 1e-5) {
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0 // NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
...@@ -396,7 +401,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -396,7 +401,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, \ is_test, \
quant_last_in_scale, \ quant_last_in_scale, \
dequant_out_scale_data, \ dequant_out_scale_data, \
quant_next_in_scale); \ quant_next_in_scale, \
residual_alpha); \
} else { \ } else { \
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \ FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \
<<<config.block_per_grid, \ <<<config.block_per_grid, \
...@@ -416,7 +422,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -416,7 +422,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, \ is_test, \
quant_last_in_scale, \ quant_last_in_scale, \
dequant_out_scale_data, \ dequant_out_scale_data, \
quant_next_in_scale); \ quant_next_in_scale, \
residual_alpha); \
} \ } \
} while (0) } while (0)
......
...@@ -937,20 +937,26 @@ struct AffineQuantStore { ...@@ -937,20 +937,26 @@ struct AffineQuantStore {
template <typename T, typename Context> template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx, void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias, const paddle::optional<DenseTensor>& bias,
float epsilon, const paddle::optional<DenseTensor>& residual,
int begin_norm_axis, const DenseTensor& norm_weight,
DenseTensor* out) { const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out) {
#if defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it"; LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else #else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type; using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
const T* weight_data = weight.data<T>(); const T* norm_weight_data = norm_weight.data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr; const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
T* out_data = dev_ctx.template Alloc<T>(out);
int32_t rows = 1; int32_t rows = 1;
int32_t cols = 1; int32_t cols = 1;
...@@ -962,283 +968,64 @@ void RmsNormKernel(const Context& dev_ctx, ...@@ -962,283 +968,64 @@ void RmsNormKernel(const Context& dev_ctx,
cols *= x.dims()[i]; cols *= x.dims()[i];
} }
DirectLoad<T, ComputeType> load(x_data, cols); if (residual) {
AffineStore<ComputeType, T> store(out_data, cols, weight_data, bias_data); // Do RMSNorm(bias_add + residual + x)
T* residual_out_data = dev_ctx.template Alloc<T>(residual_out);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>( const T* residual_data = residual.get().data<T>();
dev_ctx.stream(), load, store, rows, cols, epsilon); const T* bias_data = bias ? bias.get().data<T>() : nullptr;
#endif
}
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineStore<ComputeType, T> store(output, cols, weight, bias);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* weight,
const phi::dtype::bfloat16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
float* output);
// ========== ResidualAdd + RMSNorm ==========
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
ResidualAddBiasLoad<T, ComputeType> load( ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols); x_data, residual_data, bias_data, residual_out_data, cols);
AffineStore<ComputeType, T> store(output, cols, norm_weight, norm_bias); if (quant_scale <= 0.0f) {
// No Quantize.
T* out_data = dev_ctx.template Alloc<T>(out);
AffineStore<ComputeType, T> store(
out_data, cols, norm_weight_data, norm_bias_data);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>( DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon); dev_ctx.stream(), load, store, rows, cols, epsilon);
#endif } else {
} // Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx, AffineQuantStore<int8_t, ComputeType, T, true, true> store(
const phi::dtype::float16* x, out_data,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* residual_output,
phi::dtype::float16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* residual_output,
phi::dtype::bfloat16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
float* residual_output,
float* output);
// ===== FP16 in, Int8out RMSNorm =====
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
cols, cols,
weight, norm_weight_data,
bias, norm_bias_data,
in_scale, quant_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx, DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
const phi::dtype::bfloat16* x, dev_ctx.stream(), load, store, rows, cols, epsilon);
const phi::dtype::bfloat16* weight, }
const phi::dtype::bfloat16* bias, } else {
const float epsilon, DirectLoad<T, ComputeType> load(x_data, cols);
const int rows, if (quant_scale <= 0.0f) {
const int cols, // No Quantize.
const float in_scale, T* out_data = dev_ctx.template Alloc<T>(out);
const int quant_round_type, AffineStore<ComputeType, T> store(
const float quant_max_bound, out_data, cols, norm_weight_data, norm_bias_data);
const float quant_min_bound, DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
int8_t* output); dev_ctx.stream(), load, store, rows, cols, epsilon);
} else {
// ===== FP16 in, Int8out ResidualAdd + RMSNorm ===== // Quantize and output int8.
template <typename T, typename Context> int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx, AffineQuantStore<int8_t, ComputeType, T, true, true> store(
const T* x, out_data,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
cols, cols,
norm_weight, norm_weight_data,
norm_bias, norm_bias_data,
in_scale, quant_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>( DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon); dev_ctx.stream(), load, store, rows, cols, epsilon);
}
}
#endif #endif
} }
template void ResidualAddRmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
float* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::float16* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::bfloat16* residual_output,
int8_t* output);
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(rms_norm, PD_REGISTER_KERNEL(rms_norm,
......
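The rewritten RmsNormKernel dispatches on two independent switches: whether a residual (plus optional bias) is fused into the load, and whether the store quantizes to int8. A small illustrative sketch of that dispatch (the returned names mirror the load/store functors above; the function itself is not part of the kernel):

def rms_norm_dispatch(has_residual, quant_scale):
    # Load side: fuse x + residual + bias (also writing residual_out) or read x directly.
    load = "ResidualAddBiasLoad" if has_residual else "DirectLoad"
    # Store side: plain affine store in x's dtype, or affine store plus int8 quantization.
    store = "AffineQuantStore(int8)" if quant_scale > 0.0 else "AffineStore"
    return load, store

assert rms_norm_dispatch(True, 0.15) == ("ResidualAddBiasLoad", "AffineQuantStore(int8)")
assert rms_norm_dispatch(False, -1.0) == ("DirectLoad", "AffineStore")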
...@@ -22,64 +22,17 @@ namespace phi { ...@@ -22,64 +22,17 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx, void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias, const paddle::optional<DenseTensor>& bias,
float epsilon, const paddle::optional<DenseTensor>& residual,
int begin_norm_axis, const DenseTensor& norm_weight,
DenseTensor* out); const paddle::optional<DenseTensor>& norm_bias,
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output);
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output);
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon, const float epsilon,
const int rows, const int begin_norm_axis,
const int cols, const float quant_scale,
const float in_scale,
const int quant_round_type, const int quant_round_type,
const float quant_max_bound, const float quant_max_bound,
const float quant_min_bound, const float quant_min_bound,
T* residual_output, DenseTensor* out,
int8_t* output); DenseTensor* residual_out);
} // namespace phi } // namespace phi
...@@ -28,7 +28,8 @@ from .fused_rotary_position_embedding import fused_rotary_position_embedding ...@@ -28,7 +28,8 @@ from .fused_rotary_position_embedding import fused_rotary_position_embedding
from .variable_length_memory_efficient_attention import ( from .variable_length_memory_efficient_attention import (
variable_length_memory_efficient_attention, variable_length_memory_efficient_attention,
) )
from .rms_norm import rms_norm from .fused_rms_norm import fused_rms_norm
from .fused_layer_norm import fused_layer_norm
__all__ = [ __all__ = [
'fused_multi_head_attention', 'fused_multi_head_attention',
...@@ -42,5 +43,6 @@ __all__ = [ ...@@ -42,5 +43,6 @@ __all__ = [
'fused_dropout_add', 'fused_dropout_add',
'fused_rotary_position_embedding', 'fused_rotary_position_embedding',
'variable_length_memory_efficient_attention', 'variable_length_memory_efficient_attention',
"rms_norm", "fused_rms_norm",
"fused_layer_norm",
] ]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def fused_layer_norm(
x,
norm_weight,
norm_bias,
epsilon,
residual_alpha=1.0,
begin_norm_axis=1,
bias=None,
residual=None,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
):
r"""
Apply the fused LayerNorm kernel. It also supports the fused pattern LayerNorm(bias + residual_alpha * residual + x).
When norm_weight and norm_bias are None, it returns the fused result (bias + residual_alpha * residual + x) without normalization.
Args:
x (Tensor): the input Tensor.
norm_weight (Tensor): the weight Tensor used to affine the normalized output.
norm_bias (Tensor): the bias Tensor used to affine the normalized output.
epsilon (float): a small float number added to the variance to avoid dividing by zero.
residual_alpha (float): a scaling factor applied to the residual input. Default is 1.0.
begin_norm_axis (int): the first axis to normalize over. Default is 1.
bias (Tensor, optional): the bias of the previous layer, fused into the add.
residual (Tensor, optional): the residual input, fused into the add.
quant_scale (float): the quantization scale; quantization is enabled when it is greater than 0.
quant_round_type (int): the quantization rounding type.
quant_max_bound (float): the upper bound used to clip the quantized value.
quant_min_bound (float): the lower bound used to clip the quantized value.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
epsilon = 1e-6
paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dynamic_mode():
return _C_ops.fused_bias_residual_layernorm(
x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
residual_alpha,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('fused_layernorm', **locals())
out = None
if quant_scale <= 0:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=paddle.int8)
outputs_dict = {}
outputs_dict['out'] = out
outputs_dict['mean'] = helper.create_variable_for_type_inference(
dtype=paddle.float32
)
outputs_dict['variance'] = helper.create_variable_for_type_inference(
dtype=paddle.float32
)
residual_out = helper.create_variable_for_type_inference(dtype=x.dtype)
outputs_dict['residual_out'] = residual_out
inputs = {'x': x, 'norm_weight': norm_weight, 'norm_bias': norm_bias}
if residual is not None:
inputs['residual'] = residual
if bias is not None:
inputs['bias'] = bias
helper.append_op(
type='fused_bias_residual_layernorm',
inputs=inputs,
attrs={
"epsilon": epsilon,
"residual_alpha": residual_alpha,
"begin_norm_axis": begin_norm_axis,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
},
outputs=outputs_dict,
)
return out
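A hedged usage sketch for the residual + bias path of fused_layer_norm (GPU only; the shapes are arbitrary, the dtypes follow the docstring example above, and the indexing assumes the dynamic-graph call returns the op outputs as a tuple, as the fused RMSNorm tests below do):

import paddle
from paddle.incubate.nn.functional import fused_layer_norm

x = paddle.randn([32, 256], dtype='float16')
residual = paddle.randn([32, 256], dtype='float16')
bias = paddle.randn([256], dtype='float16')
norm_weight = paddle.randn([256], dtype='float32')
norm_bias = paddle.randn([256], dtype='float32')

# Fused LayerNorm(bias + residual + x); pass quant_scale > 0 to get an int8 output instead.
outs = fused_layer_norm(
    x, norm_weight, norm_bias, 1e-6,
    residual_alpha=1.0, begin_norm_axis=1,
    bias=bias, residual=residual,
)
out = outs[0]  # normalized result; the remaining outputs hold residual_out, mean, variance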
...@@ -13,21 +13,40 @@ ...@@ -13,21 +13,40 @@
# limitations under the License. # limitations under the License.
import paddle
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
def rms_norm(x, weight, bias, epsilon, begin_norm_axis): def fused_rms_norm(
x,
norm_weight,
norm_bias,
epsilon,
begin_norm_axis,
bias=None,
residual=None,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
):
r""" r"""
Apply RMSNorm kernel. Apply the fused RMSNorm kernel. It also supports the fused pattern RMSNorm(bias + residual + x).
Args: Args:
x (Tensor): the input Tensor.. x (Tensor): the input Tensor..
weight (Tensor): the weight Tensor to affine output. norm_weight (Tensor): the weight Tensor to affine output.
bias (Tensor): the bias Tensor to affine output. norm_bias (Tensor): the bias Tensor to affine output.
epsilon (float): a small float number to avoid divide 0. epsilon (float): a small float number to avoid divide 0.
begin_norm_axis (int): the begin axis to normalize. begin_norm_axis (int): the begin axis to normalize.
bias (Tensor, optional): the bias of the previous layer, fused into the add.
residual (Tensor, optional): the residual input, fused into the add.
quant_scale (float): the quantization scale; quantization is enabled when it is greater than 0.
quant_round_type (int): the quantization rounding type.
quant_max_bound (float): the upper bound used to clip the quantized value.
quant_min_bound (float): the lower bound used to clip the quantized value.
Returns: Returns:
Tensor: the output Tensor. Tensor: the output Tensor.
...@@ -42,18 +61,54 @@ def rms_norm(x, weight, bias, epsilon, begin_norm_axis): ...@@ -42,18 +61,54 @@ def rms_norm(x, weight, bias, epsilon, begin_norm_axis):
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
epsilon = 1e-6 epsilon = 1e-6
paddle_rmsnorm = paddle.incubate.nn.functional.rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
""" """
if in_dynamic_mode():
if in_dygraph_mode(): return _C_ops.rms_norm(
return _C_ops.rms_norm(x, weight, bias, epsilon, begin_norm_axis) x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('rms_norm', **locals()) helper = LayerHelper('rms_norm', **locals())
out = None
if quant_scale <= 0:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=paddle.int8)
outputs_dict = {}
outputs_dict['out'] = out
residual_out = helper.create_variable_for_type_inference(dtype=x.dtype)
outputs_dict['residual_out'] = residual_out
inputs = {'x': x, 'norm_weight': norm_weight}
if norm_bias:
inputs['norm_bias'] = norm_bias
if residual is not None:
inputs['residual'] = residual
if bias is not None:
inputs['bias'] = bias
helper.append_op( helper.append_op(
type='rms_norm', type='rms_norm',
inputs={'x': x, 'weight': weight, 'bias': bias}, inputs=inputs,
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis}, attrs={
outputs={'out': out}, "epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
},
outputs=outputs_dict,
) )
return out return (out, residual_out) if residual is not None else out
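A corresponding hedged sketch for fused_rms_norm with the residual and bias fusion (shapes are arbitrary; in dynamic mode the call returns both the normalized output and the pre-norm residual_out):

import paddle
from paddle.incubate.nn.functional import fused_rms_norm

x = paddle.randn([32, 256], dtype='float16')
residual = paddle.randn([32, 256], dtype='float16')
bias = paddle.randn([256], dtype='float16')
gamma = paddle.randn([256], dtype='float16')
beta = paddle.randn([256], dtype='float16')

# RMSNorm(bias + residual + x); residual_out keeps the un-normalized sum for the next block.
out, residual_out = fused_rms_norm(
    x, gamma, beta, 1e-6, begin_norm_axis=1, bias=bias, residual=residual,
)

# Passing quant_scale=0.15, quant_round_type=1, quant_max_bound=127,
# quant_min_bound=-127 would make `out` an int8 tensor instead.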
...@@ -76,6 +76,7 @@ if(NOT WITH_GPU) ...@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass) list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass) list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer) list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer)
...@@ -156,6 +157,7 @@ if(WIN32) ...@@ -156,6 +157,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress) list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_matmul_int8_op) list(REMOVE_ITEM TEST_OPS test_matmul_int8_op)
list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention) list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention)
......
This diff is collapsed.
...@@ -20,6 +20,73 @@ from paddle import fluid ...@@ -20,6 +20,73 @@ from paddle import fluid
from paddle.fluid import core from paddle.fluid import core
def quant_helper(
x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound
):
quant_value = quant_max_bound * quant_scale * x
if quant_round_type == 0:
quant_value = paddle.to_tensor(np.rint(quant_value.numpy()))
else:
quant_value = paddle.round(quant_value)
return paddle.cast(
paddle.clip(quant_value, quant_min_bound, quant_max_bound),
paddle.int8,
)
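For a quick sanity check of the helper above (illustrative numbers matching the quant_scale = 0.15 and ±127 bounds used by these tests):

# quant_value = quant_max_bound * quant_scale * x
#   x = 1.0  -> 127 * 0.15 * 1.0  = 19.05  -> rounds to 19 (int8)
#   x = 60.0 -> 127 * 0.15 * 60.0 = 1143.0 -> clipped to the max bound 127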
def naive_rms_norm(x, gamma, beta, epsilon):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + epsilon) * x
out = out * gamma + beta
return out
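In formula form, the reference implementation above computes, with the mean taken over the last axis:

\mathrm{out} = \frac{x}{\sqrt{\operatorname{mean}(x^{2}) + \epsilon}}\,\gamma + \beta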
def naive_rms_norm_int8(
x,
gamma,
beta,
epsilon,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_rms_norm(x, gamma, beta, epsilon)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
def naive_residual_biasadd_rms_norm(x, residual, bias, gamma, beta, epsilon):
x = x + residual + bias
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + epsilon) * x
out = out * gamma + beta
return out
def naive_residual_biasadd_rms_norm_int8(
x,
residual,
bias,
gamma,
beta,
epsilon,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, epsilon
)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA " not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
) )
...@@ -28,58 +95,195 @@ class TestRMSNormOp(unittest.TestCase): ...@@ -28,58 +95,195 @@ class TestRMSNormOp(unittest.TestCase):
np.random.seed(20) np.random.seed(20)
batch = 32 batch = 32
cols = 256 cols = 256
self.x_np = np.random.random([batch, 256]) self.x_np = np.random.random([batch, cols])
self.gamma_np = np.random.random([256]) self.residual_np = np.random.random([batch, cols])
self.beta_np = np.random.random([256]) self.bias_np = np.random.random([cols])
self.epsilon = 1e-6
def naive_rms_norm(self, x, gamma, beta): self.norm_weight_np = np.random.random([cols])
variance = x.pow(2).mean(-1, keepdim=True) self.norm_bias_np = np.random.random([cols])
out = paddle.rsqrt(variance + self.epsilon) * x self.epsilon = 1e-6
out = out * gamma + beta self.quant_scale = 0.15
return out self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
def check_main(self, x_np, gamma_np, beta_np, dtype): def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static() paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype)) x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype)) gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype)) beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.rms_norm( paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1 x, gamma, beta, self.epsilon, begin_norm_axis=1
) )
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta) paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_rmsnorm_out = naive_rms_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
)
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, self.epsilon
)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm_int8(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm_int8(
x,
residual,
bias,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static() paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self): def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda(): if not paddle.is_compiled_with_cuda():
return return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main( paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm(
self.x_np, self.gamma_np, self.beta_np, 'float16' self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
) )
np.testing.assert_allclose( np.testing.assert_allclose(
paddle_rmsnorm.numpy(), paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(), paddle_naive_rmsnorm.numpy(),
rtol=1e-03, rtol=1e-3,
atol=1e-3, atol=1e-3,
) )
def test_rmsnorm_fp32(self): def test_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda(): if not paddle.is_compiled_with_cuda():
return return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main( paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8(
self.x_np, self.gamma_np, self.beta_np, 'float32' self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
def test_residual_bias_add_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
) )
np.testing.assert_allclose( np.testing.assert_allclose(
paddle_rmsnorm.numpy(), paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(), paddle_naive_rmsnorm.numpy(),
rtol=1e-3, rtol=1e-3,
atol=1e-3, atol=1e-3,
) )
def test_residual_bias_add_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_rmsnorm,
paddle_naive_rmsnorm,
) = self.check_residual_bias_rmsnorm_int8(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA " not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
...@@ -90,45 +294,145 @@ class TestRMSNormStaticOp(unittest.TestCase): ...@@ -90,45 +294,145 @@ class TestRMSNormStaticOp(unittest.TestCase):
self.batch = 32 self.batch = 32
self.cols = 256 self.cols = 256
self.x_np = np.random.random([self.batch, 256]) self.x_np = np.random.random([self.batch, 256])
self.gamma_np = np.random.random([256]) self.norm_weight_np = np.random.random([256])
self.beta_np = np.random.random([256]) self.norm_bias_np = np.random.random([256])
self.residual_np = np.random.random([self.batch, 256])
self.bias_np = np.random.random([256])
self.epsilon = 1e-6 self.epsilon = 1e-6
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
self.place = paddle.CUDAPlace(0) self.place = paddle.CUDAPlace(0)
def naive_rms_norm(self, x, gamma, beta): def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static() paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype)) x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype)) gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype)) beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta) paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
paddle.enable_static() paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data( x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype name="x_static", shape=[self.batch, self.cols], dtype=dtype
) )
gamma_static = paddle.static.data( gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype name="gamma_static", shape=[self.cols], dtype=dtype
) )
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = naive_rms_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data( beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype name="beta_static", shape=[self.cols], dtype=dtype
) )
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, self.epsilon
)
paddle.enable_static()
outs = paddle.incubate.nn.functional.rms_norm( with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
residual_static = paddle.static.data(
name="residual_static",
shape=[self.batch, self.cols],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static", shape=[self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static, x_static,
gamma_static, gamma_static,
beta_static, beta_static,
self.epsilon, self.epsilon,
begin_norm_axis=1, begin_norm_axis=1,
bias=bias_static,
residual=residual_static,
) )
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
...@@ -137,17 +441,18 @@ class TestRMSNormStaticOp(unittest.TestCase): ...@@ -137,17 +441,18 @@ class TestRMSNormStaticOp(unittest.TestCase):
"x_static": x_np.astype(dtype), "x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype), "gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype), "beta_static": beta_np.astype(dtype),
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
}, },
fetch_list=[outs], fetch_list=[outs],
) )
return out_s[0], paddle_naive_rmsnorm_out return out_s[0], paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self): def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda(): if not paddle.is_compiled_with_cuda():
return return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main( paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm(
self.x_np, self.gamma_np, self.beta_np, 'float16' self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
) )
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -157,11 +462,16 @@ class TestRMSNormStaticOp(unittest.TestCase): ...@@ -157,11 +462,16 @@ class TestRMSNormStaticOp(unittest.TestCase):
atol=1e-3, atol=1e-3,
) )
def test_rmsnorm_fp32(self): def test_residual_bias_add_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda(): if not paddle.is_compiled_with_cuda():
return return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main( paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm(
self.x_np, self.gamma_np, self.beta_np, 'float32' self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
) )
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -171,6 +481,20 @@ class TestRMSNormStaticOp(unittest.TestCase): ...@@ -171,6 +481,20 @@ class TestRMSNormStaticOp(unittest.TestCase):
atol=1e-3, atol=1e-3,
) )
def test_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
print("1111")
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()