#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest

# Bounds used to clip the arguments of np.exp in sigmoid() and tanh(); they
# presumably mirror the operator's own clipping so outputs match — confirm.
SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX = -40.0, 13.0
EXP_MAX_INPUT = 40.0

def identity(x):
    """Identity activation: hand back ``x`` itself, unchanged (no copy)."""
    output = x
    return output


def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` with a clipped input.

    ``x`` is first limited to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX]
    before exponentiation. ``x`` itself is not modified.
    """
    bounded = np.clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)
    return 1. / (1. + np.exp(-bounded))


def tanh(x):
    """Hyperbolic tangent computed as ``2 / (1 + exp(-2x)) - 1``.

    The exponent ``-2x`` is capped at EXP_MAX_INPUT before calling np.exp.
    The cap is applied by boolean-mask assignment, so ``x`` must be a NumPy
    array; ``x`` itself is not modified.
    """
    neg_two_x = -2. * x
    too_large = neg_two_x > EXP_MAX_INPUT
    neg_two_x[too_large] = EXP_MAX_INPUT
    return (2. / (1. + np.exp(neg_two_x))) - 1.


def relu(x):
    """Rectified linear unit: element-wise ``max(x, 0)``."""
    lower_bound = 0
    return np.maximum(x, lower_bound)


# Maps the activation names used in the operator attributes to their NumPy
# reference implementations. The historical spelling ``ACTVATION`` is kept
# because the test classes look it up under that name.
ACTVATION = dict(
    identity=identity,
    sigmoid=sigmoid,
    tanh=tanh,
    relu=relu,
)


def lstm(
        input,  # T x 4D
        lod,  # [[offset_0, ..., offset_N]]: N + 1 sequence boundaries
        h0=None,  # N x D, zeros when None
        c0=None,  # N x D, zeros when None
        w_h=None,  # D x 4D
        w_b=None,  # 1 x 4D
        w_c=None,  # 1 x 3D
        is_reverse=False,
        act_gate=None,
        act_cell=None,
        act_cand=None):
    """NumPy reference for the forward pass of the ``lstm`` operator.

    Sequence ``i`` occupies rows ``lod[0][i]:lod[0][i + 1]`` of ``input``.
    Sequences are processed independently, one time step at a time. The 4D
    columns of ``input`` (and of ``w_h`` / ``w_b``) are split, in order, into
    [candidate, input gate, forget gate, output gate].

    Args:
        input: T x 4D array of per-step input projections.
        lod: one-level LoD, ``[[offset_0, ..., offset_N]]``.
        h0, c0: N x D initial hidden / cell states; zeros when ``None``.
        w_h: D x 4D recurrent weight.
        w_b: optional 1 x 4D bias added to every row of ``input``.
        w_c: optional 1 x 3D peephole weights ``[w_ic, w_fc, w_oc]``.
            Peephole terms are skipped when ``None``.
        is_reverse: run every sequence from its last step to its first.
        act_gate, act_cell, act_cand: element-wise activation callables for
            the gates, the cell output and the candidate respectively.

    Returns:
        ``(hidden, cell)``: two T x D float64 arrays whose rows line up with
        the rows of ``input``.
    """
    # Peephole weights for the input, forget and output gates (or None).
    peepholes = None if w_c is None else np.split(w_c, 3, axis=1)

    def _step(x_t, h_pre, c_pre):
        """Advance one time step; returns the new ``(h, c)``, each 1 x D."""
        g = np.reshape(np.dot(h_pre, w_h) + x_t, (1, -1))  # 1 x 4D
        cand, g_i, g_f, g_o = np.split(g, 4, axis=1)
        if peepholes is None:
            g_i = act_gate(g_i)  # 1 x D
            g_f = act_gate(g_f)  # 1 x D
        else:
            w_ic, w_fc, _ = peepholes
            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
        c = g_f * c_pre + g_i * act_cand(cand)  # 1 x D
        if peepholes is None:
            g_o = act_gate(g_o)  # 1 x D
        else:
            # The output-gate peephole looks at the *updated* cell state.
            _, _, w_oc = peepholes
            g_o = act_gate(g_o + w_oc * c)  # 1 x D
        h = g_o * act_cell(c)
        return h, c

    def _reverse(x, offsets):
        """Flip each sequence's rows in time; sequences keep their slots."""
        y = np.zeros_like(x)
        for b, e in zip(offsets[:-1], offsets[1:]):
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    offset = lod[0]
    batch_size = len(offset) - 1
    frame_size = input.shape[1] // 4  # D
    # The defaults are None, so build zero states instead of crashing on
    # h0[i] / c0[i] below.
    if h0 is None:
        h0 = np.zeros((batch_size, frame_size), dtype='float64')
    if c0 is None:
        c0 = np.zeros((batch_size, frame_size), dtype='float64')

    input = _reverse(input, offset) if is_reverse else input
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))

    hidden = []
    cell = []
    for i in range(batch_size):
        h_pre = h0[i]  # D
        c_pre = c0[i]  # D
        for x_t in input[offset[i]:offset[i + 1], :]:
            h_pre, c_pre = _step(x_t, h_pre, c_pre)
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())

    hidden = np.array(hidden).astype('float64')
    cell = np.array(cell).astype('float64')
    if is_reverse:
        # Restore the original time order of every sequence.
        hidden = _reverse(hidden, offset)
        cell = _reverse(cell, offset)

    # Integer division: shape entries are ints (``/`` is float on Python 3).
    assert hidden.shape == (input.shape[0], frame_size)
    assert cell.shape == (input.shape[0], frame_size)
    return hidden, cell


class TestLstmOp(OpTest):
    """Compares the ``lstm`` operator with the NumPy reference ``lstm``.

    Subclasses change the configuration by overriding ``set_argument``.
    """

    def set_argument(self):
        # Three sequences of lengths 2, 3 and 2 (7 time steps in total).
        self.lod = [[0, 2, 5, 7]]
        # Hidden size D.
        self.D = 16

        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

        self.has_initial_state = False
        self.is_reverse = False
        self.use_peepholes = True

    def setUp(self):
        self.set_argument()
        self.op_type = 'lstm'

        offsets = self.lod[0]
        num_steps = offsets[-1]
        num_seqs = len(offsets) - 1
        hidden_size = self.D
        gates_width = 4 * hidden_size

        def _normal(*shape):
            # Standard-normal float64 sample of the given shape.
            return np.random.normal(size=shape).astype('float64')

        x = _normal(num_steps, gates_width)
        if self.has_initial_state:
            h0 = _normal(num_seqs, hidden_size)
            c0 = _normal(num_seqs, hidden_size)
        else:
            h0 = np.zeros((num_seqs, hidden_size)).astype('float64')
            c0 = np.zeros((num_seqs, hidden_size)).astype('float64')
        w = _normal(hidden_size, gates_width)

        # The Bias input packs the 4D gate bias, followed (with peepholes)
        # by the 3D peephole weights.
        if self.use_peepholes:
            b = _normal(1, 7 * hidden_size)
            w_c = b[:, gates_width:]
        else:
            b = _normal(1, gates_width)
            w_c = None
        w_b = b[:, 0:gates_width]

        h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                    ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                    ACTVATION[self.act_cand])

        self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
        if self.has_initial_state:
            self.inputs['H0'] = h0
            self.inputs['C0'] = c0

        self.outputs = {'Hidden': (h, self.lod), 'Cell': (c, self.lod)}
        self.attrs = {
            'use_peepholes': self.use_peepholes,
            'is_reverse': self.is_reverse,
            'gate_activation': self.act_gate,
            'cell_activation': self.act_cell,
            'candidate_activation': self.act_cand
        }

    def test_check_output(self):
        self.check_output(atol=1e-8)

    def _add_placeholder_batch_outputs(self):
        # TODO(qingqing) remove these placeholder outputs after check_grad
        # is refined.
        num_seqs = len(self.lod[0]) - 1
        self.outputs['BatchGate'] = np.zeros(
            (num_seqs, 4 * self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (num_seqs, self.D)).astype('float64')

    def test_check_grad(self):
        self._add_placeholder_batch_outputs()
        self.check_grad(
            ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)


class TestLstmOpHasInitial(TestLstmOp):
    """Reverse LSTM with peepholes and user-provided H0 / C0.

    Also checks gradients while one input is excluded through
    ``no_grad_set``. The ``ingore`` spelling in the method names is kept so
    existing test IDs stay stable.
    """

    def set_argument(self):
        self.lod = [[0, 2, 5, 7]]
        self.D = 16

        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

        self.has_initial_state = True
        self.is_reverse = True
        self.use_peepholes = True

    def _check_hidden_grad(self, inputs_to_check, no_grad_set=None):
        """Run check_grad on 'Hidden' with respect to ``inputs_to_check``.

        ``no_grad_set`` must be a set of input *names*, e.g. set(['Bias']).
        ``set('Bias')`` would instead be the characters {'B', 'i', 'a', 's'}
        and exclude nothing.
        """
        # TODO(qingqing) remove the following placeholder outputs after
        # check_grad is refined.
        N = len(self.lod[0]) - 1
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        if no_grad_set is None:
            self.check_grad(
                inputs_to_check, ['Hidden'], max_relative_error=5e-4)
        else:
            self.check_grad(
                inputs_to_check, ['Hidden'],
                max_relative_error=5e-4,
                no_grad_set=no_grad_set)

    def test_check_grad(self):
        self._check_hidden_grad(['Input', 'Weight', 'Bias', 'H0', 'C0'])

    def test_check_grad_ingore_bias(self):
        self._check_hidden_grad(
            ['Input', 'Weight'], no_grad_set=set(['Bias']))

    def test_check_grad_ingore_weight(self):
        self._check_hidden_grad(
            ['Input', 'Bias'], no_grad_set=set(['Weight']))

    def test_check_grad_ingore_input(self):
        self._check_hidden_grad(
            ['Weight', 'Bias'], no_grad_set=set(['Input']))

    def test_check_grad_ingore_h0(self):
        self._check_hidden_grad(
            ['Input', 'Weight', 'Bias', 'C0'], no_grad_set=set(['H0']))

    def test_check_grad_ingore_c0(self):
        self._check_hidden_grad(
            ['Input', 'Weight', 'Bias', 'H0'], no_grad_set=set(['C0']))


class TestLstmOpRerverse(TestLstmOp):
    """Base TestLstmOp configuration, but every sequence runs in reverse."""

    def set_argument(self):
        # Start from the base configuration and flip only the direction.
        super(TestLstmOpRerverse, self).set_argument()
        self.is_reverse = True


class TestLstmOpNotUsePeepholes(TestLstmOp):
    """Reverse LSTM without peephole connections (Bias is only 1 x 4D)."""

    def set_argument(self):
        # Base configuration, reversed, with the peephole weights removed.
        super(TestLstmOpNotUsePeepholes, self).set_argument()
        self.is_reverse = True
        self.use_peepholes = False

if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()