# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
    default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name

__all__ = ['ctr_reader']


def monkey_patch_reader_methods(reader):
    """Attach ``reset()`` and ``start()`` helpers to a reader variable.

    Both helpers look the reader up by ``reader.name`` in the global scope
    every time they are called and forward to the method of the same name on
    the object returned by ``get_reader()``.

    Args:
        reader: The reader variable to patch. It is modified in place: it
            gains ``reset`` and ``start`` attributes, and its
            ``stop_gradient`` and ``persistable`` attributes are set to True.

    Returns:
        The same ``reader`` object that was passed in.
    """

    def _scope_reader():
        # Resolved on each call instead of being cached at patch time.
        return global_scope().find_var(reader.name).get_reader()

    def reset():
        return _scope_reader().reset()

    def start():
        return _scope_reader().start()

    reader.reset = reset
    reader.start = start
    reader.stop_gradient = True
    reader.persistable = True
    return reader


def _copy_reader_var_(block, var):
    """Create a READER variable in ``block`` that mirrors ``var``.

    The new variable has the same name as ``var``, receives the shapes and
    dtypes recorded in ``var.desc``, and is marked persistable.

    Args:
        block: The block in which the new variable is created.
        var: The reader variable whose name, shapes and dtypes are copied.

    Returns:
        The newly created variable.
    """
    mirrored = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)
    mirrored_desc = mirrored.desc
    mirrored_desc.set_shapes(var.desc.shapes())
    mirrored_desc.set_dtypes(var.desc.dtypes())
    mirrored.persistable = True
    return mirrored


def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_indexs,
        sparse_slot_indexs,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader and wire it into the default programs.

    The function performs the following steps, in order:

    1. Creates a LoDTensor blocking queue of the given ``capacity`` in the
       global scope.
    2. Appends a ``create_ctr_reader`` op to the current block of the default
       startup program; the op takes the queue as input and outputs a reader
       variable.
    3. Copies that reader variable into the current block of the default main
       program and patches ``start()`` / ``reset()`` onto the copy.
    4. Appends a ``read`` op to the main program that reads from the reader
       into the variables of ``feed_dict``.

    Args:
        feed_dict: Iterable of data variables. The ``dtype`` of every element
            is recorded on the reader variable, and the iterable itself is
            used as the ``Out`` outputs of the ``read`` op.
        file_type: Forwarded as the ``file_type`` attribute of the
            ``create_ctr_reader`` op ("gzip" or "plain" according to the
            signature comment).
        file_format: Forwarded as the ``file_format`` attribute ("csv" or
            "svm" according to the signature comment).
        dense_slot_indexs: Forwarded as the ``dense_slot_index`` attribute.
        sparse_slot_indexs: Forwarded as the ``sparse_slot_index`` attribute.
        capacity: Capacity passed to
            ``core.init_lod_tensor_blocking_queue``.
        thread_num: Forwarded as the ``thread_num`` attribute (presumably the
            number of reading threads -- confirm against the op definition).
        batch_size: Forwarded as the ``batch_size`` attribute.
        file_list: Forwarded as the ``file_list`` attribute (presumably the
            paths of the data files -- confirm against the op definition).
        slots: Forwarded as the ``sparse_slots`` attribute.
        name: Optional name prefix. When given, the queue is named
            ``<name>_queue`` and the reader ``<name>_reader``; when None,
            unique names are generated from the prefixes
            ``lod_tensor_blocking_queue`` and ``create_ctr_reader``.

    Returns:
        Variable: The reader variable living in the main program. Besides
        ``start()`` and ``reset()`` it carries a ``queue`` attribute (the
        blocking queue created here) and an ``exited`` attribute initialised
        to False.
    """
    if name is not None:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])
    else:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')

    # The queue lives in the global scope under ``queue_name`` so the
    # startup op below can refer to it by name.
    queue_var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(queue_var, capacity)

    startup_block = default_startup_program().current_block()
    startup_reader = startup_block.create_var(name=reader_name)

    op_attrs = {
        'use_data_config': False,
        'thread_num': thread_num,
        'batch_size': batch_size,
        'file_list': file_list,
        'file_type': file_type,
        'file_format': file_format,
        'dense_slot_index': dense_slot_indexs,
        'sparse_slot_index': sparse_slot_indexs,
        'sparse_slots': slots,
        'ranks': [],
        'lod_levels': [],
        'shape_concat': []
    }
    startup_block.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [startup_reader]},
        attrs=op_attrs)

    feed_dtypes = [feed_var.dtype for feed_var in feed_dict]
    startup_reader.desc.set_dtypes(feed_dtypes)
    startup_reader.persistable = True

    # Mirror the startup-program reader into the main program, then expose
    # start()/reset() on the mirrored variable.
    main_block = default_main_program().current_block()
    reader = monkey_patch_reader_methods(
        _copy_reader_var_(main_block, startup_reader))

    # monkey patch py_reader special methods
    reader.queue = feed_queue
    reader.exited = False

    main_block.append_op(
        type='read',
        inputs={'Reader': [reader]},
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader