提交 8adb5f4a 编写于 作者: T Tingquan Gao

Revert "debug for dali"

This reverts commit c641fb3c.
上级 578054dd
......@@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None):
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(
dataloader_config,
mode[-1],
config["DataLoader"],
mode,
paddle.device.get_device(),
num_threads=num_workers,
seed=seed,
......
......@@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator):
return data_batch
def dali_dataloader(config_dataloader: Dict[str, Any],
def dali_dataloader(config: Dict[str, Any],
mode: str,
device: str,
py_num_workers: int=1,
......@@ -690,6 +690,7 @@ def dali_dataloader(config_dataloader: Dict[str, Any],
DALIImageNetIterator: Iterable DALI dataloader
"""
assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}"
config_dataloader = config[mode]
device_id = int(device.split(":")[1])
device = "gpu"
seed = 42 if seed is None else seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册