From 1a32448feb2d6f50a06f4dfb90961ffe89879cc8 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Wed, 14 Dec 2022 20:21:01 +0800 Subject: [PATCH] Keep double-buffer reader for static mode (#49068) --- python/paddle/fluid/reader.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 5e678fc676..7de46e5083 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -51,6 +51,7 @@ from .dataloader.batch_sampler import _InfiniteIterableSampler from .layers.io import ( monkey_patch_reader_methods, _copy_reader_var_, + __create_unshared_decorated_reader__, ) from .unique_name import UniqueNameGenerator from .framework import _get_paddle_place, _get_paddle_place_list @@ -1351,11 +1352,6 @@ class GeneratorLoader(DataLoaderBase): self._use_double_buffer = use_double_buffer self._capacity = capacity if not self._iterable: - # Because layers.io.double_buffer is not supported anymore and that iterable is False and use_double_buffer - # is True is not spported, here if itrable is False, use_double_buffer will be - # forcely set False to avoid unexpected error. - # TODO: keep use_double_buffer - self._use_double_buffer = False self._init_non_iterable() def _wait_thread_ends(self): @@ -1410,6 +1406,7 @@ class GeneratorLoader(DataLoaderBase): 'lod_tensor_blocking_queue' ) reader_name = data_loader_unique_name_generator('create_py_reader') + double_buffer_name = data_loader_unique_name_generator('double_buffer') var = global_scope().var(queue_name) self._queue = core.init_lod_tensor_blocking_queue( @@ -1455,6 +1452,18 @@ class GeneratorLoader(DataLoaderBase): reader = monkey_patch_reader_methods(main_prog_var) + if self._use_double_buffer: + double_buffer_reader = __create_unshared_decorated_reader__( + 'create_double_buffer_reader', + reader, + {}, + name=double_buffer_name, + ) + # we return a double buffer reader. However, the reset method comes from + # py_reader. + double_buffer_reader.reset = reader.reset + reader = double_buffer_reader + self._reader = reader default_main_program().current_block().append_op( -- GitLab