提交 a14df4ac 编写于 作者: H HydrogenSulfate

fix tensor conversion in static mode with dali loader

上级 3647da6d
...@@ -142,13 +142,18 @@ class HybridValPipe(Pipeline): ...@@ -142,13 +142,18 @@ class HybridValPipe(Pipeline):
class DALIImageNetIterator(DALIGenericIterator): class DALIImageNetIterator(DALIGenericIterator):
def __init__(self, *kargs, **kwargs):
super(DALIImageNetIterator, self).__init__(*kargs, **kwargs)
self.in_dynamic_mode = paddle.in_dynamic_mode()
def __next__(self) -> List[paddle.Tensor]: def __next__(self) -> List[paddle.Tensor]:
data_batch = super(DALIImageNetIterator, data_batch = super(DALIImageNetIterator,
self).__next__() # List[Dict[str, Tensor], ...] self).__next__() # List[Dict[str, Tensor], ...]
# reformat to List[Tensor1, Tensor2, ...] # reformat to List[Tensor1, Tensor2, ...]
data_batch = [ data_batch = [
paddle.to_tensor(data_batch[0][key]) for key in self.output_map paddle.to_tensor(data_batch[0][key])
if self.in_dynamic_mode else data_batch[0][key]
for key in self.output_map
] ]
return data_batch return data_batch
......
...@@ -386,10 +386,6 @@ def run(dataloader, ...@@ -386,10 +386,6 @@ def run(dataloader,
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0]
else:
batch_size = batch[0].shape()[0] batch_size = batch[0].shape()[0]
feed_dict = { feed_dict = {
key.name: batch[idx] key.name: batch[idx]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册