ctr_reader.py 4.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
    default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name

Q
Qiao Longfei 已提交
23 24
# Names exported by a star-import of this module: only ``ctr_reader``.
__all__ = ['ctr_reader']

Q
Qiao Longfei 已提交
25 26 27 28 29 30 31 32 33 34

def monkey_patch_reader_methods(reader):
    """Attach ``reset``/``start`` helpers to a reader variable and return it.

    Both helpers look the variable up by ``reader.name`` in the global scope
    each time they are called and forward to the reader object stored there.
    The variable is also flagged with ``stop_gradient = True`` and
    ``persistable = True``.

    Args:
        reader: The reader variable to patch. It must expose ``name`` and
            accept new attributes.

    Returns:
        The same ``reader`` object, modified in place.
    """

    def _scope_reader():
        # Resolved on every call instead of being captured once here, so the
        # lookup always goes through the current global scope.
        return global_scope().find_var(reader.name).get_reader()

    def reset():
        return _scope_reader().reset()

    def start():
        return _scope_reader().start()

    reader.reset = reset
    reader.start = start
    reader.stop_gradient = True
    reader.persistable = True
    return reader


def _copy_reader_var_(block, var):
    """Create a persistable READER variable in ``block`` mirroring ``var``.

    The new variable reuses ``var``'s name and receives the shapes and dtypes
    recorded in ``var``'s descriptor.
    """
    mirrored = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)

    # Copy the descriptor metadata: shapes first, then dtypes.
    source_shapes = var.desc.shapes()
    mirrored.desc.set_shapes(source_shapes)
    source_dtypes = var.desc.dtypes()
    mirrored.desc.set_dtypes(source_dtypes)

    mirrored.persistable = True
    return mirrored


Q
Qiao Longfei 已提交
53 54 55 56 57 58 59 60 61 62 63 64
def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_indexs,
        sparse_slot_indexs,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader variable and append the ops that read from it.

    A :code:`create_ctr_reader` op is appended to the default startup
    program, and a :code:`read` op whose outputs are the variables in
    :code:`feed_dict` is appended to the default main program. The returned
    Reader carries :code:`start()` and :code:`reset()` methods that forward
    to the reader object held in the global scope, plus the attributes
    :code:`queue` (the blocking queue created here) and :code:`exited`
    (initialised to False).

    As with :code:`py_reader`, :code:`start()` is presumably called when a
    pass begins and :code:`reset()` once it has ended -- TODO confirm against
    the C++ reader implementation.

    Args:
       feed_dict(list|tuple): Variables filled by the :code:`read` op.
            Despite the name it is iterated as a sequence: every item must
            expose :code:`dtype`, and the object itself is passed as the
            op outputs.
       file_type(str): Forwarded as the :code:`file_type` attribute; the
            signature comment lists gzip or plain.
       file_format(str): Forwarded as the :code:`file_format` attribute; the
            signature comment lists csv or svm.
       dense_slot_indexs: Forwarded as the :code:`dense_slot_index`
            attribute (presumably a list of ints -- confirm against the op).
       sparse_slot_indexs: Forwarded as the :code:`sparse_slot_index`
            attribute (presumably a list of ints -- confirm against the op).
       capacity(int): Capacity passed to the LoD tensor blocking queue.
       thread_num: Forwarded as the :code:`thread_num` attribute.
       batch_size: Forwarded as the :code:`batch_size` attribute.
       file_list: Forwarded as the :code:`file_list` attribute (presumably
            the input file paths -- confirm against the op).
       slots: Forwarded as the :code:`sparse_slots` attribute.
       name(basestring): Prefix for the queue name (:code:`<name>_queue`) and
            the reader name (:code:`<name>_reader`). When None, unique names
            are generated automatically.

    Returns:
       Variable: A Reader variable in the default main program.
    """
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # The blocking queue lives in the global scope under queue_name.
    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)

    # The reader itself is created by an op in the startup program.
    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'use_data_config': False,
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'file_type': file_type,
            'file_format': file_format,
            'dense_slot_index': dense_slot_indexs,
            'sparse_slot_index': sparse_slot_indexs,
            'sparse_slots': slots,
            'ranks': [],
            'lod_levels': [],
            'shape_concat': []
        })

    dtypes = [data.dtype for data in feed_dict]
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True

    # Mirror the reader variable into the main program, where the read op
    # consumes it. The main block is looked up once and reused below.
    main_blk = default_main_program().current_block()
    main_prog_reader_var = _copy_reader_var_(main_blk, reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # Expose the blocking queue and an ``exited`` flag on the returned reader.
    reader.queue = feed_queue
    reader.exited = False

    main_blk.append_op(
        type='read',
        inputs={'Reader': [reader]},
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader