未验证 提交 9f53f3d0 编写于 作者: L LielinJiang 提交者: GitHub

Enhance logger callback for benchmark (#29106)

* enhance logger callback for benchmark
上级 e668cb07
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import time
import numbers import numbers
import warnings import warnings
...@@ -96,8 +97,8 @@ class CallbackList(object): ...@@ -96,8 +97,8 @@ class CallbackList(object):
func(*args) func(*args)
def _check_mode(self, mode): def _check_mode(self, mode):
assert mode in ['train', 'eval', 'test'], \ assert mode in ['train', 'eval', 'predict'], \
'mode should be train, eval or test' 'mode should be train, eval or predict'
def on_begin(self, mode, logs=None): def on_begin(self, mode, logs=None):
self._check_mode(mode) self._check_mode(mode)
...@@ -207,14 +208,14 @@ class Callback(object): ...@@ -207,14 +208,14 @@ class Callback(object):
of last batch of validation dataset. of last batch of validation dataset.
""" """
def on_test_begin(self, logs=None): def on_predict_begin(self, logs=None):
"""Called at the beginning of predict. """Called at the beginning of predict.
Args: Args:
logs (dict): The logs is a dict or None. logs (dict): The logs is a dict or None.
""" """
def on_test_end(self, logs=None): def on_predict_end(self, logs=None):
"""Called at the end of predict. """Called at the end of predict.
Args: Args:
...@@ -278,7 +279,7 @@ class Callback(object): ...@@ -278,7 +279,7 @@ class Callback(object):
of current batch. of current batch.
""" """
def on_test_batch_begin(self, step, logs=None): def on_predict_batch_begin(self, step, logs=None):
"""Called at the beginning of each batch in predict. """Called at the beginning of each batch in predict.
Args: Args:
...@@ -286,7 +287,7 @@ class Callback(object): ...@@ -286,7 +287,7 @@ class Callback(object):
logs (dict): The logs is a dict or None. logs (dict): The logs is a dict or None.
""" """
def on_test_batch_end(self, step, logs=None): def on_predict_batch_end(self, step, logs=None):
"""Called at the end of each batch in predict. """Called at the end of each batch in predict.
Args: Args:
...@@ -303,7 +304,9 @@ class ProgBarLogger(Callback): ...@@ -303,7 +304,9 @@ class ProgBarLogger(Callback):
log_freq (int): The frequency, in number of steps, log_freq (int): The frequency, in number of steps,
the logs such as loss, metrics are printed. Default: 1. the logs such as loss, metrics are printed. Default: 1.
verbose (int): The verbosity mode, should be 0, 1, or 2. verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2. 0 = silent, 1 = progress bar, 2 = one line per epoch, 3 = 2 +
time counter, such as average reader cost, samples per second.
Default: 2.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -351,6 +354,17 @@ class ProgBarLogger(Callback): ...@@ -351,6 +354,17 @@ class ProgBarLogger(Callback):
self.train_metrics = self.params['metrics'] self.train_metrics = self.params['metrics']
assert self.train_metrics assert self.train_metrics
self._train_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
if self._is_print():
print(
"The loss value printed in the log is the current batch, and the metric is the average value of previous step."
)
def on_epoch_begin(self, epoch=None, logs=None): def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps'] self.steps = self.params['steps']
self.epoch = epoch self.epoch = epoch
...@@ -359,6 +373,8 @@ class ProgBarLogger(Callback): ...@@ -359,6 +373,8 @@ class ProgBarLogger(Callback):
print('Epoch %d/%d' % (epoch + 1, self.epochs)) print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose) self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)
self._train_timer['batch_start_time'] = time.time()
def _updates(self, logs, mode): def _updates(self, logs, mode):
values = [] values = []
metrics = getattr(self, '%s_metrics' % (mode)) metrics = getattr(self, '%s_metrics' % (mode))
...@@ -369,15 +385,39 @@ class ProgBarLogger(Callback): ...@@ -369,15 +385,39 @@ class ProgBarLogger(Callback):
if k in logs: if k in logs:
values.append((k, logs[k])) values.append((k, logs[k]))
if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
    # Verbose level 3 appends benchmark timing figures to the regular
    # loss/metric values, using the per-mode timer dict (e.g. _train_timer).
    timer = getattr(self, '_%s_timer' % (mode))
    # Fall back to 1.0 so the averages stay defined when this runs before
    # any batch has been counted.
    cnt = timer['count'] if timer['count'] > 0 else 1.0
    samples = timer['samples'] if timer['samples'] > 0 else 1.0
    values.append(
        ('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
    values.append(
        ('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
    # Throughput covers the whole wall time spent on the counted batches:
    # reader time (data_time) plus compute time (batch_time). The previous
    # expression added batch_time twice and ignored data_time, so the
    # reported samples/sec was wrong whenever reader time differed from
    # compute time.
    elapsed = timer['data_time'] + timer['batch_time']
    # Guard against a zero denominator when no time has been accumulated.
    ips = samples / elapsed if elapsed > 0 else 0.0
    values.append(('ips', "%.5f samples/sec" % ips))
progbar.update(steps, values) progbar.update(steps, values)
def on_train_batch_begin(self, step, logs=None):
    """Record that the data for a training batch is ready.

    Stores the current time as ``batch_data_end_time`` and adds the time
    elapsed since ``batch_start_time`` to the accumulated ``data_time``
    of the training timer.

    Args:
        step (int): Index of the batch; not used here.
        logs (dict): The logs is a dict or None; not used here.
    """
    timer = self._train_timer
    data_ready = time.time()
    timer['batch_data_end_time'] = data_ready
    timer['data_time'] += data_ready - timer['batch_start_time']
def on_train_batch_end(self, step, logs=None): def on_train_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
self.train_step += 1 self.train_step += 1
self._train_timer['batch_time'] += (
time.time() - self._train_timer['batch_data_end_time'])
self._train_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._train_timer['samples'] += samples
if self._is_print() and self.train_step % self.log_freq == 0: if self._is_print() and self.train_step % self.log_freq == 0:
if self.steps is None or self.train_step < self.steps: if self.steps is None or self.train_step < self.steps:
self._updates(logs, 'train') self._updates(logs, 'train')
self._train_timer['batch_start_time'] = time.time()
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
logs = logs or {} logs = logs or {}
...@@ -390,10 +430,28 @@ class ProgBarLogger(Callback): ...@@ -390,10 +430,28 @@ class ProgBarLogger(Callback):
self.eval_step = 0 self.eval_step = 0
self.evaled_samples = 0 self.evaled_samples = 0
self._eval_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.eval_progbar = ProgressBar( self.eval_progbar = ProgressBar(
num=self.eval_steps, verbose=self.verbose) num=self.eval_steps, verbose=self.verbose)
if self._is_print(): if self._is_print():
print('Eval begin...') print('Eval begin...')
print(
"The loss value printed in the log is the current batch, and the metric is the average value of previous step."
)
self._eval_timer['batch_start_time'] = time.time()
def on_eval_batch_begin(self, step, logs=None):
    """Record that the data for an evaluation batch is ready.

    Stores the current time as ``batch_data_end_time`` and adds the time
    elapsed since ``batch_start_time`` to the accumulated ``data_time``
    of the evaluation timer.

    Args:
        step (int): Index of the batch; not used here.
        logs (dict): The logs is a dict or None; not used here.
    """
    timer = self._eval_timer
    data_ready = time.time()
    timer['batch_data_end_time'] = data_ready
    timer['data_time'] += data_ready - timer['batch_start_time']
def on_eval_batch_end(self, step, logs=None): def on_eval_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
...@@ -401,37 +459,69 @@ class ProgBarLogger(Callback): ...@@ -401,37 +459,69 @@ class ProgBarLogger(Callback):
samples = logs.get('batch_size', 1) samples = logs.get('batch_size', 1)
self.evaled_samples += samples self.evaled_samples += samples
self._eval_timer['batch_time'] += (
time.time() - self._eval_timer['batch_data_end_time'])
self._eval_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._eval_timer['samples'] += samples
if self._is_print() and self.eval_step % self.log_freq == 0: if self._is_print() and self.eval_step % self.log_freq == 0:
if self.eval_steps is None or self.eval_step < self.eval_steps: if self.eval_steps is None or self.eval_step < self.eval_steps:
self._updates(logs, 'eval') self._updates(logs, 'eval')
def on_test_begin(self, logs=None): self._eval_timer['batch_start_time'] = time.time()
def on_predict_begin(self, logs=None):
self.test_steps = logs.get('steps', None) self.test_steps = logs.get('steps', None)
self.test_metrics = logs.get('metrics', []) self.test_metrics = logs.get('metrics', [])
self.test_step = 0 self.test_step = 0
self.tested_samples = 0 self.tested_samples = 0
self._test_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.test_progbar = ProgressBar( self.test_progbar = ProgressBar(
num=self.test_steps, verbose=self.verbose) num=self.test_steps, verbose=self.verbose)
if self._is_print(): if self._is_print():
print('Predict begin...') print('Predict begin...')
def on_test_batch_end(self, step, logs=None): self._test_timer['batch_start_time'] = time.time()
def on_predict_batch_begin(self, step, logs=None):
    """Record that the data for a prediction batch is ready.

    Stores the current time as ``batch_data_end_time`` and adds the time
    elapsed since ``batch_start_time`` to the accumulated ``data_time``
    of the prediction timer (kept in ``self._test_timer``).

    Args:
        step (int): Index of the batch; not used here.
        logs (dict): The logs is a dict or None; not used here.
    """
    timer = self._test_timer
    data_ready = time.time()
    timer['batch_data_end_time'] = data_ready
    timer['data_time'] += data_ready - timer['batch_start_time']
def on_predict_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
self.test_step += 1 self.test_step += 1
samples = logs.get('batch_size', 1) samples = logs.get('batch_size', 1)
self.tested_samples += samples self.tested_samples += samples
self._test_timer['batch_time'] += (
time.time() - self._test_timer['batch_data_end_time'])
self._test_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._test_timer['samples'] += samples
if self.test_step % self.log_freq == 0 and self._is_print(): if self.test_step % self.log_freq == 0 and self._is_print():
if self.test_steps is None or self.test_step < self.test_steps: if self.test_steps is None or self.test_step < self.test_steps:
self._updates(logs, 'test') self._updates(logs, 'test')
self._test_timer['batch_start_time'] = time.time()
def on_eval_end(self, logs=None): def on_eval_end(self, logs=None):
logs = logs or {} logs = logs or {}
if self._is_print() and (self.eval_steps is not None): if self._is_print() and (self.eval_steps is not None):
self._updates(logs, 'eval') self._updates(logs, 'eval')
print('Eval samples: %d' % (self.evaled_samples)) print('Eval samples: %d' % (self.evaled_samples))
def on_test_end(self, logs=None): def on_predict_end(self, logs=None):
logs = logs or {} logs = logs or {}
if self._is_print(): if self._is_print():
if self.test_step % self.log_freq != 0 or self.verbose == 1: if self.test_step % self.log_freq != 0 or self.verbose == 1:
......
...@@ -1692,11 +1692,11 @@ class Model(object): ...@@ -1692,11 +1692,11 @@ class Model(object):
test_steps = self._len_data_loader(test_loader) test_steps = self._len_data_loader(test_loader)
logs = {'steps': test_steps} logs = {'steps': test_steps}
cbks.on_begin('test', logs) cbks.on_begin('predict', logs)
outputs = [] outputs = []
logs, outputs = self._run_one_epoch(test_loader, cbks, 'test') logs, outputs = self._run_one_epoch(test_loader, cbks, 'predict')
outputs = list(zip(*outputs)) outputs = list(zip(*outputs))
...@@ -1707,7 +1707,7 @@ class Model(object): ...@@ -1707,7 +1707,7 @@ class Model(object):
self._test_dataloader = None self._test_dataloader = None
cbks.on_end('test', logs) cbks.on_end('predict', logs)
return outputs return outputs
def _save_inference_model(self, path): def _save_inference_model(self, path):
...@@ -1793,7 +1793,7 @@ class Model(object): ...@@ -1793,7 +1793,7 @@ class Model(object):
callbacks.on_batch_begin(mode, step, logs) callbacks.on_batch_begin(mode, step, logs)
if mode != 'test': if mode != 'predict':
outs = getattr(self, mode + '_batch')(data[:len(self._inputs)], outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
data[len(self._inputs):]) data[len(self._inputs):])
if self._metrics and self._loss: if self._metrics and self._loss:
...@@ -1829,7 +1829,7 @@ class Model(object): ...@@ -1829,7 +1829,7 @@ class Model(object):
callbacks.on_batch_end(mode, step, logs) callbacks.on_batch_end(mode, step, logs)
self._reset_metrics() self._reset_metrics()
if mode == 'test': if mode == 'predict':
return logs, outputs return logs, outputs
return logs return logs
......
...@@ -159,7 +159,7 @@ class ProgressBar(object): ...@@ -159,7 +159,7 @@ class ProgressBar(object):
sys.stdout.write(info) sys.stdout.write(info)
sys.stdout.flush() sys.stdout.flush()
self._last_update = now self._last_update = now
elif self._verbose == 2: elif self._verbose == 2 or self._verbose == 3:
if self._num: if self._num:
numdigits = int(np.log10(self._num)) + 1 numdigits = int(np.log10(self._num)) + 1
count = ('step %' + str(numdigits) + 'd/%d') % (current_num, count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
......
...@@ -106,13 +106,13 @@ class TestCallbacks(unittest.TestCase): ...@@ -106,13 +106,13 @@ class TestCallbacks(unittest.TestCase):
test_logs = {} test_logs = {}
params = {'steps': eval_steps} params = {'steps': eval_steps}
cbks.on_begin('test', params) cbks.on_begin('predict', params)
for step in range(eval_steps): for step in range(eval_steps):
cbks.on_batch_begin('test', step, test_logs) cbks.on_batch_begin('predict', step, test_logs)
test_logs['batch_size'] = 2 test_logs['batch_size'] = 2
time.sleep(0.005) time.sleep(0.005)
cbks.on_batch_end('test', step, test_logs) cbks.on_batch_end('predict', step, test_logs)
cbks.on_end('test', test_logs) cbks.on_end('predict', test_logs)
cbks.on_end('train') cbks.on_end('train')
...@@ -128,6 +128,10 @@ class TestCallbacks(unittest.TestCase): ...@@ -128,6 +128,10 @@ class TestCallbacks(unittest.TestCase):
self.verbose = 2 self.verbose = 2
self.run_callback() self.run_callback()
def test_callback_verbose_3(self):
    """Run the shared callback scenario with verbosity level 3."""
    # Level 3 is the mode that adds the time counters to the per-step log.
    self.verbose = 3
    self.run_callback()
def test_visualdl_callback(self): def test_visualdl_callback(self):
# visualdl not support python2 # visualdl not support python2
if sys.version_info < (3, ): if sys.version_info < (3, ):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册