ctr_reader.py 4.3 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
    default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name


def monkey_patch_reader_methods(reader):
    """
    Attach reader helper attributes to a reader Variable.

    A :code:`reset()` function is bound onto :code:`reader`. When called, it
    looks up the variable named :code:`reader.name` in the global scope,
    fetches the reader held by that variable and returns the result of the
    held reader's own :code:`reset()`.

    :code:`reader.stop_gradient` and :code:`reader.persistable` are both set
    to True.

    Args:
       reader: The object to patch; it must expose a :code:`name` attribute
            and accept attribute assignment.

    Returns:
       The same :code:`reader` object, patched in place.
    """

    def reset():
        # The scope lookup happens on every call rather than once at patch
        # time, so the variable only has to exist when reset() is invoked.
        holder = global_scope().find_var(reader.name)
        return holder.get_reader().reset()

    reader.reset = reset
    reader.stop_gradient = True
    reader.persistable = True
    return reader


def _copy_reader_var_(block, var):
    """
    Create a READER variable in ``block`` that mirrors ``var``.

    The new variable gets the same name as ``var``, has its shapes and dtypes
    copied from ``var.desc``, and is marked persistable.

    Args:
       block: The block in which the new variable is created.
       var: The reader variable whose name, shapes and dtypes are copied.

    Returns:
       The newly created variable.
    """
    source_desc = var.desc
    clone = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)

    clone_desc = clone.desc
    clone_desc.set_shapes(source_desc.shapes())
    clone_desc.set_dtypes(source_desc.dtypes())

    clone.persistable = True
    return clone


def ctr_reader(feed_data,
               capacity,
               thread_num,
               batch_size,
               file_list,
               slots,
               name=None):
    """
    Create a CTR reader that feeds data into the default main program.

    This layer creates a LoDTensor blocking queue in the global scope,
    appends a :code:`create_ctr_reader` op to the default startup program and
    a :code:`read` op to the default main program, and returns the Reader
    Variable that lives in the main program. The :code:`read` op writes its
    outputs into :code:`feed_data`.

    The returned Reader is patched with a :code:`reset()` method and carries
    two extra attributes: :code:`queue` (the blocking queue) and
    :code:`exited` (initialised to False).

    Args:
       feed_data(list|tuple): The Variables that receive the outputs of the
            :code:`read` op. Their :code:`shape` attributes are also used as
            the shapes handed to the blocking queue.
       capacity(int): The buffer capacity of the blocking queue.
       thread_num: Forwarded unchanged as the :code:`thread_num` attribute of
            the :code:`create_ctr_reader` op (presumably the number of
            reading threads).
       batch_size: Forwarded unchanged as the :code:`batch_size` attribute of
            the :code:`create_ctr_reader` op.
       file_list: Forwarded unchanged as the :code:`file_list` attribute of
            the :code:`create_ctr_reader` op (presumably the data files to
            read).
       slots: Forwarded unchanged as the :code:`slots` attribute of the
            :code:`create_ctr_reader` op.
       name(basestring): The prefix of the queue name and the Reader name.
            If None, both names are generated automatically.

    Returns:
       Variable: A Reader from which we can get feeding data.

    Note:
       The exact types expected for :code:`thread_num`, :code:`batch_size`,
       :code:`file_list` and :code:`slots` are defined by the
       :code:`create_ctr_reader` op, not by this function; confirm them
       against the op definition.
    """
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # Shapes for the blocking queue are taken from the variables that the
    # `read` op fills below. This replaces a reference to an undefined
    # `shapes` name, which raised NameError on every call.
    # NOTE(review): assumes every entry of feed_data exposes `.shape`
    # (i.e. is a Variable) -- confirm against callers.
    shapes = [list(data.shape) for data in feed_data]

    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)

    # The reader itself is created once, by the startup program.
    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'slots': slots,
        })

    reader_var.persistable = True

    # Mirror the startup-program reader variable into the main program so the
    # `read` op below can reference it by the same name.
    main_prog_reader_var = _copy_reader_var_(
        default_main_program().current_block(), reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # Attach the queue handle and an exit flag to the returned reader.
    reader.queue = feed_queue
    reader.exited = False

    main_blk = default_main_program().current_block()
    main_blk.append_op(
        type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data})

    return reader