未验证 提交 ff2fba39 编写于 作者: C crystal 提交者: GitHub

modify the block size of the group_norm backward (#41570)

上级 9872da00
...@@ -605,8 +605,16 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T> ...@@ -605,8 +605,16 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
int flags = int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
const int max_num_threads = 1024;
int max_block_size = std::min(imsize, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
ScalarGetDsDbCUDAKernel< ScalarGetDsDbCUDAKernel<
T><<<x_dims[0] * C, block_size, 0, dev_ctx.stream()>>>( T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data); imsize, x_data, dy_data, ds_data, db_data);
if (d_scale || d_bias) { if (d_scale || d_bias) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册