未验证 提交 c213c9fc 编写于 作者: T Tingquan Gao 提交者: GitHub

Fix the training log in static graph (#525)

* Adapt to PaddleHub2.0 to eliminate warning
* Fix the training log format
上级 e7dbecd2
...@@ -22,7 +22,7 @@ from paddlehub.utils.log import logger ...@@ -22,7 +22,7 @@ from paddlehub.utils.log import logger
from paddlehub.module.module import moduleinfo, serving from paddlehub.module.module import moduleinfo, serving
import cv2 import cv2
import numpy as np import numpy as np
import paddlehub as hub import paddle.nn as nn
import tools.infer.predict as paddle_predict import tools.infer.predict as paddle_predict
from tools.infer.utils import Base64ToCV2, create_paddle_predictor from tools.infer.utils import Base64ToCV2, create_paddle_predictor
...@@ -36,8 +36,8 @@ from deploy.hubserving.clas.params import read_params ...@@ -36,8 +36,8 @@ from deploy.hubserving.clas.params import read_params
author="paddle-dev", author="paddle-dev",
author_email="paddle-dev@baidu.com", author_email="paddle-dev@baidu.com",
type="cv/class") type="cv/class")
class ClasSystem(hub.Module): class ClasSystem(nn.Layer):
def _initialize(self, use_gpu=None, enable_mkldnn=None): def __init__(self, use_gpu=None, enable_mkldnn=None):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
......
...@@ -149,7 +149,7 @@ For example, if you need to replace the model used by the deployed service, you ...@@ -149,7 +149,7 @@ For example, if you need to replace the model used by the deployed service, you
After modifying and installing (`hub install deploy/hubserving/clas/`) and before deploying, you can use `python deploy/hubserving/clas/test.py` to test the installed service module. After modifying and installing (`hub install deploy/hubserving/clas/`) and before deploying, you can use `python deploy/hubserving/clas/test.py` to test the installed service module.
1. Uninstall old service module 3. Uninstall old service module
```shell ```shell
hub uninstall clas_system hub uninstall clas_system
``` ```
......
...@@ -288,10 +288,10 @@ def run(dataloader, ...@@ -288,10 +288,10 @@ def run(dataloader,
if not use_mix: if not use_mix:
topk_name = 'top{}'.format(config.topk) topk_name = 'top{}'.format(config.topk)
metric_list.insert( metric_list.insert(
1, (topk_name, AverageMeter( 0, (topk_name, AverageMeter(
topk_name, '.5f', postfix=","))) topk_name, '.5f', postfix=",")))
metric_list.insert( metric_list.insert(
1, ("top1", AverageMeter( 0, ("top1", AverageMeter(
"top1", '.5f', postfix=","))) "top1", '.5f', postfix=",")))
metric_list = OrderedDict(metric_list) metric_list = OrderedDict(metric_list)
......
...@@ -288,6 +288,7 @@ def create_optimizer(config): ...@@ -288,6 +288,7 @@ def create_optimizer(config):
opt = OptimizerBuilder(config, **opt_config) opt = OptimizerBuilder(config, **opt_config)
return opt(lr), lr return opt(lr), lr
def create_strategy(config): def create_strategy(config):
""" """
Create build strategy and exec strategy. Create build strategy and exec strategy.
...@@ -342,7 +343,6 @@ def create_strategy(config): ...@@ -342,7 +343,6 @@ def create_strategy(config):
return build_strategy, exec_strategy return build_strategy, exec_strategy
def dist_optimizer(config, optimizer): def dist_optimizer(config, optimizer):
""" """
Create a distributed optimizer based on a normal optimizer Create a distributed optimizer based on a normal optimizer
...@@ -493,19 +493,35 @@ def run(dataloader, ...@@ -493,19 +493,35 @@ def run(dataloader,
Returns: Returns:
""" """
fetch_list = [f[0] for f in fetchs.values()] fetch_list = [f[0] for f in fetchs.values()]
metric_list = [f[1] for f in fetchs.values()] metric_list = [
if mode == "train": ("lr", AverageMeter(
metric_list.append(AverageMeter('lr', 'f', need_avg=False)) 'lr', 'f', postfix=",", need_avg=False)),
for m in metric_list: ("batch_time", AverageMeter(
'batch_cost', '.5f', postfix=" s,")),
("reader_time", AverageMeter(
'reader_cost', '.5f', postfix=" s,")),
]
topk_name = 'top{}'.format(config.topk)
metric_list.insert(0, ("loss", fetchs["loss"][1]))
metric_list.insert(0, (topk_name, fetchs[topk_name][1]))
metric_list.insert(0, ("top1", fetchs["top1"][1]))
metric_list = OrderedDict(metric_list)
for m in metric_list.values():
m.reset() m.reset()
batch_time = AverageMeter('elapse', '.3f')
use_dali = config.get('use_dali', False) use_dali = config.get('use_dali', False)
dataloader = dataloader if use_dali else dataloader() dataloader = dataloader if use_dali else dataloader()
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader): for idx, batch in enumerate(dataloader):
# ignore the warmup iters # ignore the warmup iters
if idx == 5: if idx == 5:
batch_time.reset() metric_list["batch_time"].reset()
metric_list["reader_time"].reset()
metric_list['reader_time'].update(time.time() - tic)
if use_dali: if use_dali:
batch_size = batch[0]["feed_image"].shape()[0] batch_size = batch[0]["feed_image"].shape()[0]
feed_dict = batch[0] feed_dict = batch[0]
...@@ -518,17 +534,16 @@ def run(dataloader, ...@@ -518,17 +534,16 @@ def run(dataloader,
metrics = exe.run(program=program, metrics = exe.run(program=program,
feed=feed_dict, feed=feed_dict,
fetch_list=fetch_list) fetch_list=fetch_list)
batch_time.update(time.time() - tic)
for i, m in enumerate(metrics):
metric_list[i].update(np.mean(m), batch_size)
for name, m in zip(fetchs.keys(), metrics):
metric_list[name].update(np.mean(m), batch_size)
metric_list["batch_time"].update(time.time() - tic)
if mode == "train": if mode == "train":
metric_list[-1].update(lr_scheduler.get_lr()) metric_list['lr'].update(lr_scheduler.get_lr())
fetchs_str = ''.join([str(m.value) + ' ' fetchs_str = ' '.join([str(m.value) for m in metric_list.values()])
for m in metric_list] + [batch_time.mean]) + 's' ips_info = " ips: {:.5f} images/sec.".format(
ips_info = " ips: {:.5f} images/sec.".format(batch_size / batch_size / metric_list["batch_time"].avg)
batch_time.avg)
fetchs_str += ips_info fetchs_str += ips_info
if lr_scheduler is not None: if lr_scheduler is not None:
...@@ -563,12 +578,13 @@ def run(dataloader, ...@@ -563,12 +578,13 @@ def run(dataloader,
tic = time.time() tic = time.time()
end_str = ''.join([str(m.mean) + ' ' end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
for m in metric_list] + [batch_time.total]) + 's' [metric_list["batch_time"].total])
ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count / ips_info = "ips: {:.5f} images/sec.".format(
batch_time.sum) batch_size * metric_list["batch_time"].count /
metric_list["batch_time"].sum)
if mode == 'valid': if mode == 'valid':
logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info)) logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
else: else:
end_epoch_str = "END epoch:{:<3d}".format(epoch) end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str, logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册