提交 f6d07342 编写于 作者: T tensor-tang

refine fusion lstm op test

上级 92890ac2
......@@ -61,28 +61,28 @@ class TestLstmOp(OpTest):
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float32')
x = np.random.normal(size=(T, self.M)).astype('float64')
if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
h0 = np.random.normal(size=(bs, self.D)).astype('float64')
c0 = np.random.normal(size=(bs, self.D)).astype('float64')
else:
h0 = np.zeros((bs, self.D)).astype('float32')
c0 = np.zeros((bs, self.D)).astype('float32')
h0 = np.zeros((bs, self.D)).astype('float64')
c0 = np.zeros((bs, self.D)).astype('float64')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
# this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM
bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate],
......@@ -112,7 +112,7 @@ class TestLstmOp(OpTest):
}
def test_check_output(self):
self.check_output()
self.check_output(atol=1e-8)
class TestLstmOpInitReverse(TestLstmOp):
......@@ -123,16 +123,22 @@ class TestLstmOpInitReverse(TestLstmOp):
class TestLstmOpMD1(TestLstmOp):
def set_argument(self):
self.M = 35
self.M = 36
self.D = 8
class TestLstmOpMD2(TestLstmOp):
def set_argument(self):
self.M = 36
self.M = 8
self.D = 8
class TestLstmOpMD3(TestLstmOp):
def set_argument(self):
self.M = 15
self.D = 3
class TestLstmOpBS1(TestLstmOp):
def set_argument(self):
self.lod = [[3]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册