Unverified commit dd9d7206, authored by crystal, committed by GitHub

fix group_norm address misalignment (#40657)

* fix group_norm address misalignment

* fix vectorize

* fix code

* fix vectorize length

* optimize code
Parent c63e03b1
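The misalignment fixed here comes from the order of two lines in the vectorized kernels: the per-block alignment offset used to be computed from the base pointer, and only afterwards was the pointer advanced to the block's slice, so the offset described an address the kernel never reads. A minimal host-side sketch of that arithmetic, assuming `ALIGN_BYTES == 16`; `kAlignBytes` and `MisalignedHead` are illustrative stand-ins, not Paddle code:

```cpp
// Why the offset must be taken from the pointer that is actually dereferenced.
#include <cstdint>
#include <cstdio>

constexpr int kAlignBytes = 16;

// Element offset of `p` past the previous kAlignBytes boundary; this is the
// quantity the kernels pass to ThreadReduce as `input_offset`.
template <typename T>
int MisalignedHead(const T* p) {
  return static_cast<int>(
      (reinterpret_cast<std::uintptr_t>(p) % kAlignBytes) / sizeof(T));
}

int main() {
  alignas(kAlignBytes) float buf[64] = {};
  const float* base = buf;
  const int i = 1, size = 6;  // each block handles 6 floats: a 24-byte stride

  // Old order: offset measured on `base`, pointer advanced afterwards.
  int stale = MisalignedHead(base);   // 0 -- buf itself is aligned
  const float* x = base + i * size;   // 24 bytes in -> not 16-byte aligned
  // New order (this commit): advance first, then measure the offset.
  int actual = MisalignedHead(x);     // 2 elements past the boundary

  std::printf("stale=%d actual=%d\n", stale, actual);
  return 0;
}
```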
@@ -152,6 +152,21 @@ __device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
}
}
template <typename T>
__device__ __forceinline__ void ReduceMeanAndVar(T* mean, T* var, T x_mean,
T x_var, int size) {
const int nc = blockIdx.x;
x_mean = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_mean, kps::AddFunctor<T>());
x_var = kps::details::BlockXReduce<T, kps::AddFunctor<T>>(
x_var, kps::AddFunctor<T>());
__syncthreads();
if (threadIdx.x == 0) {
mean[nc] = static_cast<T>(x_mean / size);
var[nc] = static_cast<T>(x_var / size);
}
}
template <typename T>
__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
int i = blockIdx.x;
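The new `ReduceMeanAndVar` helper block-reduces the running sums of `x` and `x*x` and lets thread 0 store their means, replacing the per-warp atomic adds that required pre-zeroed outputs. At this point `var[nc]` holds E[x²]; the group_norm kernels later derive the actual variance as E[x²] − mean². A host-side sketch of the same arithmetic (illustrative names, not Paddle code):

```cpp
// What ReduceMeanAndVar stores per block: the mean of x and the mean of x*x.
#include <cstdio>
#include <vector>

void MeanAndMeanOfSquares(const std::vector<float>& x, float* mean,
                          float* mean_sq) {
  float sum = 0.f, sq_sum = 0.f;
  for (float v : x) {
    sum += v;
    sq_sum += v * v;
  }
  *mean = sum / x.size();        // what mean[nc] receives
  *mean_sq = sq_sum / x.size();  // what var[nc] receives (E[x^2], not variance)
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  float mean = 0.f, mean_sq = 0.f;
  MeanAndMeanOfSquares(x, &mean, &mean_sq);
  std::printf("mean=%.2f var=%.2f\n", mean, mean_sq - mean * mean);  // 2.50 1.25
}
```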
@@ -162,10 +177,7 @@ __global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) {
x_mean += val;
x_var += val * val;
}
x_mean /= size;
x_var /= size;
CudaAtomicAddWithWarp(&mean[i], x_mean);
CudaAtomicAddWithWarp(&var[i], x_var);
ReduceMeanAndVar<T>(mean, var, x_mean, x_var, size);
}
template <typename T, typename AccT, int VecSize>
@@ -174,21 +186,12 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
int i = blockIdx.x;
AccT x_mean = static_cast<AccT>(0);
AccT x_var = static_cast<AccT>(0);
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
x += i * size;
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
phi::Array<const T*, 1> ins;
ins[0] = x;
ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
x_mean, kps::AddFunctor<AccT>());
x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
x_var, kps::AddFunctor<AccT>());
__syncthreads();
if (threadIdx.x == 0) {
mean[i] = static_cast<T>(x_mean / size);
var[i] = static_cast<T>(x_var / size);
}
ReduceMeanAndVar<AccT>(mean, var, x_mean, x_var, size);
}
template <typename T, int flags>
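In `VectorizedGetMeanAndVarNCHW` (and in `VectorizedGetDsDbCUDAKernel` further down) the pointer is now advanced to the block's slice before `input_offset` is taken, so the offset describes the memory actually loaded. A hedged sketch of the scalar-prologue-plus-aligned-body pattern such an offset feeds; `BlockSumAligned` and its internals are assumptions for illustration, not Paddle's `ThreadReduce`:

```cpp
#include <cstdint>
#include <cuda_runtime.h>

__global__ void BlockSumAligned(const float* x, float* out, int size) {
  constexpr int kVecSize = 4;  // float4
  const int tid = threadIdx.x;

  // Elements the slice pointer sits past the previous 16-byte boundary
  // (the commit now computes this after x has been advanced to the slice).
  const int misalign = static_cast<int>(
      (reinterpret_cast<std::uintptr_t>(x) % sizeof(float4)) / sizeof(float));
  // Scalar elements to read before float4 loads become aligned.
  const int prologue = min(misalign ? kVecSize - misalign : 0, size);

  float sum = 0.f;
  for (int i = tid; i < prologue; i += blockDim.x) sum += x[i];  // head

  const float4* vx = reinterpret_cast<const float4*>(x + prologue);
  const int vec_count = (size - prologue) / kVecSize;
  for (int i = tid; i < vec_count; i += blockDim.x) {  // aligned body
    float4 v = vx[i];
    sum += v.x + v.y + v.z + v.w;
  }

  for (int i = prologue + vec_count * kVecSize + tid; i < size;
       i += blockDim.x) {
    sum += x[i];  // scalar tail
  }

  atomicAdd(out, sum);  // crude combine; the real kernels use BlockXReduce
}
```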
@@ -272,10 +275,6 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Tensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, mean, static_cast<T>(0));
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
@@ -319,7 +318,7 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 grids(x_dims[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size) {
if (size < vec_size * block_size_nchw) {
ScalarGetMeanAndVarNCHW<T><<<grids, blocks, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, size);
} else {
@@ -328,6 +327,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
x_data, mean_data, temp_var_data, size);
}
} else {
set_zero(dev_ctx, mean, static_cast<T>(0));
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
temp_var_data);
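Two launch-side changes show up here. The scalar kernel is now chosen whenever `size < vec_size * block_size_nchw`, which reads as requiring at least one `vec_size`-wide chunk per thread before the vectorized kernel is worth launching (the old `size < vec_size` guard almost never fired); the backward pass below applies the same rule. And because the NCHW kernels now overwrite `mean`/`temp_var` through `ReduceMeanAndVar` instead of accumulating with `CudaAtomicAddWithWarp`, the `set_zero` calls only remain on the generic path that still accumulates. A hypothetical helper spelling out the guard (the names mirror the surrounding variables, but the function itself is not Paddle code):

```cpp
inline bool UseVectorizedMeanVarKernel(int size, int vec_size,
                                       int block_size_nchw) {
  // Vectorize only when every thread of the block has at least one
  // vec_size-wide chunk to load; otherwise launch the scalar kernel.
  return size >= vec_size * block_size_nchw;
}
```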
@@ -424,24 +425,15 @@ __global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
int i = blockIdx.x;
AccT ds_sum = static_cast<AccT>(0);
AccT db_sum = static_cast<AccT>(0);
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
x += i * imsize;
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
phi::Array<const T*, 2> ins;
ins[0] = x;
ins[1] = dy;
ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
&ds_sum);
ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
ds_sum, kps::AddFunctor<AccT>());
db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
db_sum, kps::AddFunctor<AccT>());
__syncthreads();
if (threadIdx.x == 0) {
ds[i] = ds_sum;
db[i] = db_sum;
}
ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
}
template <typename T>
@@ -455,8 +447,7 @@ __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
CudaAtomicAddWithWarp(&ds[nc], ds_sum);
CudaAtomicAddWithWarp(&db[nc], db_sum);
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
}
template <typename T>
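In the backward kernels above, `ds` accumulates the per-(sample, channel) sum of `dy * x` and `db` the sum of `dy`; reusing `ReduceMeanAndVar` with `size = 1` stores the block-reduced sums unscaled. A host-side sketch of what the two buffers end up holding (illustrative, not Paddle code):

```cpp
#include <cstdio>
#include <vector>

void GetDsDb(const std::vector<float>& x, const std::vector<float>& dy,
             float* ds, float* db) {
  float ds_sum = 0.f, db_sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    ds_sum += dy[i] * x[i];  // ds: sum of dy * x over the image
    db_sum += dy[i];         // db: sum of dy over the image
  }
  *ds = ds_sum;
  *db = db_sum;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f}, dy = {0.5f, 0.5f, 0.5f};
  float ds = 0.f, db = 0.f;
  GetDsDb(x, dy, &ds, &db);
  std::printf("ds=%.2f db=%.2f\n", ds, db);  // ds=3.00 db=1.50
}
```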
@@ -641,13 +632,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
if (imsize < vec_size) {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
}
if (d_bias) {
set_zero(dev_ctx, d_bias, static_cast<T>(0));
}
if (imsize < vec_size * block_size_nchw) {
ScalarGetDsDbCUDAKernel<
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
@@ -687,7 +672,6 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
dy_data, d_x_data);
}
} else {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
......