未验证 提交 d70598ec 编写于 作者: X xiaoting 提交者: GitHub

add nms_cpu for layout_picodet (#8453)

上级 484510a9
...@@ -17,6 +17,7 @@ PicoDet: ...@@ -17,6 +17,7 @@ PicoDet:
backbone: LCNet backbone: LCNet
neck: CSPPAN neck: CSPPAN
head: PicoHead head: PicoHead
nms_cpu: True
LCNet: LCNet:
scale: 1.0 scale: 1.0
......
...@@ -10,6 +10,7 @@ PicoDet: ...@@ -10,6 +10,7 @@ PicoDet:
backbone: LCNet backbone: LCNet
neck: CSPPAN neck: CSPPAN
head: PicoHead head: PicoHead
nms_cpu: True
LCNet: LCNet:
scale: 2.5 scale: 2.5
......
...@@ -36,13 +36,14 @@ class PicoDet(BaseArch): ...@@ -36,13 +36,14 @@ class PicoDet(BaseArch):
__category__ = 'architecture' __category__ = 'architecture'
def __init__(self, backbone, neck, head='PicoHead'): def __init__(self, backbone, neck, head='PicoHead', nms_cpu=False):
super(PicoDet, self).__init__() super(PicoDet, self).__init__()
self.backbone = backbone self.backbone = backbone
self.neck = neck self.neck = neck
self.head = head self.head = head
self.export_post_process = True self.export_post_process = True
self.export_nms = True self.export_nms = True
self.nms_cpu = nms_cpu
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
...@@ -69,7 +70,10 @@ class PicoDet(BaseArch): ...@@ -69,7 +70,10 @@ class PicoDet(BaseArch):
else: else:
scale_factor = self.inputs['scale_factor'] scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process( bboxes, bbox_num = self.head.post_process(
head_outs, scale_factor, export_nms=self.export_nms) head_outs,
scale_factor,
export_nms=self.export_nms,
nms_cpu=self.nms_cpu)
return bboxes, bbox_num return bboxes, bbox_num
def get_loss(self, ): def get_loss(self, ):
......
...@@ -242,6 +242,7 @@ class PicoHead(OTAVFLHead): ...@@ -242,6 +242,7 @@ class PicoHead(OTAVFLHead):
self.nms_pre = nms_pre self.nms_pre = nms_pre
self.cell_offset = cell_offset self.cell_offset = cell_offset
self.eval_size = eval_size self.eval_size = eval_size
self.device = paddle.device.get_device()
self.use_sigmoid = self.loss_vfl.use_sigmoid self.use_sigmoid = self.loss_vfl.use_sigmoid
if self.use_sigmoid: if self.use_sigmoid:
...@@ -397,7 +398,11 @@ class PicoHead(OTAVFLHead): ...@@ -397,7 +398,11 @@ class PicoHead(OTAVFLHead):
stride_tensor = paddle.concat(stride_tensor) stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor return anchor_points, stride_tensor
def post_process(self, head_outs, scale_factor, export_nms=True): def post_process(self,
head_outs,
scale_factor,
export_nms=True,
nms_cpu=False):
pred_scores, pred_bboxes = head_outs pred_scores, pred_bboxes = head_outs
if not export_nms: if not export_nms:
return pred_bboxes, pred_scores return pred_bboxes, pred_scores
...@@ -409,7 +414,12 @@ class PicoHead(OTAVFLHead): ...@@ -409,7 +414,12 @@ class PicoHead(OTAVFLHead):
axis=-1).reshape([-1, 1, 4]) axis=-1).reshape([-1, 1, 4])
# scale bbox to origin image size. # scale bbox to origin image size.
pred_bboxes /= scale_factor pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) if nms_cpu:
paddle.set_device("cpu")
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
paddle.set_device(self.device)
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num return bbox_pred, bbox_num
...@@ -767,7 +777,11 @@ class PicoHeadV2(GFLHead): ...@@ -767,7 +777,11 @@ class PicoHeadV2(GFLHead):
stride_tensor = paddle.concat(stride_tensor) stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor return anchor_points, stride_tensor
def post_process(self, head_outs, scale_factor, export_nms=True): def post_process(self,
head_outs,
scale_factor,
export_nms=True,
nms_cpu=False):
pred_scores, pred_bboxes = head_outs pred_scores, pred_bboxes = head_outs
if not export_nms: if not export_nms:
return pred_bboxes, pred_scores return pred_bboxes, pred_scores
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册