未验证 提交 04a4bdf8 编写于 作者: C crystal 提交者: GitHub

fix group_norm (#41531)

fix group_norm vectorized address misalignment
上级 c2e12949
...@@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, ...@@ -419,23 +419,6 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
} }
} }
template <typename T, typename AccT, int VecSize>
__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
T* ds, T* db) {
int i = blockIdx.x;
AccT ds_sum = static_cast<AccT>(0);
AccT db_sum = static_cast<AccT>(0);
x += i * imsize;
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
phi::Array<const T*, 2> ins;
ins[0] = x;
ins[1] = dy;
ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
&ds_sum);
ReduceMeanAndVar<AccT>(db, ds, db_sum, ds_sum, 1);
}
template <typename T> template <typename T>
__global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy, __global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
T* ds, T* db) { T* ds, T* db) {
...@@ -622,25 +605,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T> ...@@ -622,25 +605,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
int flags = int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
using AccT = typename details::MPTypeTrait<T>::Type; ScalarGetDsDbCUDAKernel<
constexpr int vec_size = sizeof(float4) / sizeof(T); T><<<x_dims[0] * C, block_size, 0, dev_ctx.stream()>>>(
const int max_num_threads = 1024; imsize, x_data, dy_data, ds_data, db_data);
int max_block_size = std::min(imsize / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
if (imsize < vec_size * block_size_nchw) {
ScalarGetDsDbCUDAKernel<
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
} else {
VectorizedGetDsDbCUDAKernel<
T, AccT, vec_size><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
}
if (d_scale || d_bias) { if (d_scale || d_bias) {
const int block = 256; const int block = 256;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册