From 5c6fc3f92ff05edbc77284a1ec34666eac34646e Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 24 Jan 2018 19:03:12 -0800 Subject: [PATCH] Make TestLstmpOp inherit from TestLstmOp --- python/paddle/v2/fluid/tests/test_lstm_op.py | 6 +-- python/paddle/v2/fluid/tests/test_lstmp_op.py | 52 +++---------------- 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_lstm_op.py b/python/paddle/v2/fluid/tests/test_lstm_op.py index d9fa01e247a..3e79f9d8e15 100644 --- a/python/paddle/v2/fluid/tests/test_lstm_op.py +++ b/python/paddle/v2/fluid/tests/test_lstm_op.py @@ -42,7 +42,7 @@ def relu(x): return np.maximum(x, 0) -ACTVATION = { +ACTIVATION = { 'identity': identity, 'sigmoid': sigmoid, 'tanh': tanh, @@ -158,8 +158,8 @@ class TestLstmOp(OpTest): w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] if self.use_peepholes else None h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand]) + ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand]) self.inputs = {'Input': (x, self.lod), 'Weight': w} diff --git a/python/paddle/v2/fluid/tests/test_lstmp_op.py b/python/paddle/v2/fluid/tests/test_lstmp_op.py index f137fc61b38..92a954a9aa5 100644 --- a/python/paddle/v2/fluid/tests/test_lstmp_op.py +++ b/python/paddle/v2/fluid/tests/test_lstmp_op.py @@ -13,39 +13,13 @@ #limitations under the License. import unittest import numpy as np -from op_test import OpTest - -SIGMOID_THRESHOLD_MIN = -40.0 -SIGMOID_THRESHOLD_MAX = 13.0 -EXP_MAX_INPUT = 40.0 - - -def identity(x): - return x - - -def sigmoid(x): - y = np.copy(x) - y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN - y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX - return 1. / (1. + np.exp(-y)) - - -def tanh(x): - y = -2. * x - y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT - return (2. / (1. + np.exp(y))) - 1. - - -def relu(x): - return np.maximum(x, 0) - +import test_lstm_op as LstmTest ACTIVATION = { - 'identity': identity, - 'sigmoid': sigmoid, - 'tanh': tanh, - 'relu': relu + 'identity': LstmTest.identity, + 'sigmoid': LstmTest.sigmoid, + 'tanh': LstmTest.tanh, + 'relu': LstmTest.relu } @@ -55,7 +29,7 @@ def lstmp( lod, # 1 x N h0=None, # N x D c0=None, # N x D - w_r=None, # P x 5D + w_r=None, # P x 4D w_rh=None, # D x P w_b=None, # 1 x 4D w_c=None, # 1 x 3D @@ -130,26 +104,16 @@ def lstmp( return projection, cell -class TestLstmpOp(OpTest): +class TestLstmpOp(LstmTest.TestLstmOp): def reset_argument(self): pass def setUp(self): - self.lod = [[0, 2, 5, 7]] - # hidden size - self.D = 16 + self.set_argument() # projection size self.P = 10 - - self.act_gate = 'sigmoid' - self.act_cell = 'tanh' - self.act_cand = 'tanh' self.act_proj = self.act_cell - self.has_initial_state = False - self.is_reverse = False - self.use_peepholes = True - self.reset_argument() self.op_type = 'lstmp' -- GitLab