#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np

import paddle
from paddle import fluid

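# Run in static mode at import time; individual tests switch to dygraph in
# setUp() and restore static mode in tearDown().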
paddle.enable_static()

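# Fix the numpy seed and limit execution to a single CPU place so the test
# runs deterministically.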
np.random.seed(123)
os.environ["CPU_NUM"] = "1"
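# Enable eager deletion so intermediate variables are garbage-collected as
# soon as they become unused (zero-GB threshold, full memory fraction, fast
# deletion mode); the recurrent op's backward must stay correct under this GC.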
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)


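# Minimal recurrent network: a SimpleRNNCell (input size 16, hidden size 32)
# unrolled over the time dimension by paddle.nn.RNN.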
class RecurrentNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.cell = paddle.nn.SimpleRNNCell(16, 32)
        self.rnn = paddle.nn.RNN(self.cell)

    def forward(self, inputs, prev_h):
        outputs, final_states = self.rnn(inputs, prev_h)
        return outputs, final_states


class TestDy2StRecurrentOpBackward(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()
        paddle.seed(100)

    def tearDown(self):
        paddle.enable_static()

    def test_recurrent_backward(self):
        net = RecurrentNet()
        inputs = paddle.rand((4, 23, 16))
        inputs.stop_gradient = False
        prev_h = paddle.randn((4, 32))
        prev_h.stop_gradient = False

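        # Dygraph baseline: run forward/backward and record the input gradient.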
        outputs, final_states = net(inputs, prev_h)
        outputs.backward()
        dy_grad = inputs.gradient()
        inputs.clear_gradient()

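        # Dy2St: convert the net with to_static, repeat the same pass, and
        # check that the static-graph gradient matches the dygraph one.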
        net = paddle.jit.to_static(net)
        outputs, final_states = net(inputs, prev_h)
        outputs.backward()
        st_grad = inputs.gradient()
        np.testing.assert_allclose(dy_grad, st_grad)


if __name__ == '__main__':
    unittest.main()