Commit dbad6f4c authored by chenguowei01

add resume model

Parent: aa160894
@@ -79,7 +79,13 @@ def parse_args():
     parser.add_argument(
         '--pretrained_model',
         dest='pretrained_model',
-        help='The path of pretrained weight',
+        help='The path of pretrained model',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--resume_model',
+        dest='resume_model',
+        help='The path of resume model',
         type=str,
         default=None)
     parser.add_argument(
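The new --resume_model option mirrors --pretrained_model but takes a checkpoint directory written by a previous run instead of a file of initial weights. Since resume() (second file below) parses the epoch from the last '_'-separated chunk of the path, the directory name is expected to end in the epoch number. A hypothetical invocation, with script name and path purely illustrative, not taken from this commit:

    python train.py --resume_model output/epoch_10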
@@ -127,7 +133,7 @@ def train(model,
     start_epoch = 0
     if resume_model is not None:
-        start_epoch = resume(optimizer, resume_model)
+        start_epoch = resume(model, optimizer, resume_model)
     elif pretrained_model is not None:
         load_pretrained_model(model, pretrained_model)
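Besides threading the model through to resume(), note the precedence: a checkpoint passed via --resume_model wins over --pretrained_model, since a checkpoint already contains trained weights. A minimal sketch of how the returned start_epoch would be consumed (the training loop itself is outside this diff, so the range-based form is an assumption):

    for epoch in range(start_epoch, num_epochs):
        pass  # one pass over the training set; body elided in this sketch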
@@ -233,7 +239,6 @@ def main(args):
     # todo, may less one than len(loader)
     num_steps_each_epoch = len(train_dataset) // (
         args.batch_size * ParallelEnv().nranks)
-    print(num_steps_each_epoch, 'num_steps_each_epoch')
     decay_step = args.num_epochs * num_steps_each_epoch
     lr_decay = fluid.layers.polynomial_decay(
         args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
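Apart from dropping a leftover debug print, this hunk leaves the schedule intact: the learning rate decays polynomially (power 0.9) to zero over every step of the run. A worked example with hypothetical numbers (none of these values come from the commit):

    num_steps_each_epoch = 1000 // (4 * 1)   # 1000 samples, batch_size 4, 1 GPU -> 250
    decay_step = 100 * num_steps_each_epoch  # 100 epochs -> lr reaches 0 at step 25000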
@@ -253,6 +258,7 @@ def main(args):
         num_epochs=args.num_epochs,
         batch_size=args.batch_size,
         pretrained_model=args.pretrained_model,
+        resume_model=args.resume_model,
         save_interval_epochs=args.save_interval_epochs,
         num_classes=train_dataset.num_classes,
         num_workers=args.num_workers)
...
@@ -78,12 +78,13 @@ def load_pretrained_model(model, pretrained_model):
             pretrained_model))


-def resume(optimizer, resume_model):
+def resume(model, optimizer, resume_model):
     if resume_model is not None:
         logging.info('Resume model from {}'.format(resume_model))
         if os.path.exists(resume_model):
             ckpt_path = os.path.join(resume_model, 'model')
-            _, opti_state_dict = fluid.load_dygraph(ckpt_path)
+            para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
+            model.set_dict(para_state_dict)
             optimizer.set_dict(opti_state_dict)
             epoch = resume_model.split('_')[-1]
             if epoch.isdigit():
...
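This hunk is the substantive fix. Before the patch, resume() discarded the parameter state returned by fluid.load_dygraph ('_, opti_state_dict = ...') and restored only the optimizer, so a "resumed" run silently continued from freshly initialized (or merely pretrained) weights. Assembled from the hunk, the patched function reads roughly as follows; the diff is truncated after 'if epoch.isdigit():', so the return handling here is an assumption:

    import os
    import logging

    import paddle.fluid as fluid


    def resume(model, optimizer, resume_model):
        if resume_model is not None:
            logging.info('Resume model from {}'.format(resume_model))
            if os.path.exists(resume_model):
                ckpt_path = os.path.join(resume_model, 'model')
                # load_dygraph returns (parameter state, optimizer state);
                # the old code threw the first element away.
                para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
                model.set_dict(para_state_dict)      # restore network weights
                optimizer.set_dict(opti_state_dict)  # restore optimizer state
                epoch = resume_model.split('_')[-1]  # 'output/epoch_10' -> '10'
                if epoch.isdigit():
                    return int(epoch)  # assumed: train() restarts from this epoch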