# test_translated_layer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import tempfile
import os
import paddle
import paddle.nn as nn
import paddle.optimizer as opt

# Training-loop sizes.
BATCH_SIZE = 16  # samples per batch
BATCH_NUM = 4  # batches per epoch (dataset holds BATCH_NUM * BATCH_SIZE items)
EPOCH_NUM = 4  # number of passes train() makes over the loader
SEED = 10  # seed for numpy (re-applied per sample) and for paddle

# Data / model dimensions.
IMAGE_SIZE = 784  # length of each 1-D input feature vector
CLASS_NUM = 10  # number of outputs produced by LinearNet


# define a random dataset
class RandomDataset(paddle.io.Dataset):
    """Synthetic dataset of ``num_samples`` (image, label) pairs.

    numpy's global RNG is re-seeded with ``SEED`` on every access, so every
    index returns the same image/label pair regardless of ``idx`` and of the
    order in which a loader visits the indices.
    """

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Re-seeding on each call (not once) is what makes all items identical.
        np.random.seed(SEED)
        # Draw order matters for reproducibility: image first, then label.
        features = np.random.random(size=(IMAGE_SIZE,)).astype('float32')
        # NOTE(review): randint's upper bound is exclusive, so labels fall in
        # [0, CLASS_NUM - 2] and the last class never appears. Left unchanged
        # because fixing it would alter the generated test data.
        target = np.random.randint(0, CLASS_NUM - 1, size=(1,)).astype('int64')
        return features, target


class LinearNet(nn.Layer):
    """A single ``nn.Linear`` mapping IMAGE_SIZE inputs to CLASS_NUM outputs.

    ``forward`` is wrapped by ``paddle.jit.to_static`` with a fixed input spec
    (``[None, IMAGE_SIZE]`` float32 named ``'x'``) so the layer can be exported
    with ``paddle.jit.save``.
    """

    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
        # Registered as a sublayer but intentionally NOT used in ``forward``.
        # NOTE(review): presumably kept so the model owns a train/eval-sensitive
        # sublayer without making outputs (and the exact loss comparisons in the
        # tests) random - confirm before wiring it into ``forward``.
        self._dropout = paddle.nn.Dropout(p=0.5)

    @paddle.jit.to_static(input_spec=[
        paddle.static.InputSpec(
            shape=[None, IMAGE_SIZE], dtype='float32', name='x')
    ])
    def forward(self, x):
        return self._linear(x)


def train(layer, loader, loss_fn, opt):
    """Train ``layer`` for ``EPOCH_NUM`` epochs and return the last batch loss.

    Args:
        layer: callable model; ``layer(image)`` produces the prediction.
        loader: zero-argument callable returning an iterable of
            ``(image, label)`` batches; it is called once per epoch.
        loss_fn: callable ``loss_fn(out, label)`` returning a loss object that
            supports ``backward()`` and ``numpy()``.
        opt: optimizer exposing ``step()`` and ``clear_grad()``. (Inside this
            function the name shadows the module-level ``opt`` alias of
            ``paddle.optimizer``; the parameter name is kept for callers.)

    Returns:
        The loss object computed on the final batch of the final epoch.

    Raises:
        ValueError: if no batch was processed (empty loader or
            ``EPOCH_NUM == 0``), so there is no loss to return.
    """
    # Initialised so an empty run fails with a clear error instead of an
    # UnboundLocalError at the ``return`` below.
    loss = None
    for epoch_id in range(EPOCH_NUM):
        for batch_id, (image, label) in enumerate(loader()):
            out = layer(image)
            loss = loss_fn(out, label)
            loss.backward()
            opt.step()
            opt.clear_grad()
            print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id,
                                                        np.mean(loss.numpy())))
    if loss is None:
        raise ValueError("train() processed no batches; nothing to return")
    return loss


class TestTranslatedLayer(unittest.TestCase):
    """Round-trip tests for a ``LinearNet`` exported with ``paddle.jit.save``.

    ``setUp`` trains a fresh ``LinearNet`` with ``train``, saves it into a
    per-test temporary directory, and each test reloads it with
    ``paddle.jit.load`` to inspect the resulting translated layer.
    """

    def setUp(self):
        # Create the temp dir first and register its cleanup right away:
        # unittest skips tearDown when setUp raises, but always runs cleanups,
        # so the directory is removed even if training or saving fails below.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.temp_dir.cleanup)

        # enable dygraph mode
        place = paddle.CPUPlace()
        paddle.disable_static(place)

        # config seed
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

        # create network
        self.layer = LinearNet()
        self.loss_fn = nn.CrossEntropyLoss()
        self.sgd = opt.SGD(learning_rate=0.001,
                           parameters=self.layer.parameters())

        # create data loader
        dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
        self.loader = paddle.io.DataLoader(
            dataset,
            places=place,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=0)

        # train
        train(self.layer, self.loader, self.loss_fn, self.sgd)

        # save
        self.model_path = os.path.join(self.temp_dir.name,
                                       './linear.example.model')
        paddle.jit.save(self.layer, self.model_path)

    def test_inference_and_fine_tuning(self):
        # Order matters: fine-tuning keeps training self.layer, which would
        # move its parameters away from the saved model that the inference
        # check compares against.
        self.load_and_inference()
        self.load_and_fine_tuning()

    def load_and_inference(self):
        """Check the reloaded layer reproduces the original eval-mode output."""
        # load
        translated_layer = paddle.jit.load(self.model_path)

        # inference
        x = paddle.randn([1, IMAGE_SIZE], 'float32')

        self.layer.eval()
        orig_pred = self.layer(x)

        translated_layer.eval()
        pred = translated_layer(x)

        self.assertTrue(np.array_equal(orig_pred.numpy(), pred.numpy()))

    def load_and_fine_tuning(self):
        """Check fine-tuning the reloaded layer matches training the original.

        The original layer keeps training with its own SGD optimizer; the
        reloaded layer trains with a new SGD of the same learning rate on the
        same loader. The final losses must be exactly equal.
        """
        # load
        translated_layer = paddle.jit.load(self.model_path)

        # train original layer continue
        self.layer.train()
        orig_loss = train(self.layer, self.loader, self.loss_fn, self.sgd)

        # fine-tuning
        translated_layer.train()
        sgd = opt.SGD(learning_rate=0.001,
                      parameters=translated_layer.parameters())
        loss = train(translated_layer, self.loader, self.loss_fn, sgd)

        self.assertTrue(
            np.array_equal(orig_loss.numpy(), loss.numpy()),
            msg="original loss:\n{}\nnew loss:\n{}\n".format(orig_loss.numpy(),
                                                             loss.numpy()))

    def test_get_program(self):
        # load
        translated_layer = paddle.jit.load(self.model_path)

        program = translated_layer.program()
        self.assertIsInstance(program, paddle.static.Program)

    def test_get_program_method_not_exists(self):
        # load
        translated_layer = paddle.jit.load(self.model_path)

        with self.assertRaises(ValueError):
            translated_layer.program('not_exists')

    def test_get_input_spec(self):
        # load
        translated_layer = paddle.jit.load(self.model_path)

        expect_spec = [
            paddle.static.InputSpec(
                shape=[None, IMAGE_SIZE], dtype='float32', name='x')
        ]
        actual_spec = translated_layer._input_spec()

        # zip() stops at the shorter list, so check lengths first; otherwise an
        # empty or truncated spec would pass vacuously.
        self.assertEqual(len(expect_spec), len(actual_spec))
        for spec_x, spec_y in zip(expect_spec, actual_spec):
            self.assertEqual(spec_x, spec_y)

    def test_get_output_spec(self):
        # load
        translated_layer = paddle.jit.load(self.model_path)

        expect_spec = [
            paddle.static.InputSpec(
                shape=[None, CLASS_NUM],
                dtype='float32',
                name='translated_layer/scale_0.tmp_1')
        ]
        actual_spec = translated_layer._output_spec()

        # Same vacuous-pass guard as in test_get_input_spec.
        self.assertEqual(len(expect_spec), len(actual_spec))
        for spec_x, spec_y in zip(expect_spec, actual_spec):
            self.assertEqual(spec_x, spec_y)

    def test_layer_state(self):
        """eval()/train() on the reloaded layer must reach all its sublayers."""
        # load
        translated_layer = paddle.jit.load(self.model_path)
        translated_layer.eval()
        self.assertEqual(translated_layer.training, False)
        for layer in translated_layer.sublayers():
            self.assertEqual(layer.training, False)

        translated_layer.train()
        self.assertEqual(translated_layer.training, True)
        for layer in translated_layer.sublayers():
            self.assertEqual(layer.training, True)

if __name__ == '__main__':
    # Running the file directly executes every TestCase defined in it.
    unittest.main()