提交 a9bce040 编写于 作者: L liuyuhui

fix tools/train.py

上级 70f17823
...@@ -109,47 +109,21 @@ def main(args): ...@@ -109,47 +109,21 @@ def main(args):
program.run(train_dataloader, config, dp_net, optimizer, program.run(train_dataloader, config, dp_net, optimizer,
lr_scheduler, epoch_id, 'train', vdl_writer) lr_scheduler, epoch_id, 'train', vdl_writer)
if use_xpu: # 2. validate with validate dataset
if paddle.distributed.get_rank() == 0: if config.validate and epoch_id % config.valid_interval == 0:
# 2. validate with validate dataset net.eval()
if config.validate and epoch_id % config.valid_interval == 0: with paddle.no_grad():
net.eval() top1_acc = program.run(valid_dataloader, config, net, None,
top1_acc = program.run(valid_dataloader, config, net, None, epoch_id, 'valid', vdl_writer)
None, None, epoch_id, 'valid') if top1_acc > best_top1_acc:
if top1_acc > best_top1_acc: best_top1_acc = top1_acc
best_top1_acc = top1_acc best_top1_epoch = epoch_id
best_top1_epoch = epoch_id model_path = os.path.join(config.model_save_dir,
if epoch_id % config.save_interval == 0: config.ARCHITECTURE["name"])
model_path = os.path.join( save_model(net, optimizer, model_path, "best_model")
config.model_save_dir, message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
config.ARCHITECTURE["name"]) best_top1_acc, best_top1_epoch)
save_model(net, optimizer, model_path, logger.info(message)
"best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info("{:s}".format(
logger.coloring(message, "RED")))
else:
# 2. validate with validate dataset
if paddle.distributed.get_rank() == 0:
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config,
net, None, None, epoch_id,
'valid', vdl_writer)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
model_path = os.path.join(
config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path,
"best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info(message)
# 3. save the persistable model # 3. save the persistable model
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册