diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index bda8e22fd282f8ff4a820e4ecb6b3bb421d57890..63a8917c55a83dfdf59dece51b5025831501a7c1 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -13,7 +13,7 @@ # limitations under the License. from py_paddle import DataProviderConverter - +import collections import paddle.trainer.PyDataProvider2 as pydp2 __all__ = ['DataFeeder'] @@ -75,6 +75,13 @@ class DataFeeder(DataProviderConverter): input_types = [] if feeding is None: feeding = default_feeding_map(data_types) + elif isinstance(feeding, collections.Sequence): + feed_list = feeding + feeding = dict() + for i, name in enumerate(feed_list): + feeding[name] = i + elif not isinstance(feeding, dict): + raise TypeError("Feeding should be dict or sequence or None.") self.feeding = feeding for each in data_types: