未验证 提交 c213c9fc 编写于 作者: T Tingquan Gao 提交者: GitHub

Fix the training log in static graph (#525)

* Adapt to PaddleHub 2.0 to eliminate warnings
* Fix the training log format
上级 e7dbecd2
......@@ -22,7 +22,7 @@ from paddlehub.utils.log import logger
from paddlehub.module.module import moduleinfo, serving
import cv2
import numpy as np
import paddlehub as hub
import paddle.nn as nn
import tools.infer.predict as paddle_predict
from tools.infer.utils import Base64ToCV2, create_paddle_predictor
......@@ -36,8 +36,8 @@ from deploy.hubserving.clas.params import read_params
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/class")
class ClasSystem(hub.Module):
def _initialize(self, use_gpu=None, enable_mkldnn=None):
class ClasSystem(nn.Layer):
def __init__(self, use_gpu=None, enable_mkldnn=None):
"""
initialize with the necessary elements
"""
......
......@@ -149,7 +149,7 @@ For example, if you need to replace the model used by the deployed service, you
After modifying the module and installing it (`hub install deploy/hubserving/clas/`), but before deploying it, you can run `python deploy/hubserving/clas/test.py` to test the installed service module.
1. Uninstall old service module
3. Uninstall old service module
```shell
hub uninstall clas_system
```
......
......@@ -288,10 +288,10 @@ def run(dataloader,
if not use_mix:
topk_name = 'top{}'.format(config.topk)
metric_list.insert(
1, (topk_name, AverageMeter(
0, (topk_name, AverageMeter(
topk_name, '.5f', postfix=",")))
metric_list.insert(
1, ("top1", AverageMeter(
0, ("top1", AverageMeter(
"top1", '.5f', postfix=",")))
metric_list = OrderedDict(metric_list)
......
......@@ -288,6 +288,7 @@ def create_optimizer(config):
opt = OptimizerBuilder(config, **opt_config)
return opt(lr), lr
def create_strategy(config):
"""
Create build strategy and exec strategy.
......@@ -342,7 +343,6 @@ def create_strategy(config):
return build_strategy, exec_strategy
def dist_optimizer(config, optimizer):
"""
Create a distributed optimizer based on a normal optimizer
......@@ -493,19 +493,35 @@ def run(dataloader,
Returns:
"""
fetch_list = [f[0] for f in fetchs.values()]
metric_list = [f[1] for f in fetchs.values()]
if mode == "train":
metric_list.append(AverageMeter('lr', 'f', need_avg=False))
for m in metric_list:
metric_list = [
("lr", AverageMeter(
'lr', 'f', postfix=",", need_avg=False)),
("batch_time", AverageMeter(
'batch_cost', '.5f', postfix=" s,")),
("reader_time", AverageMeter(
'reader_cost', '.5f', postfix=" s,")),
]
topk_name = 'top{}'.format(config.topk)
metric_list.insert(0, ("loss", fetchs["loss"][1]))
metric_list.insert(0, (topk_name, fetchs[topk_name][1]))
metric_list.insert(0, ("top1", fetchs["top1"][1]))
metric_list = OrderedDict(metric_list)
for m in metric_list.values():
m.reset()
batch_time = AverageMeter('elapse', '.3f')
use_dali = config.get('use_dali', False)
dataloader = dataloader if use_dali else dataloader()
tic = time.time()
for idx, batch in enumerate(dataloader):
# ignore the warmup iters
if idx == 5:
batch_time.reset()
metric_list["batch_time"].reset()
metric_list["reader_time"].reset()
metric_list['reader_time'].update(time.time() - tic)
if use_dali:
batch_size = batch[0]["feed_image"].shape()[0]
feed_dict = batch[0]
......@@ -518,17 +534,16 @@ def run(dataloader,
metrics = exe.run(program=program,
feed=feed_dict,
fetch_list=fetch_list)
batch_time.update(time.time() - tic)
for i, m in enumerate(metrics):
metric_list[i].update(np.mean(m), batch_size)
for name, m in zip(fetchs.keys(), metrics):
metric_list[name].update(np.mean(m), batch_size)
metric_list["batch_time"].update(time.time() - tic)
if mode == "train":
metric_list[-1].update(lr_scheduler.get_lr())
metric_list['lr'].update(lr_scheduler.get_lr())
fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.mean]) + 's'
ips_info = " ips: {:.5f} images/sec.".format(batch_size /
batch_time.avg)
fetchs_str = ' '.join([str(m.value) for m in metric_list.values()])
ips_info = " ips: {:.5f} images/sec.".format(
batch_size / metric_list["batch_time"].avg)
fetchs_str += ips_info
if lr_scheduler is not None:
......@@ -563,12 +578,13 @@ def run(dataloader,
tic = time.time()
end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total]) + 's'
ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count /
batch_time.sum)
end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
[metric_list["batch_time"].total])
ips_info = "ips: {:.5f} images/sec.".format(
batch_size * metric_list["batch_time"].count /
metric_list["batch_time"].sum)
if mode == 'valid':
logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
else:
end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册