#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
import paddle.fluid.core as core

from paddle.fluid import ParamAttr
from paddle.fluid.framework import Program, grad_var_name
from paddle.fluid.executor import Executor
from paddle.fluid.backward import append_backward

np.random.seed(123)


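# The Py*RNN classes below are plain-numpy reference implementations of each
# recurrence. Every test case builds the same computation with
# fluid.layers.StaticRNN and compares its outputs and gradients against them.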
class PyRNNBase(object):

    def __init__(self, input_shape, output_shape):
        self.x = np.ones(shape=input_shape).astype("float32")
        self.y = np.zeros(shape=output_shape).astype("float32")

    def step(self, step_id, x):
        raise NotImplementedError

    def forward(self):
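        # Run every time step in order, then reduce the collected outputs to a
        # single scalar (their mean) so the result is directly comparable to
        # paddle.mean(...) used in the fluid programs under test.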
        for step_id in range(self.x.shape[0]):
            self.step(step_id, self.x[step_id])
        return np.array([np.mean(self.y)])

    def segment_inputs(self):
        return [self.x[i] for i in range(self.x.shape[0])]


class PySimpleRNN1(PyRNNBase):

    def __init__(self, input_shape, output_shape):
        super(PySimpleRNN1, self).__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.h_boot = np.random.normal(size=(batch_size,
                                             input_dim)).astype("float32")

        self.scale = 1.0 / 2.0
        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    def step(self, step_id, x):
        if step_id == 0:
            pre_mem = self.h_boot
        else:
            pre_mem = self.mems[step_id - 1]
        self.mems[step_id] = (pre_mem + x) * self.scale
        self.y[step_id] = self.mems[step_id]


class PySimpleRNN2(PyRNNBase):

    def __init__(self, input_shape, output_shape):
        super(PySimpleRNN2, self).__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.W = np.ones(shape=(input_dim, input_dim)).astype("float32")
        self.U = np.zeros(shape=(input_dim, input_dim)).astype("float32")
        self.h_boot = np.ones(shape=(batch_size, input_dim)).astype("float32")

        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    def step(self, step_id, x):
        if step_id > 0:
            pre_mem = self.mems[step_id - 1]
        else:
            pre_mem = self.h_boot
        xW = np.matmul(x, self.W).astype("float32")
        hU = np.matmul(pre_mem, self.U).astype("float32")

        def py_sigmoid(x):
            return 1. / (1. + np.exp(-x))

        self.mems[step_id] = py_sigmoid(xW + hU)
        self.y[step_id] = self.mems[step_id]


def create_tensor(np_data, place):
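    # Wrap a numpy array in a fluid LoDTensor on the given place so it can be
    # passed through the executor's feed map.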
    tensor = core.LoDTensor()
    tensor.set(np_data, place)
    return tensor


class RecurrentOpTest1(unittest.TestCase):
    '''
    Test RNNOp
    equation:
        h_t = ( x_t + h_{t-1} ) / scale
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''

    input_dim = 2
    batch_size = 1
    sent_len = 1

    def setup_program(self):
        self.main_program = Program()
        self.startup_program = Program()
        self.place = core.CPUPlace()

    def setUp(self):
        self.setup_program()
        self.feed_data_field = {"x", "h_boot"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype='float32',
                        name='x',
                        append_batch_size=False)
        x.stop_gradient = False
        h_boot = layers.data(shape=[self.input_dim],
                             dtype='float32',
                             name='h_boot')
        h_boot.stop_gradient = False

        rnn = layers.StaticRNN()
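        # StaticRNN: h_boot seeds the memory, step_input slices x along the
        # sequence (first) dimension, and output collects h from every step.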
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

            h = layers.scale(x=layers.elementwise_add(x=h_pre, y=x_t),
                             scale=self.py_rnn.scale)

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()

    def forward(self):
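        # Feed the reference RNN's numpy inputs as LoDTensors and run the fluid
        # program, fetching the scalar mean output.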
        self.feed_map = {
            x: create_tensor(getattr(self.py_rnn, x), self.place)
            for x in self.feed_data_field
        }
        exe = Executor(self.place)
        out = exe.run(self.main_program,
                      feed=self.feed_map,
                      fetch_list=[self.output])

        return out[0]

    def backward(self):
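        # Run the program with the backward pass appended and fetch the
        # gradient variables (named via grad_var_name) of grad_data_field.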
        self.feed_map = {
            x: create_tensor(getattr(self.py_rnn, x), self.place)
            for x in self.feed_data_field
        }
        fetch_list = [
            self.main_program.global_block().var(grad_var_name(x))
            for x in self.grad_data_field
        ]

        exe = Executor(self.place)
        return exe.run(self.main_program,
                       feed=self.feed_map,
                       fetch_list=fetch_list,
                       return_numpy=False)

    def test_backward(self, rtol=0.01):
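        # Check the forward pass first, then build the backward pass with
        # append_backward and compare analytic gradients against numerical ones.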
        self.check_forward()

        with fluid.program_guard(self.main_program, self.startup_program):
            append_backward(self.output)

        ana_grad = [np.array(x) for x in self.backward()]

        num_grad = self.get_numerical_gradient()
        for idx, name in enumerate(self.grad_data_field):
            self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape)
            np.testing.assert_allclose(
                num_grad[idx],
                ana_grad[idx],
                rtol=rtol,
                atol=1e-8,
                err_msg='num_grad (' + name + ') has diff at ' +
                str(self.place) + '\nExpect ' + str(num_grad[idx]) + '\n' +
                'But Got ' + str(ana_grad[idx]) + ' in class ' +
                self.__class__.__name__)

    def check_forward(self):
        pd_output = self.forward()
        py_output = self.py_rnn.forward()
        self.assertEqual(pd_output.shape, py_output.shape)
        np.testing.assert_allclose(pd_output, py_output, rtol=0.01)

    def get_numerical_gradient(self, delta=0.005):
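        # Central finite differences: perturb each input element by +/- delta
        # and estimate d(mean(y))/d(input) as (y_pos - y_neg) / (2 * delta).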
        dloss_dout = 1.0
        feed_list = [getattr(self.py_rnn, x) for x in self.grad_data_field]
        grad_list = [np.zeros_like(x) for x in feed_list]
        for feed, grad in zip(feed_list, grad_list):
            for f, g in np.nditer([feed, grad], op_flags=['readwrite']):
                o = float(f)
                f[...] = o + delta
                y_pos = self.forward()

                f[...] = o - delta
                y_neg = self.forward()

                f[...] = o
                dout_dfeed = (y_pos - y_neg) / (delta * 2)
                g[...] = dout_dfeed[0]

        return grad_list


class RecurrentOpTest2(RecurrentOpTest1):
    r'''
    Test RNNOp
    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
       - h
    '''

    input_dim = 2
    batch_size = 10
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x", "h_boot", "W", "U"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype='float32',
                        name='x',
                        append_batch_size=False)
        x.stop_gradient = False
        h_boot = layers.data(shape=[self.input_dim],
                             dtype='float32',
                             name='h_boot')
        h_boot.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

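            # W is initialized to ones and U to zeros so that the fc parameters
            # match the reference weights used by PySimpleRNN2.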
            temp_l = layers.fc(
                input=x_t,
                size=self.input_dim,
                param_attr=ParamAttr(
                    name='W',
                    initializer=fluid.initializer.ConstantInitializer(1.0)),
                bias_attr=False)
            temp_r = layers.fc(
                input=h_pre,
                size=self.input_dim,
                param_attr=ParamAttr(
                    name='U',
                    initializer=fluid.initializer.ConstantInitializer(0.0)),
                bias_attr=False)

            h = layers.sigmoid(x=layers.elementwise_add(x=temp_l, y=temp_r))

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()

    def test_backward(self):
        super(RecurrentOpTest2, self).test_backward(rtol=0.01)


class RecurrentOpMultipleMemoryTest(RecurrentOpTest1):
    '''
    Test RNNOp with two memories
    equation:
        h_1 = h_pre_1
        h_2 = h_pre_2
        y = h_1 + h_2
    vars:
        - x
    memories:
        - h_1, h_2
    outputs:
       - y
    '''

    class PySimpleRNN3(PyRNNBase):

        def __init__(self, input_shape, output_shape):
            super(RecurrentOpMultipleMemoryTest.PySimpleRNN3,
                  self).__init__(input_shape, output_shape)

            seq_len, batch_size, input_dim = input_shape
            self.h_boot1 = np.random.normal(size=(batch_size,
                                                  input_dim)).astype("float32")
            self.h_boot2 = np.random.normal(size=(batch_size,
                                                  input_dim)).astype("float32")

            mem_dim = (seq_len, batch_size, input_dim)
            self.mems1 = np.zeros(shape=mem_dim).astype("float32")
            self.mems2 = np.zeros(shape=mem_dim).astype("float32")

        def step(self, step_id, x):
            if step_id == 0:
                pre_mem1 = self.h_boot1
                pre_mem2 = self.h_boot2
            else:
                pre_mem1 = self.mems1[step_id - 1]
                pre_mem2 = self.mems2[step_id - 1]
            self.mems1[step_id] = pre_mem1
            self.mems2[step_id] = pre_mem2
            self.y[step_id] = self.mems1[step_id] + self.mems2[step_id] + x

    input_dim = 1
    batch_size = 1
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x", "h_boot1", "h_boot2"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpMultipleMemoryTest.PySimpleRNN3(
            self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype='float32',
                        name='x',
                        append_batch_size=False)
        x.stop_gradient = False
        h_boot1 = layers.data(shape=[self.batch_size, self.input_dim],
                              dtype='float32',
                              name='h_boot1',
                              append_batch_size=False)
        h_boot1.stop_gradient = False
        h_boot2 = layers.data(shape=[self.batch_size, self.input_dim],
                              dtype='float32',
                              name='h_boot2',
                              append_batch_size=False)
        h_boot2.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre1 = rnn.memory(init=h_boot1)
            h_pre2 = rnn.memory(init=h_boot2)
            x_t = rnn.step_input(x)

            mem1 = layers.scale(x=h_pre1, scale=1.0)
            mem2 = layers.scale(x=h_pre2, scale=1.0)
            out = layers.sums(input=[mem1, x_t, mem2])

            rnn.update_memory(h_pre1, mem1)
            rnn.update_memory(h_pre2, mem2)
            rnn.output(out)

        return rnn()


class RecurrentOpNoMemBootTest(RecurrentOpTest1):
    '''
    Test RNNOp without memory boot
    equation:
        mem = x + mem_pre
        y = mem
    vars:
        - x
    memories:
        - mem
    outputs:
       - y
    '''

    class PySimpleRNN4(PyRNNBase):

        def __init__(self, input_shape, output_shape):
            super(RecurrentOpNoMemBootTest.PySimpleRNN4,
                  self).__init__(input_shape, output_shape)
            mem_dim = input_shape
            self.mems = np.zeros(shape=mem_dim).astype("float32")

        def step(self, step_id, x):
            if step_id == 0:
                pre_mem = np.zeros_like(x)
            else:
                pre_mem = self.mems[step_id - 1]
            self.mems[step_id] = pre_mem + x
            self.y[step_id] = self.mems[step_id]

    input_dim = 1
    batch_size = 1
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpNoMemBootTest.PySimpleRNN4(
            self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype='float32',
                        name='x',
                        append_batch_size=False)
        x.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
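            # No boot tensor here: the memory is built from shape/batch_ref and
            # starts from StaticRNN's default zero init_value, with its batch
            # dimension taken from x.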
            mem_pre = rnn.memory(shape=[-1, self.input_dim], batch_ref=x)
            x_t = rnn.step_input(x)
            mem = layers.elementwise_add(x=mem_pre, y=x_t)
            rnn.update_memory(mem_pre, mem)
            rnn.output(mem)

        return rnn()


class RecurrentOpSubBlockTest(RecurrentOpTest1):
    r'''
    Test RNNOp with subblock variable
    equation:
        y_ = emb * w1
        h_t = \concat([x, h_{t-1}])
        h_t = h_t * w2
        h_t = \unsqueeze(h_t, 1)
        h_t = \dot_attention(h_t, y_)
        h_t = \squeeze(h_t, 1)
        y = h_t
    vars:
        - x
        - w1
        - w2
    memories:
        - h
    outputs:
       - y
    '''

    class PySimpleRNN5(PyRNNBase):

        def __init__(self, input_shape, output_shape):
            super(RecurrentOpSubBlockTest.PySimpleRNN5,
                  self).__init__(input_shape, output_shape)

            seq_len, batch_size, input_dim = input_shape
            self.w1 = np.random.uniform(-0.1, 0.1,
                                        size=(input_dim,
                                              input_dim)).astype("float32")
            self.w2 = np.random.uniform(-0.1,
                                        0.1,
                                        size=(input_dim * 2,
                                              input_dim)).astype("float32")

            self.emb = np.random.uniform(-0.1,
                                         0.1,
                                         size=(seq_len, batch_size,
                                               input_dim)).astype("float32")

            mem_dim = (seq_len, batch_size, input_dim)
            self.mems = np.zeros(shape=mem_dim).astype("float32")
            self.oy = np.matmul(self.emb, self.w1)

        def step(self, step_id, x):

            def dot_attention(query, memory):
                attn = np.matmul(query, memory.transpose((0, 2, 1)))
                weight = softmax(attn)
                weight_memory = np.matmul(weight, memory)
                return weight_memory, weight

            def softmax(x):
                return np.exp(x) / sum(np.exp(x))

            if step_id == 0:
                pre_mem = np.zeros_like(x)
            else:
                pre_mem = self.mems[step_id - 1]
            concat_in = np.concatenate([x, pre_mem], 1)
            new_mem = np.matmul(concat_in, self.w2)

            new_mem = np.expand_dims(new_mem, 1)
            new_mem, _ = dot_attention(new_mem, self.oy)
            new_mem = np.squeeze(new_mem, 1)

            self.mems[step_id] = new_mem
            self.y[step_id] = self.mems[step_id]

    input_dim = 2
    batch_size = 3
    sent_len = 3

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x", "emb", "w1", "w2"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpSubBlockTest.PySimpleRNN5(
            self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            rnn_out = self.create_rnn_op()
            self.output = paddle.mean(rnn_out)

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype='float32',
                        name='x',
                        append_batch_size=False)
        x.stop_gradient = False

        emb = layers.data(
            name='emb',
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            append_batch_size=False)
        emb.stop_gradient = False

        w1 = layers.data(shape=[self.input_dim, self.input_dim],
                         dtype='float32',
                         name='w1',
                         append_batch_size=False)
        w1.stop_gradient = False
        w2 = layers.data(shape=[self.input_dim * 2, self.input_dim],
                         dtype='float32',
                         name='w2',
                         append_batch_size=False)
        w2.stop_gradient = False

        rnn = layers.StaticRNN()

        def dot_attention(query, memory):
            attn = layers.matmul(query, memory, transpose_y=True)
            weight = layers.softmax(attn)
            weight_memory = layers.matmul(weight, memory)

            return weight_memory, weight

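        # y is computed in the parent block but read inside rnn.step(); this is
        # the "subblock variable" case this test exercises.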
        y = layers.matmul(emb, w1)
        with rnn.step():
            pre_h = rnn.memory(shape=(self.sent_len, self.input_dim),
                               batch_ref=x,
                               init_value=0.0)
            step_in = rnn.step_input(x)
            concat_in = layers.concat([step_in, pre_h], 1)
            new_h = layers.matmul(concat_in, w2)
            new_h = layers.unsqueeze(new_h, [1])
            new_h, _ = dot_attention(new_h, y)
            new_h = layers.squeeze(new_h, [1])

            rnn.update_memory(pre_h, new_h)
            rnn.step_output(new_h)

        return rnn()


class RecurrentOpStopGradientTest(RecurrentOpTest1):
    r"""
    Test RNNOp with stop_gradient = True
    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    output:
        - h
    """

    input_dim = 2
    batch_size = 10
    sent_len = 2

    def setUp(self):
        self.setup_program()
        self.feed_data_field = {"x", "h_boot", "W", "U"}
        self.grad_data_field = {"x", "W", "U"}

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        x = layers.data(shape=[self.sent_len, self.batch_size, self.input_dim],
                        dtype="float32",
                        name="x",
                        append_batch_size=False)
        x.stop_gradient = False
        h_boot = layers.data(shape=[self.input_dim],
                             dtype="float32",
                             name="h_boot")
        h_boot.stop_gradient = True
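        # With stop_gradient=True no gradient flows back to h_boot, which is
        # why grad_data_field above only contains x, W and U.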

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)  # init doesn't have gradient
            x_t = rnn.step_input(x)

            temp_l = layers.fc(
                input=x_t,
                size=self.input_dim,
                param_attr=ParamAttr(
                    name="W",
                    initializer=fluid.initializer.ConstantInitializer(1.0)),
                bias_attr=False)
            temp_r = layers.fc(
                input=h_pre,
                size=self.input_dim,
                param_attr=ParamAttr(
                    name="U",
                    initializer=fluid.initializer.ConstantInitializer(0.0)),
                bias_attr=False)

            h = layers.sigmoid(x=layers.elementwise_add(temp_l, temp_r))

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()


if __name__ == '__main__':
    unittest.main()