#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
    default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name

__all__ = ['ctr_reader']


def monkey_patch_reader_methods(reader):
    """Attach ``reset()`` and ``start()`` helpers to a reader variable.

    Both helpers look up the variable named ``reader.name`` in the global
    scope each time they are called and forward to the object returned by
    that variable's ``get_reader()``. The reader variable is also marked
    with ``stop_gradient = True`` and ``persistable = True``.

    Args:
        reader: The reader variable to patch. It must expose a ``name``
            attribute and accept attribute assignment.

    Returns:
        The same ``reader`` object, patched in place.
    """

    def __get_reader__():
        # Resolved on every call rather than cached when patching, so the
        # lookup always goes through the current global scope.
        scope = global_scope()
        var = scope.find_var(reader.name)
        return var.get_reader()

    def reset():
        return __get_reader__().reset()

    def start():
        return __get_reader__().start()

    reader.reset = reset
    reader.start = start
    reader.stop_gradient = True
    reader.persistable = True
    return reader


def _copy_reader_var_(block, var):
    """Create a persistable READER variable in ``block`` mirroring ``var``.

    The new variable takes the name of ``var`` and receives the shapes and
    dtypes stored on ``var.desc``.

    Args:
        block: The block in which the new variable is created.
        var: The reader variable whose name, shapes and dtypes are copied.

    Returns:
        The newly created variable.
    """
    clone = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)
    source_desc = var.desc
    clone_desc = clone.desc
    clone_desc.set_shapes(source_desc.shapes())
    clone_desc.set_dtypes(source_desc.dtypes())
    clone.persistable = True
    return clone


def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_index,
        sparse_slot_index,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader that feeds data read from files.

    This layer returns a Reader Variable. A :code:`create_ctr_reader` op
    is appended to the default startup program, and a :code:`read` op whose
    outputs are the variables in :code:`feed_dict` is appended to the
    default main program. The returned Reader is patched with
    :code:`start()` and :code:`reset()` methods plus :code:`queue` and
    :code:`exited` attributes. The :code:`start()` method of the Reader
    should be called when each pass begins, while the :code:`reset()`
    method should be called when the pass ends and
    :code:`fluid.core.EOFException` raises.

    NOTE(review): the :code:`py_reader` documentation this docstring was
    derived from states that :code:`Program.clone()` cannot clone the
    reader; confirm whether the same restriction applies here.

    Args:
       feed_dict(list(variable)): a list of data variable. They become the
            outputs of the read op and their dtypes are recorded on the
            reader variable.
       file_type('gzip'|'plain'): the type of the data file
       file_format('csv'|'svm'): csv data or svm data format.
        csv data format is :
            label dense_fea,dense_fea sparse_fea,sparse_fea
        the svm data format is :
            label slot1:fea_sign slot2:fea_sign slot1:fea_sign
       dense_slot_index(list(int)): the index of dense slots
       sparse_slot_index(list(int)): the index of sparse slots
       capacity(int): The capacity of the blocking queue created for this
            reader.
       thread_num: forwarded as the :code:`thread_num` attribute of the
            :code:`create_ctr_reader` op (presumably an int giving the
            number of reading threads; confirm against the op).
       batch_size: forwarded as the :code:`batch_size` attribute of the op
            (presumably an int; confirm against the op).
       file_list: forwarded as the :code:`file_list` attribute of the op
            (presumably a list of data file paths; confirm against the
            op).
       slots: slot id of all sparse feature, forwarded as the
            :code:`sparse_slots` attribute of the op.
       name(basestring): The prefix of the queue name and Reader name. If
            None, both names will be generated automatically.

    Returns:
       Variable: A Reader from which we can get feeding data.

    Examples:

        A usage sketch of the start/reset protocol (the executor call is
        illustrative only):

        .. code-block:: python

            reader = ctr_reader(...)
            reader.start()
            try:
                while True:
                    exe.run(...)
            except fluid.core.EOFException:
                reader.reset()
    """
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # The blocking queue lives in the global scope under queue_name and is
    # wired into the reader op by name below.
    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)

    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'use_data_config': False,
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'file_type': file_type,
            'file_format': file_format,
            'dense_slot_index': dense_slot_index,
            'sparse_slot_index': sparse_slot_index,
            'sparse_slots': slots,
            'ranks': [],
            'lod_levels': [],
            'shape_concat': []
        })

    dtypes = [data.dtype for data in feed_dict]
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True

    # Mirror the startup-program reader variable into the main program so
    # the read op appended below can reference it.
    main_blk = default_main_program().current_block()
    main_prog_reader_var = _copy_reader_var_(main_blk, reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # monkey patch py_reader special methods
    reader.queue = feed_queue
    reader.exited = False

    main_blk.append_op(
        type='read',
        inputs={'Reader': [reader]},
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader