未验证 提交 184b684f 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2401 from HydrogenSulfate/fix_dali_static

Fix tensor conversion in static mode with dali loader
...@@ -142,13 +142,18 @@ class HybridValPipe(Pipeline): ...@@ -142,13 +142,18 @@ class HybridValPipe(Pipeline):
class DALIImageNetIterator(DALIGenericIterator): class DALIImageNetIterator(DALIGenericIterator):
def __init__(self, *kargs, **kwargs):
super(DALIImageNetIterator, self).__init__(*kargs, **kwargs)
self.in_dynamic_mode = paddle.in_dynamic_mode()
def __next__(self) -> List[paddle.Tensor]: def __next__(self) -> List[paddle.Tensor]:
data_batch = super(DALIImageNetIterator, data_batch = super(DALIImageNetIterator,
self).__next__() # List[Dict[str, Tensor], ...] self).__next__() # List[Dict[str, Tensor], ...]
# reformat to List[Tensor1, Tensor2, ...] # reformat to List[Tensor1, Tensor2, ...]
data_batch = [ data_batch = [
paddle.to_tensor(data_batch[0][key]) for key in self.output_map paddle.to_tensor(data_batch[0][key])
if self.in_dynamic_mode else data_batch[0][key]
for key in self.output_map
] ]
return data_batch return data_batch
......
...@@ -386,15 +386,11 @@ def run(dataloader, ...@@ -386,15 +386,11 @@ def run(dataloader,
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
if use_dali: batch_size = batch[0].shape()[0]
batch_size = batch[0]["data"].shape()[0] feed_dict = {
feed_dict = batch[0] key.name: batch[idx]
else: for idx, key in enumerate(feeds.values())
batch_size = batch[0].shape()[0] }
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
metrics = exe.run(program=program, metrics = exe.run(program=program,
feed=feed_dict, feed=feed_dict,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册