未验证 提交 2686dce8 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix ssd export (#2176)

上级 d8e29086
......@@ -43,9 +43,8 @@ class SSD(BaseArch):
self.inputs['gt_bbox'],
self.inputs['gt_class'])
else:
boxes, scores, anchors = self.ssd_head(body_feats,
self.inputs['image'])
bbox, bbox_num = self.post_process((boxes, scores), anchors,
preds, anchors = self.ssd_head(body_feats, self.inputs['image'])
bbox, bbox_num = self.post_process(preds, anchors,
self.inputs['im_shape'],
self.inputs['scale_factor'])
return bbox, bbox_num
......
......@@ -130,8 +130,8 @@ class SSDHead(nn.Layer):
box_preds = []
cls_scores = []
prior_boxes = []
for i, (feat, box_conv, score_conv
) in enumerate(zip(feats, self.box_convs, self.score_convs)):
for feat, box_conv, score_conv in zip(feats, self.box_convs,
self.score_convs):
box_pred = box_conv(feat)
box_pred = paddle.transpose(box_pred, [0, 2, 3, 1])
box_pred = paddle.reshape(box_pred, [0, -1, 4])
......@@ -148,7 +148,7 @@ class SSDHead(nn.Layer):
return self.get_loss(box_preds, cls_scores, gt_bbox, gt_class,
prior_boxes)
else:
return box_preds, cls_scores, prior_boxes
return (box_preds, cls_scores), prior_boxes
def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes):
return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册