提交 6b09ec35 编写于 作者: Y Yang Zhang 提交者: qingqing01

Unify interface of detectors (#2503)

上级 44c2837e
......@@ -16,8 +16,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from ppdet.core.workspace import register
......@@ -64,11 +62,16 @@ class MaskRCNN(object):
self.mask_head = mask_head
self.fpn = fpn
def train(self, feed_vars):
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
assert mode in ['train', 'test'], "only support 'train' and 'test' mode"
if mode == 'train':
required_fields = ['gt_label', 'gt_box', 'gt_mask', 'is_crowd', 'im_info']
else:
required_fields = ['im_shape', 'im_info']
for var in required_fields:
assert var in feed_vars, "{} has no {} field".format(feed_vars, var)
im_info = feed_vars['im_info']
gt_box = feed_vars['gt_box']
is_crowd = feed_vars['is_crowd']
body_feats = self.backbone(im)
......@@ -76,13 +79,19 @@ class MaskRCNN(object):
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
# rpn proposals
rois = self.rpn_head.get_proposals(body_feats, im_info)
# RPN proposals
rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
if self.fpn is None:
last_feat = body_feats[list(body_feats.keys())[-1]]
roi_feat = self.roi_extractor(last_feat, rois)
else:
roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
rpn_loss = self.rpn_head.get_loss(im_info, gt_box, is_crowd)
if mode == 'train':
rpn_loss = self.rpn_head.get_loss(im_info, feed_vars['gt_box'],
feed_vars['is_crowd'])
for var in ['gt_label', 'is_crowd', 'gt_box', 'im_info']:
assert var in feed_vars, "{} has no {}".format(feed_vars, var)
outs = self.bbox_assigner(
rpn_rois=rois,
gt_classes=feed_vars['gt_label'],
......@@ -91,39 +100,23 @@ class MaskRCNN(object):
im_info=feed_vars['im_info'])
rois = outs[0]
labels_int32 = outs[1]
bbox_targets = outs[2]
bbox_inside_weights = outs[3]
bbox_outside_weights = outs[4]
if self.fpn is None:
# in models without FPN, roi extractor only uses the last level of
# feature maps. And list(body_feats.keys())[-1] represents the name of
# last feature map.
last_feat = body_feats[list(body_feats.keys())[-1]]
roi_feat = self.roi_extractor(last_feat, rois)
else:
roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
loss = self.bbox_head.get_loss(roi_feat, labels_int32, bbox_targets,
bbox_inside_weights,
bbox_outside_weights)
loss = self.bbox_head.get_loss(roi_feat, labels_int32, *outs[2:])
loss.update(rpn_loss)
assert 'gt_mask' in feed_vars, "{} has no gt_mask".format(feed_vars)
outs = self.mask_assigner(
mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
rois=rois,
gt_classes=feed_vars['gt_label'],
is_crowd=feed_vars['is_crowd'],
gt_segms=feed_vars['gt_mask'],
im_info=feed_vars['im_info'],
labels_int32=labels_int32)
mask_rois, roi_has_mask_int32, mask_int32 = outs
if self.fpn is None:
bbox_head_feat = self.bbox_head.get_head_feat()
feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32)
else:
feat = self.roi_extractor(body_feats, mask_rois, spatial_scale,
True)
is_mask=True)
mask_loss = self.mask_head.get_loss(feat, mask_int32)
loss.update(mask_loss)
......@@ -132,27 +125,9 @@ class MaskRCNN(object):
loss.update({'loss': total_loss})
return loss
def test(self, feed_vars):
im = feed_vars['image']
im_info = feed_vars['im_info']
im_shape = feed_vars['im_shape']
body_feats = self.backbone(im)
# FPN
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
rois = self.rpn_head.get_proposals(body_feats, im_info, mode='test')
if self.fpn is None:
body_feat = body_feats[list(body_feats.keys())[-1]]
roi_feat = self.roi_extractor(body_feat, rois)
else:
roi_feat = self.roi_extractor(body_feats, rois, spatial_scale,
False)
bbox_pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
im_shape)
feed_vars['im_shape'])
bbox_pred = bbox_pred['bbox']
# share weight
......@@ -177,16 +152,21 @@ class MaskRCNN(object):
mask_rois = bbox * im_scale
if self.fpn is None:
mask_feat = self.roi_extractor(body_feat, mask_rois)
mask_feat = self.roi_extractor(last_feat, mask_rois)
mask_feat = self.bbox_head.get_head_feat(mask_feat)
else:
mask_feat = self.roi_extractor(body_feats, mask_rois,
spatial_scale, True)
spatial_scale, is_mask=True)
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
return {'bbox': bbox_pred, 'mask': mask_pred}
def train(self, feed_vars):
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
self.test(feed_vars)
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self.build(feed_vars, 'test')
......@@ -43,7 +43,7 @@ class RetinaNet(object):
self.fpn = fpn
self.retina_head = retina_head
def _forward(self, feed_vars, mode='train'):
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
im_info = feed_vars['im_info']
if mode == 'train':
......@@ -69,10 +69,10 @@ class RetinaNet(object):
return pred
def train(self, feed_vars):
return self._forward(feed_vars, 'train')
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
return self._forward(feed_vars, 'test')
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self._forward(feed_vars, 'test')
return self.build(feed_vars, 'test')
......@@ -41,7 +41,7 @@ class YOLOv3(object):
self.backbone = backbone
self.yolo_head = yolo_head
def _forward(self, feed_vars, mode='train'):
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
body_feats = self.backbone(im)
......@@ -63,10 +63,10 @@ class YOLOv3(object):
return self.yolo_head.get_prediction(body_feats, im_shape)
def train(self, feed_vars):
return self._forward(feed_vars, mode='train')
return self.build(feed_vars, mode='train')
def eval(self, feed_vars):
return self._forward(feed_vars, mode='test')
return self.build(feed_vars, mode='test')
def test(self, feed_vars):
return self._forward(feed_vars, mode='test')
return self.build(feed_vars, mode='test')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册