提交 dbad6f4c 编写于 作者: C chenguowei01

add resume model

上级 aa160894
......@@ -79,7 +79,13 @@ def parse_args():
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained weight',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
......@@ -127,7 +133,7 @@ def train(model,
start_epoch = 0
if resume_model is not None:
start_epoch = resume(optimizer, resume_model)
start_epoch = resume(model, optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
......@@ -233,7 +239,6 @@ def main(args):
# todo, may less one than len(loader)
num_steps_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
print(num_steps_each_epoch, 'num_steps_each_epoch')
decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
......@@ -253,6 +258,7 @@ def main(args):
num_epochs=args.num_epochs,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model,
save_interval_epochs=args.save_interval_epochs,
num_classes=train_dataset.num_classes,
num_workers=args.num_workers)
......
......@@ -78,12 +78,13 @@ def load_pretrained_model(model, pretrained_model):
pretrained_model))
def resume(optimizer, resume_model):
def resume(model, optimizer, resume_model):
if resume_model is not None:
logging.info('Resume model from {}'.format(resume_model))
if os.path.exists(resume_model):
ckpt_path = os.path.join(resume_model, 'model')
_, opti_state_dict = fluid.load_dygraph(ckpt_path)
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
optimizer.set_dict(opti_state_dict)
epoch = resume_model.split('_')[-1]
if epoch.isdigit():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册