提交 355e5565 编写于 作者: L LielinJiang

clean code

上级 10ab8ea5
# from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip
......@@ -270,22 +270,12 @@ class Trainer:
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
# for name in self.model.model_names:
# if isinstance(name, str):
# net = getattr(self.model, 'net' + name)
# net.set_dict(state_dicts['net' + name])
for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name])
for opt_name, opt in self.model.optimizers.items():
opt.set_dict(state_dicts[opt_name])
# for name in self.model.optimizer_names:
# if isinstance(name, str):
# opt = getattr(self.model, name)
# opt.set_dict(state_dicts[name])
def load(self, weight_path):
state_dicts = load(weight_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册