Unverified commit 2ac6a7e4, authored by MarDino, committed by GitHub

Add rmsnorm residual bias add and quant (#55965)

* add rmsnorm residual bias add and quant

* refine python interface

* add rmsnorm unittest

* Add layernorm

* fix layernorm unittest

* refine unittest

* fix example code

* fix review comment
Parent 1ad502df
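As a reference for the interface added in this commit, here is a minimal usage sketch of paddle.incubate.nn.functional.fused_layer_norm. The shapes, dtypes, and the residual tensor are illustrative only; fused_rms_norm follows an analogous pattern (see the rms_norm op entry in the diff below). Whether auxiliary outputs such as residual_out, mean, and variance are returned alongside the normalized tensor depends on the dispatch path, so the result is captured in a single variable here.

import paddle
from paddle.incubate.nn.functional import fused_layer_norm

# Illustrative inputs: a [32, 256] activation with a residual branch.
x = paddle.cast(paddle.randn([32, 256]), 'float16')
residual = paddle.cast(paddle.randn([32, 256]), 'float16')
norm_weight = paddle.cast(paddle.randn([256]), 'float32')
norm_bias = paddle.cast(paddle.randn([256]), 'float32')

# LayerNorm(x + residual_alpha * residual), no quantization (quant_scale <= 0).
out = fused_layer_norm(
    x,
    norm_weight,
    norm_bias,
    1e-6,
    residual_alpha=1.0,
    begin_norm_axis=1,
    residual=residual,
)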
......@@ -1017,6 +1017,16 @@
data_type : dtype
backend : place
- op : fused_bias_residual_layernorm
args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
output : Tensor(out), Tensor(residual_out), Tensor(mean), Tensor(variance)
infer_meta :
func : FusedLayerNormInferMeta
kernel :
func : fused_bias_residual_layernorm
data_type : x
optional : bias, residual, norm_weight, norm_bias, residual_out
- op : gather
args : (Tensor x, Tensor index, Scalar axis=0)
output : Tensor(out)
......@@ -2071,14 +2081,14 @@
backward : reverse_grad
- op : rms_norm
args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis)
output : Tensor(out)
args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
output : Tensor(out), Tensor(residual_out)
infer_meta :
func : RmsNormInferMeta
kernel :
func : rms_norm
data_type : x
optional : bias
optional : bias, residual, norm_bias, residual_out
- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
......
......@@ -3239,38 +3239,6 @@ void Unpool3dInferMeta(const MetaTensor& x,
}
}
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
......@@ -490,11 +490,4 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out);
} // namespace phi
......@@ -1506,6 +1506,68 @@ void FusedBiasActInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out,
MetaTensor* mean,
MetaTensor* variance) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
int32_t rows = 1;
for (int i = 0; i < begin_norm_axis; i++) {
rows *= x.dims()[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
norm_weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
norm_weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
if (quant_scale <= 0.0f) {
out->set_dtype(x.dtype());
} else {
out->set_dtype(phi::DataType::INT8);
}
out->set_layout(x.layout());
residual_out->set_dims(out_dims);
residual_out->set_dtype(x.dtype());
residual_out->set_layout(x.layout());
mean->set_dims(phi::make_ddim({rows}));
mean->set_dtype(DataType::FLOAT32);
mean->set_layout(x.layout());
variance->set_dims(phi::make_ddim({rows}));
variance->set_dtype(DataType::FLOAT32);
variance->set_layout(x.layout());
}
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......@@ -2918,6 +2980,54 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
norm_weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
norm_weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
if (quant_scale <= 0.0f) {
out->set_dtype(x.dtype());
} else {
out->set_dtype(phi::DataType::INT8);
}
out->set_layout(x.layout());
out->share_lod(x);
residual_out->set_dims(out_dims);
residual_out->set_dtype(x.dtype());
residual_out->set_layout(x.layout());
residual_out->share_lod(x);
}
void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square,
const MetaTensor& grad,
......
......@@ -301,6 +301,23 @@ void FusedBiasActInferMeta(const MetaTensor& x,
float quant_min_bound,
MetaTensor* out);
void FusedLayerNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out,
MetaTensor* mean,
MetaTensor* variance);
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
......@@ -516,6 +533,20 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale,
MetaTensor* out);
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
const MetaTensor& norm_weight,
const MetaTensor& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* residual_out);
void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square,
const MetaTensor& grad,
......
......@@ -124,6 +124,7 @@ __global__ void FusedDropoutActBias(
nullptr,
nullptr,
act,
1.0, /*Since DropoutActBias does not use residual_alpha, we set it to 1.0*/
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
......
......@@ -123,10 +123,12 @@ class FusedDropoutHelper {
FusedDropoutHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param) {
const DropoutParam& dropout_param,
const float residual_alpha = 1.0) {
rows_ = rows;
cols_ = cols;
dropout_param_ = dropout_param;
residual_alpha_ = residual_alpha;
}
// out = residual + dropout( src + bias )
......@@ -156,7 +158,8 @@ class FusedDropoutHelper {
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
quant_next_in_scale,
residual_alpha_);
}
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
......@@ -336,6 +339,7 @@ class FusedDropoutHelper {
int rows_;
int cols_;
DropoutParam dropout_param_;
float residual_alpha_;
};
template <typename T,
......@@ -348,20 +352,23 @@ class FusedDropoutLayerNormHelper
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows,
const int cols,
const float epsilon) {
const float epsilon,
const float residual_alpha = 1.0) {
using U = phi::funcs::LayerNormParamType<T>;
this->rows_ = rows;
this->cols_ = cols;
epsilon_ = epsilon;
this->residual_alpha_ = residual_alpha;
}
FusedDropoutLayerNormHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param,
const float epsilon)
const float epsilon,
const float residual_alpha = 1.0)
: FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) {
ctx, rows, cols, dropout_param, residual_alpha) {
using U = phi::funcs::LayerNormParamType<T>;
epsilon_ = epsilon;
}
......@@ -476,7 +483,8 @@ class FusedDropoutLayerNormHelper
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
quant_min_bound,
this->residual_alpha_);
}
template <typename P = phi::funcs::LayerNormParamType<T>,
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Original OneFlow copyright notice:
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
// The following code is modified from OneFlow's implementation and changed to
// use a single-pass algorithm. It supports Int8 quant/dequant Load/Store
// implementations.
#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.h"
#include <assert.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include <cub/cub.cuh>
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#endif
namespace phi {
namespace fusion {
namespace {
#ifndef PADDLE_WITH_HIP
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return a + b;
}
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return max(a, b);
}
};
template <template <typename> class ReductionOp,
typename T,
int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(
val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ float Div<float>(float a, float b) {
#ifdef OF_LAYER_NORM_USE_FAST_MATH
return __fdividef(a, b);
#else
return a / b;
#endif
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
#ifdef OF_LAYER_NORM_USE_FAST_MATH
return __frsqrt_rn(x);
#else
return rsqrt(x);
#endif
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}
template <class Func>
inline cudaError_t GetNumBlocks(Func func,
int64_t block_size,
size_t dynamic_smem_size,
int64_t max_blocks,
int64_t waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int max_active_blocks;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, func, block_size, dynamic_smem_size);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks = std::max<int>(
1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));
return cudaSuccess;
}
template <typename T>
struct DefaultComputeType {
using type = T;
};
template <>
struct DefaultComputeType<half> {
using type = float;
};
#if CUDA_VERSION >= 11000
template <>
struct DefaultComputeType<nv_bfloat16> {
using type = float;
};
#endif // CUDA_VERSION >= 11000
template <typename T>
class HasCanPackAs {
typedef char one;
struct two {
char x[2];
};
template <typename C>
static one test(decltype(&C::CanPackAs));
template <typename C>
static two test(...);
public:
enum { value = sizeof(test<T>(0)) == sizeof(char) };
};
template <typename T>
typename std::enable_if<HasCanPackAs<T>::value == true, bool>::type CanPackAs(
T t, size_t pack_size) {
return t.CanPackAs(pack_size);
}
template <typename T>
typename std::enable_if<HasCanPackAs<T>::value == false, bool>::type CanPackAs(
T t, size_t pack_size) {
return true;
}
template <typename T, int N>
struct alignas(sizeof(T) * N) Pack {
__device__ Pack() {
// do nothing
}
T elem[N];
};
template <typename SRC, typename DST>
struct DirectLoad {
using LoadType = DST;
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack = *(reinterpret_cast<const Pack<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int64_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<Pack<DST, N>*>(dst) + offset) = pack;
}
DST* dst;
int64_t row_size;
};
template <typename T>
inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) {
// Use the Welford online algorithm to compute mean and variance.
// For more details you can refer to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
*count += 1;
T delta1 = val - *mean;
*mean += Div(delta1, *count);
T delta2 = val - *mean;
*m2 += delta1 * delta2;
}
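// Merge a partial Welford state (b_mean, b_m2, b_count) from another partition
// into (*mean, *m2, *count). This is the standard parallel combination of
// Welford statistics, used below for the warp- and block-level reductions.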
template <typename T>
inline __device__ void WelfordCombine(
T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) {
if (b_count == 0) {
return;
}
T new_count = *count + b_count;
T nb_over_n = Div(b_count, new_count);
T delta = b_mean - *mean;
*mean += delta * nb_over_n;
*m2 += b_m2 + delta * delta * (*count) * nb_over_n;
*count = new_count;
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(
T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);
T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(
T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(
thread_mean, thread_m2, thread_count, mean, m2, count);
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpReduceSum(T x) {
T result = 0.0f;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
result += __shfl_xor_sync(0xffffffff, x, mask, thread_group_width);
}
return result;
}
template <typename T>
__inline__ __device__ void WelfordBlockAllReduce(T thread_mean,
T thread_m2,
T thread_count,
T* result_mean,
T* result_m2,
T* result_count) {
__shared__ T mean_shared[kWarpSize];
__shared__ T m2_shared[kWarpSize];
__shared__ T count_shared[kWarpSize];
__shared__ T mean_result_broadcast;
__shared__ T m2_result_broadcast;
__shared__ T count_result_broadcast;
const int lid = threadIdx.x % kWarpSize;
const int wid = threadIdx.x / kWarpSize;
T warp_mean = 0;
T warp_m2 = 0;
T warp_count = 0;
WelfordWarpReduce(
thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
__syncthreads();
if (lid == 0) {
mean_shared[wid] = warp_mean;
m2_shared[wid] = warp_m2;
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
T block_mean = 0;
T block_m2 = 0;
T block_count = 0;
WelfordWarpReduce(
warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);
if (lid == 0) {
mean_result_broadcast = block_mean;
m2_result_broadcast = block_m2;
count_result_broadcast = block_count;
}
}
__syncthreads();
*result_mean = mean_result_broadcast;
*result_m2 = m2_result_broadcast;
*result_count = count_result_broadcast;
}
template <typename LOAD,
typename STORE,
typename ComputeType,
int kPackSize,
int block_size>
__global__ void LayerNormBlockSMemImpl(LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance,
ComputeType col_divisor) {
using LoadType = typename LOAD::LoadType;
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<LoadType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % kPackSize == 0);
const int num_packs = static_cast<int>(cols) / kPackSize;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum = 0;
ComputeType thread_sum_square = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
LoadType pack[kPackSize];
load.template load<kPackSize>(pack, row, pack_id * kPackSize);
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
buf[i * num_packs + pack_id] = pack[i];
ComputeType pack_val = static_cast<ComputeType>(pack[i]);
thread_sum += pack_val;
thread_sum_square += pack_val * pack_val;
}
}
const ComputeType row_sum =
BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
const ComputeType row_sum_square =
BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum_square);
// use multiply instead of divide.
ComputeType row_mean = row_sum * col_divisor;
ComputeType row_sum_square_mean = row_sum_square * col_divisor;
ComputeType row_variance = max(row_sum_square_mean - row_mean * row_mean,
static_cast<ComputeType>(0.0));
ComputeType row_inv_var =
Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0 && mean && inv_variance) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[kPackSize];
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
pack[i] = (static_cast<ComputeType>(buf[i * num_packs + pack_id]) -
row_mean) *
row_inv_var;
}
store.template store<kPackSize>(pack, row, pack_id * kPackSize);
}
}
}
template <typename LOAD,
typename STORE,
typename ComputeType,
int pack_size,
int block_size>
inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream,
LOAD load,
STORE store,
int smem,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance,
ComputeType col_divisor) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size,
smem,
rows,
waves,
&grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(
load, store, rows, cols, epsilon, mean, inv_variance, col_divisor);
return cudaPeekAtLastError();
}
template <typename Func>
cudaError_t MaximizeDynamicSharedMemorySize(Func func,
const int max_smem_size) {
cudaFuncAttributes attr{};
cudaError_t err = cudaFuncGetAttributes(&attr, func);
if (err != cudaSuccess) {
return err;
}
constexpr int reserved_smem = 1024; // 1K
return cudaFuncSetAttribute(
func,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_smem_size - attr.sharedSizeBytes - reserved_smem);
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance,
ComputeType col_divisor,
bool* success) {
// Note(Zhengzekang): We choose a fixed block size to avoid layernorm diffs
// (by RichardWooSJTU).
constexpr int block_size_conf_1 = 128;
int dev = 0;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
const size_t smem = cols * sizeof(typename LOAD::LoadType);
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
pack_size,
block_size_conf_1>(stream,
load,
store,
smem,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor);
}
template <typename LOAD, typename STORE, typename ComputeType>
struct TryDispatchLayerNormBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance,
ComputeType col_divisor,
bool* success) {
if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) &&
CanPackAs<STORE>(store, 4)) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
4>(stream,
load,
store,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor,
success);
} else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) &&
CanPackAs<STORE>(store, 2)) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
2>(stream,
load,
store,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor,
success);
} else {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
1>(stream,
load,
store,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor,
success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance,
ComputeType col_divisor,
bool* success) {
return TryDispatchLayerNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>()(
stream,
load,
store,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor,
success);
}
template <typename LOAD,
typename STORE,
typename ComputeType,
int kPackSize,
int block_size>
__global__ void __launch_bounds__(1024)
LayerNormBlockUncachedImpl(LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
using LoadType = typename LOAD::LoadType;
const int tid = threadIdx.x;
assert(cols % kPackSize == 0);
const int num_packs = static_cast<int>(cols) / kPackSize;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
LoadType pack[kPackSize];
load.template load<kPackSize>(pack, row, pack_id * kPackSize);
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
WelfordCombine(static_cast<ComputeType>(pack[i]),
&thread_mean,
&thread_m2,
&thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce<ComputeType>(
thread_mean, thread_m2, thread_count, &row_mean, &row_m2, &row_count);
ComputeType row_variance =
max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
ComputeType row_inv_var =
Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0 && mean && inv_variance) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
LoadType pack[kPackSize];
ComputeType dst_pack[kPackSize];
const int pack_offset = pack_id * kPackSize;
load.template load<kPackSize>(pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
dst_pack[i] =
(static_cast<ComputeType>(pack[i]) - row_mean) * row_inv_var;
}
store.template store<kPackSize>(dst_pack, row, pack_offset);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(LayerNormBlockUncachedImpl<LOAD,
STORE,
ComputeType,
pack_size,
block_size>,
block_size,
0,
rows,
waves,
&grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(
load, store, rows, cols, epsilon, mean, inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) &&
CanPackAs<STORE>(store, 4)) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) &&
CanPackAs<STORE>(store, 2)) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockUncachedImpl(
cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value,
cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
const ComputeType col_divisor = 1.0f / cols;
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(
stream,
load,
store,
rows,
cols,
epsilon,
mean,
inv_variance,
col_divisor,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
return cudaSuccess;
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value,
cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream,
LOAD load,
STORE store,
const int64_t rows,
const int64_t cols,
const double epsilon,
ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
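// Quantize a floating point value to OutType (e.g. int8): scale by
// max_bound * scale, round (rint when round_type == 0, round otherwise), then
// clip to [min_bound, max_bound]. For example, with scale = 0.05 and
// max_bound = 127, an input of 1.2 maps to rint(127 * 0.05 * 1.2) = 8.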
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
template <typename OutType,
typename SRC,
typename DST,
bool do_scale,
bool do_center>
struct AffineQuantStore {
AffineQuantStore(OutType* y,
const int64_t row_size,
const float* gamma,
const float* beta,
const float quant_out_scale,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0)
: y(y),
row_size(row_size),
gamma(gamma),
beta(beta),
quant_round_type(quant_round_type),
quant_out_scale(quant_out_scale),
quant_max_bound(quant_max_bound),
quant_min_bound(quant_min_bound) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<OutType, N> y_pack;
Pack<float, N> gamma_pack;
Pack<float, N> beta_pack;
Pack<OutType, N> out_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
gamma_pack =
*(reinterpret_cast<const Pack<float, N>*>(gamma) + gamma_offset);
beta_pack = *(reinterpret_cast<const Pack<float, N>*>(beta) + gamma_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
float normalized_i = static_cast<float>(src[i]);
float normalized_val =
normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
quant_out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
*(reinterpret_cast<Pack<OutType, N>*>(y) + offset) = y_pack;
}
OutType* y;
int64_t row_size;
const float* gamma;
const float* beta;
const int quant_round_type;
const float quant_out_scale;
const float quant_max_bound;
const float quant_min_bound;
};
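// Fused load for LayerNorm(x + bias + residual_alpha * skip): computes the
// biased residual sum once, writes it to residual_bias_out, and passes the
// same values on as the input row of the normalization.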
template <typename T>
struct SkipLoadAndStoreResidual {
using LoadType = T;
SkipLoadAndStoreResidual(const T* src,
const T* bias,
const T* skip,
T* residual_bias_out,
float alpha,
int64_t row_size)
: src(src),
bias(bias),
skip(skip),
residual_bias_out(residual_bias_out),
alpha(alpha),
row_size(row_size) {}
template <int N>
__device__ void load(T* dst, int64_t row, int64_t col) const {
Pack<T, N> src_pack;
Pack<T, N> bias_pack;
Pack<T, N> skip_pack;
Pack<T, N> residual_out_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t bias_offset = col / N;
src_pack = *(reinterpret_cast<const Pack<T, N>*>(src) + offset);
bias_pack = *(reinterpret_cast<const Pack<T, N>*>(bias) + bias_offset);
skip_pack = *(reinterpret_cast<const Pack<T, N>*>(skip) + offset);
T alpha_val = static_cast<T>(alpha);
#pragma unroll
for (int i = 0; i < N; ++i) {
// First we need to cast src and dequant.
residual_out_pack.elem[i] =
static_cast<T>(static_cast<T>(static_cast<float>(src_pack.elem[i])) +
bias_pack.elem[i] + skip_pack.elem[i] * alpha_val);
}
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = residual_out_pack.elem[i];
}
*(reinterpret_cast<Pack<T, N>*>(residual_bias_out) + offset) =
residual_out_pack;
}
const T* src;
const T* bias;
const T* skip;
T* residual_bias_out;
float alpha;
int64_t row_size;
};
#endif
} // namespace
template <typename T, typename Context>
void FusedLayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& residual,
const paddle::optional<DenseTensor>& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out,
DenseTensor* mean,
DenseTensor* variance) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using U = phi::funcs::LayerNormParamType<T>;
const T* x_data = x.data<T>();
const U* norm_weight_data =
norm_weight ? norm_weight.get().data<U>() : nullptr;
const U* norm_bias_data = norm_bias ? norm_bias.get().data<U>() : nullptr;
int32_t rows = 1;
int32_t cols = 1;
for (int i = 0; i < begin_norm_axis; i++) {
rows *= x.dims()[i];
}
for (int i = begin_norm_axis; i < x.dims().size(); i++) {
cols *= x.dims()[i];
}
phi::fusion::DropoutParam dropout_param(true, 0, true, true, 0.0, nullptr, 0);
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
residual_bias_add_layernorm_helper(
dev_ctx, rows, cols, dropout_param, epsilon, residual_alpha);
phi::fusion::AttnLayerNorm<T> layernorm_helper(dev_ctx, epsilon, rows, cols);
// Do residual + bias + x
if (residual && norm_weight_data == nullptr && norm_bias_data == nullptr) {
const T* residual_data = residual.get().data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
T* out_data = dev_ctx.template Alloc<T>(out);
residual_bias_add_layernorm_helper.ResidualDropoutBias(
dev_ctx,
x_data,
residual_data,
bias_data,
out_data,
nullptr /*dropout_mask_out_data*/);
return;
}
U* mean_data = dev_ctx.template Alloc<U>(mean);
U* variance_data = dev_ctx.template Alloc<U>(variance);
if (residual) {
// Do Layernorm(residual + bias + x)
T* residual_out_data = dev_ctx.template Alloc<T>(residual_out);
const T* residual_data = residual.get().data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
if (quant_scale <= 0.0f) {
T* out_data = dev_ctx.template Alloc<T>(out);
residual_bias_add_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
x_data,
residual_data,
bias_data,
norm_weight_data,
norm_bias_data,
residual_out_data,
nullptr,
out_data,
mean_data,
variance_data);
} else {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
SkipLoadAndStoreResidual<T> load(x_data,
bias_data,
residual_data,
residual_out_data,
residual_alpha,
cols);
AffineQuantStore<int8_t, U, T, true, true> store(out_data,
cols,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchLayerNorm<decltype(load), decltype(store), U>(
dev_ctx.stream(),
load,
store,
rows,
cols,
epsilon,
mean_data /*ln_mean_data*/,
variance_data /*ln_var_data*/);
}
} else {
if (quant_scale <= 0.0f) {
T* out_data = dev_ctx.template Alloc<T>(out);
layernorm_helper.ComputeForward(x_data,
norm_weight_data,
norm_bias_data,
out_data,
mean_data,
variance_data);
} else {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
DirectLoad<T, U> load(x_data, cols);
AffineQuantStore<int8_t, U, T, true, true> store(out_data,
cols,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchLayerNorm<decltype(load), decltype(store), U>(dev_ctx.stream(),
load,
store,
rows,
cols,
epsilon,
mean_data,
variance_data);
}
}
#endif
}
} // namespace fusion
} // namespace phi
#ifndef PADDLE_WITH_HIP
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
GPU,
ALL_LAYOUT,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
GPU,
ALL_LAYOUT,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {}
#endif // CUDNN_VERSION_MIN
#else
PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
GPU,
ALL_LAYOUT,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::float16) {}
#endif // PADDLE_WITH_HIP
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedLayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& residual,
const paddle::optional<DenseTensor>& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const float residual_alpha,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out,
DenseTensor* mean,
DenseTensor* variance);
} // namespace fusion
} // namespace phi
......@@ -132,7 +132,8 @@ __global__ void FusedLayernormResidualDropoutBias(
T *dst,
T *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) {
LayerNormParamType<T> *var,
const float residual_alpha = 1.0) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
int idx = row_id * cols + col_id;
......@@ -175,7 +176,8 @@ __global__ void FusedLayernormResidualDropoutBias(
is_test,
&mean_val,
&var_val,
relu);
relu,
residual_alpha);
}
mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
......@@ -233,7 +235,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
T *dst,
T *layernorm_dst,
LayerNormParamType<T> *mean,
LayerNormParamType<T> *var) {
LayerNormParamType<T> *var,
const float residual_alpha = 1.0) {
if (dropout_prob != 0.0f) {
FusedLayernormResidualDropoutBias<T,
MaskType,
......@@ -258,7 +261,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
dst,
layernorm_dst,
mean,
var);
var,
residual_alpha);
} else {
FusedLayernormResidualDropoutBias<T,
MaskType,
......@@ -283,7 +287,8 @@ void LaunchFusedLayernormResidualDropoutBiasCUDAKernel(
dst,
layernorm_dst,
mean,
var);
var,
residual_alpha);
}
}
......@@ -539,7 +544,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const float quant_min_bound = -127.0,
const float residual_alpha = 1.0) {
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
......@@ -641,13 +647,13 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
dequant_out_scale[it][jt]) +
bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
residual[it][jt] * static_cast<T>(residual_alpha);
x[it][jt] = tmp;
xf[it * VecSize + jt] = U(tmp);
} else {
x[it][jt] = (static_cast<T>(x_input[it][jt]) + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
residual[it][jt] * static_cast<T>(residual_alpha);
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
......@@ -663,12 +669,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
residual[it][jt] * static_cast<T>(residual_alpha);
x[it][jt] = tmp;
} else {
x[it][jt] = static_cast<T>(x_input[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
residual[it][jt] * static_cast<T>(residual_alpha);
}
xf[it * VecSize + jt] = U(x[it][jt]);
}
......@@ -848,7 +854,8 @@ void LaunchLayernormResidualDropoutBias(
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const float quant_min_bound = -127.0,
const float residual_alpha = 1.0) {
// dropout_prob == 1.0f
// NOTE(minghaoBD): OutType should be T if drop_out_rate == 1.0
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
......@@ -942,7 +949,8 @@ void LaunchLayernormResidualDropoutBias(
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
quant_min_bound, \
residual_alpha); \
} else { \
fused_fast_ln_fwd_kernel< \
false, \
......@@ -986,7 +994,8 @@ void LaunchLayernormResidualDropoutBias(
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound); \
quant_min_bound, \
residual_alpha); \
} \
} break
......@@ -1036,7 +1045,8 @@ void LaunchLayernormResidualDropoutBias(
dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
var,
residual_alpha);
} else {
if (can_call_fast_ln_kernel) {
switch (cols) {
......@@ -1074,7 +1084,8 @@ void LaunchLayernormResidualDropoutBias(
dst,
reinterpret_cast<T *>(layernorm_dst),
mean,
var);
var,
residual_alpha);
}
}
}
......
......@@ -53,6 +53,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
typename phi::dtype::MPTypeTrait<T>::Type *mean_val,
typename phi::dtype::MPTypeTrait<T>::Type *var_val,
Functor act_func,
const float residual_alpha = 1.0,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
......@@ -121,10 +122,11 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
tmp = act_func(tmp);
}
if (HasDropout) {
dest_vec[ii] =
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
dest_vec[ii] = tmp * static_cast<T>(mask_vec[ii]) * factor +
residual_vec[ii] * static_cast<T>(residual_alpha);
} else {
dest_vec[ii] = tmp * factor + residual_vec[ii];
dest_vec[ii] =
tmp * factor + residual_vec[ii] * static_cast<T>(residual_alpha);
}
if (ComputeLayerNorm) {
U tmp = static_cast<U>(dest_vec[ii]);
......@@ -274,7 +276,8 @@ __global__ void FusedResidualDropoutBias(
const bool is_test,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
const float quant_next_in_scale = 1.0,
const float residual_alpha = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
......@@ -316,6 +319,7 @@ __global__ void FusedResidualDropoutBias(
nullptr,
nullptr,
relu,
residual_alpha,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
......@@ -345,7 +349,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
const float quant_next_in_scale = 1.0,
const float residual_alpha = 1.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
// NOTE(minghaoBD): OutType should be T if dropout_prob == 1.0
......@@ -396,7 +401,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
quant_next_in_scale, \
residual_alpha); \
} else { \
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \
<<<config.block_per_grid, \
......@@ -416,7 +422,8 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
quant_next_in_scale, \
residual_alpha); \
} \
} while (0)
......
......@@ -937,20 +937,26 @@ struct AffineQuantStore {
template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int begin_norm_axis,
DenseTensor* out) {
const paddle::optional<DenseTensor>& residual,
const DenseTensor& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* residual_out) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
const T* x_data = x.data<T>();
const T* weight_data = weight.data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
T* out_data = dev_ctx.template Alloc<T>(out);
const T* norm_weight_data = norm_weight.data<T>();
const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
int32_t rows = 1;
int32_t cols = 1;
......@@ -962,283 +968,64 @@ void RmsNormKernel(const Context& dev_ctx,
cols *= x.dims()[i];
}
DirectLoad<T, ComputeType> load(x_data, cols);
AffineStore<ComputeType, T> store(out_data, cols, weight_data, bias_data);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
dev_ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineStore<ComputeType, T> store(output, cols, weight, bias);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* weight,
const phi::dtype::bfloat16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
float* output);
// ========== ResidualAdd + RMSNorm ==========
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
if (residual) {
// Do RMSNorm(bias_add + residual + x)
T* residual_out_data = dev_ctx.template Alloc<T>(residual_out);
const T* residual_data = residual.get().data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols);
AffineStore<ComputeType, T> store(output, cols, norm_weight, norm_bias);
x_data, residual_data, bias_data, residual_out_data, cols);
if (quant_scale <= 0.0f) {
// No Quantize.
T* out_data = dev_ctx.template Alloc<T>(out);
AffineStore<ComputeType, T> store(
out_data, cols, norm_weight_data, norm_bias_data);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* residual_output,
phi::dtype::float16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* residual_output,
phi::dtype::bfloat16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
float* residual_output,
float* output);
// ===== FP16 in, Int8out RMSNorm =====
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
dev_ctx.stream(), load, store, rows, cols, epsilon);
} else {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(
out_data,
cols,
weight,
bias,
in_scale,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* weight,
const phi::dtype::bfloat16* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
// ===== FP16 in, Int8out ResidualAdd + RMSNorm =====
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
dev_ctx.stream(), load, store, rows, cols, epsilon);
}
} else {
DirectLoad<T, ComputeType> load(x_data, cols);
if (quant_scale <= 0.0f) {
// No Quantize.
T* out_data = dev_ctx.template Alloc<T>(out);
AffineStore<ComputeType, T> store(
out_data, cols, norm_weight_data, norm_bias_data);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
dev_ctx.stream(), load, store, rows, cols, epsilon);
} else {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(
out_data,
cols,
norm_weight,
norm_bias,
in_scale,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
dev_ctx.stream(), load, store, rows, cols, epsilon);
}
}
#endif
}
template void ResidualAddRmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
float* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::float16* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::bfloat16* residual_output,
int8_t* output);
} // namespace phi
PD_REGISTER_KERNEL(rms_norm,
......
......@@ -22,64 +22,17 @@ namespace phi {
template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int begin_norm_axis,
DenseTensor* out);
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output);
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output);
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const paddle::optional<DenseTensor>& residual,
const DenseTensor& norm_weight,
const paddle::optional<DenseTensor>& norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output);
DenseTensor* out,
DenseTensor* residual_out);
} // namespace phi
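For orientation, the computation the new fused rms_norm path performs (optional residual and bias add, RMS normalization, optional int8 quantization) can be summarized with a short Python reference. This is only a sketch mirroring the naive helpers in the unit tests added below, not the CUDA kernel; the function name, the defaults, and the last-axis normalization are illustrative assumptions.

import paddle

def rms_norm_reference(x, norm_weight, norm_bias=None, bias=None, residual=None,
                       epsilon=1e-6, quant_scale=-1.0,
                       quant_max_bound=127.0, quant_min_bound=-127.0):
    # Optional fused adds before normalization.
    if residual is not None:
        x = x + residual
    if bias is not None:
        x = x + bias
    residual_out = x  # pre-normalization sum, exposed for reuse by the next layer
    # RMS normalization over the last dimension, then affine transform.
    variance = x.pow(2).mean(-1, keepdim=True)
    out = paddle.rsqrt(variance + epsilon) * x * norm_weight
    if norm_bias is not None:
        out = out + norm_bias
    # Optional int8 quantization of the normalized output.
    if quant_scale > 0:
        out = paddle.clip(paddle.round(quant_max_bound * quant_scale * out),
                          quant_min_bound, quant_max_bound).cast('int8')
    return out, residual_out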
......@@ -28,7 +28,8 @@ from .fused_rotary_position_embedding import fused_rotary_position_embedding
from .variable_length_memory_efficient_attention import (
variable_length_memory_efficient_attention,
)
from .rms_norm import rms_norm
from .fused_rms_norm import fused_rms_norm
from .fused_layer_norm import fused_layer_norm
__all__ = [
'fused_multi_head_attention',
......@@ -42,5 +43,6 @@ __all__ = [
'fused_dropout_add',
'fused_rotary_position_embedding',
'variable_length_memory_efficient_attention',
"rms_norm",
"fused_rms_norm",
"fused_layer_norm",
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def fused_layer_norm(
x,
norm_weight,
norm_bias,
epsilon,
residual_alpha=1.0,
begin_norm_axis=1,
bias=None,
residual=None,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
):
r"""
    Apply the fused LayerNorm kernel. Also supports the LayerNorm(bias + residual_alpha * residual + x) fused pattern.
    When norm_weight and norm_bias are None, it returns the fused result of (bias + residual_alpha * residual + x) without normalization.
Args:
        x (Tensor): the input Tensor.
norm_weight (Tensor): the weight Tensor to affine output.
norm_bias (Tensor): the bias Tensor to affine output.
        epsilon (float): a small float added to the variance to avoid division by zero.
residual_alpha (float): a scale factor for residual. default is 1.
begin_norm_axis (int): the begin axis to normalize. default is 1.
        bias (optional|Tensor): the previous layer's bias to fuse.
        residual (optional|Tensor): the residual input to fuse.
quant_scale (float): the quant scale.
        quant_round_type (int): the quant round type.
quant_max_bound (float): the quant max bound to clip.
quant_min_bound (float): the quant min bound to clip.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
epsilon = 1e-6
            paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, begin_norm_axis=1)
"""
if in_dynamic_mode():
return _C_ops.fused_bias_residual_layernorm(
x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
residual_alpha,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('fused_layernorm', **locals())
out = None
if quant_scale <= 0:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=paddle.int8)
outputs_dict = {}
outputs_dict['out'] = out
outputs_dict['mean'] = helper.create_variable_for_type_inference(
dtype=paddle.float32
)
outputs_dict['variance'] = helper.create_variable_for_type_inference(
dtype=paddle.float32
)
residual_out = helper.create_variable_for_type_inference(dtype=x.dtype)
outputs_dict['residual_out'] = residual_out
inputs = {'x': x, 'norm_weight': norm_weight, 'norm_bias': norm_bias}
if residual is not None:
inputs['residual'] = residual
if bias is not None:
inputs['bias'] = bias
helper.append_op(
type='fused_bias_residual_layernorm',
inputs=inputs,
attrs={
"epsilon": epsilon,
"residual_alpha": residual_alpha,
"begin_norm_axis": begin_norm_axis,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
},
outputs=outputs_dict,
)
return out
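Beyond the docstring example, a dynamic-graph call of the residual + bias fused pattern might look like the sketch below. It is an illustration under assumed conditions (GPU build, shapes taken from the unit tests further down); in dynamic mode the op returns a tuple, and the tests read the normalized result from index 0, which the indexing here mirrors.

import paddle
from paddle.incubate.nn.functional import fused_layer_norm

x = paddle.cast(paddle.randn([16, 256]), 'float16')
residual = paddle.cast(paddle.randn([16, 256]), 'float16')
bias = paddle.cast(paddle.randn([256]), 'float16')
norm_weight = paddle.randn([256], dtype='float32')
norm_bias = paddle.randn([256], dtype='float32')

# LayerNorm(bias + residual_alpha * residual + x)
outs = fused_layer_norm(
    x, norm_weight, norm_bias, 1e-5,
    residual_alpha=1.0, begin_norm_axis=1,
    bias=bias, residual=residual)
out = outs[0]  # normalized result (int8 when quant_scale > 0)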
......@@ -13,21 +13,40 @@
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.framework import LayerHelper, in_dynamic_mode
def rms_norm(x, weight, bias, epsilon, begin_norm_axis):
def fused_rms_norm(
x,
norm_weight,
norm_bias,
epsilon,
begin_norm_axis,
bias=None,
residual=None,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
):
r"""
Apply RMSNorm kernel.
    Apply the fused RMSNorm kernel. Also supports the RMSNorm(bias + residual + x) fused pattern.
Args:
        x (Tensor): the input Tensor.
weight (Tensor): the weight Tensor to affine output.
bias (Tensor): the bias Tensor to affine output.
norm_weight (Tensor): the weight Tensor to affine output.
norm_bias (Tensor): the bias Tensor to affine output.
        epsilon (float): a small float added to the variance to avoid division by zero.
begin_norm_axis (int): the begin axis to normalize.
        bias (optional|Tensor): the previous layer's bias to fuse.
        residual (optional|Tensor): the residual input to fuse.
quant_scale (float): the quant scale.
        quant_round_type (int): the quant round type.
quant_max_bound (float): the quant max bound to clip.
quant_min_bound (float): the quant min bound to clip.
Returns:
Tensor: the output Tensor.
......@@ -42,18 +61,54 @@ def rms_norm(x, weight, bias, epsilon, begin_norm_axis):
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
epsilon = 1e-6
paddle_rmsnorm = paddle.incubate.nn.functional.rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dygraph_mode():
return _C_ops.rms_norm(x, weight, bias, epsilon, begin_norm_axis)
if in_dynamic_mode():
return _C_ops.rms_norm(
x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('rms_norm', **locals())
out = None
if quant_scale <= 0:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable_for_type_inference(dtype=paddle.int8)
outputs_dict = {}
outputs_dict['out'] = out
residual_out = helper.create_variable_for_type_inference(dtype=x.dtype)
outputs_dict['residual_out'] = residual_out
inputs = {'x': x, 'norm_weight': norm_weight}
    if norm_bias is not None:
inputs['norm_bias'] = norm_bias
if residual is not None:
inputs['residual'] = residual
if bias is not None:
inputs['bias'] = bias
helper.append_op(
type='rms_norm',
inputs={'x': x, 'weight': weight, 'bias': bias},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
outputs={'out': out},
inputs=inputs,
attrs={
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
},
outputs=outputs_dict,
)
return out
return (out, residual_out) if residual is not None else out
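A matching dynamic-graph sketch for the fused RMSNorm path with residual and bias, under the same assumptions: the unit tests below read element 0 of the result, and the second element is the pre-normalization sum when a residual is fused.

import paddle
from paddle.incubate.nn.functional import fused_rms_norm

x = paddle.cast(paddle.randn([32, 256]), 'float16')
residual = paddle.cast(paddle.randn([32, 256]), 'float16')
bias = paddle.cast(paddle.randn([256]), 'float16')
norm_weight = paddle.cast(paddle.randn([256]), 'float16')
norm_bias = paddle.cast(paddle.randn([256]), 'float16')

outs = fused_rms_norm(
    x, norm_weight, norm_bias, 1e-6, begin_norm_axis=1,
    bias=bias, residual=residual)
out = outs[0]           # normalized result (int8 when quant_scale > 0)
residual_out = outs[1]  # bias + residual + x, reusable by the next layer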
......@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer)
......@@ -156,6 +157,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_matmul_int8_op)
list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import core
def quant_helper(
x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound
):
quant_value = quant_max_bound * quant_scale * x
if quant_round_type == 0:
quant_value = paddle.to_tensor(np.rint(quant_value.numpy()))
else:
quant_value = paddle.round(quant_value)
return paddle.cast(
paddle.clip(quant_value, quant_min_bound, quant_max_bound),
paddle.int8,
)
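A quick numeric walk-through of the helper above, with arbitrarily chosen values to show the rounding and clipping steps:

# With quant_scale = 0.15, quant_max_bound = 127, quant_min_bound = -127:
#   x = 0.5  ->  127 * 0.15 * 0.5 = 9.525   -> round -> 10   (within range)
#   x = 9.0  ->  127 * 0.15 * 9.0 = 171.45  -> round -> 171  -> clipped to 127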
def naive_layer_norm(x, gamma, beta, epsilon):
x_float = paddle.cast(x, dtype=paddle.float32)
mean = paddle.mean(x_float, axis=-1, keepdim=True)
var = paddle.var(x_float, axis=-1, keepdim=True)
sqrt_var = paddle.rsqrt(var + epsilon)
normalized_output = (x_float - mean) * sqrt_var
out = normalized_output * gamma + beta
out = paddle.cast(out, x.dtype)
return out
def naive_layer_norm_int8(
x,
gamma,
beta,
epsilon,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_layer_norm(x, gamma, beta, epsilon)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
def naive_residual_biasadd_layer_norm(
x, residual, bias, gamma, beta, epsilon, residual_alpha
):
x = x + residual * residual_alpha + bias
mean = paddle.mean(x, axis=-1, keepdim=True)
var = paddle.var(x, axis=-1, keepdim=True)
sqrt_var = paddle.rsqrt(var + epsilon)
out = ((x - mean) * sqrt_var) * paddle.cast(gamma, x.dtype) + paddle.cast(
beta, x.dtype
)
return out
def naive_residual_biasadd_layer_norm_int8(
x,
residual,
bias,
gamma,
beta,
epsilon,
residual_alpha,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_residual_biasadd_layer_norm(
x, residual, bias, gamma, beta, epsilon, residual_alpha
)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestlayernormOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
batch = 16
cols = 256
self.x_np = np.random.uniform(-0.05, 0.05, [batch, cols])
self.residual_np = np.random.uniform(-0.05, 0.05, [batch, cols])
self.bias_np = np.random.uniform(-0.05, 0.05, [cols])
self.norm_weight_np = np.random.uniform(-0.05, 0.05, [cols])
self.norm_bias_np = np.random.uniform(-0.05, 0.05, [cols])
self.epsilon = 1e-5
self.residual_alpha = np.random.uniform(low=0.1, high=1.1, size=[1])
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1
)
paddle_naive_layernorm_out = naive_layer_norm(
x, gamma, beta, self.epsilon
)
paddle.enable_static()
return paddle_layernorm_out, paddle_naive_layernorm_out
def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_layernorm_out = naive_layer_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_layernorm_out, paddle_naive_layernorm_out
def check_residual_bias_layernorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
residual_alpha=self.residual_alpha,
)
paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm(
x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha
)
paddle.enable_static()
return paddle_layernorm_out, paddle_naive_layernorm_out
def check_residual_bias_layernorm_int8(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
residual_alpha=self.residual_alpha,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm_int8(
x,
residual,
bias,
gamma,
beta,
self.epsilon,
self.residual_alpha,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_layernorm_out, paddle_naive_layernorm_out
def test_layernorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_layernorm, paddle_naive_layernorm = self.check_layernorm(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_layernorm[0].numpy(),
paddle_naive_layernorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_layernorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
paddle_layernorm, paddle_naive_layernorm = self.check_layernorm_int8(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_layernorm[0].numpy(),
paddle_naive_layernorm.numpy(),
rtol=2,
atol=2,
)
def test_residual_bias_add_layernorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_layernorm,
paddle_naive_layernorm,
) = self.check_residual_bias_layernorm(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_layernorm[0].numpy(),
paddle_naive_layernorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_residual_bias_add_layernorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_layernorm,
paddle_naive_layernorm,
) = self.check_residual_bias_layernorm_int8(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_layernorm[0].numpy(),
paddle_naive_layernorm.numpy(),
rtol=2,
atol=2,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestlayernormStaticOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
self.batch = 16
self.cols = 256
self.x_np = np.random.uniform(-0.05, 0.05, [self.batch, self.cols])
self.residual_np = np.random.uniform(
-0.05, 0.05, [self.batch, self.cols]
)
self.bias_np = np.random.uniform(-0.05, 0.05, [self.cols])
self.norm_weight_np = np.random.uniform(-0.05, 0.05, [self.cols])
self.norm_bias_np = np.random.uniform(-0.05, 0.05, [self.cols])
self.epsilon = 1e-5
self.residual_alpha = np.random.uniform(low=0.1, high=1.1, size=[1])
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
self.place = paddle.CUDAPlace(0)
def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
paddle_naive_layernorm_out = naive_layer_norm(
x, gamma, beta, self.epsilon
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=paddle.float32
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=paddle.float32
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(np.float32),
"beta_static": beta_np.astype(np.float32),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_layernorm_out
def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
paddle_naive_layernorm_out = naive_layer_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=paddle.float32
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=paddle.float32
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(np.float32),
"beta_static": beta_np.astype(np.float32),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_layernorm_out
def check_residual_bias_layernorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm(
x, residual, bias, gamma, beta, self.epsilon, self.residual_alpha
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
residual_static = paddle.static.data(
name="residual_static",
shape=[self.batch, self.cols],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static", shape=[self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=paddle.float32
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=paddle.float32
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
residual_alpha=self.residual_alpha,
bias=bias_static,
residual=residual_static,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(np.float32),
"beta_static": beta_np.astype(np.float32),
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_layernorm_out
def check_residual_bias_layernorm_int8(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(np.float32))
beta = paddle.to_tensor(beta_np.astype(np.float32))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_naive_layernorm_out = naive_residual_biasadd_layer_norm_int8(
x,
residual,
bias,
gamma,
beta,
self.epsilon,
self.residual_alpha,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
residual_static = paddle.static.data(
name="residual_static",
shape=[self.batch, self.cols],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static", shape=[self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=paddle.float32
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=paddle.float32
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
bias=bias_static,
residual=residual_static,
residual_alpha=self.residual_alpha,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(np.float32),
"beta_static": beta_np.astype(np.float32),
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_layernorm_out
def test_layernorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_layernorm, paddle_naive_layernorm = self.check_layernorm(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_layernorm,
paddle_naive_layernorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_layernorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
paddle_layernorm, paddle_naive_layernorm = self.check_layernorm_int8(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_layernorm,
paddle_naive_layernorm.numpy(),
rtol=2,
atol=2,
)
def test_residual_bias_add_layernorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_layernorm,
paddle_naive_layernorm,
) = self.check_residual_bias_layernorm(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_layernorm,
paddle_naive_layernorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_residual_bias_add_layernorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_layernorm,
paddle_naive_layernorm,
) = self.check_residual_bias_layernorm_int8(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_layernorm,
paddle_naive_layernorm.numpy(),
rtol=2,
atol=2,
)
if __name__ == "__main__":
unittest.main()
......@@ -20,6 +20,73 @@ from paddle import fluid
from paddle.fluid import core
def quant_helper(
x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound
):
quant_value = quant_max_bound * quant_scale * x
if quant_round_type == 0:
quant_value = paddle.to_tensor(np.rint(quant_value.numpy()))
else:
quant_value = paddle.round(quant_value)
return paddle.cast(
paddle.clip(quant_value, quant_min_bound, quant_max_bound),
paddle.int8,
)
def naive_rms_norm(x, gamma, beta, epsilon):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + epsilon) * x
out = out * gamma + beta
return out
def naive_rms_norm_int8(
x,
gamma,
beta,
epsilon,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_rms_norm(x, gamma, beta, epsilon)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
def naive_residual_biasadd_rms_norm(x, residual, bias, gamma, beta, epsilon):
x = x + residual + bias
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + epsilon) * x
out = out * gamma + beta
return out
def naive_residual_biasadd_rms_norm_int8(
x,
residual,
bias,
gamma,
beta,
epsilon,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
):
out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, epsilon
)
out = quant_helper(
out, in_scale, quant_round_type, quant_max_bound, quant_min_bound
)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
......@@ -28,58 +95,195 @@ class TestRMSNormOp(unittest.TestCase):
np.random.seed(20)
batch = 32
cols = 256
self.x_np = np.random.random([batch, 256])
self.gamma_np = np.random.random([256])
self.beta_np = np.random.random([256])
self.epsilon = 1e-6
self.x_np = np.random.random([batch, cols])
self.residual_np = np.random.random([batch, cols])
self.bias_np = np.random.random([cols])
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
self.norm_weight_np = np.random.random([cols])
self.norm_bias_np = np.random.random([cols])
self.epsilon = 1e-6
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
def check_main(self, x_np, gamma_np, beta_np, dtype):
def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.rms_norm(
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1
)
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_rmsnorm_out = naive_rms_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
)
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, self.epsilon
)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm_int8(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm(
x,
gamma,
beta,
self.epsilon,
begin_norm_axis=1,
bias=bias,
residual=residual,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm_int8(
x,
residual,
bias,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-03,
rtol=1e-3,
atol=1e-3,
)
def test_rmsnorm_fp32(self):
def test_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
def test_residual_bias_add_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_residual_bias_add_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
(
paddle_rmsnorm,
paddle_naive_rmsnorm,
) = self.check_residual_bias_rmsnorm_int8(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
paddle_rmsnorm[0].numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
......@@ -90,45 +294,145 @@ class TestRMSNormStaticOp(unittest.TestCase):
self.batch = 32
self.cols = 256
self.x_np = np.random.random([self.batch, 256])
self.gamma_np = np.random.random([256])
self.beta_np = np.random.random([256])
self.norm_weight_np = np.random.random([256])
self.norm_bias_np = np.random.random([256])
self.residual_np = np.random.random([self.batch, 256])
self.bias_np = np.random.random([256])
self.epsilon = 1e-6
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
self.place = paddle.CUDAPlace(0)
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = naive_rms_norm_int8(
x,
gamma,
beta,
self.epsilon,
self.quant_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
quant_scale=self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def check_residual_bias_rmsnorm(
self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype
):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
residual = paddle.to_tensor(residual_np.astype(dtype))
bias = paddle.to_tensor(bias_np.astype(dtype))
paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm(
x, residual, bias, gamma, beta, self.epsilon
)
paddle.enable_static()
outs = paddle.incubate.nn.functional.rms_norm(
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
residual_static = paddle.static.data(
name="residual_static",
shape=[self.batch, self.cols],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static", shape=[self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.fused_rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
bias=bias_static,
residual=residual_static,
)
exe = fluid.Executor(self.place)
......@@ -137,17 +441,18 @@ class TestRMSNormStaticOp(unittest.TestCase):
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
np.testing.assert_allclose(
......@@ -157,11 +462,16 @@ class TestRMSNormStaticOp(unittest.TestCase):
atol=1e-3,
)
def test_rmsnorm_fp32(self):
def test_residual_bias_add_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm(
self.x_np,
self.norm_weight_np,
self.norm_bias_np,
self.residual_np,
self.bias_np,
'float16',
)
np.testing.assert_allclose(
......@@ -171,6 +481,20 @@ class TestRMSNormStaticOp(unittest.TestCase):
atol=1e-3,
)
def test_rmsnorm_int8(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8(
self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16'
)
print("1111")
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=2,
atol=2,
)
if __name__ == "__main__":
unittest.main()