提交 b9044c27 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add standard logger

上级 e5df00a0
......@@ -20,6 +20,8 @@ import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import time
import datetime
import argparse
import paddle
import paddle.nn as nn
......@@ -126,6 +128,12 @@ class Trainer(object):
# key:
# val: metrics list word
output_info = dict()
time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# global iter counter
global_step = 0
......@@ -135,10 +143,15 @@ class Trainer(object):
if metric_info is not None:
best_metric.update(metric_info)
tic = time.time()
for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1):
acc = 0.0
for iter_id, batch in enumerate(self.train_dataloader()):
if iter_id == 5:
for key in time_info:
time_info[key].reset()
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
.reshape([-1, 1]))
......@@ -148,8 +161,10 @@ class Trainer(object):
out = self.model(batch[0])
else:
out = self.model(batch[0], batch[1])
# calc loss
loss_dict = self.train_loss_func(out, batch[1])
for key in loss_dict:
if not key in output_info:
output_info[key] = AverageMeter(key, '7.5f')
......@@ -164,21 +179,38 @@ class Trainer(object):
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
# step opt and lr
loss_dict["loss"].backward()
optimizer.step()
optimizer.clear_grad()
lr_sch.step()
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
epoch_id, iter_id,
len(self.train_dataloader), lr_msg, metric_msg))
time_msg = "s, ".join([
"{}: {:.5f}".format(key, time_info[key].avg)
for key in time_info
])
# step opt and lr
loss_dict["loss"].backward()
optimizer.step()
optimizer.clear_grad()
lr_sch.step()
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
) * len(self.train_dataloader) - iter_id
) * time_info["batch_cost"].avg
eta_msg = "eta: {:s}".format(
str(datetime.timedelta(seconds=int(eta_sec))))
logger.info(
"[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}".
format(epoch_id, iter_id,
len(self.train_dataloader), lr_msg, metric_msg,
time_msg, ips_msg, eta_msg))
tic = time.time()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
......@@ -220,6 +252,7 @@ class Trainer(object):
@paddle.no_grad()
def eval(self, epoch_id=0):
self.model.eval()
if self.eval_loss_func is None:
loss_config = self.config.get("Loss", None)
......@@ -264,12 +297,25 @@ class Trainer(object):
self.model.train()
return eval_result
@paddle.no_grad()
def eval_cls(self, epoch_id=0):
output_info = dict()
time_info = {
"batch_cost": AverageMeter(
"batch_cost", '.5f', postfix=" s,"),
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
print_batch_step = self.config["Global"]["print_batch_step"]
metric_key = None
tic = time.time()
for iter_id, batch in enumerate(self.eval_dataloader()):
if iter_id == 5:
for key in time_info:
time_info[key].reset()
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
......@@ -305,13 +351,26 @@ class Trainer(object):
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
time_info["batch_cost"].update(time.time() - tic)
if iter_id % print_batch_step == 0:
time_msg = "s, ".join([
"{}: {:.5f}".format(key, time_info[key].avg)
for key in time_info
])
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
epoch_id, iter_id, len(self.eval_dataloader), metric_msg))
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id,
len(self.eval_dataloader), metric_msg, time_msg, ips_msg))
tic = time.time()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册