未验证 提交 78c37c48 编写于 作者: W wangguanzhong 提交者: GitHub

fix yolo eval (#613)

上级 76f6c939
...@@ -68,7 +68,8 @@ class YOLOv3Head(object): ...@@ -68,7 +68,8 @@ class YOLOv3Head(object):
background_label=-1).__dict__, background_label=-1).__dict__,
weight_prefix_name='', weight_prefix_name='',
downsample=[32, 16, 8], downsample=[32, 16, 8],
scale_x_y=1.0): scale_x_y=1.0,
clip_bbox=True):
self.norm_decay = norm_decay self.norm_decay = norm_decay
self.num_classes = num_classes self.num_classes = num_classes
self.anchor_masks = anchor_masks self.anchor_masks = anchor_masks
...@@ -86,6 +87,7 @@ class YOLOv3Head(object): ...@@ -86,6 +87,7 @@ class YOLOv3Head(object):
self.downsample = downsample self.downsample = downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0 # TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y #self.scale_x_y = scale_x_y
self.clip_bbox = clip_bbox
def _conv_bn(self, def _conv_bn(self,
input, input,
...@@ -325,7 +327,7 @@ class YOLOv3Head(object): ...@@ -325,7 +327,7 @@ class YOLOv3Head(object):
conf_thresh=self.nms.score_threshold, conf_thresh=self.nms.score_threshold,
downsample_ratio=self.downsample[i], downsample_ratio=self.downsample[i],
name=self.prefix_name + "yolo_box" + str(i), name=self.prefix_name + "yolo_box" + str(i),
clip_bbox=False) clip_bbox=self.clip_bbox)
boxes.append(box) boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1])) scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
...@@ -352,8 +354,7 @@ class YOLOv4Head(YOLOv3Head): ...@@ -352,8 +354,7 @@ class YOLOv4Head(YOLOv3Head):
__inject__ = ['nms', 'yolo_loss'] __inject__ = ['nms', 'yolo_loss']
__shared__ = ['num_classes', 'weight_prefix_name'] __shared__ = ['num_classes', 'weight_prefix_name']
def __init__( def __init__(self,
self,
anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
[72, 146], [142, 110], [192, 243], [459, 401]], [72, 146], [142, 110], [192, 243], [459, 401]],
anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
...@@ -370,7 +371,8 @@ class YOLOv4Head(YOLOv3Head): ...@@ -370,7 +371,8 @@ class YOLOv4Head(YOLOv3Head):
scale_x_y=[1.2, 1.1, 1.05], scale_x_y=[1.2, 1.1, 1.05],
yolo_loss="YOLOv3Loss", yolo_loss="YOLOv3Loss",
iou_aware=False, iou_aware=False,
iou_aware_factor=0.4, ): iou_aware_factor=0.4,
clip_bbox=False):
super(YOLOv4Head, self).__init__( super(YOLOv4Head, self).__init__(
anchors=anchors, anchors=anchors,
anchor_masks=anchor_masks, anchor_masks=anchor_masks,
...@@ -381,7 +383,8 @@ class YOLOv4Head(YOLOv3Head): ...@@ -381,7 +383,8 @@ class YOLOv4Head(YOLOv3Head):
scale_x_y=scale_x_y, scale_x_y=scale_x_y,
yolo_loss=yolo_loss, yolo_loss=yolo_loss,
iou_aware=iou_aware, iou_aware=iou_aware,
iou_aware_factor=iou_aware_factor) iou_aware_factor=iou_aware_factor,
clip_box=clip_bbox)
self.spp_stage = spp_stage self.spp_stage = spp_stage
def _upsample(self, input, scale=2, name=None): def _upsample(self, input, scale=2, name=None):
......
...@@ -42,7 +42,7 @@ class YOLOv3(object): ...@@ -42,7 +42,7 @@ class YOLOv3(object):
def __init__(self, def __init__(self,
backbone, backbone,
yolo_head='YOLOv4Head', yolo_head='YOLOv3Head',
use_fine_grained_loss=False): use_fine_grained_loss=False):
super(YOLOv3, self).__init__() super(YOLOv3, self).__init__()
self.backbone = backbone self.backbone = backbone
......
...@@ -275,11 +275,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False): ...@@ -275,11 +275,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
w *= im_width w *= im_width
h *= im_height h *= im_height
else: else:
im_size = t['im_size'][0][i].tolist() # for yolov4
xmin, ymin, xmax, ymax = \ # w = xmax - xmin
clip_bbox([xmin, ymin, xmax, ymax], im_size) # h = ymax - ymin
w = xmax - xmin w = xmax - xmin + 1
h = ymax - ymin h = ymax - ymin + 1
bbox = [xmin, ymin, w, h] bbox = [xmin, ymin, w, h]
coco_res = { coco_res = {
......
...@@ -111,7 +111,7 @@ def main(): ...@@ -111,7 +111,7 @@ def main():
extra_keys = [] extra_keys = []
if cfg.metric == 'COCO': if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape', 'im_size'] extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC': if cfg.metric == 'VOC':
extra_keys = ['gt_bbox', 'gt_class', 'is_difficult'] extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册