Unverified · Commit f218330e · authored by zhiboniu, committed by GitHub

Fix group_size: compute it as floor(C/groups) instead of ceil(C/groups); add a check that group_norm's channel count is divisible by groups (#35644)

Parent 67a094b5
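For context on the message above: the old expression `(C - 1) / groups + 1` is integer ceil(C/groups), while the new `C / groups` is floor. A minimal standalone sketch with illustrative values (not taken from the patch) shows why neither rounding is well-defined when groups does not divide C, which is exactly what the new check rules out:

```cpp
#include <iostream>

int main() {
  const int C = 30;      // channel count (illustrative value)
  const int groups = 4;  // deliberately does NOT divide C evenly

  const int old_group_size = (C - 1) / groups + 1;  // ceil(30/4) == 8
  const int new_group_size = C / groups;            // floor(30/4) == 7

  // ceil:  4 groups * 8 channels = 32 slots for 30 channels (overrun)
  // floor: 4 groups * 7 channels = 28 slots for 30 channels (2 dropped)
  std::cout << old_group_size << " vs " << new_group_size << "\n";

  // Neither rounding partitions the channels cleanly, which is why the
  // patch also enforces C % groups == 0; when that holds, ceil and
  // floor coincide and group_size is exact.
  return 0;
}
```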
@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
                       "The Attr(groups) of Op(group_norm) must be "
                       "greater than or equal to 1. But received: groups is [%s].",
                       groups));
+    PADDLE_ENFORCE_EQ(
+        channel_num % groups, 0,
+        platform::errors::InvalidArgument(
+            "Expected number of channels in input to be divisible by "
+            "num_groups, but got input channel is %d and num_groups is %d",
+            channel_num, groups));
     if (ctx->HasInput("Scale")) {
       PADDLE_ENFORCE_EQ(
......
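The practical effect of the added PADDLE_ENFORCE_EQ is that an indivisible channel/group combination now fails up front at shape inference instead of silently producing unevenly sized groups. Below is a free-standing approximation of the check's semantics; the function name and the use of std::invalid_argument are illustrative stand-ins for Paddle's enforce machinery, and only the message text mirrors the patch:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical standalone version of the new divisibility check.
void CheckGroupsDivideChannels(int channel_num, int groups) {
  if (channel_num % groups != 0) {
    throw std::invalid_argument(
        "Expected number of channels in input to be divisible by "
        "num_groups, but got input channel is " +
        std::to_string(channel_num) + " and num_groups is " +
        std::to_string(groups));
  }
}

// CheckGroupsDivideChannels(32, 4);  // OK: 32 % 4 == 0
// CheckGroupsDivideChannels(30, 4);  // throws: 30 % 4 == 2
```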
@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     const int W =
         (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                           : x_dims[x_dims.size() - 2]);
@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     const int W =
         (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                           : x_dims[x_dims.size() - 2]);
......
@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     y->mutable_data<T>(ctx.GetPlace());
     mean->mutable_data<T>(ctx.GetPlace());
@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     int imid;
     for (imid = 0; imid < imsize - (imsize % M);
          imid += M, iter_x_data += M) {
-      // TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
+      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
       // in template class/function, before we complete high
       // performance cpu vector extension, temporarily unrolling
       // loop to get high precision and performance
@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     int imid;
     for (imid = 0; imid < imsize - (imsize % M);
          imid += M, iter_x_data += M * C) {
-      // TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
+      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
       // in template class/function, before we complete high
       // performance cpu vector extension, temporarily unrolling
       // loop to get high precision and performance
@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     d_x->mutable_data<T>(ctx.GetPlace());
     math::SetConstant<DeviceContext, T> set_zero;
......
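For reference, group_size is the number of consecutive channels that share one mean/variance pair. A minimal NCHW sketch of that grouping, assuming C % groups == 0 as the patch now enforces; this is an illustrative reference implementation, not the kernel code elided above, and it omits the optional scale/bias terms:

```cpp
#include <cmath>
#include <vector>

// Reference group norm over an NCHW tensor flattened into a vector.
// Assumes y is pre-sized to x.size() and C % groups == 0.
void GroupNormNCHW(const std::vector<float>& x, std::vector<float>* y,
                   int N, int C, int HxW, int groups, float eps) {
  const int group_size = C / groups;  // exact, given the divisibility check
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < groups; ++g) {
      // Each group spans group_size consecutive channels of one sample.
      const int begin = (n * C + g * group_size) * HxW;
      const int count = group_size * HxW;
      double mean = 0.0, var = 0.0;
      for (int i = 0; i < count; ++i) mean += x[begin + i];
      mean /= count;
      for (int i = 0; i < count; ++i) {
        const double d = x[begin + i] - mean;
        var += d * d;
      }
      var /= count;
      const double inv_std = 1.0 / std::sqrt(var + eps);
      for (int i = 0; i < count; ++i)
        (*y)[begin + i] = static_cast<float>((x[begin + i] - mean) * inv_std);
    }
  }
}
```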