From 46371515fb1c3a4085484852783ddde46b629ef6 Mon Sep 17 00:00:00 2001 From: JYChen Date: Fri, 20 Aug 2021 11:41:37 +0800 Subject: [PATCH] add (N,C,*) input support for GroupNorm (#34773) * add (N,C,*) input support for GroupNorm * --amend --- paddle/fluid/operators/group_norm_op.cc | 7 ++ paddle/fluid/operators/group_norm_op.cu | 25 ++++-- paddle/fluid/operators/group_norm_op.h | 25 ++++-- .../tests/unittests/test_group_norm_op_v2.py | 78 +++++++++++++++---- python/paddle/nn/layer/norm.py | 4 +- 5 files changed, 112 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 978170b296b..e076444626e 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -37,6 +37,13 @@ class GroupNormOp : public framework::OperatorWithKernel { "GroupNorm"); auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE( + x_dim.size(), 2, + platform::errors::InvalidArgument( + "The Input(X)'s dimension of Op(group_norm) must be " + "greater than 1. But received: %u-D Tensor, which shape is [%s].", + x_dim.size(), x_dim)); + const std::string data_layout_str = ctx->Attrs().Get("data_layout"); const framework::DataLayout data_layout = diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index 18a248f5531..f199bfeb944 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -171,9 +171,16 @@ class GroupNormKernel const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] - : x_dims[1] * x_dims[2]); - + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } #ifdef __HIPCC__ int block_size = std::max(std::min(256, imsize), 64); #else @@ -349,8 +356,16 @@ class GroupNormGradKernel const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] - : x_dims[1] * x_dims[2]); + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } #ifdef __HIPCC__ int block_size = std::max(std::min(256, imsize), 64); diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h index 2f0edd0451a..f2388699e26 100644 --- a/paddle/fluid/operators/group_norm_op.h +++ b/paddle/fluid/operators/group_norm_op.h @@ -68,9 +68,16 @@ class GroupNormKernel : public framework::OpKernel { const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] - : x_dims[1] * x_dims[2]); - + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } auto* iter_x_data = x_data; auto* iter_y_data = y_data; for (int bid = 0; bid < x_dims[0]; bid++) { @@ -257,8 +264,16 @@ class GroupNormGradKernel : public framework::OpKernel { const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] - : x_dims[1] * x_dims[2]); + int imsize = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < x_dims.size(); ++i) { + imsize *= x_dims[i]; + } + } else { + for (int i = 1; i < x_dims.size() - 1; ++i) { + imsize *= x_dims[i]; + } + } auto* iter_x_data = x_data; auto* iter_d_x_data = d_x_data; auto* iter_y_data = y_data; diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py index 0e13ca17562..fbdf4a1cfd1 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -25,13 +25,29 @@ from paddle.fluid import Program, program_guard import paddle +def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups): + # original version group norm only support 4-D tensor + # this function generalizes to support differnt dimensions tensor (>= 2-D) + input_shape = x.shape + N, C = x.shape[0], x.shape[1] + G = groups + x = x.reshape((N * G, -1)) + mean = np.mean(x, axis=1, keepdims=True) + var = np.var(x, axis=1, keepdims=True) + output = (x - mean) / np.sqrt(var + epsilon) + output = output.reshape(input_shape) * scale.reshape( + (-1, 1, 1)) + bias.reshape((-1, 1, 1)) + return output + + class TestDygraphGroupNormv2(unittest.TestCase): def test_dygraph(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): places.append(fluid.CUDAPlace(0)) + shapes = [[2, 2, 2, 2], [2, 2, 4], [4, 2], [4, 2, 6, 6, 2], + [2, 2, 2, 2, 2, 2]] for p in places: - shape = [2, 2, 2, 2] def compute_v1(x): with fluid.dygraph.guard(p): @@ -62,23 +78,26 @@ class TestDygraphGroupNormv2(unittest.TestCase): self.assertRaises(ValueError, attr_data_format) - x = np.random.randn(*shape).astype("float32") - y1 = compute_v1(x) - y2 = compute_v2(x) - result = np.allclose(y1, y2, atol=1e-5) - if not result: - print("y1:", y1, "\ty2:", y2) - self.assertTrue(result) - test_weight_bias_false() - test_nn_exception() + for shape in shapes: + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + result = np.allclose(y1, y2, atol=1e-5) + if not result: + print("y1:", y1, "\ty2:", y2) + self.assertTrue(result) + test_weight_bias_false() + test_nn_exception() def test_static(self): + paddle.enable_static() places = [fluid.CPUPlace()] if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): places.append(fluid.CUDAPlace(0)) + shapes = [[2, 6, 2, 2], [2, 6, 4], [4, 6], [4, 6, 6, 6, 2], + [4, 6, 2, 2, 2, 2]] for p in places: exe = fluid.Executor(p) - shape = [2, 6, 2, 2] def compute_v1(x_np): with program_guard(Program(), Program()): @@ -98,10 +117,39 @@ class TestDygraphGroupNormv2(unittest.TestCase): r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] return r - x = np.random.randn(*shape).astype("float32") - y1 = compute_v1(x) - y2 = compute_v2(x) - self.assertTrue(np.allclose(y1, y2, atol=1e-5)) + for shape in shapes: + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2, atol=1e-5)) + + +class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase): + def test_numerical_accuracy(self): + paddle.disable_static() + shapes = [(2, 6), (2, 6, 4), (2, 6, 4, 4), (2, 6, 6, 6, 2), (2, 6, 6, 6, + 2, 3)] + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(fluid.CUDAPlace(0)) + + for place in places: + for shape in shapes: + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + data = np.random.random(shape).astype("float32") + expect_res1 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=6) + expect_res2 = group_norm_naive_for_general_dimension( + data, scale, bias, epsilon=1e-5, groups=2) + + gn1 = paddle.nn.GroupNorm(num_channels=6, num_groups=6) + gn2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2) + data_pd = paddle.to_tensor(data) + result1 = gn1(data_pd).numpy() + result2 = gn2(data_pd).numpy() + self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5)) + self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5)) if __name__ == '__main__': diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 41599809810..147e7fca3ff 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -338,8 +338,8 @@ class GroupNorm(Layer): name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: - - x: 4-D tensor with shape: (batch, num_features, height, weight). - - output: 4-D tensor with same shape as input x. + - x: Tensor with shape: (batch, num_features, *). + - output: The same shape as input x. Returns: None -- GitLab