# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np

import paddle


def forward_post_hook1(layer, input, output):
    """Forward post-hook that replaces a layer's output with ``output + output``.

    Args:
        layer: The layer the hook is attached to (unused).
        input: The inputs that were passed to the layer (unused).
        output: The value produced by the layer's forward pass.

    Returns:
        The sum of ``output`` with itself.
    """
    doubled_output = output + output
    return doubled_output


def forward_pre_hook1(layer, input):
    """Forward pre-hook that doubles the first positional input of a layer.

    Args:
        layer: The layer the hook is attached to (unused).
        input: Indexable collection of the layer's inputs; only ``input[0]``
            is read.

    Returns:
        A one-element tuple holding ``input[0] * 2``.
    """
    input_return = (input[0] * 2,)
    return input_return


class SimpleNet(paddle.nn.Layer):
    """Two-layer network with forward hooks on both sublayers and on itself.

    Hooks registered in ``__init__``:
      * ``fc1``: ``forward_post_hook1`` as a forward post-hook.
      * ``fc2``: ``forward_pre_hook1`` as a forward pre-hook.
      * the network itself: ``forward_pre_hook1`` as a pre-hook and
        ``forward_post_hook1`` as a post-hook.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(10, 10)
        # sublayer1 register post hook
        self.fc1.register_forward_post_hook(forward_post_hook1)

        self.fc2 = paddle.nn.Linear(10, 10)
        # sublayer2 register pre hook
        self.fc2.register_forward_pre_hook(forward_pre_hook1)

        # register pre/post hook
        self.register_forward_pre_hook(forward_pre_hook1)
        self.register_forward_post_hook(forward_post_hook1)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2`` to ``x`` and return ``paddle.mean`` of the result."""
        x = self.fc1(x)
        x = self.fc2(x)
        out = paddle.mean(x)

        return out


class TestNestLayerHook(unittest.TestCase):
    """Compare ``SimpleNet`` outputs across dygraph, to_static and jit-loaded runs."""

    def setUp(self):
        """Seed paddle, build the shared input and create a temp save directory."""
        paddle.seed(2022)
        self.x = paddle.randn([4, 10])
        self.temp_dir = tempfile.TemporaryDirectory()
        self.path = os.path.join(self.temp_dir.name, 'net_hook')

    def tearDown(self):
        """Remove the temporary directory created in ``setUp``."""
        self.temp_dir.cleanup()

    def train_net(self, to_static=False):
        """Run one forward pass of a freshly seeded ``SimpleNet`` on ``self.x``.

        Args:
            to_static: When true, wrap the network with ``paddle.jit.to_static``
                before the forward pass and save it to ``self.path`` afterwards.

        Returns:
            The network output converted to a Python ``float``.
        """
        # Re-seed so every call builds the network from the same seed.
        paddle.seed(2022)
        net = SimpleNet()
        if to_static:
            net = paddle.jit.to_static(net)
        out = net(self.x)

        if to_static:
            paddle.jit.save(net, self.path)

        return float(out)

    def load_train(self):
        """Load the model saved at ``self.path`` and run it on ``self.x``.

        Requires a prior ``train_net(to_static=True)`` call, which is what
        writes the model to ``self.path``.

        Returns:
            The loaded model's output converted to a Python ``float``.
        """
        net = paddle.jit.load(self.path)
        out = net(self.x)
        return float(out)

    def test_hook(self):
        """Check the three execution paths agree within ``rtol=1e-05``."""
        dy_out = self.train_net(to_static=False)
        st_out = self.train_net(to_static=True)
        load_out = self.load_train()
        print(st_out, dy_out, load_out)
        np.testing.assert_allclose(
            st_out,
            dy_out,
            rtol=1e-05,
            err_msg=f'dygraph_res is {dy_out}\nstatic_res is {st_out}',
        )
        np.testing.assert_allclose(
            st_out,
            load_out,
            rtol=1e-05,
            err_msg=f'load_out is {load_out}\nstatic_res is {st_out}',
        )


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()