From c641fb3c516891e97375f34df653f252c4e008ca Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Fri, 10 Mar 2023 03:50:45 +0000 Subject: [PATCH] debug for dali --- ppcls/data/__init__.py | 4 ++-- ppcls/data/dataloader/dali.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 5c9137b3..e4232294 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -107,8 +107,8 @@ def build_dataloader(config, *mode, seed=None): if use_dali: from ppcls.data.dataloader.dali import dali_dataloader return dali_dataloader( - config["DataLoader"], - mode, + dataloader_config, + mode[-1], paddle.device.get_device(), num_threads=num_workers, seed=seed, diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index bd654aaf..d4e0c91a 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -668,7 +668,7 @@ class DALIImageNetIterator(DALIGenericIterator): return data_batch -def dali_dataloader(config: Dict[str, Any], +def dali_dataloader(config_dataloader: Dict[str, Any], mode: str, device: str, py_num_workers: int=1, @@ -690,7 +690,6 @@ def dali_dataloader(config: Dict[str, Any], DALIImageNetIterator: Iterable DALI dataloader """ assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}" - config_dataloader = config[mode] device_id = int(device.split(":")[1]) device = "gpu" seed = 42 if seed is None else seed -- GitLab