未验证 提交 1e492570 编写于 作者: L littletomatodonkey 提交者: GitHub

fix dali train (#446)

上级 46dfc57e
......@@ -416,11 +416,15 @@ def run(dataloader,
# ignore the warmup iters
if idx == 5:
batch_time.reset()
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
if use_dali:
batch_size = batch[0]["feed_image"].shape()[0]
feed_dict = batch[0]
else:
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
metrics = exe.run(program=program,
feed=feed_dict,
fetch_list=fetch_list)
......@@ -452,7 +456,7 @@ def run(dataloader,
global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'eval':
if mode == 'valid':
if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
fetchs_str))
......@@ -471,7 +475,7 @@ def run(dataloader,
for m in metric_list] + [batch_time.total]) + 's'
ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count /
batch_time.sum)
if mode == 'eval':
if mode == 'valid':
logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
else:
end_epoch_str = "END epoch:{:<3d}".format(epoch)
......
#!/usr/bin/env bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=0.80
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3,4,5,6,7" \
tools/train.py \
python3.7 -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
tools/static/train.py \
-c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10 \
-o use_dali=true
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册