diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py
index 159eec94874f2c84174608fb79d5c1e30263c3e8..b86f5eee7bd61cb543ef709eb21371d57766dc06 100644
--- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py
+++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py
@@ -97,7 +97,8 @@ def train(use_cuda, save_dirname):
             # if math.isnan(float(avg_cost)):
             #     sys.exit("got NaN loss, training failed.")
         elif isinstance(event, fluid.EndStepEvent):
-            print("Step {0}, Epoch {1}".format(event.step, event.epoch))
+            print("Step {0}, Epoch {1} Metrics {2}".format(
+                event.step, event.epoch, map(numpy.array, event.metrics)))
 
     train_reader = paddle.batch(
         paddle.reader.shuffle(
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 8dd140a92c9bd6da3fc4e15715f21ede1270f102..544856794b52ead06e62fd733da678a99a1fac35 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -49,12 +49,14 @@ class BeginStepEvent(object):
     def __init__(self, epoch_id, step_id):
         self.epoch = epoch_id
         self.step = step_id
+        self.fetch_metrics = True
 
 
 class EndStepEvent(object):
-    def __init__(self, epoch_id, step_id):
+    def __init__(self, epoch_id, step_id, metrics):
         self.epoch = epoch_id
         self.step = step_id
+        self.metrics = metrics
 
 
 def check_and_get_place(place):
@@ -259,13 +261,25 @@ class Trainer(object):
             feeder = data_feeder.DataFeeder(
                 feed_list=feed_var_list, place=self.place)
             exe = executor.Executor(self.place)
-            for epoch_id in range(num_epochs):
-                event_handler(BeginEpochEvent(epoch_id))
-                for step_id, data in enumerate(reader()):
-                    event_handler(BeginStepEvent(epoch_id, step_id))
-                    exe.run(feed=feeder.feed(data), fetch_list=[])
-                    event_handler(EndStepEvent(epoch_id, step_id))
-                event_handler(EndEpochEvent(epoch_id))
+            reader = feeder.decorate_reader(reader, multi_devices=False)
+            self._train_by_any_executor(event_handler, exe, num_epochs, reader)
+
+    def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
+        for epoch_id in range(num_epochs):
+            event_handler(BeginEpochEvent(epoch_id))
+            for step_id, data in enumerate(reader()):
+                begin_event = BeginStepEvent(epoch_id, step_id)
+                event_handler(begin_event)
+                if begin_event.fetch_metrics:
+                    metrics = exe.run(feed=data,
+                                      fetch_list=[
+                                          var.name
+                                          for var in self.train_func_outputs
+                                      ])
+                else:
+                    metrics = exe.run(feed=data, fetch_list=[])
+                event_handler(EndStepEvent(epoch_id, step_id, metrics))
+            event_handler(EndEpochEvent(epoch_id))
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
         with executor.scope_guard(self.scope):
@@ -293,17 +307,8 @@ class Trainer(object):
                 feed_list=feed_var_list, place=self.place)
             reader = feeder.decorate_reader(reader, multi_devices=True)
             for epoch_id in range(num_epochs):
-                event_handler(BeginEpochEvent(epoch_id=epoch_id))
-                for step_id, data in enumerate(reader()):
-                    event_handler(
-                        BeginStepEvent(
-                            epoch_id=epoch_id, step_id=step_id))
-                    pe.run(feed=data, fetch_list=[])
-                    event_handler(
-                        EndStepEvent(
-                            epoch_id=epoch_id, step_id=step_id))
-
-                event_handler(EndEpochEvent(epoch_id=epoch_id))
+                self._train_by_any_executor(event_handler, pe, num_epochs,
+                                            reader)
 
     def _get_parallel_executor(self):
         return getattr(self, 'parallel_executor', None)
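
Note (not part of the patch): with this change an event handler can opt in or out of the
per-step metric fetch through BeginStepEvent.fetch_metrics and read the fetched values from
EndStepEvent.metrics, which holds the outputs of the train function as returned by the
executor. Below is a minimal sketch of such a handler; the name event_handler, the
"every 10th step" throttling, and the assumption that BeginStepEvent is exported under the
fluid namespace like EndStepEvent are illustrative, not taken from the patch.

    import numpy

    import paddle.fluid as fluid


    def event_handler(event):
        if isinstance(event, fluid.BeginStepEvent):
            # Only pay the fetch cost every 10th step (illustrative policy).
            event.fetch_metrics = event.step % 10 == 0
        elif isinstance(event, fluid.EndStepEvent):
            if event.metrics:
                # event.metrics carries the fetched train_func outputs
                # (e.g. the average cost) for this step.
                print("Step {0}, Epoch {1} Metrics {2}".format(
                    event.step, event.epoch, map(numpy.array, event.metrics)))

The handler would then be passed to trainer.train(..., event_handler=event_handler, ...) as
in the existing recognize_digits test.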