Unverified commit 46371515 authored by JYChen, committed by GitHub

add (N,C,*) input support for GroupNorm (#34773)

* add (N,C,*) input support for GroupNorm

* --amend
Parent 1aa2bde0
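In effect, this change lets paddle.nn.GroupNorm (and the underlying group_norm op) accept any input of rank >= 2 laid out as (N, C, *), instead of only 4-D NCHW tensors. A minimal usage sketch of the behavior the commit enables; the shapes below are chosen only for illustration:

import numpy as np
import paddle

gn = paddle.nn.GroupNorm(num_groups=2, num_channels=6)

# 3-D (N, C, L) and 5-D (N, C, D, H, W) inputs are now both accepted;
# each group of channels is normalized together with all trailing dims.
x3d = paddle.to_tensor(np.random.randn(2, 6, 4).astype("float32"))
x5d = paddle.to_tensor(np.random.randn(2, 6, 4, 4, 4).astype("float32"))
print(gn(x3d).shape)  # [2, 6, 4]
print(gn(x5d).shape)  # [2, 6, 4, 4, 4]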
@@ -37,6 +37,13 @@ class GroupNormOp : public framework::OperatorWithKernel {
                    "GroupNorm");
     auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GE(
+        x_dim.size(), 2,
+        platform::errors::InvalidArgument(
+            "The Input(X)'s dimension of Op(group_norm) must be "
+            "greater than 1. But received: %u-D Tensor, which shape is [%s].",
+            x_dim.size(), x_dim));
     const std::string data_layout_str =
         ctx->Attrs().Get<std::string>("data_layout");
     const framework::DataLayout data_layout =
...
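The new PADDLE_ENFORCE_GE rejects inputs of rank < 2 at InferShape time. A minimal Python sketch mirroring the check; the helper name check_group_norm_input_rank is only for illustration and is not part of the PR:

def check_group_norm_input_rank(x_shape):
    # Mirrors the PADDLE_ENFORCE_GE above: Input(X) of group_norm must be at least 2-D.
    if len(x_shape) < 2:
        raise ValueError(
            "The Input(X)'s dimension of Op(group_norm) must be greater than 1. "
            "But received: %d-D Tensor, which shape is %s." % (len(x_shape), list(x_shape)))

check_group_norm_input_rank((2, 6, 4))  # passes: (N, C, L) is 3-D
# check_group_norm_input_rank((3,))     # would raise ValueError: 1-D input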
@@ -171,9 +171,16 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     const T* bias_data = nullptr;
     if (bias) bias_data = bias->data<T>();
-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
 #ifdef __HIPCC__
     int block_size = std::max(std::min(256, imsize), 64);
 #else
@@ -349,8 +356,16 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     const T* bias_data = nullptr;
     if (bias) bias_data = bias->data<T>();
-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
 #ifdef __HIPCC__
     int block_size = std::max(std::min(256, imsize), 64);
...
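Both CUDA kernels (and the CPU kernels below) now compute imsize as the product of all spatial dimensions instead of assuming exactly two of them. A short sketch of the equivalent reduction; the helper name spatial_size is illustrative only, not part of the PR:

def spatial_size(x_dims, data_layout="NCHW"):
    # Everything except the batch and channel dims contributes to imsize.
    dims = x_dims[2:] if data_layout == "NCHW" else x_dims[1:-1]
    size = 1
    for d in dims:
        size *= int(d)
    return size

assert spatial_size((2, 6, 4)) == 4                  # (N, C, L)
assert spatial_size((2, 6, 4, 4, 4)) == 64           # (N, C, D, H, W)
assert spatial_size((2, 4, 4, 4, 6), "NHWC") == 64   # channel-last layout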
@@ -68,9 +68,16 @@ class GroupNormKernel : public framework::OpKernel<T> {
     const T* bias_data = nullptr;
     if (bias) bias_data = bias->data<T>();
-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
     auto* iter_x_data = x_data;
     auto* iter_y_data = y_data;
     for (int bid = 0; bid < x_dims[0]; bid++) {
@@ -257,8 +264,16 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
     const T* bias_data = nullptr;
     if (bias) bias_data = bias->data<T>();
-    int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
-                                                    : x_dims[1] * x_dims[2]);
+    int imsize = 1;
+    if (data_layout == DataLayout::kNCHW) {
+      for (int i = 2; i < x_dims.size(); ++i) {
+        imsize *= x_dims[i];
+      }
+    } else {
+      for (int i = 1; i < x_dims.size() - 1; ++i) {
+        imsize *= x_dims[i];
+      }
+    }
     auto* iter_x_data = x_data;
     auto* iter_d_x_data = d_x_data;
     auto* iter_y_data = y_data;
...
@@ -25,13 +25,29 @@ from paddle.fluid import Program, program_guard
 import paddle
 
 
+def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups):
+    # The original group_norm only supports 4-D tensors;
+    # this reference generalizes it to tensors of other dimensions (>= 2-D).
+    input_shape = x.shape
+    N, C = x.shape[0], x.shape[1]
+    G = groups
+    x = x.reshape((N * G, -1))
+    mean = np.mean(x, axis=1, keepdims=True)
+    var = np.var(x, axis=1, keepdims=True)
+    output = (x - mean) / np.sqrt(var + epsilon)
+    output = output.reshape(input_shape) * scale.reshape(
+        (-1, 1, 1)) + bias.reshape((-1, 1, 1))
+    return output
+
+
 class TestDygraphGroupNormv2(unittest.TestCase):
     def test_dygraph(self):
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
             places.append(fluid.CUDAPlace(0))
+        shapes = [[2, 2, 2, 2], [2, 2, 4], [4, 2], [4, 2, 6, 6, 2],
+                  [2, 2, 2, 2, 2, 2]]
         for p in places:
-            shape = [2, 2, 2, 2]
 
             def compute_v1(x):
                 with fluid.dygraph.guard(p):
@@ -62,6 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
 
             self.assertRaises(ValueError, attr_data_format)
 
+            for shape in shapes:
                 x = np.random.randn(*shape).astype("float32")
                 y1 = compute_v1(x)
                 y2 = compute_v2(x)
@@ -73,12 +90,14 @@ class TestDygraphGroupNormv2(unittest.TestCase):
             test_nn_exception()
 
     def test_static(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
             places.append(fluid.CUDAPlace(0))
+        shapes = [[2, 6, 2, 2], [2, 6, 4], [4, 6], [4, 6, 6, 6, 2],
+                  [4, 6, 2, 2, 2, 2]]
         for p in places:
             exe = fluid.Executor(p)
-            shape = [2, 6, 2, 2]
 
             def compute_v1(x_np):
                 with program_guard(Program(), Program()):
@@ -98,11 +117,40 @@ class TestDygraphGroupNormv2(unittest.TestCase):
                 r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
                 return r
 
+            for shape in shapes:
                 x = np.random.randn(*shape).astype("float32")
                 y1 = compute_v1(x)
                 y2 = compute_v2(x)
                 self.assertTrue(np.allclose(y1, y2, atol=1e-5))
 
+
+class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
+    def test_numerical_accuracy(self):
+        paddle.disable_static()
+        shapes = [(2, 6), (2, 6, 4), (2, 6, 4, 4), (2, 6, 6, 6, 2),
+                  (2, 6, 6, 6, 2, 3)]
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
+            places.append(fluid.CUDAPlace(0))
+
+        for place in places:
+            for shape in shapes:
+                scale = np.array([1]).astype("float32")
+                bias = np.array([0]).astype("float32")
+                data = np.random.random(shape).astype("float32")
+                expect_res1 = group_norm_naive_for_general_dimension(
+                    data, scale, bias, epsilon=1e-5, groups=6)
+                expect_res2 = group_norm_naive_for_general_dimension(
+                    data, scale, bias, epsilon=1e-5, groups=2)
+
+                gn1 = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
+                gn2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
+                data_pd = paddle.to_tensor(data)
+                result1 = gn1(data_pd).numpy()
+                result2 = gn2(data_pd).numpy()
+                self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5))
+                self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5))
+
+
 if __name__ == '__main__':
     unittest.main()
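The reference implementation added above reduces group normalization to one reshape plus per-row statistics. A small worked example of that grouping, assuming N=2, C=6, groups=2 and a 3-D input (shapes are illustrative only):

import numpy as np

x = np.random.randn(2, 6, 4).astype("float32")  # (N, C, L)
N, G = 2, 2                                     # 2 groups of 3 channels each
grouped = x.reshape((N * G, -1))                # -> (4, 12): 3 channels * 4 positions per row
mean = grouped.mean(axis=1, keepdims=True)      # one mean/var per (sample, group)
var = grouped.var(axis=1, keepdims=True)
out = ((grouped - mean) / np.sqrt(var + 1e-5)).reshape(x.shape)
print(out.shape)  # (2, 6, 4)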
@@ -338,8 +338,8 @@ class GroupNorm(Layer):
         name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
     Shape:
-        - x: 4-D tensor with shape: (batch, num_features, height, weight).
-        - output: 4-D tensor with same shape as input x.
+        - x: Tensor with shape: (batch, num_features, *).
+        - output: The same shape as input x.
     Returns:
         None
...