diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index bda8e22fd282f8ff4a820e4ecb6b3bb421d57890..ca3e44e5a0187da33654f4955197196b150da196 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -13,7 +13,7 @@ # limitations under the License. from py_paddle import DataProviderConverter - +import collections import paddle.trainer.PyDataProvider2 as pydp2 __all__ = ['DataFeeder'] @@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter): DataFeeder converts this mini-batch data entries into Arguments in order to feed it to C++ interface. - The example usage: + A simple usage is shown below: + + .. code-block:: python + + feeding = ['image', 'label'] + data_types = enumerate_data_types_of_data_layers(topology) + feeder = DataFeeder(data_types=data_types, feeding=feeding) + + minibatch_data = [([1.0, 2.0, 3.0, ...], 5)] + + arg = feeder(minibatch_data) + + + If the mini-batch data and the data layers do not map one to one, we + can pass a dictionary as the feeding parameter to represent the mapping + relationship. .. code-block:: python data_types = [('image', paddle.data_type.dense_vector(784)), ('label', paddle.data_type.integer_value(10))] - reader_dict = {'image':0, 'label':1} - feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict) + feeding = {'image':0, 'label':1} + feeder = DataFeeder(data_types=data_types, feeding=feeding) minibatch_data = [ ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample @@ -65,9 +80,9 @@ class DataFeeder(DataProviderConverter): a tuple of (data_name, data_type). :type data_types: list - :param reader_dict: A dictionary to specify the position of each data - in the input data. - :type feeding: dict + :param feeding: A dictionary or a sequence to specify the position of each + data in the input data. 
+ :type feeding: dict|collections.Sequence|None """ def __init__(self, data_types, feeding=None): @@ -75,6 +90,13 @@ class DataFeeder(DataProviderConverter): input_types = [] if feeding is None: feeding = default_feeding_map(data_types) + elif isinstance(feeding, collections.Sequence): + feed_list = feeding + feeding = dict() + for i, name in enumerate(feed_list): + feeding[name] = i + elif not isinstance(feeding, dict): + raise TypeError("Feeding should be dict or sequence or None.") self.feeding = feeding for each in data_types: diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index a207beb548f0fda6d7aacf5c55e53d937e18a924..f5797a86c2b71502a7791453ff86c6a486c9f185 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -81,7 +81,7 @@ class SGD(object): :type event_handler: (BaseEvent) => None :param feeding: Feeding is a map of neural network input name and array index that reader returns. - :type feeding: dict + :type feeding: dict|list :return: """ if event_handler is None: