未验证 提交 6cfe3643 编写于 作者: G George Ni 提交者: GitHub

[MOT] fix for_mot in yolo_fpn (#3039)

* fix for_mot in PPYOLOTinyFPN and PPYOLOPAN

* fix yolo_fpn
上级 fa474195
...@@ -30,7 +30,8 @@ class YOLOv3(BaseArch): ...@@ -30,7 +30,8 @@ class YOLOv3(BaseArch):
yolo_head (nn.Layer): anchor_head instance yolo_head (nn.Layer): anchor_head instance
bbox_post_process (object): `BBoxPostProcess` instance bbox_post_process (object): `BBoxPostProcess` instance
data_format (str): data format, NCHW or NHWC data_format (str): data format, NCHW or NHWC
for_mot (bool): whether return other features used in tracking model for_mot (bool): whether return other features for multi-object tracking
models, default False in pure object detection models.
""" """
super(YOLOv3, self).__init__(data_format=data_format) super(YOLOv3, self).__init__(data_format=data_format)
self.backbone = backbone self.backbone = backbone
......
...@@ -18,11 +18,9 @@ import paddle.nn.functional as F ...@@ -18,11 +18,9 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..backbones.darknet import ConvBNLayer from ..backbones.darknet import ConvBNLayer
import numpy as np
from ..shape_spec import ShapeSpec from ..shape_spec import ShapeSpec
__all__ = ['YOLOv3FPN', 'PPYOLOFPN'] __all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN']
def add_coord(x, data_format): def add_coord(x, data_format):
...@@ -492,8 +490,11 @@ class YOLOv3FPN(nn.Layer): ...@@ -492,8 +490,11 @@ class YOLOv3FPN(nn.Layer):
assert len(blocks) == self.num_blocks assert len(blocks) == self.num_blocks
blocks = blocks[::-1] blocks = blocks[::-1]
yolo_feats = [] yolo_feats = []
# add embedding features output for multi-object tracking model
if for_mot: if for_mot:
emb_feats = [] emb_feats = []
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
if i > 0: if i > 0:
if self.data_format == 'NCHW': if self.data_format == 'NCHW':
...@@ -504,7 +505,7 @@ class YOLOv3FPN(nn.Layer): ...@@ -504,7 +505,7 @@ class YOLOv3FPN(nn.Layer):
yolo_feats.append(tip) yolo_feats.append(tip)
if for_mot: if for_mot:
# add emb_feats output # add embedding features output
emb_feats.append(route) emb_feats.append(route)
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
...@@ -668,8 +669,11 @@ class PPYOLOFPN(nn.Layer): ...@@ -668,8 +669,11 @@ class PPYOLOFPN(nn.Layer):
assert len(blocks) == self.num_blocks assert len(blocks) == self.num_blocks
blocks = blocks[::-1] blocks = blocks[::-1]
yolo_feats = [] yolo_feats = []
# add embedding features output for multi-object tracking model
if for_mot: if for_mot:
emb_feats = [] emb_feats = []
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
if i > 0: if i > 0:
if self.data_format == 'NCHW': if self.data_format == 'NCHW':
...@@ -680,7 +684,7 @@ class PPYOLOFPN(nn.Layer): ...@@ -680,7 +684,7 @@ class PPYOLOFPN(nn.Layer):
yolo_feats.append(tip) yolo_feats.append(tip)
if for_mot: if for_mot:
# add emb_feats output # add embedding features output
emb_feats.append(route) emb_feats.append(route)
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
...@@ -780,11 +784,15 @@ class PPYOLOTinyFPN(nn.Layer): ...@@ -780,11 +784,15 @@ class PPYOLOTinyFPN(nn.Layer):
name=name)) name=name))
self.routes.append(route) self.routes.append(route)
def forward(self, blocks): def forward(self, blocks, for_mot=False):
assert len(blocks) == self.num_blocks assert len(blocks) == self.num_blocks
blocks = blocks[::-1] blocks = blocks[::-1]
yolo_feats = [] yolo_feats = []
# add embedding features output for multi-object tracking model
if for_mot:
emb_feats = []
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
if i == 0 and self.spp_: if i == 0 and self.spp_:
block = self.spp(block) block = self.spp(block)
...@@ -797,11 +805,18 @@ class PPYOLOTinyFPN(nn.Layer): ...@@ -797,11 +805,18 @@ class PPYOLOTinyFPN(nn.Layer):
route, tip = self.yolo_blocks[i](block) route, tip = self.yolo_blocks[i](block)
yolo_feats.append(tip) yolo_feats.append(tip)
if for_mot:
# add embedding features output
emb_feats.append(route)
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
route = self.routes[i](route) route = self.routes[i](route)
route = F.interpolate( route = F.interpolate(
route, scale_factor=2., data_format=self.data_format) route, scale_factor=2., data_format=self.data_format)
if for_mot:
return {'yolo_feats': yolo_feats, 'emb_feats': emb_feats}
else:
return yolo_feats return yolo_feats
@classmethod @classmethod
...@@ -964,11 +979,15 @@ class PPYOLOPAN(nn.Layer): ...@@ -964,11 +979,15 @@ class PPYOLOPAN(nn.Layer):
self._out_channels = self._out_channels[::-1] self._out_channels = self._out_channels[::-1]
def forward(self, blocks): def forward(self, blocks, for_mot=False):
assert len(blocks) == self.num_blocks assert len(blocks) == self.num_blocks
blocks = blocks[::-1] blocks = blocks[::-1]
# fpn
fpn_feats = [] fpn_feats = []
# add embedding features output for multi-object tracking model
if for_mot:
emb_feats = []
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
if i > 0: if i > 0:
if self.data_format == 'NCHW': if self.data_format == 'NCHW':
...@@ -978,6 +997,10 @@ class PPYOLOPAN(nn.Layer): ...@@ -978,6 +997,10 @@ class PPYOLOPAN(nn.Layer):
route, tip = self.fpn_blocks[i](block) route, tip = self.fpn_blocks[i](block)
fpn_feats.append(tip) fpn_feats.append(tip)
if for_mot:
# add embedding features output
emb_feats.append(route)
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
route = self.fpn_routes[i](route) route = self.fpn_routes[i](route)
route = F.interpolate( route = F.interpolate(
...@@ -996,6 +1019,9 @@ class PPYOLOPAN(nn.Layer): ...@@ -996,6 +1019,9 @@ class PPYOLOPAN(nn.Layer):
route, tip = self.pan_blocks[i](block) route, tip = self.pan_blocks[i](block)
pan_feats.append(tip) pan_feats.append(tip)
if for_mot:
return {'yolo_feats': pan_feats[::-1], 'emb_feats': emb_feats}
else:
return pan_feats[::-1] return pan_feats[::-1]
@classmethod @classmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册