未验证 提交 84445499 编写于 作者: Y yangjianfengo1 提交者: GitHub

Increase relative error of test_group_norm_op unittest (#55943)

* fix fp 16

* bf16 rtol

* fixed input

* code style
上级 dc4b48f6
......@@ -357,6 +357,37 @@ class TestGroupNormFP16Op_With_NHWC(TestGroupNormFP16OP):
self.attrs['epsilon'] = 0.5
self.shape = (1, 100, 4, 4)
self.dtype = np.float16
input = np.sin(
np.arange(
self.shape[0] * self.shape[1] * self.shape[2] * self.shape[3]
)
)
input = np.transpose(input.reshape(self.shape), (0, 2, 3, 1)).astype(
self.dtype
)
scale = np.sin(np.arange(self.shape[1])).astype(self.dtype)
bias = np.sin(np.arange(self.shape[1])).astype(self.dtype)
output, mean, var = group_norm_naive(
input,
scale,
bias,
self.attrs['epsilon'],
self.attrs['groups'],
self.data_format,
)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(input),
'Scale': OpTest.np_dtype_to_fluid_dtype(scale),
'Bias': OpTest.np_dtype_to_fluid_dtype(bias),
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
self.attrs['data_layout'] = self.data_format
def test_check_output(self):
rtol = 2e-3
place = core.CUDAPlace(0)
self.check_output_with_place(place, rtol=rtol)
class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op):
......@@ -374,10 +405,20 @@ class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op):
}
self.compare_between_place = False
self.init_test_case()
input = np.random.random(self.shape).astype(np.float32)
scale = np.random.random([self.shape[3]]).astype(np.float32)
bias = np.random.random([self.shape[3]]).astype(np.float32)
input = (
np.sin(
np.arange(
self.shape[0]
* self.shape[1]
* self.shape[2]
* self.shape[3]
)
)
.reshape(self.shape)
.astype(np.float32)
)
scale = np.sin(np.arange(self.shape[3])).astype(np.float32)
bias = np.sin(np.arange(self.shape[3])).astype(np.float32)
output, mean, var = group_norm_naive(
input,
scale,
......@@ -394,6 +435,11 @@ class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op):
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
def test_check_output(self):
rtol = 2e-2
place = core.CUDAPlace(0)
self.check_output_with_place(place, rtol=rtol)
class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp):
def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册