未验证 提交 45742397 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2374 from HydrogenSulfate/fix_dali

add batch Tensor collate to simplify dali code in train/eval/retrival…
......@@ -14,14 +14,12 @@
from __future__ import division
import copy
import os
import numpy as np
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import paddle
from nvidia.dali import fn
from typing import List
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.paddle import DALIGenericIterator
......@@ -143,6 +141,18 @@ class HybridValPipe(Pipeline):
return self.epoch_size("Reader")
class DALIImageNetIterator(DALIGenericIterator):
def __next__(self) -> List[paddle.Tensor]:
data_batch = super(DALIImageNetIterator,
self).__next__() # List[Dict[str, Tensor], ...]
# reformat to List[Tensor1, Tensor2, ...]
data_batch = [
paddle.to_tensor(data_batch[0][key]) for key in self.output_map
]
return data_batch
def dali_dataloader(config, mode, device, num_threads=4, seed=None):
assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1])
......@@ -278,7 +288,7 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pipe.build()
pipelines = [pipe]
# sample_per_shard = len(pipelines[0])
return DALIGenericIterator(
return DALIImageNetIterator(
pipelines, ['data', 'label'], reader_name='Reader')
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
......@@ -318,5 +328,5 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
return DALIGenericIterator(
return DALIImageNetIterator(
[pipe], ['data', 'label'], reader_name="Reader")
......@@ -47,11 +47,7 @@ def classification_eval(engine, epoch_id=0):
if iter_id == 5:
for key in time_info:
time_info[key].reset()
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0])
......
......@@ -155,11 +155,7 @@ def cal_feature(engine, name='gallery'):
logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
)
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3:
......
......@@ -29,11 +29,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
for key in engine.time_info:
engine.time_info[key].reset()
engine.time_info["reader_cost"].update(time.time() - tic)
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册