未验证 提交 7f4a1dea 编写于 作者: G Guanghua Yu 提交者: GitHub

add VisualDL and save best_model (#2361)

* add VisualDL and save best_model

* add VisualDLWriter callback
上级 87cb019d
...@@ -19,6 +19,8 @@ from __future__ import print_function ...@@ -19,6 +19,8 @@ from __future__ import print_function
import os import os
import sys import sys
import datetime import datetime
import six
import numpy as np
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -140,7 +142,10 @@ class Checkpointer(Callback): ...@@ -140,7 +142,10 @@ class Checkpointer(Callback):
def __init__(self, model): def __init__(self, model):
super(Checkpointer, self).__init__(model) super(Checkpointer, self).__init__(model)
cfg = self.model.cfg cfg = self.model.cfg
self.best_ap = 0.
self.use_ema = ('use_ema' in cfg and cfg['use_ema']) self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
if self.use_ema: if self.use_ema:
self.ema = ModelEMA( self.ema = ModelEMA(
cfg['ema_decay'], self.model.model, use_thres_step=True) cfg['ema_decay'], self.model.model, use_thres_step=True)
...@@ -152,24 +157,36 @@ class Checkpointer(Callback): ...@@ -152,24 +157,36 @@ class Checkpointer(Callback):
def on_epoch_end(self, status): def on_epoch_end(self, status):
# Checkpointer only performed during training # Checkpointer only performed during training
mode = status['mode'] mode = status['mode']
if mode != 'train': epoch_id = status['epoch_id']
return weight = None
save_name = None
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
epoch_id = status['epoch_id'] if mode == 'train':
end_epoch = self.model.cfg.epoch end_epoch = self.model.cfg.epoch
if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1: if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_dir = os.path.join(self.model.cfg.save_dir, save_name = str(
self.model.cfg.filename) epoch_id) if epoch_id != end_epoch - 1 else "model_final"
save_name = str( if self.use_ema:
epoch_id) if epoch_id != end_epoch - 1 else "model_final" weight = self.ema.apply()
if self.use_ema: else:
state_dict = self.ema.apply() weight = self.model.model
save_model(state_dict, self.model.optimizer, save_dir, elif mode == 'eval':
save_name, epoch_id + 1) if 'save_best_model' in status and status['save_best_model']:
else: for metric in self.model._metrics:
save_model(self.model.model, self.model.optimizer, save_dir, map_res = metric.get_results()
save_name, epoch_id + 1) key = 'bbox' if 'bbox' in map_res else 'mask'
if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0]
save_name = 'best_model'
if self.use_ema:
weight = self.ema.apply()
else:
weight = self.model.model
logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap))
if weight:
save_model(weight, self.model.optimizer, self.save_dir,
save_name, epoch_id + 1)
class WiferFaceEval(Callback): class WiferFaceEval(Callback):
...@@ -182,3 +199,60 @@ class WiferFaceEval(Callback): ...@@ -182,3 +199,60 @@ class WiferFaceEval(Callback):
for metric in self.model._metrics: for metric in self.model._metrics:
metric.update(self.model.model) metric.update(self.model.model)
sys.exit() sys.exit()
class VisualDLWriter(Callback):
    """
    Use VisualDL to log data or image

    Training losses, evaluation mAP values and inference images are written
    through a ``visualdl.LogWriter`` opened on ``model.cfg.vdl_log_dir``.
    Records are only written when fewer than two ranks are running, or on
    the process whose ``ParallelEnv().local_rank`` is 0.
    """

    def __init__(self, model):
        """
        Create the VisualDL writer and reset all step counters.

        Args:
            model: object handed to ``Callback.__init__``; its
                ``cfg.vdl_log_dir`` is used as the log directory.

        Raises:
            AssertionError: when not running under Python 3.
            ImportError: when the ``visualdl`` package cannot be imported.
        """
        super(VisualDLWriter, self).__init__(model)
        assert six.PY3, "VisualDL requires Python >= 3.5"
        # Imported lazily so that visualdl is only required when this
        # callback is actually used.
        try:
            from visualdl import LogWriter
        except ImportError:
            # Only a failed import means "not installed"; any other error
            # raised while importing visualdl propagates without this
            # misleading hint.
            logger.error('visualdl not found, please install visualdl. '
                         'for example: `pip install visualdl`.')
            raise
        self.vdl_writer = LogWriter(model.cfg.vdl_log_dir)
        # x-axis position of the next loss scalars
        self.vdl_loss_step = 0
        # x-axis position of the next mAP scalars
        self.vdl_mAP_step = 0
        # position of the next image inside the current frame tag
        self.vdl_image_step = 0
        # index of the current "frame_<n>" image tag
        self.vdl_image_frame = 0

    def on_step_end(self, status):
        """
        Log training losses (mode 'train') or inference images (mode 'test').

        Args:
            status (dict): reads ``status['mode']``; in 'train' mode also
                ``status['training_staus']`` (its ``get()`` must return a
                name -> value mapping); in 'test' mode also
                ``status['original_image']`` and ``status['result_image']``.
        """
        mode = status['mode']
        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
            if mode == 'train':
                training_staus = status['training_staus']
                for loss_name, loss_value in training_staus.get().items():
                    self.vdl_writer.add_scalar(loss_name, loss_value,
                                               self.vdl_loss_step)
                # advance once per call so every loss recorded for this
                # step shares the same x-axis position
                self.vdl_loss_step += 1
            elif mode == 'test':
                ori_image = status['original_image']
                result_image = status['result_image']
                frame_id = self.vdl_image_frame
                self.vdl_writer.add_image("original/frame_{}".format(frame_id),
                                          ori_image, self.vdl_image_step)
                self.vdl_writer.add_image("result/frame_{}".format(frame_id),
                                          result_image, self.vdl_image_step)
                self.vdl_image_step += 1
                # each frame can display ten pictures at most.
                if self.vdl_image_step % 10 == 0:
                    self.vdl_image_step = 0
                    self.vdl_image_frame += 1

    def on_epoch_end(self, status):
        """
        Log the first value of every metric result after an evaluation pass.

        Args:
            status (dict): reads ``status['mode']``; nothing is written
                unless it equals 'eval'.
        """
        mode = status['mode']
        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
            if mode == 'eval':
                for metric in self.model._metrics:
                    for key, map_value in metric.get_results().items():
                        self.vdl_writer.add_scalar("{}-mAP".format(key),
                                                   map_value[0],
                                                   self.vdl_mAP_step)
                # advance once per evaluation so all metrics of this pass
                # share the same x-axis position
                self.vdl_mAP_step += 1
...@@ -34,7 +34,7 @@ from ppdet.utils.visualizer import visualize_results ...@@ -34,7 +34,7 @@ from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
from .export_utils import _dump_infer_config from .export_utils import _dump_infer_config
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
...@@ -101,12 +101,17 @@ class Trainer(object): ...@@ -101,12 +101,17 @@ class Trainer(object):
def _init_callbacks(self): def _init_callbacks(self):
if self.mode == 'train': if self.mode == 'train':
self._callbacks = [LogPrinter(self), Checkpointer(self)] self._callbacks = [LogPrinter(self), Checkpointer(self)]
if self.cfg.use_vdl:
self._callbacks.append(VisualDLWriter(self))
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval': elif self.mode == 'eval':
self._callbacks = [LogPrinter(self)] self._callbacks = [LogPrinter(self)]
if self.cfg.metric == 'WiderFace': if self.cfg.metric == 'WiderFace':
self._callbacks.append(WiferFaceEval(self)) self._callbacks.append(WiferFaceEval(self))
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'test' and self.cfg.use_vdl:
self._callbacks = [VisualDLWriter(self)]
self._compose_callback = ComposeCallback(self._callbacks)
else: else:
self._callbacks = [] self._callbacks = []
self._compose_callback = None self._compose_callback = None
...@@ -268,6 +273,7 @@ class Trainer(object): ...@@ -268,6 +273,7 @@ class Trainer(object):
self.cfg.worker_num, self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler) batch_sampler=self._eval_batch_sampler)
with paddle.no_grad(): with paddle.no_grad():
self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader) self._eval_with_loader(self._eval_loader)
def _eval_with_loader(self, loader): def _eval_with_loader(self, loader):
...@@ -291,12 +297,12 @@ class Trainer(object): ...@@ -291,12 +297,12 @@ class Trainer(object):
self.status['sample_num'] = sample_num self.status['sample_num'] = sample_num
self.status['cost_time'] = time.time() - tic self.status['cost_time'] = time.time() - tic
self._compose_callback.on_epoch_end(self.status)
# accumulate metric to log out # accumulate metric to log out
for metric in self._metrics: for metric in self._metrics:
metric.accumulate() metric.accumulate()
metric.log() metric.log()
self._compose_callback.on_epoch_end(self.status)
# reset metric states for metric may performed multiple times # reset metric states for metric may performed multiple times
self._reset_metrics() self._reset_metrics()
...@@ -330,8 +336,9 @@ class Trainer(object): ...@@ -330,8 +336,9 @@ class Trainer(object):
for i, im_id in enumerate(outs['im_id']): for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)] image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert('RGB')
end = start + bbox_num[i] self.status['original_image'] = np.array(image.copy())
end = start + bbox_num[i]
bbox_res = batch_res['bbox'][start:end] \ bbox_res = batch_res['bbox'][start:end] \
if 'bbox' in batch_res else None if 'bbox' in batch_res else None
mask_res = batch_res['mask'][start:end] \ mask_res = batch_res['mask'][start:end] \
...@@ -341,7 +348,8 @@ class Trainer(object): ...@@ -341,7 +348,8 @@ class Trainer(object):
image = visualize_results(image, bbox_res, mask_res, segm_res, image = visualize_results(image, bbox_res, mask_res, segm_res,
int(outs['im_id']), catid2name, int(outs['im_id']), catid2name,
draw_threshold) draw_threshold)
self.status['result_image'] = np.array(image.copy())
self._compose_callback.on_step_end(self.status)
# save image with detection # save image with detection
save_name = self._get_save_image_name(output_dir, image_path) save_name = self._get_save_image_name(output_dir, image_path)
logger.info("Detection bbox results save in {}".format( logger.info("Detection bbox results save in {}".format(
......
...@@ -130,7 +130,7 @@ def cocoapi_eval(jsonfile, ...@@ -130,7 +130,7 @@ def cocoapi_eval(jsonfile,
results_flatten = list(itertools.chain(*results_per_category)) results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2) headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest( results_2d = itertools.zip_longest(
* [results_flatten[i::num_columns] for i in range(num_columns)]) *[results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers] table_data = [headers]
table_data += [result for result in results_2d] table_data += [result for result in results_2d]
table = AsciiTable(table_data) table = AsciiTable(table_data)
......
...@@ -277,9 +277,8 @@ class DetectionMAP(object): ...@@ -277,9 +277,8 @@ class DetectionMAP(object):
num_columns = min(6, len(results_per_category) * 2) num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category)) results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2) headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(* [ results_2d = itertools.zip_longest(
results_flatten[i::num_columns] for i in range(num_columns) *[results_flatten[i::num_columns] for i in range(num_columns)])
])
table_data = [headers] table_data = [headers]
table_data += [result for result in results_2d] table_data += [result for result in results_2d]
table = AsciiTable(table_data) table = AsciiTable(table_data)
......
...@@ -214,7 +214,7 @@ class VOCMetric(Metric): ...@@ -214,7 +214,7 @@ class VOCMetric(Metric):
self.map_type, map_stat)) self.map_type, map_stat))
def get_results(self): def get_results(self):
self.detection_map.get_map() return {'bbox': [self.detection_map.get_map()]}
class WiderFaceMetric(Metric): class WiderFaceMetric(Metric):
......
tqdm tqdm
typeguard ; python_version >= '3.4' typeguard ; python_version >= '3.4'
visualdl>=2.0.0b visualdl>=2.1.0
opencv-python opencv-python
PyYAML PyYAML
shapely shapely
......
...@@ -130,6 +130,8 @@ def main(): ...@@ -130,6 +130,8 @@ def main():
FLAGS = parse_args() FLAGS = parse_args()
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
if FLAGS.slim_config: if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config) slim_cfg = load_config(FLAGS.slim_config)
......
...@@ -72,6 +72,16 @@ def parse_args(): ...@@ -72,6 +72,16 @@ def parse_args():
help="Enable mixed precision training.") help="Enable mixed precision training.")
parser.add_argument( parser.add_argument(
"--fleet", action='store_true', default=False, help="Use fleet or not") "--fleet", action='store_true', default=False, help="Use fleet or not")
parser.add_argument(
"--use_vdl",
type=bool,
default=False,
help="whether to record the data to VisualDL.")
parser.add_argument(
'--vdl_log_dir',
type=str,
default="vdl_log_dir/scalar",
help='VisualDL logging directory for scalar.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -104,6 +114,8 @@ def main(): ...@@ -104,6 +114,8 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
cfg['fp16'] = FLAGS.fp16 cfg['fp16'] = FLAGS.fp16
cfg['fleet'] = FLAGS.fleet cfg['fleet'] = FLAGS.fleet
cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
if FLAGS.slim_config: if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config) slim_cfg = load_config(FLAGS.slim_config)
......
...@@ -96,6 +96,7 @@ def main(): ...@@ -96,6 +96,7 @@ def main():
dump_infer_config(FLAGS, cfg) dump_infer_config(FLAGS, cfg)
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
if __name__ == '__main__': if __name__ == '__main__':
enable_static_mode() enable_static_mode()
parser = ArgsParser() parser = ArgsParser()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册