diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 80cf3bc9af826e935fe0fe6ccf8cad8d6924d370..5f73e7d832a7da4f5733e7c20c1c481a8fb5d09b 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -72,7 +72,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): # build dataset if use_dali: from ppcls.data.dataloader.dali import dali_dataloader - return dali_dataloader(config, mode, paddle.device.get_device(), seed) + return dali_dataloader( + config, + mode, + paddle.device.get_device(), + num_threads=config[mode]['loader']["num_workers"], + seed=seed) class_num = config.get("class_num", None) config_dataset = config[mode]['dataset'] diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index faef45e26b3dee2e17464a502f42f9886eac6518..a0b91a9a9dd468b0bb4ba0dd314341f067f6e3f5 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -143,7 +143,7 @@ class HybridValPipe(Pipeline): return self.epoch_size("Reader") -def dali_dataloader(config, mode, device, seed=None): +def dali_dataloader(config, mode, device, num_threads=4, seed=None): assert "gpu" in device, "gpu training is required for DALI" device_id = int(device.split(':')[1]) config_dataloader = config[mode] @@ -248,6 +248,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id, shard_id, num_shards, + num_threads=num_threads, seed=seed + shard_id, pad_output=pad_output, output_dtype=output_dtype) @@ -270,6 +271,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id=device_id, shard_id=0, num_shards=1, + num_threads=num_threads, seed=seed, pad_output=pad_output, output_dtype=output_dtype) @@ -298,6 +300,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id=device_id, shard_id=shard_id, num_shards=num_shards, + num_threads=num_threads, pad_output=pad_output, output_dtype=output_dtype) else: @@ -311,6 +314,7 @@ def dali_dataloader(config, mode, device, seed=None): mean, std, device_id=device_id, + num_threads=num_threads, pad_output=pad_output, output_dtype=output_dtype) pipe.build()