提交 0e096754 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix load ckp

上级 8408d081
......@@ -106,9 +106,9 @@ def init_model(config, net, optimizer=None):
"Given dir {}.pdparams not exist.".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict, opti_dict = paddle(checkpoints)
para_dict, opti_dict = paddle.load(checkpoints)
net.set_dict(para_dict)
optimizer.set_dict(opti_dict)
optimizer.set_state_dict(opti_dict)
logger.info(
logger.coloring("Finish initing model from {}".format(checkpoints),
"HEADER"))
......
......@@ -83,9 +83,10 @@ def main(args):
if config.validate and ParallelEnv().local_rank == 0:
valid_dataloader = Reader(config, 'valid', places=place)()
last_epoch_id = config.get("last_epoch", 0)
best_top1_acc = 0.0 # best top1 acc record
best_top1_epoch = 0
for epoch_id in range(config.epochs):
best_top1_epoch = last_epoch_id
for epoch_id in range(last_epoch_id + 1, config.epochs):
net.train()
# 1. train with train dataset
program.run(train_dataloader, config, net, optimizer, lr_scheduler,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册