# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from test_resnet import ResNetHelper

import paddle


class TestResnetWithPass(unittest.TestCase):
    """Train and predict with ResNet under a fusion-enabled BuildStrategy.

    The same ``BuildStrategy`` enables these fusions and the addto option:
    - ``fuse_elewise_add_act_ops``
    - ``fuse_bn_act_ops``
    - ``fuse_bn_add_act_ops``
    - ``enable_addto``

    It is passed to ``ResNetHelper`` for both the to-static and the dygraph
    runs, and the resulting losses / predictions are compared.
    """

    def setUp(self):
        self.build_strategy = paddle.static.BuildStrategy()
        self.build_strategy.fuse_elewise_add_act_ops = True
        self.build_strategy.fuse_bn_act_ops = True
        self.build_strategy.fuse_bn_add_act_ops = True
        self.build_strategy.enable_addto = True

        self.resnet_helper = ResNetHelper()

        # NOTE: for enable_addto
        # The flag is global state: remember its current value so tearDown
        # can put it back instead of leaking it into later test cases.
        self._saved_flags = paddle.fluid.get_flags(
            ['FLAGS_max_inplace_grad_add']
        )
        paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8})

    def tearDown(self):
        """Undo the global state changed by setUp() and train()."""
        paddle.fluid.set_flags(self._saved_flags)
        # train(to_static=False) disables the global dy2static switch;
        # re-enable it (presumably the default) so other tests are unaffected.
        paddle.jit.enable_to_static(True)

    def train(self, to_static):
        """Toggle dy2static globally, then train via the helper.

        Args:
            to_static (bool): passed both to ``paddle.jit.enable_to_static``
                and to ``ResNetHelper.train``.

        Returns:
            Whatever ``ResNetHelper.train`` returns (compared as the loss).
        """
        paddle.jit.enable_to_static(to_static)
        return self.resnet_helper.train(to_static, self.build_strategy)

    def verify_predict(self):
        """Check that all ResNetHelper prediction paths agree on one input.

        A random float32 array of shape [1, 3, 224, 224] is fed to:
        - ``predict_dygraph``
        - ``predict_static``
        - ``predict_dygraph_jit``
        - ``predict_analysis_inference``

        Every result is compared against the static prediction with
        rtol=1e-05.
        """
        image = np.random.random([1, 3, 224, 224]).astype('float32')
        dy_pre = self.resnet_helper.predict_dygraph(image)
        st_pre = self.resnet_helper.predict_static(image)
        dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
        np.testing.assert_allclose(
            dy_pre,
            st_pre,
            rtol=1e-05,
            err_msg=f'dy_pre:\n {dy_pre}\n, st_pre: \n{st_pre}.',
        )
        np.testing.assert_allclose(
            dy_jit_pre,
            st_pre,
            rtol=1e-05,
            err_msg=f'dy_jit_pre:\n {dy_jit_pre}\n, st_pre: \n{st_pre}.',
        )
        np.testing.assert_allclose(
            predictor_pre,
            st_pre,
            rtol=1e-05,
            err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
        )

    def test_resnet(self):
        """Static and dygraph training must give matching losses."""
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)
        np.testing.assert_allclose(
            static_loss,
            dygraph_loss,
            rtol=1e-05,
            err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
        )
        self.verify_predict()

    def test_in_static_mode_mkldnn(self):
        """Run one static-mode training with FLAGS_use_mkldnn switched on."""
        # Report a skip rather than a silent pass when MKL-DNN is unavailable.
        if not paddle.fluid.core.is_compiled_with_mkldnn():
            self.skipTest('Paddle is not compiled with MKL-DNN')
        paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
        try:
            self.resnet_helper.train(True, self.build_strategy)
        finally:
            # Always switch the global flag back off, even if training fails.
            paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})


class TestError(unittest.TestCase):
    """Argument validation of ``paddle.jit.to_static``."""

    def test_type_error(self):
        """A non-BuildStrategy ``build_strategy`` must raise TypeError."""

        def foo(x):
            out = x + 1
            return out

        # The wrapped function is never called; the error is expected from
        # to_static itself, so the return value is not bound to a name.
        with self.assertRaises(TypeError):
            paddle.jit.to_static(foo, build_strategy="x")


if __name__ == '__main__':
    # Script entry point: discover and run every TestCase defined in this
    # module ('__main__' is unittest.main's default module, spelled out).
    unittest.main(module='__main__')