未验证 提交 f218330e 编写于 作者: Z zhiboniu 提交者: GitHub

Fix group_size: change it from ceil(C/groups) to floor(C/groups), and add a check that the input channel count is divisible by groups in group_norm (#35644)

上级 67a094b5
...@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel { ...@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
"The Attr(groups) of Op(group_norm) must be " "The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].", "greater than or equal to 1. But received: groups is [%s].",
groups)); groups));
PADDLE_ENFORCE_EQ(
channel_num % groups, 0,
platform::errors::InvalidArgument(
"Expected number of channels in input to be divisible by "
"num_groups, but got input channel is %d and num_groups is %d",
channel_num, groups));
if (ctx->HasInput("Scale")) { if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T> ...@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1; const int group_size = C / groups;
const int W = const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]); : x_dims[x_dims.size() - 2]);
...@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T> ...@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1; const int group_size = C / groups;
const int W = const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]); : x_dims[x_dims.size() - 2]);
......
...@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> { ...@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1; const int group_size = C / groups;
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace()); mean->mutable_data<T>(ctx.GetPlace());
...@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> { ...@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid; int imid;
for (imid = 0; imid < imsize - (imsize % M); for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) { imid += M, iter_x_data += M) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high // in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling // performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance // loop to get high precision and performance
...@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> { ...@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid; int imid;
for (imid = 0; imid < imsize - (imsize % M); for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) { imid += M, iter_x_data += M * C) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high // in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling // performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance // loop to get high precision and performance
...@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> { ...@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const int C = const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1; const int group_size = C / groups;
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册