提交 8adb5f4a 编写于 作者: T Tingquan Gao

Revert "debug for dali"

This reverts commit c641fb3c.
上级 578054dd
...@@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None): ...@@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None):
if use_dali: if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader( return dali_dataloader(
dataloader_config, config["DataLoader"],
mode[-1], mode,
paddle.device.get_device(), paddle.device.get_device(),
num_threads=num_workers, num_threads=num_workers,
seed=seed, seed=seed,
......
...@@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator): ...@@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator):
return data_batch return data_batch
def dali_dataloader(config_dataloader: Dict[str, Any], def dali_dataloader(config: Dict[str, Any],
mode: str, mode: str,
device: str, device: str,
py_num_workers: int=1, py_num_workers: int=1,
...@@ -690,6 +690,7 @@ def dali_dataloader(config_dataloader: Dict[str, Any], ...@@ -690,6 +690,7 @@ def dali_dataloader(config_dataloader: Dict[str, Any],
DALIImageNetIterator: Iterable DALI dataloader DALIImageNetIterator: Iterable DALI dataloader
""" """
assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}" assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}"
config_dataloader = config[mode]
device_id = int(device.split(":")[1]) device_id = int(device.split(":")[1])
device = "gpu" device = "gpu"
seed = 42 if seed is None else seed seed = 42 if seed is None else seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册