Commit 76711d60, authored by Guanghua Yu, committed by wangguanzhong

[PaddleDetection] add save best model (#3052)

* add save best model & fix dcn configs
Parent 7aac2766
......@@ -132,7 +132,7 @@ FasterRCNNEvalFeed:
FasterRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -130,7 +130,7 @@ FasterRCNNEvalFeed:
FasterRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -132,7 +132,7 @@ FasterRCNNEvalFeed:
FasterRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -135,7 +135,7 @@ FasterRCNNEvalFeed:
FasterRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -139,7 +139,7 @@ MaskRCNNEvalFeed:
MaskRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -138,7 +138,7 @@ MaskRCNNEvalFeed:
MaskRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -140,7 +140,7 @@ MaskRCNNEvalFeed:
MaskRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
......@@ -141,7 +141,7 @@ MaskRCNNEvalFeed:
MaskRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
......
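All of the `*TestFeed` hunks above make the same change: the test-time `annotation` path now appears to be given relative to the PaddleDetection repository root (`dataset/coco/...`). A hypothetical sanity check, assuming you run it from the repository root after preparing the COCO data, could look like this:

```python
# Hypothetical check (not part of the commit): verify that the annotation path
# used by the updated *TestFeed configs resolves from the PaddleDetection root.
import os

anno = "dataset/coco/annotations/instances_val2017.json"  # path from the configs above
if not os.path.isfile(anno):
    raise FileNotFoundError(
        "Expected COCO val annotations at '{}'; run this from the repository "
        "root after preparing dataset/coco.".format(anno))
print("Found annotation file:", anno)
```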
......@@ -51,7 +51,8 @@ python -u tools/train.py -c configs/faster_rcnn_r50_1x.yml --eval
Alternating between training epochs and evaluation runs is possible: simply pass
`--eval` to evaluate at every snapshot; the interval is controlled by `snapshot_iter` in the configuration file. If the evaluation dataset is large and
evaluation noticeably slows down training, we suggest reducing the evaluation frequency or evaluating after training finishes.
evaluation noticeably slows down training, we suggest reducing the evaluation frequency or evaluating after training finishes. When evaluation is performed during training,
the model with the highest mAP is saved as `best_model` at each `snapshot_iter`; `best_model` is stored under the same path as `model_final`.
- configuration options and assign Dataset path
......
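The behaviour described in the doc change above is implemented in the `tools/train.py` hunk near the end of this commit. A minimal, self-contained sketch of that bookkeeping (the `evaluate` and `save_checkpoint` helpers below are placeholders, not PaddleDetection APIs):

```python
# Sketch of the best-model tracking this commit adds to tools/train.py.
# evaluate() and save_checkpoint() are illustrative stand-ins for the real
# eval_results(...) and checkpoint.save(...) calls.
import os
import random

def evaluate(iteration):
    # Placeholder: pretend to compute a box AP for this snapshot.
    return random.uniform(30.0, 40.0)

def save_checkpoint(path):
    print("saving checkpoint to", path)

save_dir = "output/faster_rcnn_r50_1x"   # assumption: cfg.save_dir/cfg_name
snapshot_iter, max_iters = 100, 500
best_box_ap, best_iter = 0.0, 0          # mirrors best_box_ap_list = [0.0, 0]

for it in range(1, max_iters + 1):
    if it % snapshot_iter == 0:
        box_ap = evaluate(it)
        if box_ap > best_box_ap:
            best_box_ap, best_iter = box_ap, it
            # best_model is written next to model_final, under the same save_dir
            save_checkpoint(os.path.join(save_dir, "best_model"))
        print("Best test box ap: {}, in iter: {}".format(best_box_ap, best_iter))
```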
......@@ -49,7 +49,8 @@ export PYTHONPATH=$PYTHONPATH:.
python -u tools/train.py -c configs/faster_rcnn_r50_1x.yml --eval
```
Evaluation can be alternated with training by setting `--eval`; evaluation runs at every snapshot_iter, which can be changed via `snapshot_iter` in the configuration file.
If the validation set is large, evaluation will be time-consuming and slow down training, so we suggest reducing the evaluation frequency or evaluating after training finishes.
If the validation set is large, evaluation will be time-consuming and slow down training, so we suggest reducing the evaluation frequency or evaluating after training finishes. When evaluating during training, the model with the best mAP measured at each snapshot_iter is saved to
the `best_model` folder; the path of `best_model` is the same as that of `model_final`.
- Set configuration file parameters && specify the dataset path
```bash
......
......@@ -80,16 +80,18 @@ def bbox_eval(results, anno_file, outfile, with_background=True):
            for i, catid in enumerate(cat_ids)})
    xywh_results = bbox2out(results, clsid2catid)
    assert len(
        xywh_results) > 0, "The number of valid bbox detected is zero.\n \
        Please use reasonable model and check input data."
    if len(xywh_results) == 0:
        logger.warning("The number of valid bbox detected is zero.\n \
            Please use reasonable model and check input data.\n \
            stop eval!")
        return [0.0]
    with open(outfile, 'w') as f:
        json.dump(xywh_results, f)
    cocoapi_eval(outfile, 'bbox', coco_gt=coco_gt)
    map_stats = cocoapi_eval(outfile, 'bbox', coco_gt=coco_gt)
    # flush coco evaluation result
    sys.stdout.flush()
    return map_stats
def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
......@@ -137,7 +139,7 @@ def cocoapi_eval(jsonfile,
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return coco_eval.stats
def proposal2out(results, is_bbox_normalized=False):
    xywh_res = []
......
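With this hunk, `bbox_eval` returns the value of `cocoapi_eval`, i.e. pycocotools' `COCOeval.stats` (index 0 is AP at IoU=0.50:0.95), and returns `[0.0]` instead of asserting when no valid boxes were detected. A tiny illustrative sketch of what a caller can rely on (the numbers are made up):

```python
# Illustrative only: how a caller reads the list bbox_eval now returns.
# In the real pipeline it is either COCOeval.stats (12 floats, index 0 being
# AP@[IoU=0.50:0.95]) or [0.0] when no valid detections were produced.
def headline_box_ap(map_stats):
    return map_stats[0]

fake_coco_stats = [0.364, 0.581, 0.397] + [0.0] * 9  # made-up COCOeval.stats
print(headline_box_ap(fake_coco_stats))  # 0.364
print(headline_box_ap([0.0]))            # the "zero valid bbox" early return
```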
......@@ -99,6 +99,7 @@ def eval_results(results,
                 is_bbox_normalized=False,
                 output_directory=None):
    """Evaluation for evaluation program results"""
    box_ap_stats = []
    if metric == 'COCO':
        from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval
        anno_file = getattr(feed.dataset, 'annotation', None)
......@@ -112,7 +113,7 @@ def eval_results(results,
            output = 'bbox.json'
            if output_directory:
                output = os.path.join(output_directory, 'bbox.json')
            bbox_eval(results, anno_file, output, with_background)
            box_ap_stats = bbox_eval(results, anno_file, output, with_background)
        if 'mask' in results[0]:
            output = 'mask.json'
            if output_directory:
......@@ -122,9 +123,12 @@ def eval_results(results,
        if 'accum_map' in results[-1]:
            res = np.mean(results[-1]['accum_map'][0])
            logger.info('mAP: {:.2f}'.format(res * 100.))
            box_ap_stats.append(res * 100.)
        elif 'bbox' in results[0]:
            voc_bbox_eval(
            box_ap = voc_bbox_eval(
                results, num_classes, is_bbox_normalized=is_bbox_normalized)
            box_ap_stats.append(box_ap)
    return box_ap_stats
def json_eval_results(feed, metric, json_directory=None):
    """
......
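After this change `eval_results` always hands back a list, `box_ap_stats`, whose first element is the headline box AP: the full COCO stats array for the COCO metric, or a single appended value (already scaled by 100) for VOC. A small sketch of that contract as the training loop consumes it (values are illustrative; the two metrics report on different scales, which is fine because a single run uses only one metric):

```python
# Illustrative contract of eval_results after this commit: regardless of the
# metric, callers such as tools/train.py only read box_ap_stats[0].
examples = {
    "COCO": [0.364] + [0.0] * 11,  # COCOeval.stats, AP values in [0, 1]
    "VOC": [72.15],                # [100 * mAP] appended by the VOC branch
}
for metric, box_ap_stats in examples.items():
    print("{} headline box AP: {}".format(metric, box_ap_stats[0]))
```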
......@@ -33,8 +33,8 @@ __all__ = [
]
def bbox_eval(results,
              class_num,
def bbox_eval(results,
              class_num,
              overlap_thresh=0.5,
              map_type='11point',
              is_bbox_normalized=False,
......@@ -45,13 +45,13 @@ def bbox_eval(results,
    Args:
        results (list): prediction bounding box results.
        class_num (int): evaluation class number.
        overlap_thresh (float): the positive threshold of
        overlap_thresh (float): the positive threshold of
            bbox overlap
        map_type (string): method for mAP calculation,
            can only be '11point' or 'integral'
        is_bbox_normalized (bool): whether bbox is normalized
            to range [0, 1].
        evaluate_difficult (bool): whether to evaluate
        evaluate_difficult (bool): whether to evaluate
            difficult gt bbox.
    """
    assert 'bbox' in results[0]
......@@ -107,8 +107,10 @@ def bbox_eval(results,
logger.info("Accumulating evaluatation results...")
detection_map.accumulate()
map_stat = 100. * detection_map.get_map()
logger.info("mAP({:.2f}, {}) = {:.2f}".format(overlap_thresh,
map_type, 100. * detection_map.get_map()))
map_type, map_stat))
return map_stat
def prune_zero_padding(gt_box, gt_label, difficult=None):
......
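The docstring above restricts `map_type` to `'11point'` or `'integral'`. For reference, a minimal sketch of the 11-point interpolation idea on a made-up precision/recall curve (this is not the `detection_map` helper used above, just an illustration of the concept):

```python
# Minimal sketch of 11-point interpolated AP on a fabricated PR curve.
# Only illustrates what map_type='11point' refers to; the real computation
# lives in the detection_map accumulator used in bbox_eval.
def eleven_point_ap(recalls, precisions):
    ap = 0.0
    for t in [i / 10.0 for i in range(11)]:
        # max precision over the part of the curve with recall >= t
        candidates = [p for r, p in zip(recalls, precisions) if r >= t]
        ap += max(candidates) if candidates else 0.0
    return ap / 11.0

recalls = [0.1, 0.3, 0.5, 0.7, 0.9]
precisions = [0.9, 0.8, 0.7, 0.5, 0.3]
print("11-point AP: {:.3f}".format(eleven_point_ap(recalls, precisions)))
```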
......@@ -171,6 +171,7 @@ def main():
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_iter)
    best_box_ap_list = [0.0, 0]  # [mAP, iter]
    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
......@@ -198,8 +199,14 @@ def main():
                resolution = None
                if 'mask' in results[0]:
                    resolution = model.mask_head.resolution
                eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
                box_ap_stats = eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
                                            resolution, is_bbox_normalized, FLAGS.output_eval)
                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog, os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))
    train_pyreader.reset()
......