From 3f8f9814236eaafe7a04e2a689fcd6bc46460fff Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 21 Apr 2022 08:06:15 +0000 Subject: [PATCH] fix: support specify DALI threads num --- ppcls/data/__init__.py | 7 ++++++- ppcls/data/dataloader/dali.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 80cf3bc9..5f73e7d8 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -72,7 +72,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): # build dataset if use_dali: from ppcls.data.dataloader.dali import dali_dataloader - return dali_dataloader(config, mode, paddle.device.get_device(), seed) + return dali_dataloader( + config, + mode, + paddle.device.get_device(), + num_threads=config[mode]['loader']["num_workers"], + seed=seed) class_num = config.get("class_num", None) config_dataset = config[mode]['dataset'] diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index faef45e2..a0b91a9a 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -143,7 +143,7 @@ class HybridValPipe(Pipeline): return self.epoch_size("Reader") -def dali_dataloader(config, mode, device, seed=None): +def dali_dataloader(config, mode, device, num_threads=4, seed=None): assert "gpu" in device, "gpu training is required for DALI" device_id = int(device.split(':')[1]) config_dataloader = config[mode] @@ -248,6 +248,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id, shard_id, num_shards, + num_threads=num_threads, seed=seed + shard_id, pad_output=pad_output, output_dtype=output_dtype) @@ -270,6 +271,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id=device_id, shard_id=0, num_shards=1, + num_threads=num_threads, seed=seed, pad_output=pad_output, output_dtype=output_dtype) @@ -298,6 +300,7 @@ def dali_dataloader(config, mode, device, seed=None): device_id=device_id, shard_id=shard_id, num_shards=num_shards, + num_threads=num_threads, pad_output=pad_output, output_dtype=output_dtype) else: @@ -311,6 +314,7 @@ def dali_dataloader(config, mode, device, seed=None): mean, std, device_id=device_id, + num_threads=num_threads, pad_output=pad_output, output_dtype=output_dtype) pipe.build() -- GitLab