When creating your custom class, you need to implement your own initialization.
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
Then, you need to define four lists:
-- self.loss (str list): specify the training losses that you want to plot and save.
-- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
...
@@ -37,7 +37,7 @@ class BaseModel(ABC):
...
@@ -37,7 +37,7 @@ class BaseModel(ABC):
opt.output_dir,
opt.output_dir,
opt.model.name)# save all the checkpoints to save_dir
opt.model.name)# save all the checkpoints to save_dir
self.loss=OrderedDict()
self.losses=OrderedDict()
self.model_names=[]
self.model_names=[]
self.visual_names=[]
self.visual_names=[]
self.optimizers=[]
self.optimizers=[]
...
@@ -115,7 +115,7 @@ class BaseModel(ABC):
...
@@ -115,7 +115,7 @@ class BaseModel(ABC):
defget_current_losses(self):
defget_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""