提交 f6d07342 编写于 作者: T tensor-tang

refine fusion lstm op test

上级 92890ac2
...@@ -61,28 +61,28 @@ class TestLstmOp(OpTest): ...@@ -61,28 +61,28 @@ class TestLstmOp(OpTest):
T = sum(self.lod[0]) T = sum(self.lod[0])
bs = len(self.lod[0]) bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float32') x = np.random.normal(size=(T, self.M)).astype('float64')
if self.has_initial_state: if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float32') h0 = np.random.normal(size=(bs, self.D)).astype('float64')
c0 = np.random.normal(size=(bs, self.D)).astype('float32') c0 = np.random.normal(size=(bs, self.D)).astype('float64')
else: else:
h0 = np.zeros((bs, self.D)).astype('float32') h0 = np.zeros((bs, self.D)).astype('float64')
c0 = np.zeros((bs, self.D)).astype('float32') c0 = np.zeros((bs, self.D)).astype('float64')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32') wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
if self.use_peepholes: if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float32') b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else: else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float32') b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = np.copy(b[:, 0:4 * self.D]) w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc # this is the weight of fc
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32') wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
# this is the bias of fc # this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM # and it should be manually added into the bias of this fusion LSTM
bx = np.random.normal(size=(1, 4 * self.D)).astype('float32') bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
b[0, 0:4 * self.D] += bx[0, :] b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c, h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate], self.is_reverse, ACTIVATION[self.act_gate],
...@@ -112,7 +112,7 @@ class TestLstmOp(OpTest): ...@@ -112,7 +112,7 @@ class TestLstmOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(atol=1e-8)
class TestLstmOpInitReverse(TestLstmOp): class TestLstmOpInitReverse(TestLstmOp):
...@@ -123,16 +123,22 @@ class TestLstmOpInitReverse(TestLstmOp): ...@@ -123,16 +123,22 @@ class TestLstmOpInitReverse(TestLstmOp):
class TestLstmOpMD1(TestLstmOp): class TestLstmOpMD1(TestLstmOp):
def set_argument(self): def set_argument(self):
self.M = 35 self.M = 36
self.D = 8 self.D = 8
class TestLstmOpMD2(TestLstmOp): class TestLstmOpMD2(TestLstmOp):
def set_argument(self): def set_argument(self):
self.M = 36 self.M = 8
self.D = 8 self.D = 8
class TestLstmOpMD3(TestLstmOp):
def set_argument(self):
self.M = 15
self.D = 3
class TestLstmOpBS1(TestLstmOp): class TestLstmOpBS1(TestLstmOp):
def set_argument(self): def set_argument(self):
self.lod = [[3]] self.lod = [[3]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册