提交 3f8f9814 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: support specify DALI threads num

上级 6cd62518
......@@ -72,7 +72,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build dataset
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(config, mode, paddle.device.get_device(), seed)
return dali_dataloader(
config,
mode,
paddle.device.get_device(),
num_threads=config[mode]['loader']["num_workers"],
seed=seed)
class_num = config.get("class_num", None)
config_dataset = config[mode]['dataset']
......
......@@ -143,7 +143,7 @@ class HybridValPipe(Pipeline):
return self.epoch_size("Reader")
def dali_dataloader(config, mode, device, seed=None):
def dali_dataloader(config, mode, device, num_threads=4, seed=None):
assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1])
config_dataloader = config[mode]
......@@ -248,6 +248,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id,
shard_id,
num_shards,
num_threads=num_threads,
seed=seed + shard_id,
pad_output=pad_output,
output_dtype=output_dtype)
......@@ -270,6 +271,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id=device_id,
shard_id=0,
num_shards=1,
num_threads=num_threads,
seed=seed,
pad_output=pad_output,
output_dtype=output_dtype)
......@@ -298,6 +300,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id=device_id,
shard_id=shard_id,
num_shards=num_shards,
num_threads=num_threads,
pad_output=pad_output,
output_dtype=output_dtype)
else:
......@@ -311,6 +314,7 @@ def dali_dataloader(config, mode, device, seed=None):
mean,
std,
device_id=device_id,
num_threads=num_threads,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册