#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.fluid import ParamAttr, core, layers
from paddle.fluid.backward import append_backward
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, grad_var_name

# Seed the global NumPy RNG so the randomly drawn reference inputs of the
# Py* models below (np.random.normal / np.random.uniform) are reproducible.
np.random.seed(123)

class PyRNNBase:
    """NumPy reference RNN used to check the Paddle StaticRNN results.

    ``self.x`` holds the inputs with time along axis 0; subclasses implement
    :meth:`step`, which consumes one time slice and writes ``self.y``.
    """

    def __init__(self, input_shape, output_shape):
        # Inputs default to all ones, outputs start zeroed; both float32.
        self.x = np.ones(shape=input_shape).astype("float32")
        self.y = np.zeros(shape=output_shape).astype("float32")

    def step(self, step_id, x):
        """Advance the recurrence by one time step; must be overridden."""
        raise NotImplementedError

    def forward(self):
        """Run :meth:`step` over every slice of ``self.x`` and return mean(y)."""
        # Iterating an ndarray walks axis 0, i.e. yields self.x[step_id].
        for step_id, x_t in enumerate(self.x):
            self.step(step_id, x_t)
        return np.mean(self.y)

    def segment_inputs(self):
        """Return the per-time-step slices of ``self.x`` (views on axis 0)."""
        return list(self.x)


class PySimpleRNN1(PyRNNBase):
    """NumPy model of ``h_t = (x_t + h_{t-1}) * scale`` with ``scale = 0.5``.

    ``h_boot`` (standard-normal samples) is the memory used before step 0;
    the output at each step is the new memory.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.h_boot = np.random.normal(size=(batch_size, input_dim)).astype(
            "float32"
        )

        self.scale = 1.0 / 2.0
        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    def step(self, step_id, x):
        """Compute memory/output for ``step_id`` from input slice ``x``."""
        pre_mem = self.h_boot if step_id == 0 else self.mems[step_id - 1]
        self.mems[step_id] = (pre_mem + x) * self.scale
        self.y[step_id] = self.mems[step_id]


class PySimpleRNN2(PyRNNBase):
    """NumPy model of ``h_t = sigmoid(x_t @ W + h_{t-1} @ U)``.

    ``W`` is all ones, ``U`` all zeros and the boot memory ``h_boot`` all ones.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)

        seq_len, batch_size, input_dim = input_shape
        self.W = np.ones(shape=(input_dim, input_dim)).astype("float32")
        self.U = np.zeros(shape=(input_dim, input_dim)).astype("float32")
        self.h_boot = np.ones(shape=(batch_size, input_dim)).astype("float32")

        mem_dim = (seq_len, batch_size, input_dim)
        self.mems = np.zeros(shape=mem_dim).astype("float32")

    @staticmethod
    def _sigmoid(x):
        """Element-wise logistic function."""
        return 1.0 / (1.0 + np.exp(-x))

    def step(self, step_id, x):
        """Compute memory/output for ``step_id`` from input slice ``x``."""
        pre_mem = self.mems[step_id - 1] if step_id > 0 else self.h_boot
        xW = np.matmul(x, self.W).astype("float32")
        hU = np.matmul(pre_mem, self.U).astype("float32")

        self.mems[step_id] = self._sigmoid(xW + hU)
        self.y[step_id] = self.mems[step_id]


def create_tensor(np_data, place):
    """Return a new ``core.LoDTensor`` filled from ``np_data`` on ``place``.

    The data is loaded with ``LoDTensor.set``; the result is used as a feed
    value for ``Executor.run``.
    """
    tensor = core.LoDTensor()
    tensor.set(np_data, place)
    return tensor


class RecurrentOpTest1(unittest.TestCase):
    '''
    Test RNNOp
    equation:
        h_t = ( x_t + h_{t-1} ) * scale
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''

    input_dim = 2
    batch_size = 1
    sent_len = 1

    def setup_program(self):
        """Create fresh main/startup programs and select the CPU place."""
        self.main_program = Program()
        self.startup_program = Program()
        self.place = core.CPUPlace()

    def setUp(self):
        self.setup_program()
        # Attribute names of self.py_rnn that are fed to the program and
        # whose gradients are compared against numerical estimates.
        self.feed_data_field = {"x", "h_boot"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        """Build the StaticRNN computing (h_pre + x_t) * py_rnn.scale."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            name='x',
        )
        x.stop_gradient = False
        h_boot = paddle.static.data(
            shape=[-1, self.input_dim], dtype='float32', name='h_boot'
        )
        h_boot.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

            h = paddle.scale(
                x=paddle.add(x=h_pre, y=x_t),
                scale=self.py_rnn.scale,
            )

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()

    def _make_feed_map(self):
        """Wrap the current NumPy value of every feed field as a LoDTensor."""
        return {
            name: create_tensor(getattr(self.py_rnn, name), self.place)
            for name in self.feed_data_field
        }

    def forward(self):
        """Run the main program and return the fetched scalar output."""
        self.feed_map = self._make_feed_map()
        exe = Executor(self.place)
        out = exe.run(
            self.main_program, feed=self.feed_map, fetch_list=[self.output]
        )

        return out[0]

    def backward(self):
        """Run the main program and fetch the gradient of each grad field.

        The gradient variables only exist after ``append_backward`` has been
        applied to the main program (done in :meth:`test_backward`).
        """
        self.feed_map = self._make_feed_map()
        fetch_list = [
            self.main_program.global_block().var(grad_var_name(name))
            for name in self.grad_data_field
        ]

        exe = Executor(self.place)
        return exe.run(
            self.main_program,
            feed=self.feed_map,
            fetch_list=fetch_list,
            return_numpy=False,
        )

    def test_backward(self, rtol=0.01):
        self.check_forward()

        with fluid.program_guard(self.main_program, self.startup_program):
            append_backward(self.output)

        ana_grad = [np.array(x) for x in self.backward()]

        num_grad = self.get_numerical_gradient()
        # All three sequences iterate the same (unmodified) set, so their
        # orders line up.
        for name, num, ana in zip(self.grad_data_field, num_grad, ana_grad):
            self.assertEqual(num.shape, ana.shape)
            np.testing.assert_allclose(
                num,
                ana,
                rtol=rtol,
                atol=1e-8,
                err_msg=(
                    f'num_grad ({name}) has diff at {self.place}\n'
                    f'Expect {num}\n'
                    f'But Got {ana} in class {self.__class__.__name__}'
                ),
            )

    def check_forward(self):
        """Compare the Paddle forward result with the NumPy reference."""
        pd_output = self.forward()
        py_output = self.py_rnn.forward()
        self.assertEqual(pd_output.shape, py_output.shape)
        np.testing.assert_allclose(pd_output, py_output, rtol=0.01)

    def get_numerical_gradient(self, delta=0.005):
        """Estimate d(output)/d(field) for every grad field.

        Uses central differences through the Paddle forward pass. Each
        element of the NumPy input is perturbed in place by +/- ``delta``,
        then restored.
        """
        feed_list = [getattr(self.py_rnn, name) for name in self.grad_data_field]
        grad_list = [np.zeros_like(feed) for feed in feed_list]
        for feed, grad in zip(feed_list, grad_list):
            # Context manager guarantees any writeback to the operands.
            with np.nditer([feed, grad], op_flags=['readwrite']) as it:
                for f, g in it:
                    o = float(f)
                    f[...] = o + delta
                    y_pos = self.forward()

                    f[...] = o - delta
                    y_neg = self.forward()

                    f[...] = o
                    g[...] = (y_pos - y_neg) / (delta * 2)

        return grad_list


class RecurrentOpTest2(RecurrentOpTest1):
    r'''
    Test RNNOp
    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    '''

    input_dim = 2
    batch_size = 10
    sent_len = 2

    def setUp(self):
        self.setup_program()

        # The fc weights "W"/"U" are fed from the NumPy model and their
        # gradients are checked along with x and h_boot.
        self.feed_data_field = {"x", "h_boot", "W", "U"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        """Build the StaticRNN computing sigmoid(fc_W(x_t) + fc_U(h_pre))."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            name='x',
        )
        x.stop_gradient = False
        h_boot = paddle.static.data(
            shape=[-1, self.input_dim], dtype='float32', name='h_boot'
        )
        h_boot.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)
            x_t = rnn.step_input(x)

            # Parameter names match the feed keys "W" and "U"; the constant
            # initialisers match PySimpleRNN2 (W ones, U zeros).
            temp_l = paddle.static.nn.fc(
                x=x_t,
                size=self.input_dim,
                weight_attr=ParamAttr(
                    name='W',
                    initializer=paddle.nn.initializer.Constant(1.0),
                ),
                bias_attr=False,
            )
            temp_r = paddle.static.nn.fc(
                x=h_pre,
                size=self.input_dim,
                weight_attr=ParamAttr(
                    name='U',
                    initializer=paddle.nn.initializer.Constant(0.0),
                ),
                bias_attr=False,
            )

            h = paddle.nn.functional.sigmoid(x=paddle.add(x=temp_l, y=temp_r))

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()

    def test_backward(self):
        # Explicit override point; uses the same tolerance as the base class.
        super().test_backward(rtol=0.01)


class RecurrentOpMultipleMemoryTest(RecurrentOpTest1):
    '''
    Test RNNOp with two memories
    equation:
        h_1 = h_pre_1
        h_2 = h_pre_2
        y = h_1 + h_2 + x
    vars:
        - x
    memories:
        - h_1, h_2
    outputs:
       - y
    '''

    class PySimpleRNN3(PyRNNBase):
        """NumPy model: two memories carry their boot values unchanged and
        the output at each step is ``h_1 + h_2 + x_t``."""

        def __init__(self, input_shape, output_shape):
            super().__init__(input_shape, output_shape)

            seq_len, batch_size, input_dim = input_shape
            self.h_boot1 = np.random.normal(
                size=(batch_size, input_dim)
            ).astype("float32")
            self.h_boot2 = np.random.normal(
                size=(batch_size, input_dim)
            ).astype("float32")

            mem_dim = (seq_len, batch_size, input_dim)
            self.mems1 = np.zeros(shape=mem_dim).astype("float32")
            self.mems2 = np.zeros(shape=mem_dim).astype("float32")

        def step(self, step_id, x):
            """Copy both memories forward and emit their sum plus ``x``."""
            if step_id == 0:
                pre_mem1, pre_mem2 = self.h_boot1, self.h_boot2
            else:
                pre_mem1 = self.mems1[step_id - 1]
                pre_mem2 = self.mems2[step_id - 1]
            self.mems1[step_id] = pre_mem1
            self.mems2[step_id] = pre_mem2
            self.y[step_id] = self.mems1[step_id] + self.mems2[step_id] + x

    input_dim = 1
    batch_size = 1
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x", "h_boot1", "h_boot2"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpMultipleMemoryTest.PySimpleRNN3(
            self.input_shape, self.output_shape
        )

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        """Build a StaticRNN with two memories whose output is mem1+x+mem2."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            name='x',
        )
        x.stop_gradient = False
        h_boot1 = paddle.static.data(
            shape=[self.batch_size, self.input_dim],
            dtype='float32',
            name='h_boot1',
        )
        h_boot1.stop_gradient = False
        h_boot2 = paddle.static.data(
            shape=[self.batch_size, self.input_dim],
            dtype='float32',
            name='h_boot2',
        )
        h_boot2.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre1 = rnn.memory(init=h_boot1)
            h_pre2 = rnn.memory(init=h_boot2)
            x_t = rnn.step_input(x)

            # scale(..., 1.0) is an identity op that yields a new variable
            # to feed back into each memory.
            mem1 = paddle.scale(x=h_pre1, scale=1.0)
            mem2 = paddle.scale(x=h_pre2, scale=1.0)
            out = paddle.add_n([mem1, x_t, mem2])

            rnn.update_memory(h_pre1, mem1)
            rnn.update_memory(h_pre2, mem2)
            rnn.output(out)

        return rnn()


class RecurrentOpNoMemBootTest(RecurrentOpTest1):
    '''
    Test RNNOp with a memory that has no boot (initial) input
    equation:
        mem = x + mem_pre
        y = mem
    vars:
        - x
    memories:
        - mem
    outputs:
       - y
    '''

    class PySimpleRNN4(PyRNNBase):
        """NumPy model: running sum of the inputs, starting from zeros."""

        def __init__(self, input_shape, output_shape):
            super().__init__(input_shape, output_shape)
            self.mems = np.zeros(shape=input_shape).astype("float32")

        def step(self, step_id, x):
            """Add ``x`` to the previous memory (zeros before step 0)."""
            if step_id == 0:
                pre_mem = np.zeros_like(x)
            else:
                pre_mem = self.mems[step_id - 1]
            self.mems[step_id] = pre_mem + x
            self.y[step_id] = self.mems[step_id]

    input_dim = 1
    batch_size = 1
    sent_len = 2

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpNoMemBootTest.PySimpleRNN4(
            self.input_shape, self.output_shape
        )

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        """Build a StaticRNN whose memory is created without an ``init``."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            name='x',
        )
        x.stop_gradient = False

        rnn = layers.StaticRNN()
        with rnn.step():
            # No init var: the memory is declared from `shape`, presumably
            # taking its batch size from `batch_ref` (x) - confirm in StaticRNN.
            mem_pre = rnn.memory(shape=[-1, self.input_dim], batch_ref=x)
            x_t = rnn.step_input(x)
            mem = paddle.add(x=mem_pre, y=x_t)
            rnn.update_memory(mem_pre, mem)
            rnn.output(mem)

        return rnn()


class RecurrentOpSubBlockTest(RecurrentOpTest1):
    r'''
    Test RNNOp with subblock variable
    equation:
        y_ = emb * w1
        h_t = \concat([x, h_{t-1}])
        h_t = h_t * w2
        h_t = \unsqueeze(h_t, 1)
        h_t = \dot_attention(h_t, y_)
        h_t = \squeeze(h_t, 1)
        y = h_t
    vars:
        - x
        - emb
        - w1
        - w2
    memories:
        - h
    outputs:
       - y
    '''

    class PySimpleRNN5(PyRNNBase):
        """NumPy model of the attention RNN described in the class docstring."""

        def __init__(self, input_shape, output_shape):
            super().__init__(input_shape, output_shape)

            seq_len, batch_size, input_dim = input_shape
            self.w1 = np.random.uniform(
                -0.1, 0.1, size=(input_dim, input_dim)
            ).astype("float32")
            self.w2 = np.random.uniform(
                -0.1, 0.1, size=(input_dim * 2, input_dim)
            ).astype("float32")

            self.emb = np.random.uniform(
                -0.1, 0.1, size=(seq_len, batch_size, input_dim)
            ).astype("float32")

            mem_dim = (seq_len, batch_size, input_dim)
            self.mems = np.zeros(shape=mem_dim).astype("float32")
            # Attention memory y_ = emb @ w1, computed once outside the loop.
            self.oy = np.matmul(self.emb, self.w1)

        def step(self, step_id, x):
            """Compute memory/output for ``step_id`` from input slice ``x``."""

            def softmax(x):
                # Normalise over the last axis, like paddle's softmax
                # (axis=-1). Builtin sum() would reduce over axis 0 instead.
                # Subtracting the max only guards exp() against overflow.
                e = np.exp(x - np.max(x, axis=-1, keepdims=True))
                return e / np.sum(e, axis=-1, keepdims=True)

            def dot_attention(query, memory):
                attn = np.matmul(query, memory.transpose((0, 2, 1)))
                weight = softmax(attn)
                weight_memory = np.matmul(weight, memory)
                return weight_memory, weight

            if step_id == 0:
                pre_mem = np.zeros_like(x)
            else:
                pre_mem = self.mems[step_id - 1]
            concat_in = np.concatenate([x, pre_mem], 1)
            new_mem = np.matmul(concat_in, self.w2)

            new_mem = np.expand_dims(new_mem, 1)
            new_mem, _ = dot_attention(new_mem, self.oy)
            new_mem = np.squeeze(new_mem, 1)

            self.mems[step_id] = new_mem
            self.y[step_id] = self.mems[step_id]

    input_dim = 2
    batch_size = 3
    sent_len = 3

    def setUp(self):
        self.setup_program()

        self.feed_data_field = {"x", "emb", "w1", "w2"}
        self.grad_data_field = self.feed_data_field

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = RecurrentOpSubBlockTest.PySimpleRNN5(
            self.input_shape, self.output_shape
        )

        with fluid.program_guard(self.main_program, self.startup_program):
            rnn_out = self.create_rnn_op()
            self.output = paddle.mean(rnn_out)

    def create_rnn_op(self):
        """Build a StaticRNN whose step block reads ``y``, a variable
        created outside the step."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
            name='x',
        )
        x.stop_gradient = False

        emb = paddle.static.data(
            name='emb',
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype='float32',
        )
        emb.stop_gradient = False

        w1 = paddle.static.data(
            shape=[self.input_dim, self.input_dim],
            dtype='float32',
            name='w1',
        )
        w1.stop_gradient = False
        w2 = paddle.static.data(
            shape=[self.input_dim * 2, self.input_dim],
            dtype='float32',
            name='w2',
        )
        w2.stop_gradient = False

        rnn = layers.StaticRNN()

        def dot_attention(query, memory):
            attn = paddle.matmul(query, memory, transpose_y=True)
            weight = paddle.nn.functional.softmax(attn)
            weight_memory = paddle.matmul(weight, memory)

            return weight_memory, weight

        # Built in the parent block and used inside the RNN step block.
        y = paddle.matmul(emb, w1)
        with rnn.step():
            # NOTE(review): with batch_ref given, dim 0 of `shape` is
            # presumably replaced by x's batch size - confirm in StaticRNN.
            pre_h = rnn.memory(
                shape=(self.sent_len, self.input_dim),
                batch_ref=x,
                init_value=0.0,
            )
            step_in = rnn.step_input(x)
            concat_in = paddle.concat([step_in, pre_h], 1)
            new_h = paddle.matmul(concat_in, w2)
            new_h = paddle.unsqueeze(new_h, [1])
            new_h, _ = dot_attention(new_h, y)
            new_h = paddle.squeeze(new_h, [1])

            rnn.update_memory(pre_h, new_h)
            rnn.step_output(new_h)

        return rnn()


class RecurrentOpStopGradientTest(RecurrentOpTest1):
    r"""
    Test RNNOp with stop_gradient = True
    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
    memories:
        - h
    outputs:
        - h
    """

    input_dim = 2
    batch_size = 10
    sent_len = 2

    def setUp(self):
        self.setup_program()
        self.feed_data_field = {"x", "h_boot", "W", "U"}
        # h_boot has stop_gradient=True, so no gradient is checked for it.
        self.grad_data_field = {"x", "W", "U"}

        self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
        self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)

        with fluid.program_guard(self.main_program, self.startup_program):
            self.output = paddle.mean(self.create_rnn_op())

    def create_rnn_op(self):
        """Same network as RecurrentOpTest2, but h_boot stops gradients."""
        x = paddle.static.data(
            shape=[self.sent_len, self.batch_size, self.input_dim],
            dtype="float32",
            name="x",
        )
        x.stop_gradient = False
        h_boot = paddle.static.data(
            shape=[-1, self.input_dim], dtype="float32", name="h_boot"
        )
        h_boot.stop_gradient = True

        rnn = layers.StaticRNN()
        with rnn.step():
            h_pre = rnn.memory(init=h_boot)  # init doesn't have gradient
            x_t = rnn.step_input(x)

            temp_l = paddle.static.nn.fc(
                x=x_t,
                size=self.input_dim,
                weight_attr=ParamAttr(
                    name="W",
                    initializer=paddle.nn.initializer.Constant(1.0),
                ),
                bias_attr=False,
            )
            temp_r = paddle.static.nn.fc(
                x=h_pre,
                size=self.input_dim,
                weight_attr=ParamAttr(
                    name="U",
                    initializer=paddle.nn.initializer.Constant(0.0),
                ),
                bias_attr=False,
            )

            h = paddle.nn.functional.sigmoid(x=paddle.add(temp_l, temp_r))

            rnn.update_memory(h_pre, h)
            rnn.output(h)

        return rnn()


if __name__ == '__main__':
    # Run the unittest.TestCase classes defined in this module.
    unittest.main()