# base_model.py (5.2 KB) -- recovered from a repository blame view.
# Lines 1-14 committed by qingqing01.
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lines 15-21 committed by LielinJiang.
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import os
import paddle
import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod

# Line 22 committed by LielinJiang.
from ..solver.lr_scheduler import build_lr_scheduler
# Lines 23-33 committed by LielinJiang.


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """
L
LielinJiang 已提交
34
    def __init__(self, cfg):
L
LielinJiang 已提交
35 36
        """Initialize the BaseModel class.

L
LielinJiang 已提交
37 38
        Args:
            cfg (Dict)-- configs of Model.
L
LielinJiang 已提交
39 40

        When creating your custom class, you need to implement your own initialization.
L
LielinJiang 已提交
41
        In this function, you should first call <super(YourClass, self).__init__(self, cfg)>
L
LielinJiang 已提交
42
        Then, you need to define four lists:
L
LielinJiang 已提交
43 44
            -- self.losses (dict):          specify the training losses that you want to plot and save.
            -- self.nets (dict):         define networks used in our training.
L
LielinJiang 已提交
45
            -- self.visual_names (str list):        specify the images that you want to display and save.
L
LielinJiang 已提交
46 47 48
            -- self.optimizers (dict):    define and initialize optimizers. You can define one optimizer for each network.
                                          If two networks are updated at the same time, you can use itertools.chain to group them.
                                          See cycle_gan_model.py for an example.
L
LielinJiang 已提交
49
        """
L
LielinJiang 已提交
50 51
        self.cfg = cfg
        self.is_train = cfg.is_train
L
LielinJiang 已提交
52
        self.save_dir = os.path.join(
L
LielinJiang 已提交
53 54
            cfg.output_dir,
            cfg.model.name)  # save all the checkpoints to save_dir
L
LielinJiang 已提交
55

L
lijianshe02 已提交
56
        self.losses = OrderedDict()
L
LielinJiang 已提交
57 58 59
        self.nets = OrderedDict()
        self.visual_items = OrderedDict()
        self.optimizers = OrderedDict()
L
LielinJiang 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

L
LielinJiang 已提交
82
    def build_lr_scheduler(self):
L
LielinJiang 已提交
83
        self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler)
L
LielinJiang 已提交
84 85 86 87 88 89 90 91

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

L
LielinJiang 已提交
92
    def test(self):
L
LielinJiang 已提交
93 94 95 96 97
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
L
LielinJiang 已提交
98
        with paddle.no_grad():
L
LielinJiang 已提交
99
            self.forward()
L
LielinJiang 已提交
100 101 102 103 104 105 106 107 108 109 110
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def get_current_visuals(self):
L
LielinJiang 已提交
111 112
        """Return visualization images."""
        return self.visual_items
L
LielinJiang 已提交
113 114 115

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
L
lijianshe02 已提交
116
        return self.losses
L
LielinJiang 已提交
117 118 119

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
L
LielinJiang 已提交
120
        Args:
L
LielinJiang 已提交
121 122 123 124 125 126 127 128 129
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.trainable = requires_grad