-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def__init__(self,opt):
def__init__(self,cfg):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
Args:
cfg (Dict)-- configs of Model.
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
In this function, you should first call <super(YourClass, self).__init__(self, cfg)>
Then, you need to define four lists:
-- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.losses (dict): specify the training losses that you want to plot and save.
-- self.nets (dict): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
"""
self.opt=opt
self.isTrain=opt.isTrain
self.cfg=cfg
self.is_train=cfg.is_train
self.save_dir=os.path.join(
opt.output_dir,
opt.model.name)# save all the checkpoints to save_dir
cfg.output_dir,
cfg.model.name)# save all the checkpoints to save_dir
self.losses=OrderedDict()
self.model_names=[]
self.visual_names=[]
self.optimizers=[]
self.optimizer_names=[]
self.nets=OrderedDict()
self.visual_items=OrderedDict()
self.optimizers=OrderedDict()
self.image_paths=[]
self.metric=0# used for learning rate policy 'plateau'
@staticmethod
defmodify_commandline_options(parser,is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
returnparser
@abstractmethod
defset_input(self,input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.