Unverified commit 46371515 authored by JYChen, committed by GitHub

add (N,C,*) input support for GroupNorm (#34773)

* add (N,C,*) input support for GroupNorm

* --amend
Parent 1aa2bde0
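This change lets GroupNorm accept any input of rank >= 2 with layout (N, C, *) instead of only 4-D NCHW tensors. A minimal dygraph sketch of the usage this enables, based on the API exercised in the tests below (the shape and group count are illustrative only):

```python
import numpy as np
import paddle

# 3-D input with layout (N, C, L): 2 samples, 6 channels, 4 elements per channel
x = paddle.to_tensor(np.random.randn(2, 6, 4).astype("float32"))

# normalize over 2 groups of 3 channels each; any (N, C, *) shape is now accepted
gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
y = gn(x)

print(y.shape)  # [2, 6, 4] -- the output keeps the input shape
```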
......@@ -37,6 +37,13 @@ class GroupNormOp : public framework::OperatorWithKernel {
"GroupNorm");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(
x_dim.size(), 2,
platform::errors::InvalidArgument(
"The Input(X)'s dimension of Op(group_norm) must be "
"greater than 1. But received: %u-D Tensor, which shape is [%s].",
x_dim.size(), x_dim));
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_layout");
const framework::DataLayout data_layout =
......
......@@ -171,9 +171,16 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
: x_dims[1] * x_dims[2]);
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
#else
......@@ -349,8 +356,16 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
: x_dims[1] * x_dims[2]);
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
......
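Both the CUDA forward and backward kernels now compute imsize as the product of all spatial dimensions instead of hard-coding x_dims[2] * x_dims[3]. A small numpy sketch of the equivalent computation (an illustration of the index ranges used above, not the kernel code itself):

```python
import numpy as np

def spatial_size(shape, data_layout="NCHW"):
    # NCHW-like layouts: spatial dims are [2, rank); NHWC-like: [1, rank - 1)
    if data_layout == "NCHW":
        return int(np.prod(shape[2:]))
    return int(np.prod(shape[1:-1]))

print(spatial_size((2, 6, 4, 4)))     # 16 -- H * W, same as the old behavior
print(spatial_size((2, 6, 4, 4, 2)))  # 32 -- product of all trailing dims
print(spatial_size((2, 6)))           # 1  -- a 2-D input has no spatial dims
```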
......@@ -68,9 +68,16 @@ class GroupNormKernel : public framework::OpKernel<T> {
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
: x_dims[1] * x_dims[2]);
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
......@@ -257,8 +264,16 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3]
: x_dims[1] * x_dims[2]);
int imsize = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < x_dims.size(); ++i) {
imsize *= x_dims[i];
}
} else {
for (int i = 1; i < x_dims.size() - 1; ++i) {
imsize *= x_dims[i];
}
}
auto* iter_x_data = x_data;
auto* iter_d_x_data = d_x_data;
auto* iter_y_data = y_data;
......
......@@ -25,13 +25,29 @@ from paddle.fluid import Program, program_guard
import paddle
def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups):
    # the original naive group_norm reference only supported 4-D tensors;
    # this function generalizes it to tensors of different dimensions (>= 2-D)
    input_shape = x.shape
    N, C = x.shape[0], x.shape[1]
    G = groups
    x = x.reshape((N * G, -1))
    mean = np.mean(x, axis=1, keepdims=True)
    var = np.var(x, axis=1, keepdims=True)
    output = (x - mean) / np.sqrt(var + epsilon)
    output = output.reshape(input_shape) * scale.reshape(
        (-1, 1, 1)) + bias.reshape((-1, 1, 1))
    return output
class TestDygraphGroupNormv2(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
shapes = [[2, 2, 2, 2], [2, 2, 4], [4, 2], [4, 2, 6, 6, 2],
[2, 2, 2, 2, 2, 2]]
for p in places:
shape = [2, 2, 2, 2]
def compute_v1(x):
with fluid.dygraph.guard(p):
......@@ -62,23 +78,26 @@ class TestDygraphGroupNormv2(unittest.TestCase):
self.assertRaises(ValueError, attr_data_format)
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
result = np.allclose(y1, y2, atol=1e-5)
if not result:
print("y1:", y1, "\ty2:", y2)
self.assertTrue(result)
test_weight_bias_false()
test_nn_exception()
for shape in shapes:
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
result = np.allclose(y1, y2, atol=1e-5)
if not result:
print("y1:", y1, "\ty2:", y2)
self.assertTrue(result)
test_weight_bias_false()
test_nn_exception()
def test_static(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
shapes = [[2, 6, 2, 2], [2, 6, 4], [4, 6], [4, 6, 6, 6, 2],
[4, 6, 2, 2, 2, 2]]
for p in places:
exe = fluid.Executor(p)
shape = [2, 6, 2, 2]
def compute_v1(x_np):
with program_guard(Program(), Program()):
......@@ -98,10 +117,39 @@ class TestDygraphGroupNormv2(unittest.TestCase):
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
self.assertTrue(np.allclose(y1, y2, atol=1e-5))
for shape in shapes:
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
self.assertTrue(np.allclose(y1, y2, atol=1e-5))
class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
def test_numerical_accuracy(self):
paddle.disable_static()
shapes = [(2, 6), (2, 6, 4), (2, 6, 4, 4), (2, 6, 6, 6, 2), (2, 6, 6, 6,
2, 3)]
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
for place in places:
for shape in shapes:
scale = np.array([1]).astype("float32")
bias = np.array([0]).astype("float32")
data = np.random.random(shape).astype("float32")
expect_res1 = group_norm_naive_for_general_dimension(
data, scale, bias, epsilon=1e-5, groups=6)
expect_res2 = group_norm_naive_for_general_dimension(
data, scale, bias, epsilon=1e-5, groups=2)
gn1 = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
gn2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
data_pd = paddle.to_tensor(data)
result1 = gn1(data_pd).numpy()
result2 = gn2(data_pd).numpy()
self.assertTrue(np.allclose(result1, expect_res1, atol=1e-5))
self.assertTrue(np.allclose(result2, expect_res2, atol=1e-5))
if __name__ == '__main__':
......
......@@ -338,8 +338,8 @@ class GroupNorm(Layer):
name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
Shape:
- x: 4-D tensor with shape: (batch, num_features, height, weight).
- output: 4-D tensor with same shape as input x.
- x: Tensor with shape: (batch, num_features, *).
- output: The same shape as input x.
Returns:
None
......
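With the docstring now describing x as (batch, num_features, *), a quick numerical check against the naive reference defined in the test above might look like this (a sketch using the same scalar scale/bias as the tests; the group count and shape are arbitrary):

```python
import numpy as np
import paddle

# 5-D input (N, C, D, H, W); scale 1 and bias 0 match the layer's default init
data = np.random.random((2, 6, 3, 4, 5)).astype("float32")
expect = group_norm_naive_for_general_dimension(
    data, np.array([1]).astype("float32"), np.array([0]).astype("float32"),
    epsilon=1e-5, groups=3)

gn = paddle.nn.GroupNorm(num_channels=6, num_groups=3)
result = gn(paddle.to_tensor(data)).numpy()
np.testing.assert_allclose(result, expect, atol=1e-5)
```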