diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc
index e076444626e6afa7893b539481804964750e48bd..2c7fd8f4173ea72f0b40fe0d7620f168554fd33f 100644
--- a/paddle/fluid/operators/group_norm_op.cc
+++ b/paddle/fluid/operators/group_norm_op.cc
@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
             "The Attr(groups) of Op(group_norm) must be "
             "greater than or equal to 1. But received: groups is [%s].",
             groups));
+    PADDLE_ENFORCE_EQ(
+        channel_num % groups, 0,
+        platform::errors::InvalidArgument(
+            "Expected number of channels in input to be divisible by "
+            "num_groups, but got input channel is %d and num_groups is %d",
+            channel_num, groups));
 
     if (ctx->HasInput("Scale")) {
       PADDLE_ENFORCE_EQ(
diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu
index f199bfeb9443b668b5833d9a011e36577b66bb4a..e029c84090af19e186e30be63f28b01270ef94c5 100644
--- a/paddle/fluid/operators/group_norm_op.cu
+++ b/paddle/fluid/operators/group_norm_op.cu
@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
+
     const int W =
         (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                           : x_dims[x_dims.size() - 2]);
@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     const int W =
         (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                           : x_dims[x_dims.size() - 2]);
diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h
index f2388699e266f52fd1b06612ee4f78fb4ec88b21..9cb451235f152cc855e4b47388b9ce13e7ff8911 100644
--- a/paddle/fluid/operators/group_norm_op.h
+++ b/paddle/fluid/operators/group_norm_op.h
@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
 
     y->mutable_data<T>(ctx.GetPlace());
     mean->mutable_data<T>(ctx.GetPlace());
@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
             int imid;
             for (imid = 0; imid < imsize - (imsize % M);
                  imid += M, iter_x_data += M) {
-              // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
               // in template class/function, before we complete high
               // performance cpu vector extension, temporarily unrolling
               // loop to get high precision and performance
@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
             int imid;
             for (imid = 0; imid < imsize - (imsize % M);
                  imid += M, iter_x_data += M * C) {
-              // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
               // in template class/function, before we complete high
               // performance cpu vector extension, temporarily unrolling
               // loop to get high precision and performance
@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
 
     d_x->mutable_data<T>(ctx.GetPlace());
     math::SetConstant<DeviceContext, T> set_zero;
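
A minimal standalone sketch of the arithmetic this patch changes: the old expression (C - 1) / groups + 1 is a ceiling division, which rounds the group size up whenever C is not divisible by groups; with the new divisibility check rejecting that case up front, plain C / groups is exact. The channel/group values below are hypothetical, chosen only to illustrate the two formulas.

    #include <cassert>

    int main() {
      const int C = 32;      // hypothetical channel count
      const int groups = 8;  // hypothetical num_groups

      // Mirrors the new PADDLE_ENFORCE_EQ: channels must divide evenly by groups.
      assert(C % groups == 0);

      // Old formula: ceiling division. For a non-divisible case such as
      // C = 33, groups = 8 it yields (33 - 1) / 8 + 1 = 5, so the last
      // group would cover fewer channels than the others.
      const int old_group_size = (C - 1) / groups + 1;

      // New formula: exact division, valid once divisibility is enforced.
      const int group_size = C / groups;

      // For divisible inputs the two formulas agree: 32 / 8 == (32 - 1) / 8 + 1 == 4.
      assert(old_group_size == group_size);
      return 0;
    }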