提交 1c4bb5c8 编写于 作者: Q Qiao Longfei 提交者: daminglu

user need to set feed order for Trainer.train and Trainer.test (#10679)

上级 6c320526
...@@ -172,9 +172,9 @@ class Trainer(object): ...@@ -172,9 +172,9 @@ class Trainer(object):
def train(self, def train(self,
num_epochs, num_epochs,
event_handler, event_handler,
reader=None, reader,
parallel=False, feed_order,
feed_order=None): parallel=False):
""" """
Train the model. Train the model.
...@@ -202,7 +202,7 @@ class Trainer(object): ...@@ -202,7 +202,7 @@ class Trainer(object):
self._train_by_executor(num_epochs, event_handler, reader, feed_order) self._train_by_executor(num_epochs, event_handler, reader, feed_order)
def test(self, reader, feed_order=None): def test(self, reader, feed_order):
""" """
Test the model on given test data Test the model on given test data
...@@ -276,12 +276,7 @@ def build_feed_var_list(program, feed_order): ...@@ -276,12 +276,7 @@ def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program): if not isinstance(program, framework.Program):
raise TypeError("The 'program' should be an object of Program") raise TypeError("The 'program' should be an object of Program")
if feed_order is None: if isinstance(feed_order, list):
feed_var_list = [
var for var in program.global_block().vars.itervalues()
if var.is_data
]
elif isinstance(feed_order, list):
feed_var_list = [ feed_var_list = [
program.global_block().var(var_name) for var_name in feed_order program.global_block().var(var_name) for var_name in feed_order
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册