base_model.py 5.5 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import os
import paddle
import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod

L
LielinJiang 已提交
8
from ..solver.lr_scheduler import build_lr_scheduler
L
LielinJiang 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """
    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
29
            -- self.loss (str list):          specify the training losses that you want to plot and save.
L
LielinJiang 已提交
30 31 32 33 34 35
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
L
LielinJiang 已提交
36 37 38 39
        self.save_dir = os.path.join(
            opt.output_dir,
            opt.model.name)  # save all the checkpoints to save_dir

40
        self.loss = OrderedDict()
L
LielinJiang 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.optimizer_names = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

L
LielinJiang 已提交
80 81
    def build_lr_scheduler(self):
        self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
L
LielinJiang 已提交
82 83 84 85 86 87 88 89

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

L
LielinJiang 已提交
90
    def test(self):
L
LielinJiang 已提交
91 92 93 94 95
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
L
LielinJiang 已提交
96
        with paddle.no_grad():
L
LielinJiang 已提交
97
            self.forward()
L
LielinJiang 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
L
LielinJiang 已提交
112
            if isinstance(name, str) and hasattr(self, name):
L
LielinJiang 已提交
113 114 115 116 117
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
118
        return self.loss
L
LielinJiang 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    # print('trainable:', param.trainable)
                    param.trainable = requires_grad
                    # param.stop_gradient = not requires_grad