提交 dc02bfdf 编写于 作者: Y Yu Yang

Add list type of feeding

上级 4a99c441
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from py_paddle import DataProviderConverter from py_paddle import DataProviderConverter
import collections
import paddle.trainer.PyDataProvider2 as pydp2 import paddle.trainer.PyDataProvider2 as pydp2
__all__ = ['DataFeeder'] __all__ = ['DataFeeder']
...@@ -75,6 +75,13 @@ class DataFeeder(DataProviderConverter): ...@@ -75,6 +75,13 @@ class DataFeeder(DataProviderConverter):
input_types = [] input_types = []
if feeding is None: if feeding is None:
feeding = default_feeding_map(data_types) feeding = default_feeding_map(data_types)
elif isinstance(feeding, collections.Sequence):
feed_list = feeding
feeding = dict()
for i, name in enumerate(feed_list):
feeding[name] = i
elif not isinstance(feeding, dict):
raise TypeError("Feeding should be dict or sequence or None.")
self.feeding = feeding self.feeding = feeding
for each in data_types: for each in data_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册