base_model.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import os
import paddle
import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod

from .criterions.builder import build_criterion
from ..solver import build_lr_scheduler, build_optimizer
from ..utils.visual import tensor2img


class BaseModel(ABC):
    r"""This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following four functions:
        -- <__init__>:          initialize the class.
        -- <setup_input>:       unpack data from dataset and apply preprocessing.
        -- <forward>:           produce intermediate results.
        -- <train_iter>:        calculate losses, gradients, and update network weights.

    # trainer training logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/

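    A minimal subclass sketch is shown below. It is illustrative only: the
    SimpleModel name, the single-conv "generator" and the 'img'/'gt' data keys
    are hypothetical and not part of this base class.

        class SimpleModel(BaseModel):
            def __init__(self, params=None):
                super(SimpleModel, self).__init__(params)
                # any paddle.nn.Layer works here; a single conv keeps it short
                self.nets['net_g'] = paddle.nn.Conv2D(3, 3, 3, padding=1)

            def setup_input(self, input):
                self.inputs = input
                self.visual_items['input'] = input['img']

            def forward(self):
                self.output = self.nets['net_g'](self.inputs['img'])

            def train_iter(self, optims=None):
                self.forward()
                loss = paddle.nn.functional.l1_loss(self.output,
                                                    self.inputs['gt'])
                loss.backward()
                optims['optim'].step()
                optims['optim'].clear_grad()
                self.losses['loss_pixel'] = loss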
    """
    def __init__(self, params=None):
        """Initialize the BaseModel class.

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <super(YourClass, self).__init__(params)>.
        Then, you need to define the following members:

            -- self.losses (dict):          specify the training losses that you want to plot and save.
            -- self.nets (dict):            define the networks used in training.
            -- self.visual_items (dict):    specify the images that you want to display and save.
            -- self.optimizers (dict):      define and initialize optimizers. You can define one optimizer
                                            for each network. If two networks are updated at the same time,
                                            you can use itertools.chain to group them.
                                            See cycle_gan_model.py for an example.

        Args:
            params (dict): Hyper params for train or test. Default: None.
        """
        self.params = params
        self.is_train = True if self.params is None else self.params.get(
            'is_train', True)

        self.nets = OrderedDict()
        self.optimizers = OrderedDict()
        self.metrics = OrderedDict()
        self.losses = OrderedDict()
        self.visual_items = OrderedDict()

    @abstractmethod
    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <train_iter> and <test_iter>."""
        pass

    @abstractmethod
    def train_iter(self, optims=None):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def set_total_iter(self, total_iter):
        self.total_iter = total_iter

    def test_iter(self, metrics=None):
        """Calculate metrics; called in every test iteration"""
        self.eval()
        with paddle.no_grad():
            self.forward()
        self.train()

    def setup_train_mode(self, is_train):
        self.is_train = is_train

    def setup_lr_schedulers(self, cfg):
        self.lr_scheduler = build_lr_scheduler(cfg)
        return self.lr_scheduler

    def setup_optimizers(self, lr, cfg):
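        """Build the optimizers described in cfg, all sharing the learning rate lr.

        cfg either describes a single optimizer (it then carries a 'name' key)
        or maps optimizer names to per-optimizer configs; each entry lists the
        networks it updates in 'net_names'. A purely illustrative layout (the
        names and values below are placeholders, not defaults):

            cfg = {'optimG': {'name': 'Adam', 'net_names': ['netG']},
                   'optimD': {'name': 'Adam', 'net_names': ['netD']}}
        """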
        if cfg.get('name', None):
            cfg_ = cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers['optim'] = build_optimizer(cfg_, lr, parameters)
        else:
            for opt_name, opt_cfg in cfg.items():
                cfg_ = opt_cfg.copy()
                net_names = cfg_.pop('net_names')
                parameters = []
                for net_name in net_names:
                    parameters += self.nets[net_name].parameters()
                self.optimizers[opt_name] = build_optimizer(
                    cfg_, lr, parameters)

        return self.optimizers

    def setup_metrics(self, cfg):
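        """Build the metric objects described in cfg.

        cfg is either one metric config or a dict of named metric configs; the
        two illustrative layouts below use hypothetical metric names:

            cfg = {'name': 'PSNR'}
            cfg = {'psnr': {'name': 'PSNR'}, 'ssim': {'name': 'SSIM'}}
        """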
        from ..metrics import build_metric
        if isinstance(list(cfg.values())[0], dict):
            for metric_name, cfg_ in cfg.items():
                self.metrics[metric_name] = build_metric(cfg_)
        else:
            metric = build_metric(cfg)
            self.metrics[metric.__class__.__name__] = metric

        return self.metrics

    def eval(self):
        """Make nets eval mode during test time"""
        for net in self.nets.values():
            net.eval()

    def train(self):
        """Make nets train mode during train time"""
        for net in self.nets.values():
            net.train()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        if hasattr(self, 'image_paths'):
            return self.image_paths
        return []

    def get_current_visuals(self):
        """Return visualization images."""
        return self.visual_items

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
        return self.losses

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Args:
            nets (network list): a list of networks
            requires_grad (bool): whether the networks require gradients or not
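
        A typical (illustrative) use in a GAN subclass is freezing the
        discriminator while the generator is updated, for example
        self.set_requires_grad(self.nets['netD'], False); the 'netD' key is
        hypothetical.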
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.trainable = requires_grad

    def export_model(self,
                     export_model,
                     output_dir=None,
                     inputs_size=[],
                     export_serving_model=False,
                     model_name=None):
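        """Export the listed sub-networks to static inference models via paddle.jit.

        The argument layout below is inferred from the code in this method and
        the concrete values are illustrative only: export_model is a list of
        dicts holding a net name and its number of inputs, and inputs_size
        holds one shape per input, e.g.

            model.export_model(
                export_model=[{'name': 'netG', 'inputs_num': 1}],
                inputs_size=[[None, 3, 256, 256]],
                output_dir='inference_model')
        """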
        inputs_num = 0
        for net in export_model:
            input_spec = [
                paddle.static.InputSpec(shape=inputs_size[inputs_num + i],
                                        dtype="float32")
                for i in range(net["inputs_num"])
            ]
            inputs_num = inputs_num + net["inputs_num"]
            static_model = paddle.jit.to_static(self.nets[net["name"]],
                                                input_spec=input_spec)
            if output_dir is None:
                output_dir = 'inference_model'
            if model_name is None:
                model_name = '{}_{}'.format(self.__class__.__name__.lower(),
                                            net["name"])
            paddle.jit.save(
                static_model,
                os.path.join(
                    output_dir, model_name))
            if export_serving_model:
                from paddle_serving_client.io import inference_model_to_serving
                model_name = '{}_{}'.format(self.__class__.__name__.lower(),
                                            net["name"])

                inference_model_to_serving(
                    dirname=output_dir,
                    serving_server="{}/{}/serving_server".format(output_dir,
                                                                 model_name),
                    serving_client="{}/{}/serving_client".format(output_dir,
                                                                 model_name),
                    model_filename="{}.pdmodel".format(model_name),
                    params_filename="{}.pdiparams".format(model_name))