提交 5c6fc3f9 编写于 作者: Y Yibing Liu

Make TestLstmpOp inherit from TestLstmOp

上级 9ecc54a1
...@@ -42,7 +42,7 @@ def relu(x): ...@@ -42,7 +42,7 @@ def relu(x):
return np.maximum(x, 0) return np.maximum(x, 0)
ACTVATION = { ACTIVATION = {
'identity': identity, 'identity': identity,
'sigmoid': sigmoid, 'sigmoid': sigmoid,
'tanh': tanh, 'tanh': tanh,
...@@ -158,8 +158,8 @@ class TestLstmOp(OpTest): ...@@ -158,8 +158,8 @@ class TestLstmOp(OpTest):
w_b = b[:, 0:4 * self.D] w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_c = b[:, 4 * self.D:] if self.use_peepholes else None
h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell], ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
ACTVATION[self.act_cand]) ACTIVATION[self.act_cand])
self.inputs = {'Input': (x, self.lod), 'Weight': w} self.inputs = {'Input': (x, self.lod), 'Weight': w}
......
...@@ -13,39 +13,13 @@ ...@@ -13,39 +13,13 @@
#limitations under the License. #limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest import test_lstm_op as LstmTest
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def identity(x):
return x
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def relu(x):
return np.maximum(x, 0)
ACTIVATION = { ACTIVATION = {
'identity': identity, 'identity': LstmTest.identity,
'sigmoid': sigmoid, 'sigmoid': LstmTest.sigmoid,
'tanh': tanh, 'tanh': LstmTest.tanh,
'relu': relu 'relu': LstmTest.relu
} }
...@@ -55,7 +29,7 @@ def lstmp( ...@@ -55,7 +29,7 @@ def lstmp(
lod, # 1 x N lod, # 1 x N
h0=None, # N x D h0=None, # N x D
c0=None, # N x D c0=None, # N x D
w_r=None, # P x 5D w_r=None, # P x 4D
w_rh=None, # D x P w_rh=None, # D x P
w_b=None, # 1 x 4D w_b=None, # 1 x 4D
w_c=None, # 1 x 3D w_c=None, # 1 x 3D
...@@ -130,26 +104,16 @@ def lstmp( ...@@ -130,26 +104,16 @@ def lstmp(
return projection, cell return projection, cell
class TestLstmpOp(OpTest): class TestLstmpOp(LstmTest.TestLstmOp):
def reset_argument(self): def reset_argument(self):
pass pass
def setUp(self): def setUp(self):
self.lod = [[0, 2, 5, 7]] self.set_argument()
# hidden size
self.D = 16
# projection size # projection size
self.P = 10 self.P = 10
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = self.act_cell self.act_proj = self.act_cell
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
self.reset_argument() self.reset_argument()
self.op_type = 'lstmp' self.op_type = 'lstmp'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册