#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections import Counter

import numpy as np
from test_fetch_feed import Linear, Pool2D

import paddle
from paddle import fluid
from paddle.jit.api import to_static
from paddle.jit.dy2static import convert_to_static


class TestCacheProgram(unittest.TestCase):
    def setUp(self):
        self.batch_num = 5
        self.dygraph_class = Pool2D
        self.data = np.random.random((1, 2, 4, 4)).astype('float32')

    def test_cache(self):
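        # Run the same network for several batches; with program caching the
        # forward op types and the outputs should stay identical across batches.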
        prev_ops, cur_ops = Counter(), Counter()
        prev_out, cur_out = None, None
        with fluid.dygraph.guard(fluid.CPUPlace()):
            static_net = self.dygraph_class()
            for batch_id in range(self.batch_num):
                out = static_net(self.data)
                # Check outputs
                prev_out = cur_out
                cur_out = out
                # Check forward ops
                prev_ops = cur_ops
                cur_ops = Counter(
                    [
                        op.type
                        for op in fluid.default_main_program().block(0).ops
                    ]
                )
                if batch_id > 0:
                    prev_out_numpy = (
                        prev_out[0].numpy()
                        if isinstance(prev_out, (tuple, list))
                        else prev_out.numpy()
                    )
                    cur_out_numpy = (
                        cur_out[0].numpy()
                        if isinstance(cur_out, (tuple, list))
                        else cur_out.numpy()
                    )
                    np.testing.assert_allclose(
                        prev_out_numpy,
                        cur_out_numpy,
                        rtol=1e-05,
                        err_msg='Output in previous batch is {}\n Output in current batch is \n{}'.format(
                            prev_out_numpy, cur_out_numpy
                        ),
                    )
                    self.assertEqual(prev_ops, cur_ops)


class TestCacheProgram2(TestCacheProgram):
    def setUp(self):
        self.batch_num = 5
        self.dygraph_class = Linear
        self.data = np.random.random((4, 10)).astype('float32')


class TestCacheProgramWithOptimizer(unittest.TestCase):
    def setUp(self):
        self.dygraph_class = Linear
        self.data = np.random.random((4, 10)).astype('float32')
        self.batch_num = 5

    def train_static(self):
        return self.train(to_static=True)

    def train_dygraph(self):
        return self.train(to_static=False)

    def train(self, to_static=False):
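        # Toggle dynamic-to-static conversion globally for this training run.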
        paddle.jit.enable_to_static(to_static)

        with fluid.dygraph.guard(fluid.CPUPlace()):
            dygraph_net = self.dygraph_class()
            adam = fluid.optimizer.AdamOptimizer(
                learning_rate=0.001, parameter_list=dygraph_net.parameters()
            )
            loss_data = []
            for batch_id in range(self.batch_num):
                input = fluid.dygraph.to_variable(self.data)
                pred, avg_loss = dygraph_net(input)

                loss_data.append(avg_loss.numpy())
                avg_loss.backward()
                adam.minimize(avg_loss)
                dygraph_net.clear_gradients()

        return loss_data

    def test_with_optimizer(self):
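        # Training the same network with and without to_static should produce
        # matching per-batch losses.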
        dygraph_loss = self.train_dygraph()
        static_loss = self.train_static()
        np.testing.assert_allclose(
            dygraph_loss,
            static_loss,
            rtol=1e-05,
            err_msg='dygraph is {}\n static_res is \n{}'.format(
                dygraph_loss, static_loss
            ),
        )


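# Trivial function used below to exercise convert_to_static's cache.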
def simple_func(x):
    inputs = fluid.dygraph.to_variable(x)
    mean = paddle.mean(inputs)
    return mean


class TestConvertWithCache(unittest.TestCase):
    def test_cache(self):
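        # Converting the same function twice should return the cached callable.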
        static_func = convert_to_static(simple_func)
        # Get transformed function from cache.
        cached_func = convert_to_static(simple_func)
        self.assertEqual(id(static_func), id(cached_func))


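# Exercises conversion of a for loop with `continue` and `break`:
# sums the even numbers from 0 up to `limit`.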
@to_static
def sum_even_until_limit(max_len, limit):
    ret_sum = fluid.dygraph.to_variable(np.zeros(1).astype('int32'))
    for i in range(max_len):
        if i % 2 > 0:
            continue
        elif i > limit:
            break

        ret_sum += i
    return ret_sum


def sum_under_while(limit):
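    # Exercises while-loop conversion: returns 0 + 1 + ... + limit.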
    i = fluid.dygraph.to_variable(np.zeros(1).astype('int32'))
    ret_sum = fluid.dygraph.to_variable(np.zeros(1).astype('int32'))
    while i <= limit:
        ret_sum += i
        i += 1
    return ret_sum


class TestToOutputWithCache(unittest.TestCase):
    def test_output(self):
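        # Check the numerical results of the converted control-flow functions.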
        with fluid.dygraph.guard():
            ret = sum_even_until_limit(80, 10)
            self.assertEqual(ret.numpy(), 30)

            ret = to_static(sum_under_while)(100)
            self.assertEqual(ret.numpy(), 5050)


if __name__ == '__main__':
    unittest.main()