Unverified commit 97d3d6ee authored by MarDino, committed by GitHub

Integrate rmsnorm kernel (#54998)

* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant headerfile to avoid rocm compile error
* fix rocm compile not found cub
* Add document
Parent 852d7a12
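For reference, a minimal dygraph usage sketch of the new API added by this commit (assuming a GPU build of Paddle that includes this change; variable names are illustrative and mirror the example in the rms_norm docstring below):

import paddle

x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
out = paddle.incubate.nn.functional.rms_norm(x, weight, bias, 1e-6, begin_norm_axis=1)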
......@@ -1994,6 +1994,16 @@
data_type : x
backward : reverse_grad
- op : rms_norm
args : (Tensor x, Tensor weight, Tensor bias, float epsilon, int begin_norm_axis)
output : Tensor(out)
infer_meta :
func : RmsNormInferMeta
kernel :
func : rms_norm
data_type : x
optional : bias
- op : rmsprop_
args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
......
......@@ -3137,6 +3137,38 @@ void Unpool3dInferMeta(const MetaTensor& x,
}
}
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out) {
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
auto x_dims_size = x_dims_vec.size();
size_t normalized_dims = 1;
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
normalized_dims *= x_dims_vec[i];
}
PADDLE_ENFORCE_EQ(normalized_dims,
weight.dims()[0],
phi::errors::InvalidArgument(
"The normalized size of Input(X) must equal to be"
"the size of Weight, but received"
"normalized size of Input(X) is [%d], received size"
"of Weight is [%d]",
normalized_dims,
weight.dims()[0]));
auto out_dims = phi::make_ddim(x_dims_vec);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
......@@ -479,4 +479,11 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& bias,
const float epsilon,
const int begin_norm_axis,
MetaTensor* out);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Original OneFlow copyright notice:
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
// The following code is modified from OneFlow's implementation and changed to use a
// single-pass algorithm. It supports Int8 quant/dequant Load/Store implementations.
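// Per row, the fused kernel computes inv_rms = rsqrt(mean(x * x) + epsilon) and
// stores out = x * inv_rms * weight + bias (bias is optional); the row is staged in
// shared memory so it is read from global memory only once.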
#include "paddle/phi/kernels/rms_norm_kernel.h"
#include <assert.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include <cub/cub.cuh>
#endif
namespace phi {
namespace {
#ifndef PADDLE_WITH_HIP
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return a + b;
}
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return max(a, b);
}
};
template <template <typename> class ReductionOp,
typename T,
int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(
val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ float Div<float>(float a, float b) {
return a / b;
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}
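// Heuristic grid sizing: launch at most `waves` full waves of resident blocks
// (sm_count * max_active_blocks per wave), clamped to [1, max_blocks].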
template <class Func>
inline cudaError_t GetNumBlocks(Func func,
int32_t block_size,
size_t dynamic_smem_size,
int32_t max_blocks,
int32_t waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int max_active_blocks;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, func, block_size, dynamic_smem_size);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks = std::max<int>(
1, std::min<int32_t>(max_blocks, sm_count * max_active_blocks * waves));
return cudaSuccess;
}
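// SFINAE detector: HasCanPackAs<T>::value is true iff T declares a CanPackAs member.
// The CanPackAs() overloads below forward to that member when it exists and default
// to returning true otherwise.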
template <typename T>
class HasCanPackAs {
typedef char one;
struct two {
char x[2];
};
template <typename C>
static one test(decltype(&C::CanPackAs));
template <typename C>
static two test(...);
public:
enum { value = sizeof(test<T>(0)) == sizeof(char) };
};
template <typename T>
typename std::enable_if<HasCanPackAs<T>::value == true, bool>::type CanPackAs(
T t, size_t pack_size) {
return t.CanPackAs(pack_size);
}
template <typename T>
typename std::enable_if<HasCanPackAs<T>::value == false, bool>::type CanPackAs(
T t, size_t pack_size) {
return true;
}
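// alignas(sizeof(T) * N) keeps Pack<T, N> naturally aligned so the reinterpret_cast
// based packed loads/stores below can be issued as single vector accesses.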
template <typename T, int N>
struct alignas(sizeof(T) * N) Pack {
__device__ Pack() {
// do nothing
}
T elem[N];
};
template <typename SRC, typename DST>
struct DirectLoad {
using LoadType = DST;
DirectLoad(const SRC* src, int32_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int32_t row, int32_t col) const {
Pack<SRC, N> pack;
const int32_t offset = (row * row_size + col) / N;
pack = *(reinterpret_cast<const Pack<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int32_t row_size;
};
template <typename SRC, typename DST>
struct ResidualAddBiasLoad {
using LoadType = DST;
ResidualAddBiasLoad(const SRC* src,
const SRC* residual,
const SRC* bias,
SRC* residual_out,
int32_t row_size)
: src(src),
residual(residual),
bias(bias),
residual_out(residual_out),
row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int32_t row, int32_t col) const {
Pack<SRC, N> src_pack;
Pack<SRC, N> residual_pack;
Pack<SRC, N> bias_pack;
const int32_t offset = (row * row_size + col) / N;
src_pack = *(reinterpret_cast<const Pack<SRC, N>*>(src) + offset);
residual_pack = *(reinterpret_cast<const Pack<SRC, N>*>(residual) + offset);
if (bias) {
bias_pack = *(reinterpret_cast<const Pack<SRC, N>*>(bias) + col / N);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
bias_pack.elem[i] = static_cast<SRC>(0.0f);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
src_pack.elem[i] =
src_pack.elem[i] + residual_pack.elem[i] + bias_pack.elem[i];
}
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(src_pack.elem[i]);
}
*(reinterpret_cast<Pack<SRC, N>*>(residual_out) + offset) = src_pack;
}
const SRC* src;
const SRC* residual;
const SRC* bias;
SRC* residual_out;
int32_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int32_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int32_t row, int32_t col) {
Pack<DST, N> pack;
const int32_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<Pack<DST, N>*>(dst) + offset) = pack;
}
DST* dst;
int32_t row_size;
};
template <typename T>
inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) {
// Use the Welford online algorithm to compute mean and variance.
// For more details you can refer to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
*count += 1;
T delta1 = val - *mean;
*mean += Div(delta1, *count);
T delta2 = val - *mean;
*m2 += delta1 * delta2;
}
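// Merge two partial Welford accumulators (mean, m2, count) using the parallel
// combination formula, so per-thread statistics can be reduced across a warp or block.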
template <typename T>
inline __device__ void WelfordCombine(
T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) {
if (b_count == 0) {
return;
}
T new_count = *count + b_count;
T nb_over_n = Div(b_count, new_count);
T delta = b_mean - *mean;
*mean += delta * nb_over_n;
*m2 += b_m2 + delta * delta * (*count) * nb_over_n;
*count = new_count;
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(
T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);
T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(
T thread_mean, T thread_m2, T thread_count, T* mean, T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(
thread_mean, thread_m2, thread_count, mean, m2, count);
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpReduceSum(T x) {
T result = 0.0f;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
result += __shfl_xor_sync(0xffffffff, x, mask, thread_group_width);
}
return result;
}
template <typename T>
__inline__ __device__ void WelfordBlockAllReduce(T thread_mean,
T thread_m2,
T thread_count,
T* result_mean,
T* result_m2,
T* result_count) {
__shared__ T mean_shared[kWarpSize];
__shared__ T m2_shared[kWarpSize];
__shared__ T count_shared[kWarpSize];
__shared__ T mean_result_broadcast;
__shared__ T m2_result_broadcast;
__shared__ T count_result_broadcast;
const int lid = threadIdx.x % kWarpSize;
const int wid = threadIdx.x / kWarpSize;
T warp_mean = 0;
T warp_m2 = 0;
T warp_count = 0;
WelfordWarpReduce(
thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
__syncthreads();
if (lid == 0) {
mean_shared[wid] = warp_mean;
m2_shared[wid] = warp_m2;
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
T block_mean = 0;
T block_m2 = 0;
T block_count = 0;
WelfordWarpReduce(
warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);
if (lid == 0) {
mean_result_broadcast = block_mean;
m2_result_broadcast = block_m2;
count_result_broadcast = block_count;
}
}
__syncthreads();
*result_mean = mean_result_broadcast;
*result_m2 = m2_result_broadcast;
*result_count = count_result_broadcast;
}
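// Single-pass RMSNorm: each block walks rows with a grid stride, stages the packed
// inputs in dynamic shared memory, block-reduces the sum of squares, and then rescales
// the cached values by rsqrt(mean_square + epsilon) through the Store functor.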
template <typename LOAD,
typename STORE,
typename ComputeType,
int kPackSize,
int block_size>
__global__ void __launch_bounds__(block_size)
RmsNormBlockSMemImpl(LOAD load,
STORE store,
const int32_t rows,
const int32_t cols,
const float epsilon,
ComputeType col_divisor) {
using LoadType = typename LOAD::LoadType;
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<LoadType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % kPackSize == 0);
const int num_packs = static_cast<int>(cols) / kPackSize;
for (int32_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum_square = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
LoadType pack[kPackSize];
load.template load<kPackSize>(pack, row, pack_id * kPackSize);
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
buf[i * num_packs + pack_id] = pack[i];
ComputeType pack_val = static_cast<ComputeType>(pack[i]);
thread_sum_square += pack_val * pack_val;
}
}
const ComputeType row_sum_square =
BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum_square);
// Multiply by col_divisor (== 1 / cols) instead of dividing, so row_rms is the mean of squares. Author(zhengzekang).
ComputeType row_rms = row_sum_square * col_divisor;
ComputeType row_inv_rms =
Rsqrt(row_rms + static_cast<ComputeType>(epsilon));
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[kPackSize];
#pragma unroll
for (int i = 0; i < kPackSize; ++i) {
pack[i] = static_cast<ComputeType>(buf[i * num_packs + pack_id]) *
row_inv_rms;
}
store.template store<kPackSize>(pack, row, pack_id * kPackSize);
}
}
}
template <typename LOAD,
typename STORE,
typename ComputeType,
int kPackSize,
int block_size>
inline cudaError_t LaunchRmsNormBlockSMemImpl(cudaStream_t stream,
LOAD load,
STORE store,
int smem,
const int32_t rows,
const int32_t cols,
const float epsilon,
ComputeType col_divisor) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, kPackSize, block_size>,
block_size,
smem,
rows,
waves,
&grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, kPackSize, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(
load, store, rows, cols, epsilon, col_divisor);
return cudaPeekAtLastError();
}
template <typename Func>
cudaError_t MaximizeDynamicSharedMemorySize(Func func,
const int max_smem_size) {
cudaFuncAttributes attr{};
cudaError_t err = cudaFuncGetAttributes(&attr, func);
if (err != cudaSuccess) {
return err;
}
constexpr int reserved_smem = 1024; // 1K
return cudaFuncSetAttribute(
func,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_smem_size - attr.sharedSizeBytes - reserved_smem);
}
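// Block-size selection: prefer 1024, then 512, then 256 threads, but only when that
// configuration keeps the same occupancy as the 128-thread baseline, or when
// rows <= sm_count so each row can still get its own block; otherwise fall back to
// 128 threads. *success is set to false when even the 128-thread configuration cannot
// be resident with the required shared memory.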
template <typename LOAD, typename STORE, typename ComputeType, int kPackSize>
inline cudaError_t TryDispatchRmsNormBlockSMemImplBlockSize(
cudaStream_t stream,
LOAD load,
STORE store,
const int32_t rows,
const int32_t cols,
const float epsilon,
ComputeType col_divisor,
bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
int dev = 0;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count = 0;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
static const bool max_smem_configed = [=]() {
int max_smem_size = 0;
cudaError_t err = cudaDeviceGetAttribute(
&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
if (err != cudaSuccess) {
return false;
}
err =
MaximizeDynamicSharedMemorySize(RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_1>,
max_smem_size);
if (err != cudaSuccess) {
return false;
}
err =
MaximizeDynamicSharedMemorySize(RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_2>,
max_smem_size);
if (err != cudaSuccess) {
return false;
}
err =
MaximizeDynamicSharedMemorySize(RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_3>,
max_smem_size);
if (err != cudaSuccess) {
return false;
}
err =
MaximizeDynamicSharedMemorySize(RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_4>,
max_smem_size);
if (err != cudaSuccess) {
return false;
}
return true;
}();
const size_t smem = cols * sizeof(typename LOAD::LoadType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_1>,
block_size_conf_1,
smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_4>,
block_size_conf_4,
smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1 ||
(max_active_blocks_conf_4 > 0 && rows <= sm_count)) {
*success = true;
return LaunchRmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_4>(
stream, load, store, smem, rows, cols, epsilon, col_divisor);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_3>,
block_size_conf_3,
smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1 ||
(max_active_blocks_conf_3 > 0 && rows <= sm_count)) {
*success = true;
return LaunchRmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_3>(
stream, load, store, smem, rows, cols, epsilon, col_divisor);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
RmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_2>,
block_size_conf_2,
smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1 ||
(max_active_blocks_conf_2 > 0 && rows <= sm_count)) {
*success = true;
return LaunchRmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_2>(
stream, load, store, smem, rows, cols, epsilon, col_divisor);
}
*success = true;
return LaunchRmsNormBlockSMemImpl<LOAD,
STORE,
ComputeType,
kPackSize,
block_size_conf_1>(
stream, load, store, smem, rows, cols, epsilon, col_divisor);
}
template <typename LOAD, typename STORE, typename ComputeType>
struct TryDispatchRmsNormBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream,
LOAD load,
STORE store,
const int32_t rows,
const int32_t cols,
const float epsilon,
ComputeType col_divisor,
bool* success) {
if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) &&
CanPackAs<STORE>(store, 4)) {
return TryDispatchRmsNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
4>(
stream, load, store, rows, cols, epsilon, col_divisor, success);
} else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) &&
CanPackAs<STORE>(store, 2)) {
return TryDispatchRmsNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
2>(
stream, load, store, rows, cols, epsilon, col_divisor, success);
} else {
return TryDispatchRmsNormBlockSMemImplBlockSize<LOAD,
STORE,
ComputeType,
1>(
stream, load, store, rows, cols, epsilon, col_divisor, success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchRmsNormBlockSMemImpl(cudaStream_t stream,
LOAD load,
STORE store,
const int32_t rows,
const int32_t cols,
const float epsilon,
ComputeType col_divisor,
bool* success) {
return TryDispatchRmsNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, col_divisor, success);
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value,
cudaError_t>::type
DispatchRmsNorm(cudaStream_t stream,
LOAD load,
STORE store,
const int32_t rows,
const int32_t cols,
const float epsilon) {
const ComputeType col_divisor = 1.0f / cols;
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType>(
stream,
load,
store,
rows,
cols,
epsilon,
col_divisor,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
return cudaSuccess;
}
template <typename SRC, typename DST>
struct SkipLoadAndStoreResidual {
using LoadType = DST;
// Note: SRC is required to be the same type as DST.
SkipLoadAndStoreResidual(SRC* src,
const SRC* bias,
const SRC* skip,
SRC* residual_bias_out,
float alpha,
int32_t row_size)
: src(src),
bias(bias),
skip(skip),
residual_bias_out(residual_bias_out),
alpha(alpha),
row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int32_t row, int32_t col) const {
Pack<SRC, N> src_pack;
Pack<SRC, N> bias_pack;
Pack<SRC, N> skip_pack;
Pack<DST, N> residual_out_pack;
const int32_t offset = (row * row_size + col) / N;
const int32_t bias_offset = col / N;
src_pack = *(reinterpret_cast<const Pack<SRC, N>*>(src) + offset);
bias_pack = *(reinterpret_cast<const Pack<SRC, N>*>(bias) + bias_offset);
skip_pack = *(reinterpret_cast<const Pack<SRC, N>*>(skip) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
residual_out_pack.elem[i] =
src_pack.elem[i] + bias_pack.elem[i] + skip_pack.elem[i];
}
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = residual_out_pack.elem[i];
}
*(reinterpret_cast<Pack<SRC, N>*>(residual_bias_out) + offset) =
residual_out_pack;
}
SRC* src;
const SRC* bias;
const SRC* skip;
SRC* residual_bias_out;
float alpha;
int32_t row_size;
};
template <typename SRC, typename DST>
struct AffineStore {
AffineStore(DST* y, int32_t row_size, const DST* gamma, const DST* beta)
: y(y), row_size(row_size), gamma(gamma), beta(beta) {}
template <int N>
__device__ void store(const SRC* src, int32_t row, int32_t col) {
Pack<DST, N> y_pack;
Pack<DST, N> gamma_pack;
Pack<DST, N> beta_pack;
const int32_t offset = (row * row_size + col) / N;
const int32_t gamma_offset = col / N;
gamma_pack = *(reinterpret_cast<const Pack<DST, N>*>(gamma) + gamma_offset);
// Author(Zhengzekang): Bias may be optional (beta can be null).
if (beta) {
beta_pack = *(reinterpret_cast<const Pack<DST, N>*>(beta) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; i++) {
beta_pack.elem[i] = static_cast<DST>(0.0f);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
float normalized_i = static_cast<float>(src[i]);
float normalized_val =
normalized_i * static_cast<float>(gamma_pack.elem[i]) +
static_cast<float>(beta_pack.elem[i]);
y_pack.elem[i] = static_cast<DST>(normalized_val);
}
*(reinterpret_cast<Pack<DST, N>*>(y) + offset) = y_pack;
}
DST* y;
int32_t row_size;
const DST* gamma;
const DST* beta;
};
// ======== For Int8 Output ========
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
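// round_type == 0 rounds to nearest with ties to even (rint); any other value rounds
// halves away from zero (round).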
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
template <typename OutType,
typename SRC,
typename DST,
bool do_scale,
bool do_center>
struct AffineQuantStore {
AffineQuantStore(OutType* y,
const int64_t row_size,
const DST* gamma,
const DST* beta,
const float quant_out_scale,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0)
: y(y),
row_size(row_size),
gamma(gamma),
beta(beta),
quant_round_type(quant_round_type),
quant_out_scale(quant_out_scale),
quant_max_bound(quant_max_bound),
quant_min_bound(quant_min_bound) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<OutType, N> y_pack;
Pack<DST, N> gamma_pack;
Pack<DST, N> beta_pack;
Pack<OutType, N> out_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
gamma_pack = *(reinterpret_cast<const Pack<DST, N>*>(gamma) + gamma_offset);
// Author(Zhengzekang): Bias may be optional (beta can be null).
if (beta) {
beta_pack = *(reinterpret_cast<const Pack<DST, N>*>(beta) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; i++) {
beta_pack.elem[i] = static_cast<DST>(0.0f);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
float normalized_i = static_cast<float>(src[i]);
float normalized_val =
normalized_i * static_cast<float>(gamma_pack.elem[i]) +
static_cast<float>(beta_pack.elem[i]);
y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
quant_out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
*(reinterpret_cast<Pack<OutType, N>*>(y) + offset) = y_pack;
}
OutType* y;
int64_t row_size;
const DST* gamma;
const DST* beta;
const int quant_round_type;
const float quant_out_scale;
const float quant_max_bound;
const float quant_min_bound;
};
#endif
} // namespace
template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int begin_norm_axis,
DenseTensor* out) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
const T* x_data = x.data<T>();
const T* weight_data = weight.data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
T* out_data = dev_ctx.template Alloc<T>(out);
int32_t rows = 1;
int32_t cols = 1;
for (int i = 0; i < begin_norm_axis; i++) {
rows *= x.dims()[i];
}
for (int i = begin_norm_axis; i < x.dims().size(); i++) {
cols *= x.dims()[i];
}
DirectLoad<T, ComputeType> load(x_data, cols);
AffineStore<ComputeType, T> store(out_data, cols, weight_data, bias_data);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
dev_ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineStore<ComputeType, T> store(output, cols, weight, bias);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* weight,
const phi::dtype::bfloat16* bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* output);
template void RmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
float* output);
// ========== ResidualAdd + RMSNorm ==========
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols);
AffineStore<ComputeType, T> store(output, cols, norm_weight, norm_bias);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::float16* residual_output,
phi::dtype::float16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
phi::dtype::bfloat16* residual_output,
phi::dtype::bfloat16* output);
template void ResidualAddRmsNormWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
float* residual_output,
float* output);
// ===== FP16 in, Int8out RMSNorm =====
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
DirectLoad<T, ComputeType> load(x, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
cols,
weight,
bias,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* weight,
const float* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* weight,
const phi::dtype::float16* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template void RmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* weight,
const phi::dtype::bfloat16* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
// ===== FP16 in, Int8out ResidualAdd + RMSNorm =====
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with CUDA, ROCM platform isn't support it";
#else
using ComputeType = typename phi::dtype::MPTypeTrait<T>::Type;
ResidualAddBiasLoad<T, ComputeType> load(
x, residual, bias, residual_output, cols);
AffineQuantStore<int8_t, ComputeType, T, true, true> store(output,
cols,
norm_weight,
norm_bias,
in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchRmsNorm<decltype(load), decltype(store), ComputeType>(
ctx.stream(), load, store, rows, cols, epsilon);
#endif
}
template void ResidualAddRmsNormInt8OutWrapper(const phi::GPUContext& ctx,
const float* x,
const float* residual,
const float* bias,
const float* norm_weight,
const float* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
float* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::float16* x,
const phi::dtype::float16* residual,
const phi::dtype::float16* bias,
const phi::dtype::float16* norm_weight,
const phi::dtype::float16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::float16* residual_output,
int8_t* output);
template void ResidualAddRmsNormInt8OutWrapper(
const phi::GPUContext& ctx,
const phi::dtype::bfloat16* x,
const phi::dtype::bfloat16* residual,
const phi::dtype::bfloat16* bias,
const phi::dtype::bfloat16* norm_weight,
const phi::dtype::bfloat16* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
phi::dtype::bfloat16* residual_output,
int8_t* output);
} // namespace phi
PD_REGISTER_KERNEL(rms_norm,
GPU,
ALL_LAYOUT,
phi::RmsNormKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void RmsNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int begin_norm_axis,
DenseTensor* out);
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
T* output);
template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
T* residual_output,
T* output);
template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* weight,
const T* bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
int8_t* output);
template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
const T* x,
const T* residual,
const T* bias,
const T* norm_weight,
const T* norm_bias,
const float epsilon,
const int rows,
const int cols,
const float in_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
T* residual_output,
int8_t* output);
} // namespace phi
......@@ -21,6 +21,7 @@ from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_gate_attention import fused_gate_attention
from .fused_rotary_position_embedding import fused_rotary_position_embedding
from .rms_norm import rms_norm
__all__ = [
......@@ -33,4 +34,5 @@ __all__ = [
'fused_ec_moe',
'fused_dropout_add',
'fused_rotary_position_embedding',
"rms_norm",
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
def rms_norm(x, weight, bias, epsilon, begin_norm_axis):
r"""
Apply RMSNorm kernel.
Args:
x (Tensor): the input Tensor.
weight (Tensor): the weight Tensor to affine the output.
bias (Tensor): the bias Tensor to affine the output.
epsilon (float): a small float added to the mean square to avoid division by zero.
begin_norm_axis (int): the first axis over which to normalize.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16)
epsilon = 1e-6
paddle_rmsnorm = paddle.incubate.nn.functional.rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dygraph_mode():
return _C_ops.rms_norm(x, weight, bias, epsilon, begin_norm_axis)
helper = LayerHelper('rms_norm', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='rms_norm',
inputs={'x': x, 'weight': weight, 'bias': bias},
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
outputs={'out': out},
)
return out
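As a rough reference of what the op computes per row, here is a sketch that mirrors the naive_rms_norm helper used in the unit test below; the function name is only illustrative and it assumes begin_norm_axis selects the last axis:

import paddle

def naive_rms_norm(x, weight, bias, epsilon=1e-6):
    # out = x * rsqrt(mean(x^2) + epsilon) * weight + bias, normalized over the last axis
    variance = x.pow(2).mean(-1, keepdim=True)
    return paddle.rsqrt(variance + epsilon) * x * weight + bias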
......@@ -75,6 +75,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer)
......@@ -154,6 +155,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import core
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestRMSNormOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
batch = 32
cols = 256
self.x_np = np.random.random([batch, cols])
self.gamma_np = np.random.random([cols])
self.beta_np = np.random.random([cols])
self.epsilon = 1e-6
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_rmsnorm_out = paddle.incubate.nn.functional.rms_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1
)
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle.enable_static()
return paddle_rmsnorm_out, paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-03,
atol=1e-3,
)
def test_rmsnorm_fp32(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
)
np.testing.assert_allclose(
paddle_rmsnorm.numpy(),
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestRMSNormStaticOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
self.batch = 32
self.cols = 256
self.x_np = np.random.random([self.batch, self.cols])
self.gamma_np = np.random.random([self.cols])
self.beta_np = np.random.random([self.cols])
self.epsilon = 1e-6
self.place = paddle.CUDAPlace(0)
def naive_rms_norm(self, x, gamma, beta):
variance = x.pow(2).mean(-1, keepdim=True)
out = paddle.rsqrt(variance + self.epsilon) * x
out = out * gamma + beta
return out
def check_main(self, x_np, gamma_np, beta_np, dtype):
paddle.disable_static()
x = paddle.to_tensor(x_np.astype(dtype))
gamma = paddle.to_tensor(gamma_np.astype(dtype))
beta = paddle.to_tensor(beta_np.astype(dtype))
paddle_naive_rmsnorm_out = self.naive_rms_norm(x, gamma, beta)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static", shape=[self.batch, self.cols], dtype=dtype
)
gamma_static = paddle.static.data(
name="gamma_static", shape=[self.cols], dtype=dtype
)
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype=dtype
)
outs = paddle.incubate.nn.functional.rms_norm(
x_static,
gamma_static,
beta_static,
self.epsilon,
begin_norm_axis=1,
)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x_np.astype(dtype),
"gamma_static": gamma_np.astype(dtype),
"beta_static": beta_np.astype(dtype),
},
fetch_list=[outs],
)
return out_s[0], paddle_naive_rmsnorm_out
def test_rmsnorm_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float16'
)
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_rmsnorm_fp32(self):
if not paddle.is_compiled_with_cuda():
return
paddle_rmsnorm, paddle_naive_rmsnorm = self.check_main(
self.x_np, self.gamma_np, self.beta_np, 'float32'
)
np.testing.assert_allclose(
paddle_rmsnorm,
paddle_naive_rmsnorm.numpy(),
rtol=1e-3,
atol=1e-3,
)
if __name__ == "__main__":
unittest.main()