diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 5e678fc67662d896cfe741c6088df1b316f2e5d2..7de46e5083aac0e2a231cca6a5f3e652ede7e432 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -51,6 +51,7 @@ from .dataloader.batch_sampler import _InfiniteIterableSampler from .layers.io import ( monkey_patch_reader_methods, _copy_reader_var_, + __create_unshared_decorated_reader__, ) from .unique_name import UniqueNameGenerator from .framework import _get_paddle_place, _get_paddle_place_list @@ -1351,11 +1352,6 @@ class GeneratorLoader(DataLoaderBase): self._use_double_buffer = use_double_buffer self._capacity = capacity if not self._iterable: - # Because layers.io.double_buffer is not supported anymore and that iterable is False and use_double_buffer - # is True is not spported, here if itrable is False, use_double_buffer will be - # forcely set False to avoid unexpected error. - # TODO: keep use_double_buffer - self._use_double_buffer = False self._init_non_iterable() def _wait_thread_ends(self): @@ -1410,6 +1406,7 @@ class GeneratorLoader(DataLoaderBase): 'lod_tensor_blocking_queue' ) reader_name = data_loader_unique_name_generator('create_py_reader') + double_buffer_name = data_loader_unique_name_generator('double_buffer') var = global_scope().var(queue_name) self._queue = core.init_lod_tensor_blocking_queue( @@ -1455,6 +1452,18 @@ class GeneratorLoader(DataLoaderBase): reader = monkey_patch_reader_methods(main_prog_var) + if self._use_double_buffer: + double_buffer_reader = __create_unshared_decorated_reader__( + 'create_double_buffer_reader', + reader, + {}, + name=double_buffer_name, + ) + # we return a double buffer reader. However, the reset method comes from + # py_reader. + double_buffer_reader.reset = reader.reset + reader = double_buffer_reader + self._reader = reader default_main_program().current_block().append_op(