未验证 提交 d9c3123a 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13181 from tensor-tang/refine/fusion/ut

refine fusion lstm and gru op test
...@@ -37,7 +37,7 @@ def fusion_gru( ...@@ -37,7 +37,7 @@ def fusion_gru(
h0, h0,
wh, wh,
np.zeros( np.zeros(
(1, wh.shape[1]), dtype='float64'), (1, wh.shape[1]), dtype='float32'),
is_reverse, is_reverse,
act_state, act_state,
act_gate) act_gate)
...@@ -62,15 +62,15 @@ class TestFusionGRUOp(OpTest): ...@@ -62,15 +62,15 @@ class TestFusionGRUOp(OpTest):
T = sum(self.lod[0]) T = sum(self.lod[0])
N = len(self.lod[0]) N = len(self.lod[0])
x = np.random.rand(T, self.M).astype('float64') x = np.random.rand(T, self.M).astype('float32')
wx = np.random.rand(self.M, 3 * self.D).astype('float64') wx = np.random.rand(self.M, 3 * self.D).astype('float32')
wh = np.random.rand(self.D, 3 * self.D).astype('float64') wh = np.random.rand(self.D, 3 * self.D).astype('float32')
bias = np.random.rand( bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros( 1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64') (1, 3 * self.D), dtype='float32')
h0 = np.random.rand( h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros( N, self.D).astype('float32') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64') (N, self.D), dtype='float32')
_, _, _, hidden = fusion_gru( _, _, _, hidden = fusion_gru(
x, self.lod, h0, wx, wh, bias, self.is_reverse, x, self.lod, h0, wx, wh, bias, self.is_reverse,
...@@ -93,7 +93,9 @@ class TestFusionGRUOp(OpTest): ...@@ -93,7 +93,9 @@ class TestFusionGRUOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-8) for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output()
class TestFusionGRUOpNoInitial(TestFusionGRUOp): class TestFusionGRUOpNoInitial(TestFusionGRUOp):
......
...@@ -114,6 +114,8 @@ class TestFusionLSTMOp(OpTest): ...@@ -114,6 +114,8 @@ class TestFusionLSTMOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册