# test_lstm_op.py: unit tests for the 'lstm' operator.
import unittest
import numpy as np
from op_test import OpTest

# Bounds that `sigmoid` clips its input to before calling np.exp.
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
# Upper cap applied by `tanh` to its exponent argument (-2 * x); presumably
# there to keep np.exp from overflowing on large negative inputs.
# NOTE(review): these values presumably mirror the operator kernel's
# clipping -- confirm against the kernel before changing them.
EXP_MAX_INPUT = 40.0


def identity(x):
    """Identity activation: hand back ``x`` untouched (same object)."""
    result = x
    return result


def sigmoid(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)).

    The input is first clipped to
    [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], so exp never sees an
    argument outside that range. The caller's array is not modified.
    """
    clipped = np.clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)
    return 1. / (1. + np.exp(-clipped))


def tanh(x):
    """Hyperbolic tangent computed as 2 / (1 + exp(-2x)) - 1.

    The exponent argument ``-2x`` is capped at ``EXP_MAX_INPUT`` before
    calling np.exp. Works for NumPy arrays as well as Python/NumPy scalars,
    matching ``sigmoid``. The input is never modified.
    """
    # np.minimum (rather than masked assignment) also works on 0-d values,
    # where `-2. * x` would be a plain float that cannot be indexed.
    exponent = np.minimum(-2. * np.asarray(x), EXP_MAX_INPUT)
    return (2. / (1. + np.exp(exponent))) - 1.


def relu(x):
    """Rectified linear unit: elementwise maximum of ``x`` and zero."""
    floor = 0
    return np.maximum(x, floor)


# Name -> NumPy reference activation. The names are the same strings that
# TestLstmOp passes as the op's gate/cell/candidate activation attributes.
# The identifier is a misspelling of "ACTIVATION"; it is kept as-is because
# TestLstmOp looks the table up under this exact name.
ACTVATION = dict(
    identity=identity,
    sigmoid=sigmoid,
    tanh=tanh,
    relu=relu,
)


def lstm(
        input,  # T x 4D
        lod,  # [[o_0, ..., o_N]]: N + 1 sequence offsets
        h0=None,  # N x D
        c0=None,  # N x D
        w_h=None,  # D x 4D
        w_b=None,  # 1 x 4D
        w_c=None,  # 1 x 3D
        is_reverse=False,
        act_gate=None,
        act_cell=None,
        act_cand=None):
    """NumPy reference implementation of an LSTM over LoD-packed sequences.

    Args:
        input: T x 4D array. After adding the recurrent term, each row is
            split into [candidate, input gate, forget gate, output gate].
        lod: ``lod[0]`` holds the N + 1 offsets delimiting the N sequences
            inside the T rows of ``input``.
        h0, c0: optional N x D initial hidden / cell states, one row per
            sequence. ``None`` starts that state from zeros.
        w_h: D x 4D recurrent weight (required).
        w_b: optional 1 x 4D bias added to every input row.
        w_c: optional 1 x 3D peephole weights [w_ic, w_fc, w_oc]. With
            ``None``, no peephole terms are used.
        is_reverse: process every sequence from its last row to its first.
        act_gate, act_cell, act_cand: activation callables for the gates,
            the cell output and the candidate (all required).

    Returns:
        (hidden, cell, gate) as float64 arrays of shape T x D, T x D and
        T x 4D. With ``is_reverse``, hidden and cell are returned in the
        original time order, while gate stays in the reversed processing
        order.
    """

    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
        """Run one time step; return (new hidden, new cell, gate values)."""
        g = np.dot(h_pre, w_h)  # 1 x 4D
        g = g + x
        g = np.reshape(g, (1, g.size))
        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)

        if w_c is not None:
            # Peephole: input and forget gates look at the previous cell.
            w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
            g_i = g_i + w_ic * c_pre
            g_f = g_f + w_fc * c_pre
        g_i = act_gate(g_i)  # 1 x D
        g_f = act_gate(g_f)  # 1 x D
        c = g_f * c_pre + g_i * act_cand(c_tmp)  # 1 x D

        if w_c is not None:
            # Peephole: the output gate looks at the freshly updated cell.
            g_o = g_o + w_oc * c
        g_o = act_gate(g_o)  # 1 x D
        h = g_o * act_cell(c)
        bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
        return h, c, bg

    def _reverse(x, lod):
        """Reverse the row order inside each [lod[k], lod[k + 1]) segment."""
        y = np.zeros_like(x)
        for b, e in zip(lod[:-1], lod[1:]):
            y[b:e, :] = x[b:e, :][::-1]
        return y

    offset = lod[0]
    # Hidden size D: every input row carries four D-wide slices.
    frame_size = input.shape[1] // 4
    hidden = []
    cell = []
    gate = []

    input = _reverse(input, offset) if is_reverse else input
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))

    for i in range(len(offset) - 1):
        # A missing initial state means "start from zeros".
        h_pre = h0[i] if h0 is not None else np.zeros(frame_size)  # (D,)
        c_pre = c0[i] if c0 is not None else np.zeros(frame_size)  # (D,)
        for x_t in input[offset[i]:offset[i + 1], :]:
            h_pre, c_pre, g_pre = _step(x_t, w_h, w_c, h_pre, c_pre,
                                        act_gate, act_cell, act_cand)
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())
            gate.append(g_pre.flatten())

    # Explicit reshape keeps the 2-D layout even when there are no steps.
    hidden = np.array(hidden, dtype='float64').reshape(len(hidden), frame_size)
    cell = np.array(cell, dtype='float64').reshape(len(cell), frame_size)
    gate = np.array(gate, dtype='float64').reshape(len(gate), 4 * frame_size)

    # Restore the caller's time order for hidden/cell; gates are left in
    # processing order.
    if is_reverse:
        hidden = _reverse(hidden, offset)
        cell = _reverse(cell, offset)

    assert gate.shape == input.shape
    assert hidden.shape == (input.shape[0], frame_size)
    assert cell.shape == (input.shape[0], frame_size)
    return hidden, cell, gate


class TestLstmOp(OpTest):
    """Compares the 'lstm' operator with the NumPy ``lstm`` reference."""

    def set_argument(self):
        # Three sequences of lengths 2, 4 and 3 (T = 9 rows).
        self.lod = [[0, 2, 6, 9]]
        self.D = 16
        # Row i of the expected BatchGate is row sort_idx[i] of the reference
        # gates (presumably the op's batch-major ordering -- confirm).
        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

        # When False, H0/C0 are not fed to the op.
        self.has_initial_state = True
        self.is_reverse = False

    def setUp(self):
        self.set_argument()
        self.op_type = 'lstm'

        T = self.lod[0][-1]
        N = len(self.lod[0]) - 1

        x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
        h0 = np.zeros((N, self.D)).astype('float64')
        c0 = np.zeros((N, self.D)).astype('float64')
        w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
        # Bias packs the 4D gate bias followed by the 3D peephole weights.
        b = np.random.normal(size=(1, 7 * self.D)).astype('float64')

        w_b = b[:, 0:4 * self.D]
        w_c = b[:, 4 * self.D:]
        # h0/c0 are zeros, so the reference is also valid when H0/C0 are not
        # fed (NOTE(review): assumes the op defaults missing states to
        # zeros -- confirm).
        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                       ACTVATION[self.act_cand])

        g_sort = np.zeros_like(x)
        for i, j in enumerate(self.sort_idx):
            g_sort[i, :] = g[j, :]

        self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
        if self.has_initial_state:
            self.inputs['H0'] = h0
            self.inputs['C0'] = c0

        self.outputs = {
            'Hidden': (h, self.lod),
            'Cell': (c, self.lod),
            'BatchGate': g_sort,
        }
        self.attrs = {
            'usePeepholes': True,
            'isReverse': self.is_reverse,
            'gateActivation': self.act_gate,
            'cellActivation': self.act_cell,
            'candidateActivation': self.act_cand
        }

    def test_check_output(self):
        self.check_output()

    # TODO(qingqing) add more unit testing case
    def test_check_grad(self):
        # TODO(qingqing) remove following two lines after the check_grad is refined.
        self.outputs['BatchGate'] = None
        self.outputs['BatchCellPreAct'] = None
        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])


class TestLstmOpHasNoInitial(TestLstmOp):
    """Variant with D = 64, has_initial_state=False and is_reverse=True."""

    def set_argument(self):
        # Reuse the base LoD, sort order and activations; override only the
        # fields this variant changes.
        super(TestLstmOpHasNoInitial, self).set_argument()
        self.D = 64
        self.has_initial_state = False
        self.is_reverse = True


class TestLstmOpRerverse(TestLstmOp):
    """Variant with D = 64, initial state kept and is_reverse=True.

    The class name's misspelling is kept so the test ID stays stable.
    """

    def set_argument(self):
        # Reuse the base LoD, sort order, activations and initial-state flag;
        # override only the hidden size and the direction.
        super(TestLstmOpRerverse, self).set_argument()
        self.D = 64
        self.is_reverse = True


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')