提交 1777cd09 编写于 作者: T tensor-tang

refine fusion lstm op test

上级 4b28fab8
...@@ -43,13 +43,13 @@ def fusion_lstm( ...@@ -43,13 +43,13 @@ def fusion_lstm(
act_cell, act_cand) act_cell, act_cand)
class TestLstmOp(OpTest): class TestFusionLSTMOp(OpTest):
def set_argument(self): def set_conf(self):
pass pass
def setUp(self): def setUp(self):
self.op_type = 'fusion_lstm' self.op_type = 'fusion_lstm'
self.lod = [[2, 3, 2]] self.lod = [[2, 3, 5, 4]]
self.M = 8 self.M = 8
self.D = 16 self.D = 16
self.has_initial_state = False self.has_initial_state = False
...@@ -58,33 +58,33 @@ class TestLstmOp(OpTest): ...@@ -58,33 +58,33 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.use_peepholes = False self.use_peepholes = False
self.set_argument() self.set_conf()
T = sum(self.lod[0]) T = sum(self.lod[0])
bs = len(self.lod[0]) bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float64') x = np.random.normal(size=(T, self.M)).astype('float32')
if self.has_initial_state: if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float64') h0 = np.random.normal(size=(bs, self.D)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float64') c0 = np.random.normal(size=(bs, self.D)).astype('float32')
else: else:
h0 = np.zeros((bs, self.D)).astype('float64') h0 = np.zeros((bs, self.D)).astype('float32')
c0 = np.zeros((bs, self.D)).astype('float64') c0 = np.zeros((bs, self.D)).astype('float32')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
if self.use_peepholes: if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64') b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
else: else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64') b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
w_b = np.copy(b[:, 0:4 * self.D]) w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc # this is the weight of fc
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64') wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
# this is the bias of fc # this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM # and it should be manually added into the bias of this fusion LSTM
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64') bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
b[0, 0:4 * self.D] += bx[0, :] b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c, h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate], self.is_reverse, ACTIVATION[self.act_gate],
...@@ -114,35 +114,44 @@ class TestLstmOp(OpTest): ...@@ -114,35 +114,44 @@ class TestLstmOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-8) self.check_output()
class TestLstmOpInitReverse(TestLstmOp): class TestFusionLSTMOpInit(TestFusionLSTMOp):
def set_argument(self): def set_conf(self):
self.has_initial_state = True self.has_initial_state = True
self.is_reverse = True
class TestLstmOpMD1(TestLstmOp): # class TestFusionLSTMOpReverse(TestFusionLSTMOp):
def set_argument(self): # def set_conf(self):
# self.is_reverse = True
# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
# def set_conf(self):
# self.has_initial_state = True
# self.is_reverse = True
class TestFusionLSTMOpMD1(TestFusionLSTMOp):
def set_conf(self):
self.M = 36 self.M = 36
self.D = 8 self.D = 8
class TestLstmOpMD2(TestLstmOp): class TestFusionLSTMOpMD2(TestFusionLSTMOp):
def set_argument(self): def set_conf(self):
self.M = 8 self.M = 8
self.D = 8 self.D = 8
class TestLstmOpMD3(TestLstmOp): class TestFusionLSTMOpMD3(TestFusionLSTMOp):
def set_argument(self): def set_conf(self):
self.M = 15 self.M = 15
self.D = 3 self.D = 3
class TestLstmOpBS1(TestLstmOp): class TestFusionLSTMOpBS1(TestFusionLSTMOp):
def set_argument(self): def set_conf(self):
self.lod = [[3]] self.lod = [[3]]
self.D = 16 self.D = 16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册