From eb7d87545e14112fea1ef82d4d0eb60d3faa4a10 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 18 May 2018 13:06:34 +0800 Subject: [PATCH] add trainer.stop and fix a bug for train_by_parallel_executor (#10762) --- .../fit_a_line/test_fit_a_line.py | 26 +++++++++---------- python/paddle/fluid/trainer.py | 13 +++++++--- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py index 973142ccdf..4c8505acf3 100644 --- a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py @@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname): optimizer=fluid.optimizer.SGD(learning_rate=0.001)) def event_handler(event): - if isinstance(event, fluid.EndEpochEvent): - test_metrics = trainer.test( - reader=test_reader, feed_order=['x', 'y']) - print test_metrics - ''' - - ... - ['25.768919467926025'] - ['15.343549569447836'] - ... - - ''' - if float(test_metrics[0]) < 20.0: + if isinstance(event, fluid.EndStepEvent): + if event.step == 10: + test_metrics = trainer.test( + reader=test_reader, feed_order=['x', 'y']) + print test_metrics + ''' + ... + ['25.768919467926025'] + ['15.343549569447836'] + ... + ''' if save_dirname is not None: trainer.save_params(save_dirname) - return + trainer.stop() trainer.train( reader=train_reader, diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index f4292208c9..7da123dd92 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -100,6 +100,7 @@ class Trainer(object): param_path=None, place=None, parallel=False): + self.__stop = False self.parallel = parallel # 1. we need to generate a framework.Program by calling # program_func. Reference: fluid.program_guard in @@ -210,6 +211,12 @@ class Trainer(object): 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' ) + def stop(self): + """ + stop training + """ + self.__stop = True + def train(self, num_epochs, event_handler, reader=None, feed_order=None): """ Train the model. @@ -289,6 +296,8 @@ class Trainer(object): for epoch_id in range(num_epochs): event_handler(BeginEpochEvent(epoch_id)) for step_id, data in enumerate(reader()): + if self.__stop: + return begin_event = BeginStepEvent(epoch_id, step_id) event_handler(begin_event) if begin_event.fetch_metrics: @@ -327,9 +336,7 @@ class Trainer(object): feeder = data_feeder.DataFeeder( feed_list=feed_var_list, place=self.place) reader = feeder.decorate_reader(reader, multi_devices=True) - for epoch_id in range(num_epochs): - self._train_by_any_executor(event_handler, pe, num_epochs, - reader) + self._train_by_any_executor(event_handler, pe, num_epochs, reader) def _get_parallel_executor(self): return getattr(self, 'parallel_executor', None) -- GitLab