# test_lstm_op.py
import unittest
import numpy as np
from op_test import OpTest

# Lower/upper clamp applied to the argument of sigmoid() before it is
# exponentiated.
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
# Upper clamp applied to the exponent (-2 * x) computed inside tanh().
EXP_MAX_INPUT = 40.0


def identity(x):
    """Return ``x`` itself, untouched (the no-op activation)."""
    result = x
    return result


def sigmoid(x):
    """Logistic sigmoid with the input clamped before exponentiation.

    Args:
        x: Array-like (or scalar) of pre-activation values.

    Returns:
        ``1 / (1 + exp(-y))`` where ``y`` is ``x`` limited to the range
        ``[SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX]``. ``x`` itself is
        not modified.
    """
    # np.clip returns a new array, so the caller's data is left untouched.
    y = np.clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)
    return 1. / (1. + np.exp(-y))


def tanh(x):
    """Hyperbolic tangent computed as ``2 / (1 + exp(-2x)) - 1``.

    The exponent ``-2 * x`` is capped at ``EXP_MAX_INPUT`` before it is passed
    to ``np.exp``.

    Args:
        x: Array-like or scalar of pre-activation values.

    Returns:
        The element-wise result; ``x`` itself is not modified.
    """
    # np.minimum (instead of boolean-mask assignment) also accepts plain
    # Python scalars, which cannot be indexed.
    y = np.minimum(-2. * x, EXP_MAX_INPUT)
    return (2. / (1. + np.exp(y))) - 1.


def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero."""
    floor = 0
    return np.maximum(x, floor)


# Lookup table from an activation name (the strings stored in the test's
# act_gate / act_cell / act_cand settings) to its NumPy reference function.
# NOTE(review): the name is misspelled ("ACTVATION"); it is kept as-is because
# it is referenced by this spelling elsewhere in the file.
ACTVATION = {
    'identity': identity,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'relu': relu
}


def lstm(
        input,  # T x 4D
        lod,  # 1 x (N + 1) offsets
        h0=None,  # N x D
        c0=None,  # N x D
        w_h=None,  # D x 4D
        w_b=None,  # 1 x 4D
        w_c=None,  # 1 x 3D
        is_reverse=False,
        act_gate=None,
        act_cell=None,
        act_cand=None):
    """NumPy reference LSTM over a batch of packed variable-length sequences.

    The four column blocks of ``input`` (and of ``w_h`` / ``w_b``) are, in
    order: candidate, input gate, forget gate, output gate.

    Args:
        input: ndarray of shape (T, 4D) holding the already-projected inputs
            of all sequences stacked along axis 0.
        lod: Nested list whose first entry is the list of row offsets;
            sequence ``i`` occupies rows ``lod[0][i]:lod[0][i + 1]``.
        h0: Initial hidden state per sequence, (N, D). Zeros when ``None``.
        c0: Initial cell state per sequence, (N, D). Zeros when ``None``.
        w_h: Recurrent weight, (D, 4D). Required.
        w_b: Optional bias, (1, 4D), added to every row of ``input``.
        w_c: Optional peephole weights, (1, 3D), split into the input-,
            forget- and output-gate parts. No peepholes when ``None``.
        is_reverse: If true, each sequence is processed last-row-first.
        act_gate: Callable applied to the three gates. Required.
        act_cell: Callable applied to the new cell state when computing the
            hidden state. Required.
        act_cand: Callable applied to the candidate block. Required.

    Returns:
        Tuple ``(hidden, cell, gate)`` of float64 arrays with shapes
        (T, D), (T, D) and (T, 4D). ``gate`` holds, per step, the activated
        candidate followed by the input, forget and output gates.
    """

    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
        """Advance one time step; return (h, c, activated-gate row)."""
        g = np.dot(h_pre, w_h) + x
        g = np.reshape(g, (1, g.size))  # 1 x 4D
        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
        cand = act_cand(c_tmp)  # 1 x D

        if w_c is None:
            g_i = act_gate(g_i)  # 1 x D
            g_f = act_gate(g_f)  # 1 x D
            c = g_f * c_pre + g_i * cand  # 1 x D
            g_o = act_gate(g_o)  # 1 x D
        else:
            w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
            # Input/forget peepholes look at the previous cell state ...
            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
            c = g_f * c_pre + g_i * cand  # 1 x D
            # ... while the output peephole looks at the new one.
            g_o = act_gate(g_o + w_oc * c)  # 1 x D

        h = g_o * act_cell(c)
        bg = np.concatenate((cand, g_i, g_f, g_o), axis=1)
        return h, c, bg

    def _reverse(x, lod):
        """Return a copy of ``x`` with the rows of each sequence flipped."""
        y = np.zeros_like(x)
        for b, e in zip(lod[:-1], lod[1:]):
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    offset = lod[0]
    batch_size = len(offset) - 1
    hidden_size = input.shape[1] // 4

    # The signature advertises h0/c0 as optional; fall back to a zero state
    # instead of failing on ``None[i]``.
    if h0 is None:
        h0 = np.zeros((batch_size, hidden_size))
    if c0 is None:
        c0 = np.zeros((batch_size, hidden_size))

    hidden = []
    cell = []
    gate = []

    if is_reverse:
        input = _reverse(input, offset)
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))

    for i in range(batch_size):
        # compute one sequence
        x = input[offset[i]:offset[i + 1], :]
        h_pre = h0[i]  # 1 x D
        c_pre = c0[i]  # 1 x D
        for x_t in x:
            # compute one step
            h_pre, c_pre, g_pre = _step(x_t, w_h, w_c, h_pre, c_pre, act_gate,
                                        act_cell, act_cand)
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())
            gate.append(g_pre.flatten())

    hidden = np.array(hidden).astype('float64')
    cell = np.array(cell).astype('float64')
    gate = np.array(gate).astype('float64')

    # NOTE(review): only hidden and cell are flipped back to the caller's row
    # order; gate stays in processing order -- confirm this is intended.
    if is_reverse:
        hidden = _reverse(hidden, offset)
        cell = _reverse(cell, offset)

    assert gate.shape == input.shape
    # Integer division keeps the expected shapes integral.
    assert hidden.shape == (input.shape[0], hidden_size)
    assert cell.shape == (input.shape[0], hidden_size)
    return hidden, cell, gate


class TestLstmOp(OpTest):
    """Checks the 'lstm' operator against the NumPy ``lstm`` reference.

    ``set_data`` chooses the sequence layout, hidden size and activations;
    ``setUp`` draws random inputs, runs the reference implementation and
    fills in ``inputs`` / ``outputs`` / ``attrs`` for the OpTest machinery.
    """

    def set_data(self):
        """Set the test configuration (override in subclasses)."""
        # self.lod = [[0, 2, 6, 9]]
        # self.D = 64
        # self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

        # Currently a single sequence of length 1 with hidden size 4.
        self.lod = [[0, 1]]
        self.D = 4
        self.sort_idx = [0]

        # self.act_gate = 'identity'
        # self.act_cell = 'identity'
        # self.act_cand = 'identity'
        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

        self.is_reverse = False

    def setUp(self):
        """Build random inputs and the expected outputs for the op."""
        self.set_data()
        self.op_type = 'lstm'

        T = self.lod[0][-1]  # total number of rows over all sequences
        N = len(self.lod[0]) - 1  # number of sequences

        x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
        h0 = np.zeros((N, self.D)).astype('float64')
        c0 = np.zeros((N, self.D)).astype('float64')
        w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
        b = np.random.normal(size=(1, 7 * self.D)).astype('float64')

        # The op's single 'Bias' tensor packs the 4D gate bias followed by
        # the 3D peephole weights; split it for the reference function.
        w_b = b[:, 0:4 * self.D]
        w_c = b[:, 4 * self.D:]
        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                       ACTVATION[self.act_cand])

        # Gate rows permuted by sort_idx; only needed by the 'BatchGate'
        # output, which is currently commented out below.
        g_sort = np.zeros_like(x)
        for i, j in enumerate(self.sort_idx):
            g_sort[i, :] = g[j, :]

        self.inputs = {
            'Input': (x, self.lod),
            'H0': h0,
            'C0': c0,
            'Weight': w,
            'Bias': b
        }
        self.outputs = {
            'Hidden': (h, self.lod),
            'Cell': (c, self.lod),
            #'BatchGate': g_sort,
        }
        self.attrs = {
            'usePeepholes': True,
            'isReverse': self.is_reverse,
            'gateActivation': self.act_gate,
            'cellActivation': self.act_cell,
            'candidateActivation': self.act_cand
        }

    # The name does not start with "test", so unittest's default loader does
    # not collect this method; the forward check is effectively disabled.
    def not_test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        """Gradient check of Hidden and Cell w.r.t. Input and Weight."""
        # NOTE(review): presumably these intermediate outputs must be declared
        # for the gradient check -- confirm against OpTest.check_grad.
        self.outputs['BatchGate'] = None
        self.outputs['BatchCellPreAct'] = None
        self.check_grad(['Input', 'Weight'], ['Hidden', 'Cell'])
        #['Input', 'Weight', 'Bias'], ['Hidden', 'Cell'])

    #class TestLstmOpRerverse(TestLstmOp):
    #    def set_data(self):
    #        self.lod = [[0, 2, 6, 9]]
    #        self.D = 64
    #        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
    #
    #        self.act_gate = 'sigmoid'
    #        self.act_cell = 'tanh'
    #        self.act_cand = 'tanh'
    #
    #        self.is_reverse = True


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()