提交 eb7d8754 编写于 作者: Q Qiao Longfei 提交者: daminglu

add trainer.stop and fix a bug for train_by_parallel_executor (#10762)

上级 54ae8e45
...@@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname): ...@@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname):
optimizer=fluid.optimizer.SGD(learning_rate=0.001)) optimizer=fluid.optimizer.SGD(learning_rate=0.001))
def event_handler(event): def event_handler(event):
if isinstance(event, fluid.EndEpochEvent): if isinstance(event, fluid.EndStepEvent):
test_metrics = trainer.test( if event.step == 10:
reader=test_reader, feed_order=['x', 'y']) test_metrics = trainer.test(
print test_metrics reader=test_reader, feed_order=['x', 'y'])
''' print test_metrics
'''
... ...
['25.768919467926025'] ['25.768919467926025']
['15.343549569447836'] ['15.343549569447836']
... ...
'''
'''
if float(test_metrics[0]) < 20.0:
if save_dirname is not None: if save_dirname is not None:
trainer.save_params(save_dirname) trainer.save_params(save_dirname)
return trainer.stop()
trainer.train( trainer.train(
reader=train_reader, reader=train_reader,
......
...@@ -100,6 +100,7 @@ class Trainer(object): ...@@ -100,6 +100,7 @@ class Trainer(object):
param_path=None, param_path=None,
place=None, place=None,
parallel=False): parallel=False):
self.__stop = False
self.parallel = parallel self.parallel = parallel
# 1. we need to generate a framework.Program by calling # 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in # program_func. Reference: fluid.program_guard in
...@@ -210,6 +211,12 @@ class Trainer(object): ...@@ -210,6 +211,12 @@ class Trainer(object):
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
) )
def stop(self):
"""
stop training
"""
self.__stop = True
def train(self, num_epochs, event_handler, reader=None, feed_order=None): def train(self, num_epochs, event_handler, reader=None, feed_order=None):
""" """
Train the model. Train the model.
...@@ -289,6 +296,8 @@ class Trainer(object): ...@@ -289,6 +296,8 @@ class Trainer(object):
for epoch_id in range(num_epochs): for epoch_id in range(num_epochs):
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()): for step_id, data in enumerate(reader()):
if self.__stop:
return
begin_event = BeginStepEvent(epoch_id, step_id) begin_event = BeginStepEvent(epoch_id, step_id)
event_handler(begin_event) event_handler(begin_event)
if begin_event.fetch_metrics: if begin_event.fetch_metrics:
...@@ -327,9 +336,7 @@ class Trainer(object): ...@@ -327,9 +336,7 @@ class Trainer(object):
feeder = data_feeder.DataFeeder( feeder = data_feeder.DataFeeder(
feed_list=feed_var_list, place=self.place) feed_list=feed_var_list, place=self.place)
reader = feeder.decorate_reader(reader, multi_devices=True) reader = feeder.decorate_reader(reader, multi_devices=True)
for epoch_id in range(num_epochs): self._train_by_any_executor(event_handler, pe, num_epochs, reader)
self._train_by_any_executor(event_handler, pe, num_epochs,
reader)
def _get_parallel_executor(self): def _get_parallel_executor(self):
return getattr(self, 'parallel_executor', None) return getattr(self, 'parallel_executor', None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册