# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
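
# These tests cover paddle.save / paddle.load round-trips for layer and
# optimizer state dicts, including very large parameters and static graph mode.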

BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
SEED = 10

IMAGE_SIZE = 784
CLASS_NUM = 10

LARGE_PARAM = 2**26


def random_batch_reader():
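    """Return a reader that yields BATCH_NUM random (image, label) tensor batches."""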
    def _get_random_inputs_and_labels():
        np.random.seed(SEED)
        image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32')
        label = np.random.randint(0, CLASS_NUM - 1, (
            BATCH_SIZE,
            1, )).astype('int64')
        return image, label

    def __reader__():
        for _ in range(BATCH_NUM):
            batch_image, batch_label = _get_random_inputs_and_labels()
            batch_image = paddle.to_tensor(batch_image)
            batch_label = paddle.to_tensor(batch_label)
            yield batch_image, batch_label

    return __reader__


class LinearNet(nn.Layer):
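    """A single linear layer mapping IMAGE_SIZE features to CLASS_NUM logits."""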
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

    def forward(self, x):
        return self._linear(x)


class LayerWithLargeParameters(paddle.nn.Layer):
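    """A layer whose linear weight has LARGE_PARAM output features, used to test saving large tensors."""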
    def __init__(self):
        super(LayerWithLargeParameters, self).__init__()
        self._l = paddle.nn.Linear(10, LARGE_PARAM)

    def forward(self, x):
        y = self._l(x)
        return y


def train(layer, loader, loss_fn, opt):
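    """Run a simple training loop: forward, loss, backward, step, clear gradients."""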
    for epoch_id in range(EPOCH_NUM):
        for batch_id, (image, label) in enumerate(loader()):
            out = layer(image)
            loss = loss_fn(out, label)
            loss.backward()
            opt.step()
            opt.clear_grad()


class TestSaveLoadLargeParameters(unittest.TestCase):
    def setUp(self):
        pass

    def test_large_parameters_paddle_save(self):
        # enable dygraph mode
        paddle.disable_static()
        # create network
        layer = LayerWithLargeParameters()
        save_dict = layer.state_dict()

        path = "test_paddle_save_load_large_param_save/layer" + ".pdparams"
        paddle.save(layer.state_dict(), path)
        dict_load = paddle.load(path)
        # compare results before and after saving
        for key, value in save_dict.items():
            self.assertTrue(
                np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15)


class TestSaveLoad(unittest.TestCase):
    def setUp(self):
        # enable dygraph mode
        paddle.disable_static()

        # config seed
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

    def build_and_train_model(self):
        # create network
        layer = LinearNet()
        loss_fn = nn.CrossEntropyLoss()

        adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

        # create data loader
        # TODO: using the new DataLoader causes an unknown timeout on Windows; replace it
        loader = random_batch_reader()

        # train
        train(layer, loader, loss_fn, adam)

        return layer, adam

    def check_load_state_dict(self, orig_dict, load_dict):
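        """Check that every loaded entry exactly matches the original state dict."""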
        for var_name, value in orig_dict.items():
            self.assertTrue(np.array_equal(value.numpy(), load_dict[var_name]))

    def test_save_load(self):
        layer, opt = self.build_and_train_model()

        # save
        layer_save_path = "test_paddle_save_load.linear.pdparams"
        opt_save_path = "test_paddle_save_load.linear.pdopt"
        layer_state_dict = layer.state_dict()
        opt_state_dict = opt.state_dict()

        paddle.save(layer_state_dict, layer_save_path)
        paddle.save(opt_state_dict, opt_save_path)

        # load
        load_layer_state_dict = paddle.load(layer_save_path)
        load_opt_state_dict = paddle.load(opt_save_path)

        self.check_load_state_dict(layer_state_dict, load_layer_state_dict)
        self.check_load_state_dict(opt_state_dict, load_opt_state_dict)

        # test save load in static mode
        paddle.enable_static()
        static_save_path = "static_mode_test/test_paddle_save_load.linear.pdparams"
        paddle.save(layer_state_dict, static_save_path)
        load_static_state_dict = paddle.load(static_save_path)
        self.check_load_state_dict(layer_state_dict, load_static_state_dict)

        # error test cases; some of them rely on the base tests above
        # 1. test saving an object that is not a dict
        test_list = [1, 2, 3]
        with self.assertRaises(NotImplementedError):
            paddle.save(test_list, "not_dict_error_path")

        # 2. test save path format error
        with self.assertRaises(ValueError):
            paddle.save(layer_state_dict, "test_paddle_save_load.linear.model/")

        # 3. test loading from a path that does not exist
        with self.assertRaises(ValueError):
            paddle.load("test_paddle_save_load.linear.params")

        # 4. test loading from an old-format save path
        with self.assertRaises(ValueError):
            paddle.load("test_paddle_save_load.linear")


if __name__ == '__main__':
    unittest.main()