提交 64fad361 编写于 作者: G gaotingquan 提交者: Tingquan Gao

Fix the exception of DALI

上级 6404d7c5
...@@ -289,8 +289,9 @@ def create_strategy(config): ...@@ -289,8 +289,9 @@ def create_strategy(config):
exec_strategy = paddle.static.ExecutionStrategy() exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1 exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = (10000 if 'AMP' in config and exec_strategy.num_iteration_per_drop_scope = (
config.AMP.get("use_pure_fp16", False) else 10) 10000
if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)
fuse_op = True if 'AMP' in config else False fuse_op = True if 'AMP' in config else False
...@@ -357,7 +358,8 @@ def mixed_precision_optimizer(config, optimizer): ...@@ -357,7 +358,8 @@ def mixed_precision_optimizer(config, optimizer):
if 'AMP' in config: if 'AMP' in config:
amp_cfg = config.AMP if config.AMP else dict() amp_cfg = config.AMP if config.AMP else dict()
scale_loss = amp_cfg.get('scale_loss', 1.0) scale_loss = amp_cfg.get('scale_loss', 1.0)
use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', False) use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
False)
use_pure_fp16 = amp_cfg.get('use_pure_fp16', False) use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
optimizer = paddle.static.amp.decorate( optimizer = paddle.static.amp.decorate(
optimizer, optimizer,
...@@ -501,7 +503,21 @@ def run(dataloader, ...@@ -501,7 +503,21 @@ def run(dataloader,
use_dali = config.get('use_dali', False) use_dali = config.get('use_dali', False)
dataloader = dataloader if use_dali else dataloader() dataloader = dataloader if use_dali else dataloader()
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader):
idx = 0
batch_size = None
while True:
# DALI may raise a RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
try:
batch = next(dataloader)
except StopIteration:
break
except RuntimeError:
logger.warning(
"Except RuntimeError when reading data from dataloader, try to read once again..."
)
continue
idx += 1
# ignore the warmup iters # ignore the warmup iters
if idx == 5: if idx == 5:
metric_list["batch_time"].reset() metric_list["batch_time"].reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册