提交 3f8f9814 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: support specify DALI threads num

上级 6cd62518
...@@ -72,7 +72,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): ...@@ -72,7 +72,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build dataset # build dataset
if use_dali: if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(config, mode, paddle.device.get_device(), seed) return dali_dataloader(
config,
mode,
paddle.device.get_device(),
num_threads=config[mode]['loader']["num_workers"],
seed=seed)
class_num = config.get("class_num", None) class_num = config.get("class_num", None)
config_dataset = config[mode]['dataset'] config_dataset = config[mode]['dataset']
......
...@@ -143,7 +143,7 @@ class HybridValPipe(Pipeline): ...@@ -143,7 +143,7 @@ class HybridValPipe(Pipeline):
return self.epoch_size("Reader") return self.epoch_size("Reader")
def dali_dataloader(config, mode, device, seed=None): def dali_dataloader(config, mode, device, num_threads=4, seed=None):
assert "gpu" in device, "gpu training is required for DALI" assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1]) device_id = int(device.split(':')[1])
config_dataloader = config[mode] config_dataloader = config[mode]
...@@ -248,6 +248,7 @@ def dali_dataloader(config, mode, device, seed=None): ...@@ -248,6 +248,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id, device_id,
shard_id, shard_id,
num_shards, num_shards,
num_threads=num_threads,
seed=seed + shard_id, seed=seed + shard_id,
pad_output=pad_output, pad_output=pad_output,
output_dtype=output_dtype) output_dtype=output_dtype)
...@@ -270,6 +271,7 @@ def dali_dataloader(config, mode, device, seed=None): ...@@ -270,6 +271,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id=device_id, device_id=device_id,
shard_id=0, shard_id=0,
num_shards=1, num_shards=1,
num_threads=num_threads,
seed=seed, seed=seed,
pad_output=pad_output, pad_output=pad_output,
output_dtype=output_dtype) output_dtype=output_dtype)
...@@ -298,6 +300,7 @@ def dali_dataloader(config, mode, device, seed=None): ...@@ -298,6 +300,7 @@ def dali_dataloader(config, mode, device, seed=None):
device_id=device_id, device_id=device_id,
shard_id=shard_id, shard_id=shard_id,
num_shards=num_shards, num_shards=num_shards,
num_threads=num_threads,
pad_output=pad_output, pad_output=pad_output,
output_dtype=output_dtype) output_dtype=output_dtype)
else: else:
...@@ -311,6 +314,7 @@ def dali_dataloader(config, mode, device, seed=None): ...@@ -311,6 +314,7 @@ def dali_dataloader(config, mode, device, seed=None):
mean, mean,
std, std,
device_id=device_id, device_id=device_id,
num_threads=num_threads,
pad_output=pad_output, pad_output=pad_output,
output_dtype=output_dtype) output_dtype=output_dtype)
pipe.build() pipe.build()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册