未验证 提交 9f53f3d0 编写于 作者: L LielinJiang 提交者: GitHub

Enhance logger callback for benchmark (#29106)

* enhance logger callback for benchmark
上级 e668cb07
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import time
import numbers
import warnings
......@@ -96,8 +97,8 @@ class CallbackList(object):
func(*args)
def _check_mode(self, mode):
assert mode in ['train', 'eval', 'test'], \
'mode should be train, eval or test'
assert mode in ['train', 'eval', 'predict'], \
'mode should be train, eval or predict'
def on_begin(self, mode, logs=None):
self._check_mode(mode)
......@@ -207,14 +208,14 @@ class Callback(object):
of last batch of validation dataset.
"""
def on_test_begin(self, logs=None):
def on_predict_begin(self, logs=None):
"""Called at the beginning of predict.
Args:
logs (dict): The logs is a dict or None.
"""
def on_test_end(self, logs=None):
def on_predict_end(self, logs=None):
"""Called at the end of predict.
Args:
......@@ -278,7 +279,7 @@ class Callback(object):
of current batch.
"""
def on_test_batch_begin(self, step, logs=None):
def on_predict_batch_begin(self, step, logs=None):
"""Called at the beginning of each batch in predict.
Args:
......@@ -286,7 +287,7 @@ class Callback(object):
logs (dict): The logs is a dict or None.
"""
def on_test_batch_end(self, step, logs=None):
def on_predict_batch_end(self, step, logs=None):
"""Called at the end of each batch in predict.
Args:
......@@ -303,7 +304,9 @@ class ProgBarLogger(Callback):
log_freq (int): The frequency, in number of steps,
the logs such as loss, metrics are printed. Default: 1.
verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
0 = silent, 1 = progress bar, 2 = one line per epoch, 3 = 2 +
time counter, such as average reader cost, samples per second.
Default: 2.
Examples:
.. code-block:: python
......@@ -351,6 +354,17 @@ class ProgBarLogger(Callback):
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._train_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
if self._is_print():
print(
"The loss value printed in the log is the current batch, and the metric is the average value of previous step."
)
def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps']
self.epoch = epoch
......@@ -359,6 +373,8 @@ class ProgBarLogger(Callback):
print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)
self._train_timer['batch_start_time'] = time.time()
def _updates(self, logs, mode):
values = []
metrics = getattr(self, '%s_metrics' % (mode))
......@@ -369,15 +385,39 @@ class ProgBarLogger(Callback):
if k in logs:
values.append((k, logs[k]))
if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
    # Benchmark verbosity: append the timing statistics accumulated in the
    # per-mode timer dict to the values handed to the progress bar.
    timer = getattr(self, '_%s_timer' % (mode))
    # Fall back to 1 so the averages stay defined before any batch is counted.
    cnt = timer['count'] if timer['count'] > 0 else 1.0
    samples = timer['samples'] if timer['samples'] > 0 else 1.0
    values.append(
        ('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
    values.append(
        ('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
    # Throughput covers the full step: time spent waiting for data plus the
    # time spent running the batch. The previous code added batch_time to
    # itself, which ignored reader cost and double-counted compute time.
    elapsed = timer['data_time'] + timer['batch_time']
    # No time recorded yet (e.g. no batch has run): report 0 instead of
    # raising ZeroDivisionError.
    ips = samples / elapsed if elapsed > 0 else 0.0
    values.append(('ips', "%.5f samples/sec" % ips))
progbar.update(steps, values)
def on_train_batch_begin(self, step, logs=None):
    """Record when the data for a training batch became available.

    The interval since ``batch_start_time`` is added to the accumulated
    ``data_time`` of the training timer.

    Args:
        step (int): Index of the current batch (unused here).
        logs (dict): The logs is a dict or None (unused here).
    """
    timer = self._train_timer
    data_ready_at = time.time()
    timer['batch_data_end_time'] = data_ready_at
    timer['data_time'] += data_ready_at - timer['batch_start_time']
def on_train_batch_end(self, step, logs=None):
    """Update training counters and timers after a batch, then maybe log.

    Args:
        step (int): Index of the current batch (unused here).
        logs (dict): The logs is a dict or None. ``batch_size`` is read
            from it and defaults to 1 when absent.
    """
    logs = {} if not logs else logs
    self.train_step += 1

    timer = self._train_timer
    timer['batch_time'] += time.time() - timer['batch_data_end_time']
    timer['count'] += 1
    timer['samples'] += logs.get('batch_size', 1)

    due = self._is_print() and self.train_step % self.log_freq == 0
    # The step that reaches ``self.steps`` is not reported from here.
    if due and (self.steps is None or self.train_step < self.steps):
        self._updates(logs, 'train')

    # Start timing the data loading of the next batch.
    timer['batch_start_time'] = time.time()
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
......@@ -390,10 +430,28 @@ class ProgBarLogger(Callback):
self.eval_step = 0
self.evaled_samples = 0
self._eval_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.eval_progbar = ProgressBar(
num=self.eval_steps, verbose=self.verbose)
if self._is_print():
print('Eval begin...')
print(
"The loss value printed in the log is the current batch, and the metric is the average value of previous step."
)
self._eval_timer['batch_start_time'] = time.time()
def on_eval_batch_begin(self, step, logs=None):
    """Record when the data for an evaluation batch became available.

    The interval since ``batch_start_time`` is added to the accumulated
    ``data_time`` of the evaluation timer.

    Args:
        step (int): Index of the current batch (unused here).
        logs (dict): The logs is a dict or None (unused here).
    """
    timer = self._eval_timer
    data_ready_at = time.time()
    timer['batch_data_end_time'] = data_ready_at
    timer['data_time'] += data_ready_at - timer['batch_start_time']
def on_eval_batch_end(self, step, logs=None):
logs = logs or {}
......@@ -401,37 +459,69 @@ class ProgBarLogger(Callback):
samples = logs.get('batch_size', 1)
self.evaled_samples += samples
self._eval_timer['batch_time'] += (
time.time() - self._eval_timer['batch_data_end_time'])
self._eval_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._eval_timer['samples'] += samples
if self._is_print() and self.eval_step % self.log_freq == 0:
if self.eval_steps is None or self.eval_step < self.eval_steps:
self._updates(logs, 'eval')
def on_test_begin(self, logs=None):
self._eval_timer['batch_start_time'] = time.time()
def on_predict_begin(self, logs=None):
self.test_steps = logs.get('steps', None)
self.test_metrics = logs.get('metrics', [])
self.test_step = 0
self.tested_samples = 0
self._test_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.test_progbar = ProgressBar(
num=self.test_steps, verbose=self.verbose)
if self._is_print():
print('Predict begin...')
def on_test_batch_end(self, step, logs=None):
self._test_timer['batch_start_time'] = time.time()
def on_predict_batch_begin(self, step, logs=None):
    """Record when the data for a predict batch became available.

    The interval since ``batch_start_time`` is added to the accumulated
    ``data_time`` of the predict timer.

    Args:
        step (int): Index of the current batch (unused here).
        logs (dict): The logs is a dict or None (unused here).
    """
    timer = self._test_timer
    data_ready_at = time.time()
    timer['batch_data_end_time'] = data_ready_at
    timer['data_time'] += data_ready_at - timer['batch_start_time']
def on_predict_batch_end(self, step, logs=None):
    """Update predict counters and timers after a batch, then maybe log.

    Args:
        step (int): Index of the current batch (unused here).
        logs (dict): The logs is a dict or None. ``batch_size`` is read
            from it and defaults to 1 when absent.
    """
    logs = logs or {}
    self.test_step += 1
    # Looked up once and reused for both the sample counter and the timer
    # (the batch size was previously fetched twice).
    samples = logs.get('batch_size', 1)
    self.tested_samples += samples

    self._test_timer['batch_time'] += (
        time.time() - self._test_timer['batch_data_end_time'])
    self._test_timer['count'] += 1
    self._test_timer['samples'] += samples

    if self.test_step % self.log_freq == 0 and self._is_print():
        # The step that reaches ``self.test_steps`` is not reported here.
        if self.test_steps is None or self.test_step < self.test_steps:
            self._updates(logs, 'test')

    # Start timing the data loading of the next batch.
    self._test_timer['batch_start_time'] = time.time()
def on_eval_end(self, logs=None):
    """Emit the final evaluation progress update and the sample count.

    Nothing is reported when printing is disabled or when
    ``self.eval_steps`` is None.

    Args:
        logs (dict): The logs is a dict or None.
    """
    logs = logs or {}
    if not self._is_print() or self.eval_steps is None:
        return
    self._updates(logs, 'eval')
    print('Eval samples: %d' % (self.evaled_samples))
def on_test_end(self, logs=None):
def on_predict_end(self, logs=None):
logs = logs or {}
if self._is_print():
if self.test_step % self.log_freq != 0 or self.verbose == 1:
......
......@@ -1692,11 +1692,11 @@ class Model(object):
test_steps = self._len_data_loader(test_loader)
logs = {'steps': test_steps}
cbks.on_begin('test', logs)
cbks.on_begin('predict', logs)
outputs = []
logs, outputs = self._run_one_epoch(test_loader, cbks, 'test')
logs, outputs = self._run_one_epoch(test_loader, cbks, 'predict')
outputs = list(zip(*outputs))
......@@ -1707,7 +1707,7 @@ class Model(object):
self._test_dataloader = None
cbks.on_end('test', logs)
cbks.on_end('predict', logs)
return outputs
def _save_inference_model(self, path):
......@@ -1793,7 +1793,7 @@ class Model(object):
callbacks.on_batch_begin(mode, step, logs)
if mode != 'test':
if mode != 'predict':
outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
data[len(self._inputs):])
if self._metrics and self._loss:
......@@ -1829,7 +1829,7 @@ class Model(object):
callbacks.on_batch_end(mode, step, logs)
self._reset_metrics()
if mode == 'test':
if mode == 'predict':
return logs, outputs
return logs
......
......@@ -159,7 +159,7 @@ class ProgressBar(object):
sys.stdout.write(info)
sys.stdout.flush()
self._last_update = now
elif self._verbose == 2:
elif self._verbose == 2 or self._verbose == 3:
if self._num:
numdigits = int(np.log10(self._num)) + 1
count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
......
......@@ -106,13 +106,13 @@ class TestCallbacks(unittest.TestCase):
test_logs = {}
params = {'steps': eval_steps}
cbks.on_begin('test', params)
cbks.on_begin('predict', params)
for step in range(eval_steps):
cbks.on_batch_begin('test', step, test_logs)
cbks.on_batch_begin('predict', step, test_logs)
test_logs['batch_size'] = 2
time.sleep(0.005)
cbks.on_batch_end('test', step, test_logs)
cbks.on_end('test', test_logs)
cbks.on_batch_end('predict', step, test_logs)
cbks.on_end('predict', test_logs)
cbks.on_end('train')
......@@ -128,6 +128,10 @@ class TestCallbacks(unittest.TestCase):
self.verbose = 2
self.run_callback()
def test_callback_verbose_3(self):
self.verbose = 3
self.run_callback()
def test_visualdl_callback(self):
# visualdl not support python2
if sys.version_info < (3, ):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册