提交 fe362526 编写于 作者: Q qingqing01

Change MyProgBarLogger to LoggerCallBack

上级 5e496435
...@@ -23,7 +23,7 @@ from hapi.model import Input, set_device ...@@ -23,7 +23,7 @@ from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy
from utility import postprocess from utility import postprocess
from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy
import data import data
...@@ -91,7 +91,7 @@ def main(FLAGS): ...@@ -91,7 +91,7 @@ def main(FLAGS):
model.evaluate( model.evaluate(
eval_data=test_loader, eval_data=test_loader,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
def beam_search(FLAGS): def beam_search(FLAGS):
...@@ -140,7 +140,7 @@ def beam_search(FLAGS): ...@@ -140,7 +140,7 @@ def beam_search(FLAGS):
model.evaluate( model.evaluate(
eval_data=test_loader, eval_data=test_loader,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -28,7 +28,7 @@ from hapi.model import Input, set_device ...@@ -28,7 +28,7 @@ from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger from utility import SeqAccuracy, LoggerCallBack
from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy
import data import data
...@@ -129,7 +129,7 @@ def main(FLAGS): ...@@ -129,7 +129,7 @@ def main(FLAGS):
eval_data=test_loader, eval_data=test_loader,
epochs=FLAGS.epoch, epochs=FLAGS.epoch,
save_dir=FLAGS.checkpoint_path, save_dir=FLAGS.checkpoint_path,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)]) callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -102,31 +102,31 @@ class SeqAccuracy(Metric): ...@@ -102,31 +102,31 @@ class SeqAccuracy(Metric):
return self._name return self._name
class MyProgBarLogger(ProgBarLogger): class LoggerCallBack(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None): def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None):
super(MyProgBarLogger, self).__init__(log_freq, verbose) super(LoggerCallBack, self).__init__(log_freq, verbose)
self.train_bs = train_bs self.train_bs = train_bs
self.eval_bs = eval_bs if eval_bs else train_bs self.eval_bs = eval_bs if eval_bs else train_bs
def on_train_batch_end(self, step, logs=None): def on_train_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']] logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_train_batch_end(step, logs) super(LoggerCallBack, self).on_train_batch_end(step, logs)
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
logs = logs or {} logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']] logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_epoch_end(epoch, logs) super(LoggerCallBack, self).on_epoch_end(epoch, logs)
def on_eval_batch_end(self, step, logs=None): def on_eval_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']] logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_eval_batch_end(step, logs) super(LoggerCallBack, self).on_eval_batch_end(step, logs)
def on_eval_end(self, logs=None): def on_eval_end(self, logs=None):
logs = logs or {} logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']] logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_eval_end(logs) super(LoggerCallBack, self).on_eval_end(logs)
def index2word(ids): def index2word(ids):
......
...@@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer): ...@@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer):
if mode == 'train': if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given' assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_end(epoch) callbacks.on_epoch_end(epoch, logs)
return logs return logs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册