提交 892cc82d 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1766 from reyoung/feature/add_list_type_of_feeding

Add list type of feeding
......@@ -13,7 +13,7 @@
# limitations under the License.
from py_paddle import DataProviderConverter
import collections
import paddle.trainer.PyDataProvider2 as pydp2
__all__ = ['DataFeeder']
......@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter):
DataFeeder converts this mini-batch data entries into Arguments in order
to feed it to C++ interface.
The example usage:
The simple usage shows below
.. code-block:: python
feeding = ['image', 'label']
data_types = enumerate_data_types_of_data_layers(topology)
feeder = DataFeeder(data_types=data_types, feeding=feeding)
minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]
arg = feeder(minibatch_data)
If mini-batch data and data layers are not one to one mapping, we
could pass a dictionary to feeding parameter to represent the mapping
relationship.
.. code-block:: python
data_types = [('image', paddle.data_type.dense_vector(784)),
('label', paddle.data_type.integer_value(10))]
reader_dict = {'image':0, 'label':1}
feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict)
feeding = {'image':0, 'label':1}
feeder = DataFeeder(data_types=data_types, feeding=feeding)
minibatch_data = [
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
......@@ -65,9 +80,9 @@ class DataFeeder(DataProviderConverter):
a tuple of (data_name, data_type).
:type data_types: list
:param reader_dict: A dictionary to specify the position of each data
in the input data.
:type feeding: dict
:param feeding: A dictionary or a sequence to specify the position of each
data in the input data.
:type feeding: dict|collections.Sequence|None
"""
def __init__(self, data_types, feeding=None):
......@@ -75,6 +90,13 @@ class DataFeeder(DataProviderConverter):
input_types = []
if feeding is None:
feeding = default_feeding_map(data_types)
elif isinstance(feeding, collections.Sequence):
feed_list = feeding
feeding = dict()
for i, name in enumerate(feed_list):
feeding[name] = i
elif not isinstance(feeding, dict):
raise TypeError("Feeding should be dict or sequence or None.")
self.feeding = feeding
for each in data_types:
......
......@@ -81,7 +81,7 @@ class SGD(object):
:type event_handler: (BaseEvent) => None
:param feeding: Feeding is a map of neural network input name and array
index that reader returns.
:type feeding: dict
:type feeding: dict|list
:return:
"""
if event_handler is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册