提交 fe362526 编写于 作者: Q qingqing01

Change MyProgBarLogger to LoggerCallBack

上级 5e496435
......@@ -23,7 +23,7 @@ from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger, SeqBeamAccuracy
from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy
from utility import postprocess
from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy
import data
......@@ -91,7 +91,7 @@ def main(FLAGS):
model.evaluate(
eval_data=test_loader,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)])
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
def beam_search(FLAGS):
......@@ -140,7 +140,7 @@ def beam_search(FLAGS):
model.evaluate(
eval_data=test_loader,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)])
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__':
......
......@@ -28,7 +28,7 @@ from hapi.model import Input, set_device
from hapi.vision.transforms import BatchCompose
from utility import add_arguments, print_arguments
from utility import SeqAccuracy, MyProgBarLogger
from utility import SeqAccuracy, LoggerCallBack
from seq2seq_attn import Seq2SeqAttModel, WeightCrossEntropy
import data
......@@ -129,7 +129,7 @@ def main(FLAGS):
eval_data=test_loader,
epochs=FLAGS.epoch,
save_dir=FLAGS.checkpoint_path,
callbacks=[MyProgBarLogger(10, 2, FLAGS.batch_size)])
callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])
if __name__ == '__main__':
......
......@@ -102,31 +102,31 @@ class SeqAccuracy(Metric):
return self._name
class MyProgBarLogger(ProgBarLogger):
class LoggerCallBack(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, train_bs=None, eval_bs=None):
super(MyProgBarLogger, self).__init__(log_freq, verbose)
super(LoggerCallBack, self).__init__(log_freq, verbose)
self.train_bs = train_bs
self.eval_bs = eval_bs if eval_bs else train_bs
def on_train_batch_end(self, step, logs=None):
logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_train_batch_end(step, logs)
super(LoggerCallBack, self).on_train_batch_end(step, logs)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['loss'] = [l / self.train_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_epoch_end(epoch, logs)
super(LoggerCallBack, self).on_epoch_end(epoch, logs)
def on_eval_batch_end(self, step, logs=None):
logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_eval_batch_end(step, logs)
super(LoggerCallBack, self).on_eval_batch_end(step, logs)
def on_eval_end(self, logs=None):
logs = logs or {}
logs['loss'] = [l / self.eval_bs for l in logs['loss']]
super(MyProgBarLogger, self).on_eval_end(logs)
super(LoggerCallBack, self).on_eval_end(logs)
def index2word(ids):
......
......@@ -1267,7 +1267,7 @@ class Model(fluid.dygraph.Layer):
if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_end(epoch)
callbacks.on_epoch_end(epoch, logs)
return logs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册