diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 5c9137b35cb516a79b435748fb64fd093f55b433..e4232294a1d5a347ea742148fd1f98a4cccac5cf 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None): if use_dali: from ppcls.data.dataloader.dali import dali_dataloader return dali_dataloader( - config["DataLoader"], - mode, + dataloader_config, + mode[-1], paddle.device.get_device(), num_threads=num_workers, seed=seed, diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index bd654aafaa2e9ee658c37eff7cbd860b97e61890..d4e0c91a3ef3f7e8beca33c96d77e8c1f04caf5a 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator): return data_batch -def dali_dataloader(config: Dict[str, Any], +def dali_dataloader(config_dataloader: Dict[str, Any], mode: str, device: str, py_num_workers: int=1, @@ -690,7 +690,6 @@ def dali_dataloader(config: Dict[str, Any], DALIImageNetIterator: Iterable DALI dataloader """ assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}" - config_dataloader = config[mode] device_id = int(device.split(":")[1]) device = "gpu" seed = 42 if seed is None else seed