提交 1777cd09 编写于 作者: T tensor-tang

refine fusion lstm op test

上级 4b28fab8
......@@ -43,13 +43,13 @@ def fusion_lstm(
act_cell, act_cand)
class TestLstmOp(OpTest):
def set_argument(self):
class TestFusionLSTMOp(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'fusion_lstm'
self.lod = [[2, 3, 2]]
self.lod = [[2, 3, 5, 4]]
self.M = 8
self.D = 16
self.has_initial_state = False
......@@ -58,33 +58,33 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
self.set_argument()
self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float64')
x = np.random.normal(size=(T, self.M)).astype('float32')
if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float64')
c0 = np.random.normal(size=(bs, self.D)).astype('float64')
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
else:
h0 = np.zeros((bs, self.D)).astype('float64')
c0 = np.zeros((bs, self.D)).astype('float64')
h0 = np.zeros((bs, self.D)).astype('float32')
c0 = np.zeros((bs, self.D)).astype('float32')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
# this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate],
......@@ -114,35 +114,44 @@ class TestLstmOp(OpTest):
}
def test_check_output(self):
self.check_output(atol=1e-8)
self.check_output()
class TestLstmOpInitReverse(TestLstmOp):
def set_argument(self):
class TestFusionLSTMOpInit(TestFusionLSTMOp):
def set_conf(self):
self.has_initial_state = True
self.is_reverse = True
class TestLstmOpMD1(TestLstmOp):
def set_argument(self):
# class TestFusionLSTMOpReverse(TestFusionLSTMOp):
# def set_conf(self):
# self.is_reverse = True
# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
# def set_conf(self):
# self.has_initial_state = True
# self.is_reverse = True
class TestFusionLSTMOpMD1(TestFusionLSTMOp):
def set_conf(self):
self.M = 36
self.D = 8
class TestLstmOpMD2(TestLstmOp):
def set_argument(self):
class TestFusionLSTMOpMD2(TestFusionLSTMOp):
def set_conf(self):
self.M = 8
self.D = 8
class TestLstmOpMD3(TestLstmOp):
def set_argument(self):
class TestFusionLSTMOpMD3(TestFusionLSTMOp):
def set_conf(self):
self.M = 15
self.D = 3
class TestLstmOpBS1(TestLstmOp):
def set_argument(self):
class TestFusionLSTMOpBS1(TestFusionLSTMOp):
def set_conf(self):
self.lod = [[3]]
self.D = 16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册