未验证 提交 7f4a1dea 编写于 作者: G Guanghua Yu 提交者: GitHub

add VisualDL and save best_model (#2361)

* add VisualDL and save best_model

* add VisualDLWriter callback
上级 87cb019d
......@@ -19,6 +19,8 @@ from __future__ import print_function
import os
import sys
import datetime
import six
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
......@@ -140,7 +142,10 @@ class Checkpointer(Callback):
def __init__(self, model):
super(Checkpointer, self).__init__(model)
cfg = self.model.cfg
self.best_ap = 0.
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
if self.use_ema:
self.ema = ModelEMA(
cfg['ema_decay'], self.model.model, use_thres_step=True)
......@@ -152,24 +157,36 @@ class Checkpointer(Callback):
def on_epoch_end(self, status):
    """
    Save model weights at the end of a training or evaluation epoch.

    In ``train`` mode a snapshot is written every
    ``cfg.snapshot_epoch`` epochs (named after the epoch id) and on the
    last epoch (named ``model_final``). In ``eval`` mode, when
    ``status['save_best_model']`` is set, the weights are written as
    ``best_model`` each time a metric reports a first AP value higher
    than the best one seen so far.

    Args:
        status (dict): shared trainer status; ``mode`` and ``epoch_id``
            are always read, ``save_best_model`` is optional.
    """
    mode = status['mode']
    epoch_id = status['epoch_id']

    # In multi-card runs only local rank 0 writes checkpoints.
    env = ParallelEnv()
    if env.nranks >= 2 and env.local_rank != 0:
        return

    weight = None
    save_name = None
    if mode == 'train':
        end_epoch = self.model.cfg.epoch
        is_last_epoch = epoch_id == end_epoch - 1
        if epoch_id % self.model.cfg.snapshot_epoch == 0 or is_last_epoch:
            save_name = "model_final" if is_last_epoch else str(epoch_id)
            # With EMA enabled the averaged weights are saved instead of
            # the raw network.
            weight = self.ema.apply() if self.use_ema else self.model.model
    elif mode == 'eval':
        if 'save_best_model' in status and status['save_best_model']:
            for metric in self.model._metrics:
                map_res = metric.get_results()
                # Prefer the bbox result; otherwise fall back to mask.
                key = 'bbox' if 'bbox' in map_res else 'mask'
                if map_res[key][0] > self.best_ap:
                    self.best_ap = map_res[key][0]
                    save_name = 'best_model'
                    weight = (self.ema.apply()
                              if self.use_ema else self.model.model)
                logger.info("Best test {} ap is {:0.3f}.".format(
                    key, self.best_ap))

    # Compare against None explicitly: the weight object may be a
    # container whose truth value depends on its length.
    if weight is not None:
        save_model(weight, self.model.optimizer, self.save_dir, save_name,
                   epoch_id + 1)
class WiferFaceEval(Callback):
......@@ -182,3 +199,60 @@ class WiferFaceEval(Callback):
for metric in self.model._metrics:
metric.update(self.model.model)
sys.exit()
class VisualDLWriter(Callback):
    """
    Use VisualDL to log data or image

    Scalars from training (``on_step_end`` in ``train`` mode) and
    evaluation (``on_epoch_end`` in ``eval`` mode), and image pairs from
    inference (``on_step_end`` in ``test`` mode), are written through a
    ``visualdl.LogWriter`` rooted at ``model.cfg.vdl_log_dir``.
    In multi-card runs only local rank 0 writes.
    """

    # Number of images grouped under one "frame_<n>" tag.
    _IMAGES_PER_FRAME = 10

    def __init__(self, model):
        """
        Args:
            model: object exposing ``cfg.vdl_log_dir``; it is also passed
                to the ``Callback`` base initializer.

        Raises:
            AssertionError: when not running on Python 3.
            Exception: whatever importing ``visualdl`` raised, re-raised
                after logging an installation hint.
        """
        super(VisualDLWriter, self).__init__(model)

        assert six.PY3, "VisualDL requires Python >= 3.5"
        try:
            # Imported lazily so visualdl is only required when enabled.
            from visualdl import LogWriter
        except Exception:
            logger.error('visualdl not found, please install visualdl. '
                         'for example: `pip install visualdl`.')
            raise
        self.vdl_writer = LogWriter(model.cfg.vdl_log_dir)
        self.vdl_loss_step = 0
        self.vdl_mAP_step = 0
        self.vdl_image_step = 0
        self.vdl_image_frame = 0

    @staticmethod
    def _is_main_process():
        """Return True for single-card runs or for local rank 0."""
        env = ParallelEnv()
        return env.nranks < 2 or env.local_rank == 0

    def on_step_end(self, status):
        """
        Log training scalars or inference images for the finished step.

        Args:
            status (dict): shared trainer status. ``mode`` is always
                read; ``training_staus`` is read in ``train`` mode and
                ``original_image`` / ``result_image`` in ``test`` mode.
        """
        mode = status['mode']
        if not self._is_main_process():
            return

        if mode == 'train':
            training_staus = status['training_staus']
            for loss_name, loss_value in training_staus.get().items():
                self.vdl_writer.add_scalar(loss_name, loss_value,
                                           self.vdl_loss_step)
            # Advance once per training step so that every scalar logged
            # for this step shares the same x coordinate.
            self.vdl_loss_step += 1
        elif mode == 'test':
            ori_image = status['original_image']
            result_image = status['result_image']
            self.vdl_writer.add_image(
                "original/frame_{}".format(self.vdl_image_frame), ori_image,
                self.vdl_image_step)
            self.vdl_writer.add_image(
                "result/frame_{}".format(self.vdl_image_frame),
                result_image, self.vdl_image_step)
            self.vdl_image_step += 1
            # each frame can display ten pictures at most.
            if self.vdl_image_step % self._IMAGES_PER_FRAME == 0:
                self.vdl_image_step = 0
                self.vdl_image_frame += 1

    def on_epoch_end(self, status):
        """
        Log the first value of every metric result after an evaluation.

        Args:
            status (dict): shared trainer status; only ``mode`` is read.
        """
        mode = status['mode']
        if not self._is_main_process() or mode != 'eval':
            return

        for metric in self.model._metrics:
            for key, map_value in metric.get_results().items():
                self.vdl_writer.add_scalar("{}-mAP".format(key),
                                           map_value[0], self.vdl_mAP_step)
        # One x-axis tick per evaluation pass, shared by all metrics.
        self.vdl_mAP_step += 1
......@@ -34,7 +34,7 @@ from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results
import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
from .export_utils import _dump_infer_config
from ppdet.utils.logger import setup_logger
......@@ -101,12 +101,17 @@ class Trainer(object):
def _init_callbacks(self):
    """
    Build the callback list for the current mode.

    - ``train``: ``LogPrinter`` and ``Checkpointer``, plus
      ``VisualDLWriter`` when ``use_vdl`` is enabled.
    - ``eval``: ``LogPrinter``, plus ``WiferFaceEval`` when
      ``cfg.metric`` is ``'WiderFace'``.
    - ``test``: ``VisualDLWriter`` only, and only when ``use_vdl`` is
      enabled.

    Sets ``self._callbacks`` to the list and ``self._compose_callback``
    to a ``ComposeCallback`` wrapping it, or to ``None`` when the list
    is empty.
    """
    # Not every entry point sets 'use_vdl', so treat a missing key as
    # disabled instead of failing on the lookup.
    use_vdl = 'use_vdl' in self.cfg and self.cfg['use_vdl']

    if self.mode == 'train':
        callbacks = [LogPrinter(self), Checkpointer(self)]
        if use_vdl:
            callbacks.append(VisualDLWriter(self))
    elif self.mode == 'eval':
        callbacks = [LogPrinter(self)]
        if self.cfg.metric == 'WiderFace':
            callbacks.append(WiferFaceEval(self))
    elif self.mode == 'test' and use_vdl:
        callbacks = [VisualDLWriter(self)]
    else:
        callbacks = []

    self._callbacks = callbacks
    self._compose_callback = ComposeCallback(callbacks) if callbacks else None
......@@ -268,6 +273,7 @@ class Trainer(object):
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
with paddle.no_grad():
self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader)
def _eval_with_loader(self, loader):
......@@ -291,12 +297,12 @@ class Trainer(object):
self.status['sample_num'] = sample_num
self.status['cost_time'] = time.time() - tic
self._compose_callback.on_epoch_end(self.status)
# accumulate metric to log out
for metric in self._metrics:
metric.accumulate()
metric.log()
self._compose_callback.on_epoch_end(self.status)
# reset metric states for metric may performed multiple times
self._reset_metrics()
......@@ -330,8 +336,9 @@ class Trainer(object):
for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
end = start + bbox_num[i]
self.status['original_image'] = np.array(image.copy())
end = start + bbox_num[i]
bbox_res = batch_res['bbox'][start:end] \
if 'bbox' in batch_res else None
mask_res = batch_res['mask'][start:end] \
......@@ -341,7 +348,8 @@ class Trainer(object):
image = visualize_results(image, bbox_res, mask_res, segm_res,
int(outs['im_id']), catid2name,
draw_threshold)
self.status['result_image'] = np.array(image.copy())
self._compose_callback.on_step_end(self.status)
# save image with detection
save_name = self._get_save_image_name(output_dir, image_path)
logger.info("Detection bbox results save in {}".format(
......
......@@ -130,7 +130,7 @@ def cocoapi_eval(jsonfile,
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(
* [results_flatten[i::num_columns] for i in range(num_columns)])
*[results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
......
......@@ -277,9 +277,8 @@ class DetectionMAP(object):
num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(* [
results_flatten[i::num_columns] for i in range(num_columns)
])
results_2d = itertools.zip_longest(
*[results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
......
......@@ -214,7 +214,7 @@ class VOCMetric(Metric):
self.map_type, map_stat))
def get_results(self):
    """
    Return the detection mAP in the common metric-result layout.

    Returns:
        dict: ``{'bbox': [mAP]}``, where ``mAP`` is whatever
        ``self.detection_map.get_map()`` returns.
    """
    # Query the map once; the earlier bare call discarded its result.
    return {'bbox': [self.detection_map.get_map()]}
class WiderFaceMetric(Metric):
......
tqdm
typeguard ; python_version >= '3.4'
visualdl>=2.0.0b
visualdl>=2.1.0
opencv-python
PyYAML
shapely
......
......@@ -130,6 +130,8 @@ def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
merge_config(FLAGS.opt)
if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config)
......
......@@ -72,6 +72,16 @@ def parse_args():
help="Enable mixed precision training.")
parser.add_argument(
"--fleet", action='store_true', default=False, help="Use fleet or not")
parser.add_argument(
"--use_vdl",
type=bool,
default=False,
help="whether to record the data to VisualDL.")
parser.add_argument(
'--vdl_log_dir',
type=str,
default="vdl_log_dir/scalar",
help='VisualDL logging directory for scalar.')
args = parser.parse_args()
return args
......@@ -104,6 +114,8 @@ def main():
cfg = load_config(FLAGS.config)
cfg['fp16'] = FLAGS.fp16
cfg['fleet'] = FLAGS.fleet
cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
merge_config(FLAGS.opt)
if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config)
......
......@@ -96,6 +96,7 @@ def main():
dump_infer_config(FLAGS, cfg)
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
if __name__ == '__main__':
enable_static_mode()
parser = ArgsParser()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册