未验证 提交 45742397 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2374 from HydrogenSulfate/fix_dali

add batch Tensor collate to simplify dali code in train/eval/retrieval…
...@@ -14,14 +14,12 @@ ...@@ -14,14 +14,12 @@
from __future__ import division from __future__ import division
import copy
import os import os
import numpy as np
import nvidia.dali.ops as ops import nvidia.dali.ops as ops
import nvidia.dali.types as types import nvidia.dali.types as types
import paddle import paddle
from nvidia.dali import fn from typing import List
from nvidia.dali.pipeline import Pipeline from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.paddle import DALIGenericIterator from nvidia.dali.plugin.paddle import DALIGenericIterator
...@@ -143,6 +141,18 @@ class HybridValPipe(Pipeline): ...@@ -143,6 +141,18 @@ class HybridValPipe(Pipeline):
return self.epoch_size("Reader") return self.epoch_size("Reader")
class DALIImageNetIterator(DALIGenericIterator):
    """DALI iterator that yields each batch as a plain list of Tensors.

    The parent ``DALIGenericIterator`` yields ``List[Dict[str, Tensor]]``;
    this subclass flattens it into ``[Tensor, Tensor, ...]`` ordered by
    ``self.output_map`` so callers can unpack batches the same way as a
    regular DataLoader batch.
    """

    def __next__(self) -> List[paddle.Tensor]:
        # Parent yields a list of per-pipeline dicts; only the first
        # pipeline's dict is consumed here (assumes a single pipeline —
        # matches how dali_dataloader constructs this iterator).
        raw_batch = super().__next__()
        per_key = raw_batch[0]
        # Order the tensors by the declared output names (e.g. data, label).
        return [paddle.to_tensor(per_key[name]) for name in self.output_map]
def dali_dataloader(config, mode, device, num_threads=4, seed=None): def dali_dataloader(config, mode, device, num_threads=4, seed=None):
assert "gpu" in device, "gpu training is required for DALI" assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1]) device_id = int(device.split(':')[1])
...@@ -278,7 +288,7 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None): ...@@ -278,7 +288,7 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pipe.build() pipe.build()
pipelines = [pipe] pipelines = [pipe]
# sample_per_shard = len(pipelines[0]) # sample_per_shard = len(pipelines[0])
return DALIGenericIterator( return DALIImageNetIterator(
pipelines, ['data', 'label'], reader_name='Reader') pipelines, ['data', 'label'], reader_name='Reader')
else: else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256) resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
...@@ -318,5 +328,5 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None): ...@@ -318,5 +328,5 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pad_output=pad_output, pad_output=pad_output,
output_dtype=output_dtype) output_dtype=output_dtype)
pipe.build() pipe.build()
return DALIGenericIterator( return DALIImageNetIterator(
[pipe], ['data', 'label'], reader_name="Reader") [pipe], ['data', 'label'], reader_name="Reader")
...@@ -47,11 +47,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -47,11 +47,7 @@ def classification_eval(engine, epoch_id=0):
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
time_info["reader_cost"].update(time.time() - tic) time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0]) batch[0] = paddle.to_tensor(batch[0])
......
...@@ -155,11 +155,7 @@ def cal_feature(engine, name='gallery'): ...@@ -155,11 +155,7 @@ def cal_feature(engine, name='gallery'):
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch = [paddle.to_tensor(x) for x in batch] batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3: if len(batch) == 3:
......
...@@ -29,11 +29,7 @@ def train_epoch(engine, epoch_id, print_batch_step): ...@@ -29,11 +29,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
for key in engine.time_info: for key in engine.time_info:
engine.time_info[key].reset() engine.time_info[key].reset()
engine.time_info["reader_cost"].update(time.time() - tic) engine.time_info["reader_cost"].update(time.time() - tic)
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False): if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1]) batch[1] = batch[1].reshape([batch_size, -1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册