diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py index 973142ccdfa3c8a6f56abeb2a5a5081784aa8db1..4c8505acf322a8ee33799c009b523cd70bd01db3 100644 --- a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py @@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname): optimizer=fluid.optimizer.SGD(learning_rate=0.001)) def event_handler(event): - if isinstance(event, fluid.EndEpochEvent): - test_metrics = trainer.test( - reader=test_reader, feed_order=['x', 'y']) - print test_metrics - ''' - - ... - ['25.768919467926025'] - ['15.343549569447836'] - ... - - ''' - if float(test_metrics[0]) < 20.0: + if isinstance(event, fluid.EndStepEvent): + if event.step == 10: + test_metrics = trainer.test( + reader=test_reader, feed_order=['x', 'y']) + print test_metrics + ''' + ... + ['25.768919467926025'] + ['15.343549569447836'] + ... + ''' if save_dirname is not None: trainer.save_params(save_dirname) - return + trainer.stop() trainer.train( reader=train_reader, diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index f4292208c949ea271a7a65267d0a71208e74e75f..7da123dd92ed9d111d68cd70efb8ce1493452609 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -100,6 +100,7 @@ class Trainer(object): param_path=None, place=None, parallel=False): + self.__stop = False self.parallel = parallel # 1. we need to generate a framework.Program by calling # program_func. Reference: fluid.program_guard in @@ -210,6 +211,12 @@ class Trainer(object): 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' ) + def stop(self): + """ + stop training + """ + self.__stop = True + def train(self, num_epochs, event_handler, reader=None, feed_order=None): """ Train the model. @@ -289,6 +296,8 @@ class Trainer(object): for epoch_id in range(num_epochs): event_handler(BeginEpochEvent(epoch_id)) for step_id, data in enumerate(reader()): + if self.__stop: + return begin_event = BeginStepEvent(epoch_id, step_id) event_handler(begin_event) if begin_event.fetch_metrics: @@ -327,9 +336,7 @@ class Trainer(object): feeder = data_feeder.DataFeeder( feed_list=feed_var_list, place=self.place) reader = feeder.decorate_reader(reader, multi_devices=True) - for epoch_id in range(num_epochs): - self._train_by_any_executor(event_handler, pe, num_epochs, - reader) + self._train_by_any_executor(event_handler, pe, num_epochs, reader) def _get_parallel_executor(self): return getattr(self, 'parallel_executor', None)