提交 0130cc96 编写于 作者: Z Zhang Ting 提交者: Aurelius84

fixed group_norm's bug and modified unittest (#20506)

* modified group_norm's unittest, which previously masked failures behind a bare pass statement, test=develop

* fix group_norm bug: a segmentation fault when scale or bias is None, test=develop
上级 c7ae6c62
......@@ -220,7 +220,8 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
......@@ -237,8 +238,9 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = scale_data[gid * group_size + cid];
T v_bias = bias_data[gid * group_size + cid];
T v_scale = 1., v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
......@@ -256,7 +258,8 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data)
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
......@@ -276,8 +279,9 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
T dly = tmp_y[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = scale_data[gid * group_size + cid];
T v_bias = bias_data[gid * group_size + cid];
T v_scale = 1.0, v_bias = 0.;
if (scale_data) v_scale = scale_data[gid * group_size + cid];
if (bias_data) v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
......
......@@ -195,12 +195,10 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp):
class TestGroupNormAPI_With_NHWC(OpTest):
def test_case1(self):
data1 = fluid.layers.data(
name='data1', shape=[3, 3, 4], dtype='float32')
data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float32')
out1 = fluid.layers.group_norm(
input=data1, groups=2, data_layout="NHWC")
data2 = fluid.layers.data(
name='data2', shape=[4, 3, 3], dtype='float32')
data2 = fluid.data(name='data2', shape=[None, 4, 3, 3], dtype='float32')
out2 = fluid.layers.group_norm(
input=data2, groups=2, data_layout="NCHW")
......@@ -223,14 +221,17 @@ class TestGroupNormAPI_With_NHWC(OpTest):
self.assertTrue(np.allclose(results[0], expect_res1[0]))
self.assertTrue(np.allclose(results[1], expect_res2[0]))
class TestGroupNormException(OpTest):
    # group_norm only accepts data_layout of "NHWC" or "NCHW";
    # any other value must raise ValueError.
    def test_exception(self):
        data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float32")

        def attr_data_format():
            # "NDHW" is an invalid layout and should be rejected.
            out = fluid.layers.group_norm(
                input=data, groups=2, data_layout="NDHW")

        # assertRaises (rather than a bare try/except: pass) ensures the test
        # actually fails if no ValueError is raised.
        self.assertRaises(ValueError, attr_data_format)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册