提交 64fad361 编写于 作者: G gaotingquan 提交者: Tingquan Gao

Fix the exception of DALI

上级 6404d7c5
......@@ -289,8 +289,9 @@ def create_strategy(config):
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = (10000 if 'AMP' in config and
config.AMP.get("use_pure_fp16", False) else 10)
exec_strategy.num_iteration_per_drop_scope = (
10000
if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)
fuse_op = True if 'AMP' in config else False
......@@ -357,7 +358,8 @@ def mixed_precision_optimizer(config, optimizer):
if 'AMP' in config:
amp_cfg = config.AMP if config.AMP else dict()
scale_loss = amp_cfg.get('scale_loss', 1.0)
use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', False)
use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
False)
use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
optimizer = paddle.static.amp.decorate(
optimizer,
......@@ -501,7 +503,21 @@ def run(dataloader,
use_dali = config.get('use_dali', False)
dataloader = dataloader if use_dali else dataloader()
tic = time.time()
for idx, batch in enumerate(dataloader):
idx = 0
batch_size = None
while True:
# The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
try:
batch = next(dataloader)
except StopIteration:
break
except RuntimeError:
logger.warning(
"Except RuntimeError when reading data from dataloader, try to read once again..."
)
continue
idx += 1
# ignore the warmup iters
if idx == 5:
metric_list["batch_time"].reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册