提交 c4ad0dd0 编写于 作者: Y yuyang18

Add fetch metrics

上级 2a0205a5
...@@ -97,7 +97,8 @@ def train(use_cuda, save_dirname): ...@@ -97,7 +97,8 @@ def train(use_cuda, save_dirname):
# if math.isnan(float(avg_cost)): # if math.isnan(float(avg_cost)):
# sys.exit("got NaN loss, training failed.") # sys.exit("got NaN loss, training failed.")
elif isinstance(event, fluid.EndStepEvent): elif isinstance(event, fluid.EndStepEvent):
- print("Step {0}, Epoch {1}".format(event.step, event.epoch))
+ print("Step {0}, Epoch {1} Metrics {2}".format(
+     event.step, event.epoch, map(numpy.array, event.metrics)))
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
......
...@@ -49,12 +49,14 @@ class BeginStepEvent(object): ...@@ -49,12 +49,14 @@ class BeginStepEvent(object):
def __init__(self, epoch_id, step_id): def __init__(self, epoch_id, step_id):
self.epoch = epoch_id self.epoch = epoch_id
self.step = step_id self.step = step_id
self.fetch_metrics = True
  class EndStepEvent(object):
-     def __init__(self, epoch_id, step_id):
+     def __init__(self, epoch_id, step_id, metrics):
          self.epoch = epoch_id
          self.step = step_id
+         self.metrics = metrics
def check_and_get_place(place): def check_and_get_place(place):
...@@ -259,12 +261,24 @@ class Trainer(object): ...@@ -259,12 +261,24 @@ class Trainer(object):
feeder = data_feeder.DataFeeder( feeder = data_feeder.DataFeeder(
feed_list=feed_var_list, place=self.place) feed_list=feed_var_list, place=self.place)
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
reader = feeder.decorate_reader(reader, multi_devices=False)
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
for epoch_id in range(num_epochs): for epoch_id in range(num_epochs):
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()): for step_id, data in enumerate(reader()):
- event_handler(BeginStepEvent(epoch_id, step_id))
- exe.run(feed=feeder.feed(data), fetch_list=[])
- event_handler(EndStepEvent(epoch_id, step_id))
+ begin_event = BeginStepEvent(epoch_id, step_id)
+ event_handler(begin_event)
+ if begin_event.fetch_metrics:
+     metrics = exe.run(feed=data,
+                       fetch_list=[
+                           var.name
+                           for var in self.train_func_outputs
+                       ])
+ else:
+     metrics = exe.run(feed=data, fetch_list=[])
+ event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
def _test_by_executor(self, reader, feed_order, fetch_list): def _test_by_executor(self, reader, feed_order, fetch_list):
...@@ -293,17 +307,8 @@ class Trainer(object): ...@@ -293,17 +307,8 @@ class Trainer(object):
feed_list=feed_var_list, place=self.place) feed_list=feed_var_list, place=self.place)
reader = feeder.decorate_reader(reader, multi_devices=True) reader = feeder.decorate_reader(reader, multi_devices=True)
for epoch_id in range(num_epochs): for epoch_id in range(num_epochs):
- event_handler(BeginEpochEvent(epoch_id=epoch_id))
- for step_id, data in enumerate(reader()):
-     event_handler(
-         BeginStepEvent(
-             epoch_id=epoch_id, step_id=step_id))
-     pe.run(feed=data, fetch_list=[])
-     event_handler(
-         EndStepEvent(
-             epoch_id=epoch_id, step_id=step_id))
- event_handler(EndEpochEvent(epoch_id=epoch_id))
+ self._train_by_any_executor(event_handler, pe, num_epochs,
+                             reader)
def _get_parallel_executor(self): def _get_parallel_executor(self):
return getattr(self, 'parallel_executor', None) return getattr(self, 'parallel_executor', None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册