ctr_reader.py 5.7 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
    default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name

Q
Qiao Longfei 已提交
23 24
# Public API of this module: only the reader factory is exported.
__all__ = ["ctr_reader"]

Q
Qiao Longfei 已提交
25 26 27 28 29 30 31 32 33 34

def monkey_patch_reader_methods(reader):
    """Attach ``start()`` / ``reset()`` to a reader variable and return it.

    Both attached methods look up the variable named ``reader.name`` in the
    global scope at call time, fetch its underlying reader via
    ``get_reader()``, and forward the call to it. The variable is also marked
    ``stop_gradient`` and ``persistable``.

    Args:
        reader: the reader Variable to patch (modified in place).

    Returns:
        The same ``reader`` object, for convenience.
    """

    def _lookup_native_reader():
        # Resolved lazily on every call rather than once at patch time.
        found = global_scope().find_var(reader.name)
        return found.get_reader()

    def reset():
        return _lookup_native_reader().reset()

    def start():
        return _lookup_native_reader().start()

    reader.reset, reader.start = reset, start
    reader.stop_gradient = True
    reader.persistable = True
    return reader


def _copy_reader_var_(block, var):
    """Create a persistable READER variable in ``block`` mirroring ``var``.

    The new variable has the same name as ``var`` and copies its shapes and
    dtypes from ``var.desc``.

    Args:
        block: the Block in which to create the copy.
        var: the source reader Variable.

    Returns:
        The newly created Variable.
    """
    copied = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)
    source_desc = var.desc
    target_desc = copied.desc
    target_desc.set_shapes(source_desc.shapes())
    target_desc.set_dtypes(source_desc.dtypes())
    copied.persistable = True
    return copied


Q
Qiao Longfei 已提交
53 54 55 56
def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_index,
        sparse_slot_index,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader that reads data files in C++ and feeds a program.

    This layer returns a Reader Variable. Calling it:

    * creates a LoD-tensor blocking queue of size ``capacity`` in the global
      scope;
    * appends a ``create_ctr_reader`` op to the default startup program that
      builds the reader over ``file_list`` (using ``thread_num`` threads and
      a batch size of ``batch_size``) on top of that queue;
    * copies the reader variable into the default main program and appends a
      ``read`` op whose outputs are the variables in ``feed_dict``.

    The returned Reader provides :code:`start()` and :code:`reset()`. The
    :code:`start()` method should be called when each pass begins, while the
    :code:`reset()` method should be called when the pass ends and
    :code:`fluid.core.EOFException` is raised. The queue is also exposed as
    the Reader's ``queue`` attribute.

    Args:
        feed_dict(list(Variable)): the data variables the ``read`` op writes
            into; their dtypes are recorded on the reader. Any iterable of
            variables is accepted.
        file_type('gzip'|'plain'): whether the data files are gzip-compressed
            or plain text.
        file_format('csv'|'svm'): the line format of the data files. The csv
            format is::

                label dense_fea,dense_fea sparse_fea,sparse_fea

            and the svm format is::

                label slot1:fea_sign slot2:fea_sign slot1:fea_sign

        dense_slot_index(list(int)): the indexes of the dense slots.
        sparse_slot_index(list(int)): the indexes of the sparse slots.
        capacity(int): the capacity of the blocking queue passed to the
            ``create_ctr_reader`` op.
        thread_num(int): the number of threads the C++ reader uses to read
            files.
        batch_size(int): batch size of data.
        file_list(list(str)): the names of the files to read.
        slots(list(int64)): the sparse slot ids (passed to the op as the
            ``sparse_slots`` attribute).
        name(str|None): the prefix of the queue name and the reader name.
            If None, unique names are generated automatically.

    Returns:
        Variable: A Reader from which we can get feeding data.

    Examples:

        1. The basic usage of :code:`ctr_reader` is as follows:

        .. code-block:: python

            py_reader = fluid.contrib.ctr_reader.ctr_reader(
                feed_dict=datas, file_type='plain', file_format='csv',
                file_list=file_list, dense_slot_index=[1, 2, 3, 4],
                sparse_slot_index=[], capacity=64, thread_num=20,
                batch_size=1000, slots=[], name='ctr_reader')

    """
    # feed_dict is consumed twice below (dtypes, then read-op outputs), so
    # materialize it once; a one-shot iterable would otherwise leave the
    # read op with no outputs.
    feed_dict = list(feed_dict)

    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)

    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'use_data_config': False,
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'file_type': file_type,
            'file_format': file_format,
            'dense_slot_index': dense_slot_index,
            'sparse_slot_index': sparse_slot_index,
            'sparse_slots': slots,
            # Shape/LoD attributes are intentionally left empty here.
            'ranks': [],
            'lod_levels': [],
            'shape_concat': []
        })

    dtypes = [data.dtype for data in feed_dict]
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True

    main_prog_reader_var = _copy_reader_var_(
        default_main_program().current_block(), reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # Attach ctr_reader-specific attributes (mirroring py_reader's API).
    reader.queue = feed_queue
    reader.exited = False

    main_blk = default_main_program().current_block()
    main_blk.append_op(
        type='read',
        inputs={'Reader': [reader]},
        # NOTE(review): presumably disables output shape inference, since no
        # shapes are set on the reader above — confirm against the read op.
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader