diff --git a/examples/ocr/eval.py b/examples/ocr/eval.py index a432c74128519dcfeec91294dd30786d46898919..1adffa5401679ab0d49cc586c0238ce1c01fa1b8 100644 --- a/examples/ocr/eval.py +++ b/examples/ocr/eval.py @@ -23,7 +23,7 @@ from hapi.model import Input, set_device from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments -from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy +from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy from utility import postprocess from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy import data @@ -91,7 +91,7 @@ def main(FLAGS): model.evaluate( eval_data=test_loader, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) def beam_search(FLAGS): @@ -140,7 +140,7 @@ def beam_search(FLAGS): model.evaluate( eval_data=test_loader, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) if __name__ == '__main__': diff --git a/examples/ocr/train.py b/examples/ocr/train.py index d36f2962d60072eedb809fe7f9a39c5c2e4cffb6..d72173dfde7791b53af80f04697f8e3defd01445 100644 --- a/examples/ocr/train.py +++ b/examples/ocr/train.py @@ -28,7 +28,7 @@ from hapi.model import Input, set_device from hapi.vision.transforms import BatchCompose from utility import add_arguments, print_arguments -from utility import SeqAccuracy, MyProgBarLogger +from utility import SeqAccuracy, LoggerCallBack from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy import data @@ -129,7 +129,7 @@ def main(FLAGS): eval_data=test_loader, epochs=FLAGS.epoch, save_dir=FLAGS.checkpoint_path, - callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) + callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)]) if __name__ == '__main__': diff --git a/examples/ocr/utility.py b/examples/ocr/utility.py index 49d60e0567adfef335759c41602b749a0bb21dba..d47b3f17d16452c1292402abc15b534eec4b3459 100644 --- a/examples/ocr/utility.py +++ b/examples/ocr/utility.py @@ -102,31 +102,31 @@ class SeqAccuracy(Metric): return self._name -class MyProgBarLogger(ProgBarLogger): +class LoggerCallBack(ProgBarLogger): def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None): - super(MyProgBarLogger, self).__init__(log_freq, verbose) + super(LoggerCallBack, self).__init__(log_freq, verbose) self.train_bs = train_bs self.eval_bs = eval_bs if eval_bs else train_bs def on_train_batch_end(self, step, logs=None): logs = logs or {} logs['loss'] = [l / self.train_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_train_batch_end(step, logs) + super(LoggerCallBack, self).on_train_batch_end(step, logs) def on_epoch_end(self, epoch, logs=None): logs = logs or {} logs['loss'] = [l / self.train_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_epoch_end(epoch, logs) + super(LoggerCallBack, self).on_epoch_end(epoch, logs) def on_eval_batch_end(self, step, logs=None): logs = logs or {} logs['loss'] = [l / self.eval_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_eval_batch_end(step, logs) + super(LoggerCallBack, self).on_eval_batch_end(step, logs) def on_eval_end(self, logs=None): logs = logs or {} logs['loss'] = [l / self.eval_bs for l in logs['loss']] - super(MyProgBarLogger, self).on_eval_end(logs) + super(LoggerCallBack, self).on_eval_end(logs) def index2word(ids): diff --git a/hapi/model.py b/hapi/model.py index effae5c91e32c1eef8f00902df89c045a1f03950..6a00a36838d553b20603acb4f0aadf6ff12f0da8 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer): if mode == 'train': assert epoch is not None, 'when mode is train, epoch must be given' - callbacks.on_epoch_end(epoch) + callbacks.on_epoch_end(epoch, logs) return logs