提交 a14df4ac 编写于 作者: H HydrogenSulfate

fix tensor conversion in static mode with dali loader

上级 3647da6d
......@@ -142,13 +142,18 @@ class HybridValPipe(Pipeline):
class DALIImageNetIterator(DALIGenericIterator):
def __init__(self, *kargs, **kwargs):
super(DALIImageNetIterator, self).__init__(*kargs, **kwargs)
self.in_dynamic_mode = paddle.in_dynamic_mode()
def __next__(self) -> List[paddle.Tensor]:
data_batch = super(DALIImageNetIterator,
self).__next__() # List[Dict[str, Tensor], ...]
# reformat to List[Tensor1, Tensor2, ...]
data_batch = [
paddle.to_tensor(data_batch[0][key]) for key in self.output_map
paddle.to_tensor(data_batch[0][key])
if self.in_dynamic_mode else data_batch[0][key]
for key in self.output_map
]
return data_batch
......
......@@ -386,15 +386,11 @@ def run(dataloader,
profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0]
else:
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
metrics = exe.run(program=program,
feed=feed_dict,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册