提交 e677a23d 编写于 作者: L LielinJiang

refine code

上级 fd163f53
...@@ -296,7 +296,7 @@ class ProgBarLogger(Callback): ...@@ -296,7 +296,7 @@ class ProgBarLogger(Callback):
self.tested_samples = 0 self.tested_samples = 0
self.test_progbar = ProgressBar( self.test_progbar = ProgressBar(
num=self.test_steps, verbose=self.verbose) num=self.test_steps, verbose=self.verbose)
if ParallelEnv().local_rank == 0: if self._is_print():
print('Predict begin...') print('Predict begin...')
def on_test_batch_end(self, step, logs=None): def on_test_batch_end(self, step, logs=None):
......
...@@ -1303,22 +1303,19 @@ class Model(fluid.dygraph.Layer): ...@@ -1303,22 +1303,19 @@ class Model(fluid.dygraph.Layer):
cbks.on_begin('train') cbks.on_begin('train')
for epoch in range(epochs): for epoch in range(epochs):
loader = train_loader
cbks.on_epoch_begin(epoch) cbks.on_epoch_begin(epoch)
logs = self._run_one_epoch(loader, cbks, 'train') logs = self._run_one_epoch(train_loader, cbks, 'train')
cbks.on_epoch_end(epoch, logs) cbks.on_epoch_end(epoch, logs)
if do_eval and epoch % eval_freq == 0: if do_eval and epoch % eval_freq == 0:
loader = eval_loader
eval_steps = self._len_data_loader(loader) eval_steps = self._len_data_loader(eval_loader)
cbks.on_begin('eval', { cbks.on_begin('eval', {
'steps': eval_steps, 'steps': eval_steps,
'metrics_name': self._metrics_name() 'metrics_name': self._metrics_name()
}) })
logs = self._run_one_epoch(loader, cbks, 'eval') logs = self._run_one_epoch(eval_loader, cbks, 'eval')
cbks.on_end('eval', logs) cbks.on_end('eval', logs)
...@@ -1414,15 +1411,13 @@ class Model(fluid.dygraph.Layer): ...@@ -1414,15 +1411,13 @@ class Model(fluid.dygraph.Layer):
verbose=verbose, verbose=verbose,
metrics=self._metrics_name(), ) metrics=self._metrics_name(), )
loader = eval_loader eval_steps = self._len_data_loader(eval_loader)
eval_steps = self._len_data_loader(loader)
cbks.on_begin('eval', { cbks.on_begin('eval', {
'steps': eval_steps, 'steps': eval_steps,
'metrics_name': self._metrics_name() 'metrics_name': self._metrics_name()
}) })
logs = self._run_one_epoch(loader, cbks, 'eval') logs = self._run_one_epoch(eval_loader, cbks, 'eval')
cbks.on_end('eval', logs) cbks.on_end('eval', logs)
...@@ -1522,18 +1517,16 @@ class Model(fluid.dygraph.Layer): ...@@ -1522,18 +1517,16 @@ class Model(fluid.dygraph.Layer):
self._test_dataloader = test_loader self._test_dataloader = test_loader
loader = test_loader
cbks = config_callbacks(callbacks, model=self, verbose=1) cbks = config_callbacks(callbacks, model=self, verbose=1)
test_steps = self._len_data_loader(loader) test_steps = self._len_data_loader(test_loader)
logs = {'steps': test_steps} logs = {'steps': test_steps}
cbks.on_begin('test', logs) cbks.on_begin('test', logs)
outputs = [] outputs = []
logs, outputs = self._run_one_epoch(loader, cbks, 'test') logs, outputs = self._run_one_epoch(test_loader, cbks, 'test')
outputs = list(zip(*outputs)) outputs = list(zip(*outputs))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册