data_feeder.py 4.8 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from py_paddle import DataProviderConverter
Y
Yu Yang 已提交
16
import collections
Y
Yu Yang 已提交
17
import paddle.trainer.PyDataProvider2 as pydp2
D
dangqingqing 已提交
18 19

__all__ = ['DataFeeder']
20 21


Y
Yu Yang 已提交
22 23 24 25 26 27 28
def default_feeding_map(data_types):
    """
    Build the default feeding map, where every data name is fed from the
    position it occupies in ``data_types``.

    :param data_types: sequence of items whose first element is a data name.
    :return: dict mapping each data name to its index in ``data_types``; if a
             name repeats, its last index wins.
    """
    return {entry[0]: position for position, entry in enumerate(data_types)}


29 30 31 32
class DataFeeder(DataProviderConverter):
    """
    DataFeeder converts the data returned by paddle.reader into a data structure
    of Arguments which is defined in the API. The paddle.reader usually returns
    a list of mini-batch data entries. Each data entry in the list is one sample.
    Each sample is a list or a tuple with one feature or multiple features.
    DataFeeder converts these mini-batch data entries into Arguments in order
    to feed them to the C++ interface.

    A simple usage is shown below:

    ..  code-block:: python

        data_types = [('image', paddle.data_type.dense_vector(784)),
                      ('label', paddle.data_type.integer_value(10))]
        feeding = ['image', 'label']
        feeder = DataFeeder(data_types=data_types, feeding=feeding)

        minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]

        arg = feeder(minibatch_data)


    If the mini-batch data and the data layers are not a one-to-one mapping,
    pass a dictionary as the ``feeding`` parameter to describe which position
    of each sample holds the value for each data name. Positions that no data
    name refers to (the third column below) are simply not fed.

    ..  code-block:: python

        data_types = [('image', paddle.data_type.dense_vector(784)),
                      ('label', paddle.data_type.integer_value(10))]
        feeding = {'image':0, 'label':1}
        feeder = DataFeeder(data_types=data_types, feeding=feeding)
        minibatch_data = [
                           ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ),  # first sample
                           ( [1.0,2.0,3.0,4.0], 5, [6,7,8] )   # second sample
                         ]
        # or minibatch_data = [
        #                       [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ],  # first sample
        #                       [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ]   # second sample
        #                     ]
        arg = feeder(minibatch_data)

    ..  note::

        This module is for internal use only. Users should use the `reader`
        interface.

    :param data_types: A list to specify data name and type. Each item is
                       a tuple of (data_name, data_type), where data_type
                       must be a ``PyDataProvider2.InputType``.
    :type data_types: list
    :param feeding: A dictionary or a sequence to specify the position of each
                    data in the input data. A sequence of names is turned
                    into a ``{name: index}`` dict; None means each sample
                    follows the order of ``data_types``.
    :type feeding: dict|collections.Sequence|None
    :raises TypeError: if ``feeding`` is not a dict, a sequence or None, or if
                       a data_type item is not an ``InputType``.
    """

    def __init__(self, data_types, feeding=None):
        # ``collections.Sequence`` was removed in Python 3.10; use the
        # ``collections.abc`` location and fall back for Python 2.
        try:
            from collections.abc import Sequence
        except ImportError:  # Python 2
            from collections import Sequence

        if feeding is None:
            feeding = default_feeding_map(data_types)
        elif isinstance(feeding, Sequence):
            # Position in the list is the position inside each sample.
            feeding = {name: i for i, name in enumerate(feeding)}
        elif not isinstance(feeding, dict):
            raise TypeError("Feeding should be dict or sequence or None.")
        self.feeding = feeding

        self.input_names = []
        input_types = []
        for each in data_types:
            self.input_names.append(each[0])
            if not isinstance(each[1], pydp2.InputType):
                raise TypeError("second item in each data_type should be an "
                                "InputType")
            input_types.append(each[1])
        DataProviderConverter.__init__(self, input_types)

    def __len__(self):
        """Return the number of data inputs declared in ``data_types``."""
        return len(self.input_names)

    def convert(self, dat, argument=None):
        """
        Reorder every sample of a mini-batch into ``data_types`` order and
        convert the result with ``DataProviderConverter.convert``.

        :param dat: A list of mini-batch data. Each sample is a list or tuple
                    with one feature or multiple features.
        :type dat: list
        :param argument: Optional Arguments object, passed through unchanged
                         to ``DataProviderConverter.convert``. The Arguments
                         definition is in the API.
        :type argument: py_paddle.swig_paddle.Arguments
        :return: whatever ``DataProviderConverter.convert`` returns for the
                 reordered samples.
        :raises KeyError: if a data name from ``data_types`` has no entry in
                          ``feeding``.
        """
        # The sample position of each input is the same for every sample, so
        # resolve it once per batch instead of once per sample. ``self.feeding``
        # is re-read on every call in case the caller's dict was updated.
        indices = [self.feeding[name] for name in self.input_names]
        reordered = [[sample[i] for i in indices] for sample in dat]
        return DataProviderConverter.convert(self, reordered, argument)