提交 892cc82d 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1766 from reyoung/feature/add_list_type_of_feeding

Add list type of feeding
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from py_paddle import DataProviderConverter from py_paddle import DataProviderConverter
import collections
import paddle.trainer.PyDataProvider2 as pydp2 import paddle.trainer.PyDataProvider2 as pydp2
__all__ = ['DataFeeder'] __all__ = ['DataFeeder']
...@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter): ...@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter):
DataFeeder converts this mini-batch data entries into Arguments in order DataFeeder converts this mini-batch data entries into Arguments in order
to feed it to C++ interface. to feed it to C++ interface.
The example usage: The simple usage shows below
.. code-block:: python
feeding = ['image', 'label']
data_types = enumerate_data_types_of_data_layers(topology)
feeder = DataFeeder(data_types=data_types, feeding=feeding)
minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]
arg = feeder(minibatch_data)
If mini-batch data and data layers are not one to one mapping, we
could pass a dictionary to feeding parameter to represent the mapping
relationship.
.. code-block:: python .. code-block:: python
data_types = [('image', paddle.data_type.dense_vector(784)), data_types = [('image', paddle.data_type.dense_vector(784)),
('label', paddle.data_type.integer_value(10))] ('label', paddle.data_type.integer_value(10))]
reader_dict = {'image':0, 'label':1} feeding = {'image':0, 'label':1}
feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict) feeder = DataFeeder(data_types=data_types, feeding=feeding)
minibatch_data = [ minibatch_data = [
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
...@@ -65,9 +80,9 @@ class DataFeeder(DataProviderConverter): ...@@ -65,9 +80,9 @@ class DataFeeder(DataProviderConverter):
a tuple of (data_name, data_type). a tuple of (data_name, data_type).
:type data_types: list :type data_types: list
:param reader_dict: A dictionary to specify the position of each data :param feeding: A dictionary or a sequence to specify the position of each
in the input data. data in the input data.
:type feeding: dict :type feeding: dict|collections.Sequence|None
""" """
def __init__(self, data_types, feeding=None): def __init__(self, data_types, feeding=None):
...@@ -75,6 +90,13 @@ class DataFeeder(DataProviderConverter): ...@@ -75,6 +90,13 @@ class DataFeeder(DataProviderConverter):
input_types = [] input_types = []
if feeding is None: if feeding is None:
feeding = default_feeding_map(data_types) feeding = default_feeding_map(data_types)
elif isinstance(feeding, collections.Sequence):
feed_list = feeding
feeding = dict()
for i, name in enumerate(feed_list):
feeding[name] = i
elif not isinstance(feeding, dict):
raise TypeError("Feeding should be dict or sequence or None.")
self.feeding = feeding self.feeding = feeding
for each in data_types: for each in data_types:
......
...@@ -81,7 +81,7 @@ class SGD(object): ...@@ -81,7 +81,7 @@ class SGD(object):
:type event_handler: (BaseEvent) => None :type event_handler: (BaseEvent) => None
:param feeding: Feeding is a map of neural network input name and array :param feeding: Feeding is a map of neural network input name and array
index that reader returns. index that reader returns.
:type feeding: dict :type feeding: dict|list
:return: :return:
""" """
if event_handler is None: if event_handler is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册