Unverified commit 657dd5a9, authored by crystal, committed by GitHub

Optimize group_norm op forward (#39596)

* optimize group norm forward

* use vectorized optimization

* add scalar calculation code

* optimize code
Parent 75280d36
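The change splits the mean/variance computation by layout: for NCHW input, each (batch, group) pair owns one contiguous row of group_size * imsize elements, so the new VectorizedGetMeanAndVarNCHW kernel reads that row through 16-byte vector loads (ThreadReduce handles the misaligned head and the scalar tail), and ScalarGetMeanAndVarNCHW covers rows shorter than the vector width; NHWC input keeps the original GroupNormForwardGetMeanAndVar path. The sketch below illustrates the same head/body/tail read pattern in isolation. It is not the Paddle kernel: the kernel name, the fixed float/float4 types, the 256-thread block, and the shared-memory reduction are assumptions standing in for the templated T/AccT types and the kps::details::BlockXReduce call used in the diff.

// Minimal standalone sketch of the vectorized mean/variance reduction idea.
// Illustrative only: the names here are not Paddle APIs.
// One block per contiguous row of `row_size` floats; threads read the aligned
// middle of the row as float4 and cover the misaligned head and the tail with
// scalar loads, then reduce the partial sums across the block.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreads = 256;

__global__ void RowMeanVarSketch(const float* __restrict__ x, int row_size,
                                 float* mean, float* var) {
  const float* row = x + static_cast<size_t>(blockIdx.x) * row_size;
  float sum = 0.f, sq_sum = 0.f;

  // Head: scalar loads until the row pointer is 16-byte aligned.
  int misalign =
      (reinterpret_cast<uintptr_t>(row) % sizeof(float4)) / sizeof(float);
  int head = misalign ? 4 - misalign : 0;
  if (head > row_size) head = row_size;
  for (int i = threadIdx.x; i < head; i += blockDim.x) {
    float v = row[i];
    sum += v;
    sq_sum += v * v;
  }

  // Body: vectorized float4 loads over the aligned middle of the row.
  int body_vecs = (row_size - head) / 4;
  const float4* row_vec = reinterpret_cast<const float4*>(row + head);
  for (int i = threadIdx.x; i < body_vecs; i += blockDim.x) {
    float4 v = row_vec[i];
    sum += v.x + v.y + v.z + v.w;
    sq_sum += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
  }

  // Tail: scalar loads for the remaining (row_size - head) % 4 elements.
  for (int i = head + body_vecs * 4 + threadIdx.x; i < row_size;
       i += blockDim.x) {
    float v = row[i];
    sum += v;
    sq_sum += v * v;
  }

  // Block reduction in shared memory (the patch uses kps::details::BlockXReduce).
  __shared__ float s_sum[kThreads];
  __shared__ float s_sq[kThreads];
  s_sum[threadIdx.x] = sum;
  s_sq[threadIdx.x] = sq_sum;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sq[threadIdx.x] += s_sq[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    mean[blockIdx.x] = s_sum[0] / row_size;
    var[blockIdx.x] = s_sq[0] / row_size;  // E[x^2]; variance = E[x^2] - mean^2
  }
}

int main() {
  const int rows = 2, row_size = 1000;
  float *x, *mean, *var;
  cudaMallocManaged(&x, rows * row_size * sizeof(float));
  cudaMallocManaged(&mean, rows * sizeof(float));
  cudaMallocManaged(&var, rows * sizeof(float));
  for (int i = 0; i < rows * row_size; ++i) x[i] = 1.0f;  // mean 1, E[x^2] 1
  RowMeanVarSketch<<<rows, kThreads>>>(x, row_size, mean, var);
  cudaDeviceSynchronize();
  printf("mean=%f  E[x^2]=%f\n", mean[0], var[0]);
  return 0;
}

Like the patch, the sketch leaves var holding E[x^2]; in the diff below, GroupNormForward later converts it to a variance via x_var - x_mean * x_mean.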
@@ -29,6 +29,7 @@ namespace operators {
 using DataLayout = framework::DataLayout;
 enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
+#define ALIGN_BYTES 16
 
 #define CHECK_CASE(i, flags, kernel_name, ...) \
   if (i == flags) {                            \
@@ -56,8 +57,7 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
 template <typename T>
 __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
                                               int imsize, int groups,
-                                              int group_size, T* mean, T* var,
-                                              const DataLayout data_layout) {
+                                              int group_size, T* mean, T* var) {
   int gid = blockIdx.y;
   int cid = blockIdx.x;
   int bid = blockIdx.z;
@@ -68,13 +68,10 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
   T x_mean = 0, x_var = 0;
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
     T val;
-    if (data_layout == DataLayout::kNCHW) {
-      val = x[(bid * C + ccid) * imsize + imid];
-    } else {
-      int hid = imid / W;
-      int wid = imid % W;
-      val = x[(bid * H + hid) * W * C + wid * C + ccid];
-    }
+    int hid = imid / W;
+    int wid = imid % W;
+    val = x[(bid * H + hid) * W * C + wid * C + ccid];
     x_mean += val;
     x_var += val * val;
   }
@@ -84,6 +81,85 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
   CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
 }
 
+template <typename T, typename AccT, int VecSize>
+__device__ __forceinline__ void ThreadReduce(const T* input, int size,
+                                             const int offset, AccT* mean,
+                                             AccT* var) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+  int tid = threadIdx.x;
+  if (offset > 0) {
+    input -= offset;
+    size += offset;
+    if (tid >= offset) {
+      AccT temp = input[tid];
+      *mean += temp;
+      *var += temp * temp;
+    }
+    size -= blockDim.x;
+    input += blockDim.x;
+  }
+  int remain = size % (VecSize * blockDim.x);
+
+  T ins[VecSize];
+  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
+
+  // vector part
+  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
+    *ins_vec = reinterpret_cast<const VecT*>(input)[tid];
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      AccT temp = ins[i];
+      *mean += temp;
+      *var += temp * temp;
+    }
+  }
+
+  // scalar part
+  tid = size - remain + threadIdx.x;
+  for (; tid < size; tid += blockDim.x) {
+    AccT temp = input[tid];
+    *mean += temp;
+    *var += temp * temp;
+  }
+}
+
+template <typename T>
+__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
+  int i = blockIdx.x;
+  T x_mean = 0, x_var = 0;
+  for (int j = threadIdx.x; j < size; j += blockDim.x) {
+    T val;
+    val = x[i * size + j];
+    x_mean += val;
+    x_var += val * val;
+  }
+  x_mean /= size;
+  x_var /= size;
+  CudaAtomicAddWithWarp(&mean[i], x_mean);
+  CudaAtomicAddWithWarp(&var[i], x_var);
+}
+
+template <typename T, typename AccT, int VecSize>
+__global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
+                                            int size) {
+  int i = blockIdx.x;
+  AccT x_mean = static_cast<AccT>(0);
+  AccT x_var = static_cast<AccT>(0);
+  const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
+  x += i * size;
+  ThreadReduce<T, AccT, VecSize>(x, size, input_offset, &x_mean, &x_var);
+  x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
+      x_mean, kps::AddFunctor<AccT>());
+  x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
+      x_var, kps::AddFunctor<AccT>());
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    mean[i] = static_cast<T>(x_mean / size);
+    var[i] = static_cast<T>(x_var / size);
+  }
+}
+
 template <typename T, int flags>
 __global__ void GroupNormForward(const T* x, const T* mean, const T* var,
                                  const T* scale, const T* bias, int N, int C,
@@ -96,26 +172,34 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var,
   int H = imsize / W;
   int ccid = gid * group_size + cid;
   if (ccid >= C) return;
-  T x_mean = mean[bid * groups + gid];
-  T x_var = var[bid * groups + gid];
+  auto ng = bid * groups + gid;
+  T x_mean = mean[ng];
+  T x_var = var[ng];
   x_var = x_var - x_mean * x_mean;
-  T var_inv = 1.0 / sqrt(x_var + epsilon);
-  if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var;
+  T var_inv = rsqrt(x_var + epsilon);
+  if (cid == 0 && threadIdx.x == 0) {
+    real_var[ng] = x_var;
+  }
   for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
     T val;
     int hid, wid;
+    int index = (bid * C + ccid) * imsize + imid;
     if (data_layout == DataLayout::kNCHW) {
-      val = x[(bid * C + ccid) * imsize + imid];
+      val = x[index];
     } else {
       hid = imid / W;
       wid = imid % W;
       val = x[(bid * H + hid) * W * C + wid * C + ccid];
     }
     val = (val - x_mean) * var_inv;
-    if (flags & kHasScale) val *= scale[gid * group_size + cid];
-    if (flags & kHasBias) val += bias[gid * group_size + cid];
+    if (flags & kHasScale) {
+      val *= scale[ccid];
+    }
+    if (flags & kHasBias) {
+      val += bias[ccid];
+    }
     if (data_layout == DataLayout::kNCHW) {
-      y[(bid * C + ccid) * imsize + imid] = val;
+      y[index] = val;
     } else {
       y[(bid * H + hid) * W * C + wid * C + ccid] = val;
     }
@@ -182,16 +266,41 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
         imsize *= x_dims[i];
       }
     }
 #ifdef __HIPCC__
     int block_size = std::max(std::min(256, imsize), 64);
 #else
     int block_size = std::min(1024, imsize);
 #endif
     dim3 grid(group_size, groups, x_dims[0]);
     dim3 threads(block_size, 1, 1);
-    GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
-        x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
-        temp_var_data, data_layout);
+    if (data_layout == DataLayout::kNCHW) {
+      using AccT = typename details::MPTypeTrait<T>::Type;
+      constexpr int vec_size = sizeof(float4) / sizeof(T);
+      int size = group_size * imsize;
+      const int max_num_threads = 1024;
+      int max_block_size = std::min(size / vec_size, max_num_threads);
+      int block_size_nchw = 1;
+      while (block_size_nchw < max_block_size) {
+        block_size_nchw *= 2;
+      }
+      block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
+      dim3 grids(x_dims[0] * groups);
+      dim3 blocks(block_size_nchw);
+      if (size < vec_size) {
+        ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
+            x_data, mean_data, temp_var_data, size);
+      } else {
+        VectorizedGetMeanAndVarNCHW<
+            T, AccT, vec_size><<<grids, blocks, 0, dev_ctx.stream()>>>(
+            x_data, mean_data, temp_var_data, size);
+      }
+    } else {
+      GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
+          x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
+          temp_var_data);
+    }
     int flags =
         (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
     UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,
......
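For reference, the NCHW branch in the last hunk sizes its thread block from the row length: each thread consumes vec_size elements per pass, so roughly size / vec_size threads suffice, rounded up to a power of two and clamped to [kWarpSize, 1024]; rows shorter than vec_size skip vectorization entirely via ScalarGetMeanAndVarNCHW. The small host-side sketch below reproduces that arithmetic; the helper name and default arguments are hypothetical, not Paddle APIs.

// Hypothetical helper mirroring the block-size heuristic in the hunk above.
#include <algorithm>
#include <cstdio>

int PickBlockSizeNCHW(int size, int vec_size, int warp_size = 32,
                      int max_threads = 1024) {
  int max_block_size = std::min(size / vec_size, max_threads);
  int block = 1;
  while (block < max_block_size) block *= 2;  // smallest power of two >= need
  return std::max(block, warp_size);          // never launch less than a warp
}

int main() {
  // e.g. one group of 3 channels over a 56x56 feature map, float4 loads
  printf("%d\n", PickBlockSizeNCHW(3 * 56 * 56, 4));  // prints 1024
  printf("%d\n", PickBlockSizeNCHW(192, 4));          // prints 64
  return 0;
}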