import unittest

import logging

from op_test import get_numeric_gradient
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
import numpy as np
import paddle.v2.framework.core as core
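
# The wildcard import above supplies the layer helpers used throughout this
# file: data, fc, scale, sigmoid, sums, elementwise_add, mean and StaticRNN.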


class PyRNNBase(object):
    def __init__(self, input_shape, output_shape):
        self.x = np.ones(shape=input_shape).astype("float32")
        self.y = np.zeros(shape=output_shape).astype("float32")

    def step(self, step_id, x):
        # Overridden by subclasses: run one step of the recurrence.
        pass

    def forward(self):
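        # Run every time step, then reduce all step outputs to a scalar loss
        # (their mean), mirroring the mean op the tests build on the Paddle
        # side.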
        for step_id in range(self.x.shape[0]):
            self.step(step_id, self.x[step_id])
        return np.array([np.mean(self.y)])

    def segment_inputs(self):
        return [self.x[i] for i in range(self.x.shape[0])]


class PySimpleRNN1(PyRNNBase):
    def __init__(self, input_shape, output_shape):
        super(PySimpleRNN1, self).__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.h_boot = np.random.normal(size=(batch_size,
                                             input_dim)).astype("float32")

        self.scale = 1.0 / 2.0
        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    def step(self, step_id, x):
        if step_id == 0:
            pre_mem = self.h_boot
        else:
            pre_mem = self.mems[step_id - 1]
        self.mems[step_id] = (pre_mem + x) * self.scale
        self.y[step_id] = self.mems[step_id]


class PySimpleRNN2(PyRNNBase):
    def __init__(self, input_shape, output_shape):
        super(PySimpleRNN2, self).__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
        self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
        self.h_boot = np.ones(shape=(batch_size, input_dim)).astype("float32")

        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    def step(self, step_id, x):
        if step_id > 0:
            pre_mem = self.mems[step_id - 1]
        else:
            pre_mem = self.h_boot
        xW = np.matmul(x, self.W).astype("float32")
        hU = np.matmul(pre_mem, self.U).astype("float32")

        def py_sigmoid(x):
            return 1. / (1. + np.exp(-x))

        self.mems[step_id] = py_sigmoid(xW + hU)
        self.y[step_id] = self.mems[step_id]


def create_tensor(np_data, place):
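    # Copy a NumPy array into a LoDTensor on the given place so it can be
    # fed to Executor.run through the feed map.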
    tensor = core.LoDTensor()
    tensor.set(np_data, place)
    return tensor


class RecurrentOpTest1(unittest.TestCase):
    '''
    Test RNNOp
    equation:
        h_t = ( x_t + h_{t-1} ) / scale
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''
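
    # One step of the recurrence under test, as a NumPy sketch (names are
    # illustrative; PySimpleRNN1.step above is the executable reference):
    #     h_t = (h_prev + x_t) * scale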

    input_dim = 2
    batch_size = 1
    sent_len = 1

    def setup_program(self):
        self.main_program = Program()
        self.startup_program = Program()
        self.p_info = {
            "main_program": self.main_program,
            "startup_program": self.startup_program
        }
        self.place = core.CPUPlace()

    def setUp(self):
        self.setup_program()
        self.data_field = {"x", "h_boot"}

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)

        self.output = mean(x=self.create_rnn_op(), **self.p_info)

    def create_rnn_op(self):
        x = data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            data_type='float32',
            name='x',
            append_batch_size=False,
            **self.p_info)
        x.stop_gradient = False
        h_boot = data(
            shape=[self.input_dim],
            data_type='float32',
            name='h_boot',
            **self.p_info)
        h_boot.stop_gradient = False

        rnn = StaticRNN(main_program=self.main_program)
        with rnn.step():
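            # StaticRNN usage: declare a memory initialized from h_boot, take
            # one time-step slice of x, compute the new state, then register
            # the memory update and the per-step output.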
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

            h = scale(
                x=elementwise_add(
                    x=h_pre, y=x_t, **self.p_info),
                scale=self.py_rnn.scale,
                **self.p_info)

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()

    def forward(self):
        self.feed_map = {
            x: create_tensor(getattr(self.py_rnn, x), self.place)
            for x in self.data_field
        }
        exe = Executor(self.place)
        out = exe.run(self.main_program,
                      feed=self.feed_map,
                      fetch_list=[self.output])

        return np.array(out[0])

    def backward(self):
        self.feed_map = {
            x: create_tensor(getattr(self.py_rnn, x), self.place)
            for x in self.data_field
        }
        fetch_list = [
            self.main_program.global_block().var(x + "@GRAD")
            for x in self.data_field
        ]

        exe = Executor(self.place)
        return exe.run(self.main_program,
                       feed=self.feed_map,
                       fetch_list=fetch_list)

    def test_backward(self):
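        # Check the forward pass first, then append the backward ops and
        # compare the analytic gradients (the "@GRAD" variables fetched in
        # backward) against central-difference estimates.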
        self.check_forward()

        append_backward_ops(self.output)

        ana_grad = [np.array(x) for x in self.backward()]

        num_grad = self.get_numerical_gradient()
        for idx, name in enumerate(self.data_field):
            self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape)
            self.assertTrue(
                np.isclose(
                    num_grad[idx], ana_grad[idx], rtol=0.1).all())

    def check_forward(self):
        print 'test recurrent op forward'
        pd_output = self.forward()
        py_output = self.py_rnn.forward()
        print 'pd_output', pd_output
        print
        print 'py_output', py_output
        self.assertEqual(pd_output.shape, py_output.shape)
        self.assertTrue(np.isclose(pd_output, py_output, rtol=0.1).all())

    def get_numerical_gradient(self, delta=0.005):
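        # Perturb each input element in place by +/- delta, rerun the full
        # forward pass, and take the centered difference as the numerical
        # gradient of the scalar loss with respect to that element.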
        dloss_dout = 1.0
        feed_list = [getattr(self.py_rnn, x) for x in self.data_field]
        grad_list = [np.zeros_like(x) for x in feed_list]
        for feed, grad in zip(feed_list, grad_list):
            for f, g in np.nditer([feed, grad], op_flags=['readwrite']):
                o = float(f)
                f[...] = o + delta
                y_pos = self.forward()

                f[...] = o - delta
                y_neg = self.forward()

                f[...] = o
                dout_dfeed = (y_pos - y_neg) / (delta * 2)
                g[...] = dout_dfeed[0]

        return grad_list


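# A minimal sketch of the central-difference rule that get_numerical_gradient
# applies element-wise above; this helper is illustrative only and is not
# used by the tests:
def _central_difference(f, x, delta=0.005):
    return (f(x + delta) - f(x - delta)) / (2.0 * delta)

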
class RecurrentOpTest2(RecurrentOpTest1):
    '''
    Test RNNOp
    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''
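
    # In create_rnn_op below, the two fc layers (parameters named 'W' and
    # 'U') compute W x_t and U h_{t-1}; the sigmoid of their elementwise sum
    # becomes the new memory, matching PySimpleRNN2.step above.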

    input_dim = 2
    batch_size = 10
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.data_field = {"x", "h_boot", "W", "U"}

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)

        self.output = mean(x=self.create_rnn_op(), **self.p_info)

    def create_rnn_op(self):
        x = data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            data_type='float32',
            name='x',
            append_batch_size=False,
            **self.p_info)
        x.stop_gradient = False
        h_boot = data(
            shape=[self.input_dim],
            data_type='float32',
            name='h_boot',
            **self.p_info)
        h_boot.stop_gradient = False

        rnn = StaticRNN(main_program=self.main_program)
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

            temp_l = fc(input=x_t,
                        size=self.input_dim,
                        param_attr={'name': 'W'},
                        bias_attr=False,
                        **self.p_info)
            temp_r = fc(input=h_pre,
                        size=self.input_dim,
                        param_attr={'name': 'U'},
                        bias_attr=False,
                        **self.p_info)

            h = sigmoid(
                x=elementwise_add(
                    x=temp_l, y=temp_r, **self.p_info),
                **self.p_info)

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()


class RecurrentOpTest3(RecurrentOpTest1):
    '''
    Test RNNOp with two memories
    equation:
        h_1 = h_pre_1
        h_2 = h_pre_2
        y = h_1 + h_2
    vars:
        - x
    memories:
        - h_1, h_2
    outputs:
        - y
    '''
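
    # Both memories copy themselves forward unchanged, so at every step the
    # output is h_boot1 + h_boot2 + x_t (see PySimpleRNN3.step below).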

    class PySimpleRNN3(PyRNNBase):
        def __init__(self, input_shape, output_shape):
            super(RecurrentOpTest3.PySimpleRNN3, self).__init__(input_shape,
                                                                output_shape)

            seq_len, batch_size, input_dim = input_shape
            self.h_boot1 = np.random.normal(size=(batch_size,
                                                  input_dim)).astype("float32")
            self.h_boot2 = np.random.normal(size=(batch_size,
                                                  input_dim)).astype("float32")

            mem_dim = (seq_len, batch_size, input_dim)
            self.mems1 = np.zeros(shape=mem_dim).astype("float32")
            self.mems2 = np.zeros(shape=mem_dim).astype("float32")

        def step(self, step_id, x):
            if step_id == 0:
                pre_mem1 = self.h_boot1
                pre_mem2 = self.h_boot2
            else:
                pre_mem1 = self.mems1[step_id - 1]
                pre_mem2 = self.mems2[step_id - 1]
            self.mems1[step_id] = pre_mem1
            self.mems2[step_id] = pre_mem2
            self.y[step_id] = self.mems1[step_id] + self.mems2[step_id] + x

    input_dim = 1
    batch_size = 1
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.data_field = {"x", "h_boot1", "h_boot2"}

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpTest3.PySimpleRNN3(self.input_shape,
                                                    self.output_shape)

        self.output = mean(x=self.create_rnn_op(), **self.p_info)

    def create_rnn_op(self):
        x = data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            data_type='float32',
            name='x',
            append_batch_size=False,
            **self.p_info)
        x.stop_gradient = False
        h_boot1 = data(
            shape=[self.batch_size, self.input_dim],
            data_type='float32',
            name='h_boot1',
            append_batch_size=False,
            **self.p_info)
        h_boot1.stop_gradient = False
        h_boot2 = data(
            shape=[self.batch_size, self.input_dim],
            data_type='float32',
            name='h_boot2',
            append_batch_size=False,
            **self.p_info)
        h_boot2.stop_gradient = False

        rnn = StaticRNN(main_program=self.main_program)
        with rnn.step():
            h_pre1 = rnn.memory(init=h_boot1)
            h_pre2 = rnn.memory(init=h_boot2)
            x_t = rnn.step_input(x)

            mem1 = scale(x=h_pre1, scale=1.0, **self.p_info)
            mem2 = scale(x=h_pre2, scale=1.0, **self.p_info)
            out = sums(input=[mem1, x_t, mem2], **self.p_info)

            rnn.update_memory(h_pre1, mem1)
            rnn.update_memory(h_pre2, mem2)
            rnn.output(out)

        return rnn()


if __name__ == '__main__':
    unittest.main()