提交 eb7d8754 编写于 作者: Q Qiao Longfei 提交者: daminglu

add trainer.stop and fix a bug for train_by_parallel_executor (#10762)

上级 54ae8e45
......@@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname):
optimizer=fluid.optimizer.SGD(learning_rate=0.001))
def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
test_metrics = trainer.test(
reader=test_reader, feed_order=['x', 'y'])
print test_metrics
'''
...
['25.768919467926025']
['15.343549569447836']
...
'''
if float(test_metrics[0]) < 20.0:
if isinstance(event, fluid.EndStepEvent):
if event.step == 10:
test_metrics = trainer.test(
reader=test_reader, feed_order=['x', 'y'])
print test_metrics
'''
...
['25.768919467926025']
['15.343549569447836']
...
'''
if save_dirname is not None:
trainer.save_params(save_dirname)
return
trainer.stop()
trainer.train(
reader=train_reader,
......
......@@ -100,6 +100,7 @@ class Trainer(object):
param_path=None,
place=None,
parallel=False):
self.__stop = False
self.parallel = parallel
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
......@@ -210,6 +211,12 @@ class Trainer(object):
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def stop(self):
"""
stop training
"""
self.__stop = True
def train(self, num_epochs, event_handler, reader=None, feed_order=None):
"""
Train the model.
......@@ -289,6 +296,8 @@ class Trainer(object):
for epoch_id in range(num_epochs):
event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()):
if self.__stop:
return
begin_event = BeginStepEvent(epoch_id, step_id)
event_handler(begin_event)
if begin_event.fetch_metrics:
......@@ -327,9 +336,7 @@ class Trainer(object):
feeder = data_feeder.DataFeeder(
feed_list=feed_var_list, place=self.place)
reader = feeder.decorate_reader(reader, multi_devices=True)
for epoch_id in range(num_epochs):
self._train_by_any_executor(event_handler, pe, num_epochs,
reader)
self._train_by_any_executor(event_handler, pe, num_epochs, reader)
def _get_parallel_executor(self):
return getattr(self, 'parallel_executor', None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册