# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest

import os
import numpy as np
import shutil
import tempfile

import paddle
from paddle import fluid
from paddle import to_tensor
from paddle.nn import Conv2D, Linear, ReLU, Sequential, Softmax

from paddle import Model
from paddle.static import InputSpec
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.metric import Accuracy
from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet
import paddle.vision.models as models
import paddle.fluid.dygraph.jit as jit
from paddle.io import DistributedBatchSampler, Dataset
from paddle.hapi.model import prepare_distributed_context
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator


class LeNetDygraph(paddle.nn.Layer):
    """Reference LeNet assembled directly from dygraph layers.

    ``TestModel.setUpClass`` trains this network with a hand-written dygraph
    loop to produce the reference accuracy and weights that the high-level
    ``paddle.Model`` API is later compared against.

    Args:
        num_classes (int): number of outputs of the classifier head. When
            ``num_classes <= 0`` no classifier head is built and ``forward``
            returns the output of the convolutional feature extractor.
    """

    def __init__(self, num_classes=10):
        super(LeNetDygraph, self).__init__()
        self.num_classes = num_classes
        self.features = Sequential(
            Conv2D(
                1, 6, 3, stride=1, padding=1),
            ReLU(),
            paddle.fluid.dygraph.Pool2D(2, 'max', 2),
            Conv2D(
                6, 16, 5, stride=1, padding=0),
            ReLU(),
            paddle.fluid.dygraph.Pool2D(2, 'max', 2))

        if num_classes > 0:
            # For a 1x28x28 input (as produced by MnistDataset) the feature
            # extractor yields 16 x 5 x 5 = 400 values. The last layer uses
            # ``num_classes`` rather than a hard-coded 10 so the argument is
            # actually honoured.
            self.fc = Sequential(
                Linear(400, 120), Linear(120, 84), Linear(84, num_classes))

    def forward(self, inputs):
        """Run the feature extractor and, if present, the classifier head."""
        x = self.features(inputs)

        if self.num_classes > 0:
            x = fluid.layers.flatten(x, 1)
            x = self.fc(x)
        return x


class MnistDataset(MNIST):
    """MNIST split reshaped to (1, 28, 28) images, optionally with labels.

    Args:
        mode (str): split forwarded to ``MNIST`` (this file uses 'train' and
            'test').
        return_label (bool): when False, each sample is the 1-tuple
            ``(img,)`` so the dataset can feed ``Model.predict``.
        sample_num (int|None): keep only the first ``sample_num`` samples;
            ``None`` keeps the whole split.
    """

    def __init__(self, mode, return_label=True, sample_num=None):
        super(MnistDataset, self).__init__(mode=mode)
        self.return_label = return_label
        # Compare against None so that ``sample_num=0`` really yields an
        # empty dataset instead of silently keeping the full split.
        if sample_num is not None:
            self.images = self.images[:sample_num]
            self.labels = self.labels[:sample_num]

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        img = np.reshape(img, [1, 28, 28])
        if self.return_label:
            return img, np.array(label).astype('int64')
        return img,

    def __len__(self):
        return len(self.images)


def compute_acc(pred, label):
    """Return the top-1 accuracy of ``pred`` with respect to ``label``.

    Args:
        pred: array-like of shape (N, C) holding per-class scores.
        label: array-like of N integer class ids, either flat (N,) or as a
            column (N, 1).

    Returns:
        The fraction of rows whose arg-max class equals the label.
    """
    pred = np.argmax(pred, -1)
    # Flatten the labels so (N,) and (N, 1) inputs are compared element-wise.
    # Comparing ``pred[:, None]`` against a flat label would broadcast to an
    # (N, N) matrix and produce a meaningless accuracy.
    label = np.asarray(label).reshape(-1)
    correct = pred == label
    return np.sum(correct) / correct.shape[0]


def dynamic_train(model, dataloader):
    """Train ``model`` for one pass over ``dataloader`` with a manual loop.

    Uses Adam (learning rate 0.001) on the summed cross-entropy loss and
    clears gradients after every step. The caller is expected to have
    enabled dygraph mode (``TestModel.setUpClass`` does so before calling).
    """
    optim = fluid.optimizer.Adam(
        learning_rate=0.001, parameter_list=model.parameters())
    model.train()
    for inputs, labels in dataloader:
        outputs = model(inputs)
        loss = CrossEntropyLoss(reduction="sum")(outputs, labels)
        total_loss = fluid.layers.reduce_sum(loss)
        total_loss.backward()
        optim.minimize(total_loss)
        model.clear_gradients()


def dynamic_evaluate(model, dataloader):
    """Compute the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch yields ``(inputs, labels)``; arg-max predictions are compared
    against ``labels.numpy()`` as a column. The number of hits is divided by
    the size of ``dataloader.dataset``.
    """
    with fluid.dygraph.no_grad():
        model.eval()
        num_correct = 0
        for batch_inputs, batch_labels in dataloader:
            logits = model(batch_inputs)
            predicted = np.argmax(logits.numpy(), -1)[:, np.newaxis]
            hits = predicted == batch_labels.numpy()
            num_correct += hits.astype('int').sum()

    return num_correct / len(dataloader.dataset)


@unittest.skipIf(not fluid.is_compiled_with_cuda(),
                 'CPU testing is not supported')
class TestModel(unittest.TestCase):
    """End-to-end checks of ``paddle.Model`` fit/evaluate/predict on MNIST.

    ``setUpClass`` trains a reference ``LeNetDygraph`` with a hand-written
    dygraph loop and records its validation accuracy in ``cls.acc1``. It
    also saves the trained weights to ``cls.weight_path``. The tests then
    check that ``paddle.Model`` reproduces that accuracy in both dygraph and
    static mode. Requires a CUDA build.
    """

    @classmethod
    def setUpClass(cls):
        if not fluid.is_compiled_with_cuda():
            # There is no ``self`` in a classmethod; raising SkipTest is the
            # supported way to skip from setUpClass.
            raise unittest.SkipTest('module not tested when ONLY_CPU compling')
        cls.device = paddle.set_device('gpu')
        fluid.enable_dygraph(cls.device)

        sp_num = 1280
        cls.train_dataset = MnistDataset(mode='train', sample_num=sp_num)
        cls.val_dataset = MnistDataset(mode='test', sample_num=sp_num)
        # Same split and size as val_dataset, but without labels.
        cls.test_dataset = MnistDataset(
            mode='test', return_label=False, sample_num=sp_num)

        cls.train_loader = fluid.io.DataLoader(
            cls.train_dataset, places=cls.device, batch_size=64)
        cls.val_loader = fluid.io.DataLoader(
            cls.val_dataset, places=cls.device, batch_size=64)
        cls.test_loader = fluid.io.DataLoader(
            cls.test_dataset, places=cls.device, batch_size=64)

        seed = 333
        paddle.seed(seed)
        paddle.framework.random._manual_program_seed(seed)

        # Reference run: train and evaluate a hand-written dygraph LeNet.
        dy_lenet = LeNetDygraph()
        cls.init_param = dy_lenet.state_dict()
        dynamic_train(dy_lenet, cls.train_loader)

        cls.acc1 = dynamic_evaluate(dy_lenet, cls.val_loader)

        cls.inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
        cls.labels = [InputSpec([None, 1], 'int64', 'label')]

        # Persist the reference weights so evaluate/predict can load them.
        cls.save_dir = tempfile.mkdtemp()
        cls.weight_path = os.path.join(cls.save_dir, 'lenet')
        fluid.dygraph.save_dygraph(dy_lenet.state_dict(), cls.weight_path)

        fluid.disable_dygraph()

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.save_dir)

    def test_fit_dygraph(self):
        self.fit(True)

    def test_fit_static(self):
        self.fit(False)

    def test_fit_dynamic_with_rank(self):
        self.fit(True, 2, 0)

    def test_fit_static_with_rank(self):
        self.fit(False, 2, 0)

    def test_evaluate_dygraph(self):
        self.evaluate(True)

    def test_evaluate_static(self):
        self.evaluate(False)

    def test_predict_dygraph(self):
        self.predict(True)

    def test_predict_static(self):
        self.predict(False)

    def test_prepare_context(self):
        prepare_distributed_context()

    def fit(self, dynamic, num_replicas=None, rank=None):
        """Train LeNet via ``Model.fit`` and compare with the reference.

        Args:
            dynamic (bool): run in dygraph mode if True, else static mode.
            num_replicas (int|None): forwarded to ``DistributedBatchSampler``.
            rank (int|None): forwarded to ``DistributedBatchSampler``.
        """
        if dynamic:
            fluid.enable_dygraph(self.device)
        try:
            seed = 333
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            net = LeNet()
            optim_new = fluid.optimizer.Adam(
                learning_rate=0.001, parameter_list=net.parameters())
            model = Model(net, inputs=self.inputs, labels=self.labels)
            model.prepare(
                optim_new,
                loss=CrossEntropyLoss(reduction="sum"),
                metrics=Accuracy())
            model.fit(self.train_dataset, batch_size=64, shuffle=False)

            # Same seed, data order and loss as the reference run in
            # setUpClass; the test expects the accuracy to match it.
            result = model.evaluate(self.val_dataset, batch_size=64)
            np.testing.assert_allclose(result['acc'], self.acc1)

            train_sampler = DistributedBatchSampler(
                self.train_dataset,
                batch_size=64,
                shuffle=False,
                num_replicas=num_replicas,
                rank=rank)
            val_sampler = DistributedBatchSampler(
                self.val_dataset,
                batch_size=64,
                shuffle=False,
                num_replicas=num_replicas,
                rank=rank)

            train_loader = fluid.io.DataLoader(
                self.train_dataset,
                batch_sampler=train_sampler,
                places=self.device,
                return_list=True)

            val_loader = fluid.io.DataLoader(
                self.val_dataset,
                batch_sampler=val_sampler,
                places=self.device,
                return_list=True)

            # Also exercise fit() on user-built DataLoaders.
            model.fit(train_loader, val_loader)
        finally:
            # Restore static mode even if an assertion above failed.
            if dynamic:
                fluid.disable_dygraph()

    def evaluate(self, dynamic):
        """Load the reference weights and check ``Model.evaluate`` accuracy.

        Args:
            dynamic (bool): run in dygraph mode if True, else static mode.
        """
        if dynamic:
            fluid.enable_dygraph(self.device)
        try:
            model = Model(LeNet(), self.inputs, self.labels)
            model.prepare(metrics=Accuracy())
            model.load(self.weight_path)
            result = model.evaluate(self.val_dataset, batch_size=64)
            np.testing.assert_allclose(result['acc'], self.acc1)

            sampler = DistributedBatchSampler(
                self.val_dataset, batch_size=64, shuffle=False)

            val_loader = fluid.io.DataLoader(
                self.val_dataset,
                batch_sampler=sampler,
                places=self.device,
                return_list=True)

            # Also exercise evaluate() on a user-built DataLoader.
            model.evaluate(val_loader)
        finally:
            if dynamic:
                fluid.disable_dygraph()

    def predict(self, dynamic):
        """Load the reference weights and check ``Model.predict`` output.

        Args:
            dynamic (bool): run in dygraph mode if True, else static mode.
        """
        if dynamic:
            fluid.enable_dygraph(self.device)
        try:
            model = Model(LeNet(), self.inputs)
            model.prepare()
            model.load(self.weight_path)
            output = model.predict(
                self.test_dataset, batch_size=64, stack_outputs=True)
            np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))

            # test_dataset has no labels; it covers the same split and size
            # as val_dataset, so the labels are taken from there.
            acc = compute_acc(output[0], self.val_dataset.labels)
            np.testing.assert_allclose(acc, self.acc1)

            sampler = DistributedBatchSampler(
                self.test_dataset, batch_size=64, shuffle=False)

            test_loader = fluid.io.DataLoader(
                self.test_dataset,
                batch_sampler=sampler,
                places=self.device,
                return_list=True)

            # The loader yields unlabeled samples, so exercise predict() on
            # it (this helper previously called evaluate() here).
            model.predict(test_loader)
        finally:
            if dynamic:
                fluid.disable_dygraph()

    def test_predict_without_inputs(self):
        fluid.enable_dygraph(self.device)
        try:
            model = Model(LeNet())
            model.prepare()
            model.load(self.weight_path)
            model._inputs = None
            output = model.predict(
                self.test_dataset, batch_size=64, stack_outputs=True)
            np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
        finally:
            fluid.disable_dygraph()

    def test_summary_gpu(self):
        paddle.disable_static(self.device)
        try:
            rnn = paddle.nn.LSTM(16, 32, 2)
            paddle.summary(
                rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
        finally:
            # setUpClass and the other tests in this class end in static
            # mode; restore it instead of leaking dygraph mode.
            paddle.enable_static()

class MyModel(paddle.nn.Layer):
    """Single fully connected layer mapping 20 input features to 10 outputs.

    Used as a tiny, fast network throughout the CPU-only tests below.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self._fc = Linear(20, 10)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        y = self._fc(x)
        return y


class MyDataset(Dataset):
    """Synthetic dataset of 40 random (features, label) pairs.

    Each access draws a fresh float32 feature vector of length 20 and an
    int64 label of shape (1,) in [0, 10) from numpy's global RNG, so the
    samples are only reproducible if that RNG is seeded.
    """

    def __len__(self):
        return 40

    def __getitem__(self, idx):
        features = np.random.random(size=(20, )).astype(np.float32)
        target = np.random.randint(0, 10, size=(1, )).astype(np.int64)
        return features, target


class TestModelFunction(unittest.TestCase):
    """Unit tests of individual ``paddle.Model`` methods on small CPU models."""

    def set_seed(self, seed=1024):
        """Seed both the global paddle RNG and the static-program RNG."""
        paddle.seed(seed)
        paddle.framework.random._manual_program_seed(seed)

    def test_train_batch(self, dynamic=True):
        # NOTE(review): the ``dynamic`` parameter is unused (the loop below
        # covers both modes); it is kept so the signature is unchanged.
        dim = 20
        data = np.random.random(size=(4, dim)).astype(np.float32)
        label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)

        def get_expect():
            # Reference loss from one manual SGD step in dygraph mode.
            fluid.enable_dygraph(fluid.CPUPlace())
            self.set_seed()
            m = MyModel()
            optim = fluid.optimizer.SGD(learning_rate=0.001,
                                        parameter_list=m.parameters())
            m.train()
            output = m(to_tensor(data))
            loss = CrossEntropyLoss(reduction='sum')(output, to_tensor(label))
            avg_loss = fluid.layers.reduce_sum(loss)
            avg_loss.backward()
            optim.minimize(avg_loss)
            m.clear_gradients()
            fluid.disable_dygraph()
            return avg_loss.numpy()

        ref = get_expect()
        for dynamic in [True, False]:
            device = paddle.set_device('cpu')
            if dynamic:
                fluid.enable_dygraph(device)
            self.set_seed()

            net = MyModel()
            optim2 = fluid.optimizer.SGD(learning_rate=0.001,
                                         parameter_list=net.parameters())

            inputs = [InputSpec([None, dim], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]
            model = Model(net, inputs, labels)
            model.prepare(optim2, loss=CrossEntropyLoss(reduction="sum"))
            loss, = model.train_batch([data], [label])
            np.testing.assert_allclose(loss.flatten(), ref.flatten())
            if dynamic:
                fluid.disable_dygraph()

    def test_test_batch(self):
        dim = 20
        data = np.random.random(size=(4, dim)).astype(np.float32)

        def get_expect():
            # Reference forward pass of a freshly seeded MyModel in dygraph.
            fluid.enable_dygraph(fluid.CPUPlace())
            self.set_seed()
            m = MyModel()
            m.eval()
            output = m(to_tensor(data))
            fluid.disable_dygraph()
            return output.numpy()

        ref = get_expect()
        for dynamic in [True, False]:
            device = paddle.set_device('cpu')
            if dynamic:
                fluid.enable_dygraph(device)
            self.set_seed()
            net = MyModel()
            inputs = [InputSpec([None, dim], 'float32', 'x')]
            model = Model(net, inputs)
            model.prepare()
            out, = model.predict_batch([data])

            np.testing.assert_allclose(out, ref, rtol=1e-6)
            if dynamic:
                fluid.disable_dygraph()

    def test_save_load(self):
        for dynamic in [True, False]:
            # A fresh directory per mode: each iteration deletes its own
            # directory at the end, so it must not be shared.
            path = tempfile.mkdtemp()
            device = paddle.set_device('cpu')
            if dynamic:
                fluid.enable_dygraph(device)
            net = MyModel()
            inputs = [InputSpec([None, 20], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]
            optim = fluid.optimizer.SGD(learning_rate=0.001,
                                        parameter_list=net.parameters())
            model = Model(net, inputs, labels)
            model.prepare(
                optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
            model.save(path + '/test')
            model.load(path + '/test')
            shutil.rmtree(path)
            if dynamic:
                fluid.disable_dygraph()

    def test_dynamic_load(self):
        mnist_data = MnistDataset(mode='train')
        # Cover both the 2.0 optimizer API and the legacy fluid optimizer.
        for new_optimizer in [True, False]:
            path = tempfile.mkdtemp()
            paddle.disable_static()
            net = LeNet()
            inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]
            if new_optimizer:
                optim = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=net.parameters())
            else:
                optim = fluid.optimizer.Adam(
                    learning_rate=0.001, parameter_list=net.parameters())
            model = Model(net, inputs, labels)
            model.prepare(
                optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
            model.fit(mnist_data, batch_size=64, verbose=0)
            model.save(path + '/test')
            model.load(path + '/test')
            shutil.rmtree(path)
            paddle.enable_static()

    def test_dynamic_save_static_load(self):
        path = tempfile.mkdtemp()
        # dynamic saving
        device = paddle.set_device('cpu')
        fluid.enable_dygraph(device)
        model = Model(MyModel())
        optim = fluid.optimizer.SGD(learning_rate=0.001,
                                    parameter_list=model.parameters())
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
        model.save(path + '/test')
        fluid.disable_dygraph()

        # static loading
        inputs = [InputSpec([None, 20], 'float32', 'x')]
        labels = [InputSpec([None, 1], 'int64', 'label')]
        model = Model(MyModel(), inputs, labels)
        optim = fluid.optimizer.SGD(learning_rate=0.001,
                                    parameter_list=model.parameters())
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
        model.load(path + '/test')
        shutil.rmtree(path)

    def test_static_save_dynamic_load(self):
        path = tempfile.mkdtemp()

        # static saving
        net = MyModel()
        inputs = [InputSpec([None, 20], 'float32', 'x')]
        labels = [InputSpec([None, 1], 'int64', 'label')]
        optim = fluid.optimizer.SGD(learning_rate=0.001,
                                    parameter_list=net.parameters())
        model = Model(net, inputs, labels)
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
        model.save(path + '/test')

        # dynamic loading
        device = paddle.set_device('cpu')
        fluid.enable_dygraph(device)

        net = MyModel()
        inputs = [InputSpec([None, 20], 'float32', 'x')]
        labels = [InputSpec([None, 1], 'int64', 'label')]
        optim = fluid.optimizer.SGD(learning_rate=0.001,
                                    parameter_list=net.parameters())
        model = Model(net, inputs, labels)
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
        model.load(path + '/test')
        shutil.rmtree(path)
        fluid.disable_dygraph()

    def test_parameters(self):
        for dynamic in [True, False]:
            device = paddle.set_device('cpu')
            if dynamic:
                fluid.enable_dygraph(device)
            net = MyModel()
            inputs = [InputSpec([None, 20], 'float32', 'x')]
            model = Model(net, inputs)
            model.prepare()
            params = model.parameters()
            # The first parameter is the Linear(20, 10) weight.
            self.assertTrue(params[0].shape[0] == 20)
            self.assertTrue(params[0].shape[1] == 10)
            if dynamic:
                fluid.disable_dygraph()

    def test_summary(self):
        def _get_param_from_state_dict(state_dict):
            params = 0
            for k, v in state_dict.items():
                params += np.prod(v.numpy().shape)
            return params

        for dynamic in [True, False]:
            device = paddle.set_device('cpu')
            # NOTE(review): dygraph is enabled for the first iteration and
            # never disabled, so the "static" iteration presumably still runs
            # in dygraph mode. ``v.numpy()`` above may depend on that --
            # confirm before adding a disable_dygraph() here.
            if dynamic:
                fluid.enable_dygraph(device)
            net = MyModel()
            inputs = [InputSpec([None, 20], 'float32', 'x')]
            model = Model(net, inputs)
            model.prepare()
            params_info = model.summary()
            gt_params = _get_param_from_state_dict(net.state_dict())

            np.testing.assert_allclose(params_info['total_params'], gt_params)
            print(params_info)

            # ``(20)`` is the plain int 20; these exercise the non-tuple
            # input_size forms.
            model.summary(input_size=(20))
            model.summary(input_size=[(20)])
            model.summary(input_size=(20), dtype='float32')

    def test_summary_nlp(self):
        def _get_param_from_state_dict(state_dict):
            params = 0
            for k, v in state_dict.items():
                params += np.prod(v.numpy().shape)
            return params

        nlp_net = paddle.nn.GRU(input_size=2,
                                hidden_size=3,
                                num_layers=3,
                                direction="bidirectional")
        paddle.summary(nlp_net, (1, 1, 2))

        rnn = paddle.nn.LSTM(16, 32, 2)
        params_info = paddle.summary(
            rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
        gt_params = _get_param_from_state_dict(rnn.state_dict())
        np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

        rnn = paddle.nn.GRU(16, 32, 2, direction='bidirectional')
        params_info = paddle.summary(rnn, (4, 23, 16))
        gt_params = _get_param_from_state_dict(rnn.state_dict())
        np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

        rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
        params_info = paddle.summary(rnn, (4, 23, 16))
        gt_params = _get_param_from_state_dict(rnn.state_dict())
        np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)

    def test_summary_dtype(self):
        input_shape = (3, 1)
        net = paddle.nn.Embedding(10, 3, sparse=True)
        paddle.summary(net, input_shape, dtypes='int64')

    def test_summary_error(self):
        with self.assertRaises(TypeError):
            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
            paddle.summary(nlp_net, (1, 1, '2'))

        with self.assertRaises(ValueError):
            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
            paddle.summary(nlp_net, (-1, -1))

        paddle.disable_static()
        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
        paddle.summary(nlp_net, (1, 1, 2))

    def test_static_flops(self):
        paddle.disable_static()
        net = models.__dict__['mobilenet_v2'](pretrained=False)
        inputs = paddle.randn([1, 3, 224, 224])
        static_program = jit._trace(net, inputs=[inputs])[1]
        paddle.flops(static_program, [1, 3, 224, 224], print_detail=True)

    def test_dynamic_flops(self):
        net = models.__dict__['mobilenet_v2'](pretrained=False)

        def customize_dropout(m, x, y):
            m.total_ops += 0

        paddle.flops(
            net, [1, 3, 224, 224],
            custom_ops={paddle.nn.Dropout: customize_dropout},
            print_detail=True)

    def test_export_deploy_model(self):
        self.set_seed()
        np.random.seed(201)
        for dynamic in [True, False]:
            if dynamic:
                paddle.disable_static()
            prog_translator = ProgramTranslator()
            if not dynamic:
                # The static iteration exports with dygraph-to-static
                # translation switched off.
                prog_translator.enable(False)
            try:
                net = LeNet()
                inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
                model = Model(net, inputs)
                model.prepare()
                # mkdtemp always creates the directory.
                save_dir = tempfile.mkdtemp()
                tensor_img = np.array(
                    np.random.random((1, 1, 28, 28)), dtype=np.float32)

                model.save(save_dir, training=False)
                ori_results = model.predict_batch(tensor_img)
                if dynamic:
                    fluid.disable_dygraph()

                # Reload the exported inference model with a plain executor
                # and check it reproduces the Model's own predictions.
                place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
                ) else fluid.CUDAPlace(0)
                new_scope = fluid.Scope()
                with fluid.scope_guard(new_scope):
                    exe = fluid.Executor(place)
                    [inference_program, feed_target_names, fetch_targets] = (
                        paddle.static.io.load_inference_model(
                            path_prefix=save_dir, executor=exe))
                    results = exe.run(inference_program,
                                      feed={feed_target_names[0]: tensor_img},
                                      fetch_list=fetch_targets)
                    np.testing.assert_allclose(
                        results, ori_results, rtol=1e-5, atol=1e-7)
                    shutil.rmtree(save_dir)
            finally:
                if not dynamic:
                    # ProgramTranslator state is presumably process-wide;
                    # switch translation back on (assumed default) so the
                    # disable above does not leak into other tests.
                    prog_translator.enable(True)
            paddle.enable_static()

    def test_dygraph_export_deploy_model_about_inputs(self):
        self.set_seed()
        np.random.seed(201)
        mnist_data = MnistDataset(mode='train')
        paddle.disable_static()
        # without inputs: the input signature must be inferred from a prior
        # fit/train_batch/eval_batch/predict_batch call before exporting.
        for initial in ["fit", "train_batch", "eval_batch", "predict_batch"]:
            save_dir = tempfile.mkdtemp()
            net = LeNet()
            model = Model(net)
            optim = fluid.optimizer.Adam(
                learning_rate=0.001, parameter_list=model.parameters())
            model.prepare(
                optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
            if initial == "fit":
                model.fit(mnist_data, batch_size=64, verbose=0)
            else:
                img = np.array(
                    np.random.random((1, 1, 28, 28)), dtype=np.float32)
                label = np.array(np.random.rand(1, 1), dtype=np.int64)
                if initial == "train_batch":
                    model.train_batch([img], [label])
                elif initial == "eval_batch":
                    model.eval_batch([img], [label])
                else:
                    model.predict_batch([img])

            model.save(save_dir, training=False)
            shutil.rmtree(save_dir)
        # with inputs, and the type of inputs is InputSpec
        save_dir = tempfile.mkdtemp()
        net = LeNet()
        inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')
        model = Model(net, inputs)
        optim = fluid.optimizer.Adam(
            learning_rate=0.001, parameter_list=model.parameters())
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
        model.save(save_dir, training=False)
        shutil.rmtree(save_dir)


class TestModelWithLRScheduler(unittest.TestCase):
    """Check that ``Model.fit`` advances an LR scheduler by step or by epoch."""

    @staticmethod
    def _make_optimizer(base_lr, boundaries, warmup_steps, parameters=None):
        """Build Momentum with a LinearWarmup-wrapped PiecewiseDecay schedule.

        Warmup runs from ``base_lr / 5`` to ``base_lr`` over ``warmup_steps``
        scheduler steps. The wrapped PiecewiseDecay switches between
        ``base_lr * 0.1**i`` for ``i`` in ``0..len(boundaries)``.
        """
        momentum = 0.9
        weight_decay = 5e-4
        values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
        learning_rate = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=boundaries, values=values)
        learning_rate = paddle.optimizer.lr.LinearWarmup(
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            start_lr=base_lr / 5.,
            end_lr=base_lr,
            verbose=True)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            momentum=momentum,
            parameters=parameters)
        return optimizer

    @staticmethod
    def _fit_model(make_optimizer, **fit_kwargs):
        """Fit ``MyModel`` on ``MyDataset`` in the current mode.

        ``make_optimizer`` receives the network parameters; ``fit_kwargs``
        (e.g. ``epochs``, ``callbacks``) are forwarded to ``Model.fit``.
        Returns the fitted ``Model``.
        """
        net = MyModel()
        inputs = [InputSpec([None, 20], 'float32', 'x')]
        labels = [InputSpec([None, 1], 'int64', 'label')]
        optim = make_optimizer(net.parameters())
        model = Model(net, inputs, labels)
        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))

        dataset = MyDataset()
        model.fit(dataset, dataset, batch_size=4, num_workers=0, **fit_kwargs)
        return model

    def test_fit_by_step(self):
        base_lr = 1e-3
        boundaries = [5, 8]

        def make_optimizer(parameters=None):
            return self._make_optimizer(base_lr, boundaries, 4, parameters)

        # 10 epochs x (40 samples / batch 4) = 100 steps, past every boundary.
        expected_lr = base_lr * (0.1**len(boundaries))

        # dynamic test
        device = paddle.set_device('cpu')
        fluid.enable_dygraph(device)
        model = self._fit_model(make_optimizer, epochs=10)
        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
                                   expected_lr)

        # static test
        paddle.enable_static()
        model = self._fit_model(make_optimizer, epochs=10)
        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
                                   expected_lr)

    def test_fit_by_epoch(self):
        base_lr = 1e-3
        boundaries = [5, 8]
        epochs = 10
        warmup_epochs = 4

        def make_optimizer(parameters=None):
            return self._make_optimizer(base_lr, boundaries, warmup_epochs,
                                        parameters)

        # Stepping once per epoch, a boundary only takes effect if it is
        # reached within ``epochs`` after the warmup phase.
        cnt = sum(1 for b in boundaries if b + warmup_epochs <= epochs)
        expected_lr = base_lr * (0.1**cnt)

        # dynamic test
        device = paddle.set_device('cpu')
        fluid.enable_dygraph(device)
        model = self._fit_model(
            make_optimizer,
            epochs=epochs,
            callbacks=paddle.callbacks.LRScheduler(
                by_step=False, by_epoch=True))
        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
                                   expected_lr)

        # static test
        paddle.enable_static()
        model = self._fit_model(
            make_optimizer,
            epochs=epochs,
            callbacks=paddle.callbacks.LRScheduler(
                by_step=False, by_epoch=True))
        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
                                   expected_lr)

class TestRaiseError(unittest.TestCase):
    """Check that ``paddle.Model`` rejects invalid configurations."""

    def test_input_without_name(self):
        net = MyModel()
        # InputSpec without a name must be rejected.
        inputs = [InputSpec([None, 10], 'float32')]
        labels = [InputSpec([None, 1], 'int64', 'label')]
        with self.assertRaises(ValueError):
            model = Model(net, inputs, labels)

    def test_static_without_inputs(self):
        paddle.enable_static()
        net = MyModel()
        with self.assertRaises(TypeError):
            model = Model(net)

    def test_save_infer_model_without_inputs_and_run_in_dygraph(self):
        paddle.disable_static()
        net = MyModel()
        save_dir = tempfile.mkdtemp()
        try:
            with self.assertRaises(RuntimeError):
                model = Model(net)
                model.save(save_dir, training=False)
        finally:
            # Remove the temporary directory and restore static mode even
            # when the expected error is not raised.
            shutil.rmtree(save_dir)
            paddle.enable_static()

    def test_save_infer_model_without_file_prefix(self):
        paddle.enable_static()
        net = LeNet()
        inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
        model = Model(net, inputs)
        model.prepare()
        # An empty path has no file prefix and must be rejected.
        path = ""
        with self.assertRaises(ValueError):
            model.save(path, training=False)

if __name__ == '__main__':
    # Run all TestCase classes in this module with the default unittest runner.
    unittest.main()