From 1e4925704d7f41ac233dd0e8e87f1ebe9365c954 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Thu, 3 Dec 2020 15:48:35 +0800 Subject: [PATCH] fix dali train (#446) --- tools/static/program.py | 18 +++++++++++------- tools/{ => static}/run_dali.sh | 8 ++++---- 2 files changed, 15 insertions(+), 11 deletions(-) rename tools/{ => static}/run_dali.sh (53%) diff --git a/tools/static/program.py b/tools/static/program.py index 147aa4b7..8741cb09 100644 --- a/tools/static/program.py +++ b/tools/static/program.py @@ -416,11 +416,15 @@ def run(dataloader, # ignore the warmup iters if idx == 5: batch_time.reset() - batch_size = batch[0].shape()[0] - feed_dict = { - key.name: batch[idx] - for idx, key in enumerate(feeds.values()) - } + if use_dali: + batch_size = batch[0]["feed_image"].shape()[0] + feed_dict = batch[0] + else: + batch_size = batch[0].shape()[0] + feed_dict = { + key.name: batch[idx] + for idx, key in enumerate(feeds.values()) + } metrics = exe.run(program=program, feed=feed_dict, fetch_list=fetch_list) @@ -452,7 +456,7 @@ def run(dataloader, global total_step logger.scaler('loss', metrics[0][0], total_step, vdl_writer) total_step += 1 - if mode == 'eval': + if mode == 'valid': if idx % config.get('print_interval', 10) == 0: logger.info("{:s} step:{:<4d} {:s}".format(mode, idx, fetchs_str)) @@ -471,7 +475,7 @@ def run(dataloader, for m in metric_list] + [batch_time.total]) + 's' ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count / batch_time.sum) - if mode == 'eval': + if mode == 'valid': logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info)) else: end_epoch_str = "END epoch:{:<3d}".format(epoch) diff --git a/tools/run_dali.sh b/tools/static/run_dali.sh similarity index 53% rename from tools/run_dali.sh rename to tools/static/run_dali.sh index 7146dbd8..63f0344d 100644 --- a/tools/run_dali.sh +++ b/tools/static/run_dali.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export CUDA_VISIBLE_DEVICES="0,1,2,3" export FLAGS_fraction_of_gpu_memory_to_use=0.80 -python -m paddle.distributed.launch \ - --selected_gpus="0,1,2,3,4,5,6,7" \ - tools/train.py \ +python3.7 -m paddle.distributed.launch \ + --selected_gpus="0,1,2,3" \ + tools/static/train.py \ -c ./configs/ResNet/ResNet50.yaml \ -o print_interval=10 \ -o use_dali=true -- GitLab