未验证 提交 56882ce4 编写于 作者: L lijianshe02 提交者: GitHub

change input data type and decrease max_relative_error value in...

change input data type and decrease max_relative_error value in test_check_grad for grop_nom_op test test=develop (#21608)
上级 84b72671
...@@ -44,7 +44,7 @@ class TestGroupNormOp(OpTest): ...@@ -44,7 +44,7 @@ class TestGroupNormOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "group_norm" self.op_type = "group_norm"
self.data_format = "NCHW" self.data_format = "NCHW"
self.dtype = np.float32 self.dtype = np.float64
self.shape = (2, 4, 3, 3) self.shape = (2, 4, 3, 3)
self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"} self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
self.compare_between_place = False self.compare_between_place = False
...@@ -113,7 +113,7 @@ class TestGroupNormOp(OpTest): ...@@ -113,7 +113,7 @@ class TestGroupNormOp(OpTest):
place, place,
set(['X', 'Scale', 'Bias']), set(['X', 'Scale', 'Bias']),
'Y', 'Y',
max_relative_error=0.01) max_relative_error=0.005)
def init_test_case(self): def init_test_case(self):
pass pass
...@@ -193,19 +193,19 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp): ...@@ -193,19 +193,19 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp):
self.compare_between_place = True self.compare_between_place = True
class TestGroupNormAPI_With_NHWC(OpTest): class TestGroupNormAPI_With_NHWC(unittest.TestCase):
def test_case1(self): def test_case1(self):
data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float32') data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float64')
out1 = fluid.layers.group_norm( out1 = fluid.layers.group_norm(
input=data1, groups=2, data_layout="NHWC") input=data1, groups=2, data_layout="NHWC")
data2 = fluid.data(name='data2', shape=[None, 4, 3, 3], dtype='float32') data2 = fluid.data(name='data2', shape=[None, 4, 3, 3], dtype='float64')
out2 = fluid.layers.group_norm( out2 = fluid.layers.group_norm(
input=data2, groups=2, data_layout="NCHW") input=data2, groups=2, data_layout="NCHW")
data1_np = np.random.random((2, 3, 3, 4)).astype("float32") data1_np = np.random.random((2, 3, 3, 4)).astype("float64")
data2_np = np.random.random((2, 4, 3, 3)).astype("float32") data2_np = np.random.random((2, 4, 3, 3)).astype("float64")
scale = np.array([1]).astype("float32") scale = np.array([1]).astype("float64")
bias = np.array([0]).astype("float32") bias = np.array([0]).astype("float64")
place = core.CPUPlace() place = core.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -225,7 +225,7 @@ class TestGroupNormAPI_With_NHWC(OpTest): ...@@ -225,7 +225,7 @@ class TestGroupNormAPI_With_NHWC(OpTest):
class TestGroupNormException(unittest.TestCase): class TestGroupNormException(unittest.TestCase):
# data_layout is not NHWC or NCHW # data_layout is not NHWC or NCHW
def test_exception(self): def test_exception(self):
data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float32") data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float64")
def attr_data_format(): def attr_data_format():
out = fluid.layers.group_norm( out = fluid.layers.group_norm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册