提交 55fa094e 编写于 作者: L liuyuhui

fix one card eval in multicards training

上级 dd70cb1b
......@@ -119,7 +119,8 @@ def create_metric(out,
classes_num=1000,
use_distillation=False,
multilabel=False,
mode="train"):
mode="train",
use_xpu=False):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -175,11 +176,12 @@ def create_metric(out,
fetch_list.append(ham_dist)
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
for idx, fetch in enumerate(fetch_list):
fetch_list[idx] = paddle.distributed.all_reduce(
fetch, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
if not use_xpu:
if mode != "train" and paddle.distributed.get_world_size() > 1:
for idx, fetch in enumerate(fetch_list):
fetch_list[idx] = paddle.distributed.all_reduce(
fetch, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs = OrderedDict()
for idx, name in enumerate(metric_names):
......@@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"):
use_mix = config.get('use_mix') and mode == 'train'
use_distillation = config.get('use_distillation')
multilabel = config.get('multilabel', False)
use_xpu = config.get("use_xpu", False)
out = net(feeds["image"])
......@@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"):
classes_num,
use_distillation,
multilabel=multilabel,
mode=mode)
mode=mode,
use_xpu=use_xpu)
fetchs.update(metric)
return fetchs
......
......@@ -109,21 +109,47 @@ def main(args):
program.run(train_dataloader, config, dp_net, optimizer,
lr_scheduler, epoch_id, 'train', vdl_writer)
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config, net, None,
None, epoch_id, 'valid', vdl_writer)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info(message)
if use_xpu:
if paddle.distributed.get_rank() == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
top1_acc = program.run(valid_dataloader, config, net,
None, None, epoch_id, 'valid')
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
if epoch_id % config.save_interval == 0:
model_path = os.path.join(
config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path,
"best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info("{:s}".format(
logger.coloring(message, "RED")))
else:
# 2. validate with validate dataset
if paddle.distributed.get_rank() == 0:
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config,
net, None, None, epoch_id,
'valid', vdl_writer)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
model_path = os.path.join(
config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path,
"best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info(message)
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册