You need to sign in or sign up before continuing.
未验证 提交 78c37c48 编写于 作者: W wangguanzhong 提交者: GitHub

fix yolo eval (#613)

上级 76f6c939
......@@ -68,7 +68,8 @@ class YOLOv3Head(object):
background_label=-1).__dict__,
weight_prefix_name='',
downsample=[32, 16, 8],
scale_x_y=1.0):
scale_x_y=1.0,
clip_bbox=True):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.anchor_masks = anchor_masks
......@@ -86,6 +87,7 @@ class YOLOv3Head(object):
self.downsample = downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y
self.clip_bbox = clip_bbox
def _conv_bn(self,
input,
......@@ -325,7 +327,7 @@ class YOLOv3Head(object):
conf_thresh=self.nms.score_threshold,
downsample_ratio=self.downsample[i],
name=self.prefix_name + "yolo_box" + str(i),
clip_bbox=False)
clip_bbox=self.clip_bbox)
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
......@@ -352,25 +354,25 @@ class YOLOv4Head(YOLOv3Head):
__inject__ = ['nms', 'yolo_loss']
__shared__ = ['num_classes', 'weight_prefix_name']
def __init__(
self,
anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
[72, 146], [142, 110], [192, 243], [459, 401]],
anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=-1,
keep_top_k=-1,
nms_threshold=0.45,
background_label=-1).__dict__,
spp_stage=5,
num_classes=80,
weight_prefix_name='',
downsample=[8, 16, 32],
scale_x_y=[1.2, 1.1, 1.05],
yolo_loss="YOLOv3Loss",
iou_aware=False,
iou_aware_factor=0.4, ):
def __init__(self,
anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
[72, 146], [142, 110], [192, 243], [459, 401]],
anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=-1,
keep_top_k=-1,
nms_threshold=0.45,
background_label=-1).__dict__,
spp_stage=5,
num_classes=80,
weight_prefix_name='',
downsample=[8, 16, 32],
scale_x_y=[1.2, 1.1, 1.05],
yolo_loss="YOLOv3Loss",
iou_aware=False,
iou_aware_factor=0.4,
clip_bbox=False):
super(YOLOv4Head, self).__init__(
anchors=anchors,
anchor_masks=anchor_masks,
......@@ -381,7 +383,8 @@ class YOLOv4Head(YOLOv3Head):
scale_x_y=scale_x_y,
yolo_loss=yolo_loss,
iou_aware=iou_aware,
iou_aware_factor=iou_aware_factor)
iou_aware_factor=iou_aware_factor,
clip_box=clip_bbox)
self.spp_stage = spp_stage
def _upsample(self, input, scale=2, name=None):
......
......@@ -42,7 +42,7 @@ class YOLOv3(object):
def __init__(self,
backbone,
yolo_head='YOLOv4Head',
yolo_head='YOLOv3Head',
use_fine_grained_loss=False):
super(YOLOv3, self).__init__()
self.backbone = backbone
......
......@@ -275,11 +275,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
w *= im_width
h *= im_height
else:
im_size = t['im_size'][0][i].tolist()
xmin, ymin, xmax, ymax = \
clip_bbox([xmin, ymin, xmax, ymax], im_size)
w = xmax - xmin
h = ymax - ymin
# for yolov4
# w = xmax - xmin
# h = ymax - ymin
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
......
......@@ -111,7 +111,7 @@ def main():
extra_keys = []
if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape', 'im_size']
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册