From fe36252699133e68ef8b1cb07a8a372361f88158 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 21 Apr 2020 05:41:02 +0000 Subject: [PATCH] Change MyProgBarLogger to LoggerCallBack --- examples/ocr/eval.py | 6 +++--- examples/ocr/train.py | 4 ++-- examples/ocr/utility.py | 12 ++++++------ hapi/model.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/ocr/eval.py b/examples/ocr/eval.py index a432c74..1adffa5 100644 --- a/examples/ocr/eval.py +++ b/examples/ocr/eval.py @@ -23,7 +23,7 @@ from hapi.model import Input, set_device from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments -from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy +from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy from utility import postprocess from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy import data @@ -91,7 +91,7 @@ def main(FLAGS): model.evaluate( eval_data=test_loader, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) def beam_search(FLAGS): @@ -140,7 +140,7 @@ def beam_search(FLAGS): model.evaluate( eval_data=test_loader, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) if __name__ == '__main__': diff --git a/examples/ocr/train.py b/examples/ocr/train.py index d36f296..d72173d 100644 --- a/examples/ocr/train.py +++ b/examples/ocr/train.py @@ -28,7 +28,7 @@ from hapi.model import Input, set_device from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments -from utility import SeqAccuracy, MyProgBarLogger +from utility import SeqAccuracy, LoggerCallBack from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy import data @@ -129,7 +129,7 @@ def main(FLAGS): eval_data=test_loader, epochs=FLAGS.epoch, save_dir=FLAGS.checkpoint_path, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) if __name__ == '__main__': diff --git a/examples/ocr/utility.py b/examples/ocr/utility.py index 49d60e0..d47b3f1 100644 --- a/examples/ocr/utility.py +++ b/examples/ocr/utility.py @@ -102,31 +102,31 @@ class SeqAccuracy(Metric): return self._name -class MyProgBarLogger(ProgBarLogger): +class LoggerCallBack(ProgBarLogger): def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None): - super(MyProgBarLogger, self).__init__(log_freq, verbose) + super(LoggerCallBack, self).__init__(log_freq, verbose) self.train_bs = train_bs self.eval_bs = eval_bs if eval_bs else train_bs def on_train_batch_end(self, step, logs=None): logs = logs or {} logs['loss'] = [l / self.train_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_train_batch_end(step, logs) + super(LoggerCallBack, self).on_train_batch_end(step, logs) def on_epoch_end(self, epoch, logs=None): logs = logs or {} logs['loss'] = [l / self.train_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_epoch_end(epoch, logs) + super(LoggerCallBack, self).on_epoch_end(epoch, logs) def on_eval_batch_end(self, step, logs=None): logs = logs or {} logs['loss'] = [l / self.eval_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_eval_batch_end(step, logs) + super(LoggerCallBack, self).on_eval_batch_end(step, logs) def on_eval_end(self, logs=None): logs = logs or {} logs['loss'] = [l / self.eval_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_eval_end(logs) + super(LoggerCallBack, self).on_eval_end(logs) def index2word(ids): diff --git a/hapi/model.py b/hapi/model.py index effae5c..6a00a36 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer): if mode == 'train': assert epoch is not None, 'when mode is train, epoch must be given' - callbacks.on_epoch_end(epoch) + callbacks.on_epoch_end(epoch, logs) return logs -- GitLab