提交 5c6fc3f9 编写于 作者: Y Yibing Liu

Make TestLstmpOp inherit from TestLstmOp

上级 9ecc54a1
......@@ -42,7 +42,7 @@ def relu(x):
return np.maximum(x, 0)
ACTVATION = {
ACTIVATION = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
......@@ -158,8 +158,8 @@ class TestLstmOp(OpTest):
w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
ACTIVATION[self.act_cand])
self.inputs = {'Input': (x, self.lod), 'Weight': w}
......
......@@ -13,39 +13,13 @@
#limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def identity(x):
return x
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def relu(x):
return np.maximum(x, 0)
import test_lstm_op as LstmTest
ACTIVATION = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
'identity': LstmTest.identity,
'sigmoid': LstmTest.sigmoid,
'tanh': LstmTest.tanh,
'relu': LstmTest.relu
}
......@@ -55,7 +29,7 @@ def lstmp(
lod, # 1 x N
h0=None, # N x D
c0=None, # N x D
w_r=None, # P x 5D
w_r=None, # P x 4D
w_rh=None, # D x P
w_b=None, # 1 x 4D
w_c=None, # 1 x 3D
......@@ -130,26 +104,16 @@ def lstmp(
return projection, cell
class TestLstmpOp(OpTest):
class TestLstmpOp(LstmTest.TestLstmOp):
def reset_argument(self):
pass
def setUp(self):
self.lod = [[0, 2, 5, 7]]
# hidden size
self.D = 16
self.set_argument()
# projection size
self.P = 10
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.act_proj = self.act_cell
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
self.reset_argument()
self.op_type = 'lstmp'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册