提交 1c4bb5c8 编写于 作者: Q Qiao Longfei 提交者: daminglu

user need to set feed order for Trainer.train and Trainer.test (#10679)

上级 6c320526
......@@ -172,9 +172,9 @@ class Trainer(object):
def train(self,
num_epochs,
event_handler,
reader=None,
parallel=False,
feed_order=None):
reader,
feed_order,
parallel=False):
"""
Train the model.
......@@ -202,7 +202,7 @@ class Trainer(object):
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
def test(self, reader, feed_order=None):
def test(self, reader, feed_order):
"""
Test the model on given test data
......@@ -276,12 +276,7 @@ def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program):
raise TypeError("The 'program' should be an object of Program")
if feed_order is None:
feed_var_list = [
var for var in program.global_block().vars.itervalues()
if var.is_data
]
elif isinstance(feed_order, list):
if isinstance(feed_order, list):
feed_var_list = [
program.global_block().var(var_name) for var_name in feed_order
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册