提交 c641fb3c 编写于 作者: G gaotingquan 提交者: Wei Shengyu

debug for dali

上级 58daf805
......@@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None):
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(
config["DataLoader"],
mode,
dataloader_config,
mode[-1],
paddle.device.get_device(),
num_threads=num_workers,
seed=seed,
......
......@@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator):
return data_batch
def dali_dataloader(config: Dict[str, Any],
def dali_dataloader(config_dataloader: Dict[str, Any],
mode: str,
device: str,
py_num_workers: int=1,
......@@ -690,7 +690,6 @@ def dali_dataloader(config: Dict[str, Any],
DALIImageNetIterator: Iterable DALI dataloader
"""
assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}"
config_dataloader = config[mode]
device_id = int(device.split(":")[1])
device = "gpu"
seed = 42 if seed is None else seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册