From ca51b6f7a24cff866247909c8986e303b6a8415c Mon Sep 17 00:00:00 2001
From: liuyuhui
Date: Sat, 8 May 2021 05:29:12 +0000
Subject: [PATCH] fix one card eval in multicards training

---
 tools/program.py | 18 ++++++++++------
 tools/train.py   | 56 +++++++++++++++++++++++++++++++++++-------------
 2 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/tools/program.py b/tools/program.py
index fd155ab8..4666d962 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -119,7 +119,8 @@ def create_metric(out,
                   classes_num=1000,
                   use_distillation=False,
                   multilabel=False,
-                  mode="train"):
+                  mode="train",
+                  use_xpu=False):
     """
     Create measures of model accuracy, such as top1 and top5

@@ -175,11 +176,12 @@
             fetch_list.append(ham_dist)

     # multi cards' eval
-    if mode != "train" and paddle.distributed.get_world_size() > 1:
-        for idx, fetch in enumerate(fetch_list):
-            fetch_list[idx] = paddle.distributed.all_reduce(
-                fetch, op=paddle.distributed.ReduceOp.
-                SUM) / paddle.distributed.get_world_size()
+    if not use_xpu:
+        if mode != "train" and paddle.distributed.get_world_size() > 1:
+            for idx, fetch in enumerate(fetch_list):
+                fetch_list[idx] = paddle.distributed.all_reduce(
+                    fetch, op=paddle.distributed.ReduceOp.
+                    SUM) / paddle.distributed.get_world_size()

     fetchs = OrderedDict()
     for idx, name in enumerate(metric_names):
@@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"):
     use_mix = config.get('use_mix') and mode == 'train'
     use_distillation = config.get('use_distillation')
     multilabel = config.get('multilabel', False)
+    use_xpu = config.get("use_xpu", False)

     out = net(feeds["image"])

@@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"):
                                classes_num,
                                use_distillation,
                                multilabel=multilabel,
-                               mode=mode)
+                               mode=mode,
+                               use_xpu=use_xpu)
         fetchs.update(metric)

     return fetchs
diff --git a/tools/train.py b/tools/train.py
index 48e15676..38113fb6 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -109,21 +109,47 @@ def main(args):
         program.run(train_dataloader, config, dp_net, optimizer,
                     lr_scheduler, epoch_id, 'train', vdl_writer)

-        # 2. validate with validate dataset
-        if config.validate and epoch_id % config.valid_interval == 0:
-            net.eval()
-            with paddle.no_grad():
-                top1_acc = program.run(valid_dataloader, config, net, None,
-                                       None, epoch_id, 'valid', vdl_writer)
-            if top1_acc > best_top1_acc:
-                best_top1_acc = top1_acc
-                best_top1_epoch = epoch_id
-                model_path = os.path.join(config.model_save_dir,
-                                          config.ARCHITECTURE["name"])
-                save_model(net, optimizer, model_path, "best_model")
-            message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
-                best_top1_acc, best_top1_epoch)
-            logger.info(message)
+        if use_xpu:
+            if paddle.distributed.get_rank() == 0:
+                # 2. validate with validate dataset
+                if config.validate and epoch_id % config.valid_interval == 0:
+                    net.eval()
+                    top1_acc = program.run(valid_dataloader, config, net,
+                                           None, None, epoch_id, 'valid')
+                    if top1_acc > best_top1_acc:
+                        best_top1_acc = top1_acc
+                        best_top1_epoch = epoch_id
+                        if epoch_id % config.save_interval == 0:
+                            model_path = os.path.join(
+                                config.model_save_dir,
+                                config.ARCHITECTURE["name"])
+                            save_model(net, optimizer, model_path,
+                                       "best_model")
+                    message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                        best_top1_acc, best_top1_epoch)
+                    logger.info("{:s}".format(
+                        logger.coloring(message, "RED")))
+
+        else:
+            # 2. validate with validate dataset
+            if paddle.distributed.get_rank() == 0:
+                if config.validate and epoch_id % config.valid_interval == 0:
+                    net.eval()
+                    with paddle.no_grad():
+                        top1_acc = program.run(valid_dataloader, config,
+                                               net, None, None, epoch_id,
+                                               'valid', vdl_writer)
+                    if top1_acc > best_top1_acc:
+                        best_top1_acc = top1_acc
+                        best_top1_epoch = epoch_id
+                        model_path = os.path.join(
+                            config.model_save_dir,
+                            config.ARCHITECTURE["name"])
+                        save_model(net, optimizer, model_path,
+                                   "best_model")
+                    message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                        best_top1_acc, best_top1_epoch)
+                    logger.info(message)

         # 3. save the persistable model
         if epoch_id % config.save_interval == 0:
--
GitLab
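
Note on the program.py hunk: when evaluation runs on more than one card, each card scores only its own shard of the validation set, so the per-card metric tensors must be averaged with an all-reduce before they describe the whole dataset. The patch keeps that collective for the default (GPU) path and skips it when `use_xpu` is set, letting rank 0 alone validate in train.py instead. The snippet below is a minimal standalone sketch of that averaging pattern, not code from this repository; the helper name `average_metric_across_cards` is made up, and it assumes the process group is started by launching with `python -m paddle.distributed.launch`.

    import paddle
    import paddle.distributed as dist

    # Assumed setup: one process per card, started via
    # `python -m paddle.distributed.launch`; skipped on a single card.
    if dist.get_world_size() > 1:
        dist.init_parallel_env()

    def average_metric_across_cards(metric):
        """Average a per-card metric tensor over all cards (hypothetical helper)."""
        world_size = dist.get_world_size()
        if world_size <= 1:
            return metric
        # all_reduce sums `metric` in place across every rank; dividing by
        # the world size afterwards gives each card the global mean.
        dist.all_reduce(metric, op=dist.ReduceOp.SUM)
        return metric / world_size

    # Example: each card computes accuracy on its shard, then averages.
    local_top1 = paddle.to_tensor([0.75])  # placeholder per-card accuracy
    global_top1 = average_metric_across_cards(local_top1)

In the patch itself, train.py additionally gates the whole validation step on `paddle.distributed.get_rank() == 0`, so only the first card runs eval and saves `best_model`; that gating is what fixes single-card evaluation inside a multi-card training job.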