未验证 提交 bf700a8e 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix ssd background to last (#2145)

上级 6d92ef31
...@@ -165,6 +165,7 @@ class Trainer(object): ...@@ -165,6 +165,7 @@ class Trainer(object):
if not self._weights_loaded: if not self._weights_loaded:
self.load_weights(self.cfg.pretrain_weights) self.load_weights(self.cfg.pretrain_weights)
model = self.model
if self._nranks > 1: if self._nranks > 1:
model = paddle.DataParallel(self.model) model = paddle.DataParallel(self.model)
else: else:
......
...@@ -102,7 +102,7 @@ class DetectionMAP(object): ...@@ -102,7 +102,7 @@ class DetectionMAP(object):
self.evaluate_difficult = evaluate_difficult self.evaluate_difficult = evaluate_difficult
self.reset() self.reset()
def update(self, bbox, gt_box, gt_label, difficult=None): def update(self, bbox, score, label, gt_box, gt_label, difficult=None):
""" """
Update metric statistics from given prediction and ground Update metric statistics from given prediction and ground
truth information. truth information.
...@@ -117,13 +117,13 @@ class DetectionMAP(object): ...@@ -117,13 +117,13 @@ class DetectionMAP(object):
# record class score positive # record class score positive
visited = [False] * len(gt_label) visited = [False] * len(gt_label)
for b in bbox: for b, s, l in zip(bbox, score, label):
label, score, xmin, ymin, xmax, ymax = b.tolist() xmin, ymin, xmax, ymax = b.tolist()
pred = [xmin, ymin, xmax, ymax] pred = [xmin, ymin, xmax, ymax]
max_idx = -1 max_idx = -1
max_overlap = -1.0 max_overlap = -1.0
for i, gl in enumerate(gt_label): for i, gl in enumerate(gt_label):
if int(gl) == int(label): if int(gl) == int(l):
overlap = jaccard_overlap(pred, gt_box[i], overlap = jaccard_overlap(pred, gt_box[i],
self.is_bbox_normalized) self.is_bbox_normalized)
if overlap > max_overlap: if overlap > max_overlap:
...@@ -134,12 +134,12 @@ class DetectionMAP(object): ...@@ -134,12 +134,12 @@ class DetectionMAP(object):
if self.evaluate_difficult or \ if self.evaluate_difficult or \
int(np.array(difficult[max_idx])) == 0: int(np.array(difficult[max_idx])) == 0:
if not visited[max_idx]: if not visited[max_idx]:
self.class_score_poss[int(label)].append([score, 1.0]) self.class_score_poss[int(l)].append([s, 1.0])
visited[max_idx] = True visited[max_idx] = True
else: else:
self.class_score_poss[int(label)].append([score, 0.0]) self.class_score_poss[int(l)].append([s, 0.0])
else: else:
self.class_score_poss[int(label)].append([score, 0.0]) self.class_score_poss[int(l)].append([s, 0.0])
def reset(self): def reset(self):
""" """
......
...@@ -148,6 +148,8 @@ class VOCMetric(Metric): ...@@ -148,6 +148,8 @@ class VOCMetric(Metric):
def update(self, inputs, outputs): def update(self, inputs, outputs):
bboxes = outputs['bbox'].numpy() bboxes = outputs['bbox'].numpy()
scores = outputs['score'].numpy()
labels = outputs['label'].numpy()
bbox_lengths = outputs['bbox_num'].numpy() bbox_lengths = outputs['bbox_num'].numpy()
if bboxes.shape == (1, 1) or bboxes is None: if bboxes.shape == (1, 1) or bboxes is None:
...@@ -171,9 +173,12 @@ class VOCMetric(Metric): ...@@ -171,9 +173,12 @@ class VOCMetric(Metric):
else difficults[i] else difficults[i]
bbox_num = bbox_lengths[i] bbox_num = bbox_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num] bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
score = scores[bbox_idx:bbox_idx + bbox_num]
label = labels[bbox_idx:bbox_idx + bbox_num]
gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label, gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
difficult) difficult)
self.detection_map.update(bbox, gt_box, gt_label, difficult) self.detection_map.update(bbox, score, label, gt_box, gt_label,
difficult)
bbox_idx += bbox_num bbox_idx += bbox_num
def accumulate(self): def accumulate(self):
......
...@@ -54,4 +54,14 @@ class SSD(BaseArch): ...@@ -54,4 +54,14 @@ class SSD(BaseArch):
return {"loss": self._forward()} return {"loss": self._forward()}
def get_pred(self): def get_pred(self):
return dict(zip(['bbox', 'bbox_num'], self._forward())) bbox_pred, bbox_num = self._forward()
label = bbox_pred[:, 0]
score = bbox_pred[:, 1]
bbox = bbox_pred[:, 2:]
output = {
'bbox': bbox,
'score': score,
'label': label,
'bbox_num': bbox_num
}
return output
...@@ -58,7 +58,7 @@ class SSDHead(nn.Layer): ...@@ -58,7 +58,7 @@ class SSDHead(nn.Layer):
__inject__ = ['anchor_generator', 'loss'] __inject__ = ['anchor_generator', 'loss']
def __init__(self, def __init__(self,
num_classes=81, num_classes=80,
in_channels=(512, 1024, 512, 256, 256, 256), in_channels=(512, 1024, 512, 256, 256, 256),
anchor_generator=AnchorGeneratorSSD().__dict__, anchor_generator=AnchorGeneratorSSD().__dict__,
kernel_size=3, kernel_size=3,
...@@ -67,7 +67,8 @@ class SSDHead(nn.Layer): ...@@ -67,7 +67,8 @@ class SSDHead(nn.Layer):
conv_decay=0., conv_decay=0.,
loss='SSDLoss'): loss='SSDLoss'):
super(SSDHead, self).__init__() super(SSDHead, self).__init__()
self.num_classes = num_classes # add background class
self.num_classes = num_classes + 1
self.in_channels = in_channels self.in_channels = in_channels
self.anchor_generator = anchor_generator self.anchor_generator = anchor_generator
self.loss = loss self.loss = loss
...@@ -106,7 +107,7 @@ class SSDHead(nn.Layer): ...@@ -106,7 +107,7 @@ class SSDHead(nn.Layer):
score_conv_name, score_conv_name,
nn.Conv2D( nn.Conv2D(
in_channels=in_channels[i], in_channels=in_channels[i],
out_channels=num_prior * num_classes, out_channels=num_prior * self.num_classes,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding)) padding=padding))
else: else:
...@@ -114,7 +115,7 @@ class SSDHead(nn.Layer): ...@@ -114,7 +115,7 @@ class SSDHead(nn.Layer):
score_conv_name, score_conv_name,
SepConvLayer( SepConvLayer(
in_channels=in_channels[i], in_channels=in_channels[i],
out_channels=num_prior * num_classes, out_channels=num_prior * self.num_classes,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding, padding=padding,
conv_decay=conv_decay, conv_decay=conv_decay,
...@@ -129,8 +130,8 @@ class SSDHead(nn.Layer): ...@@ -129,8 +130,8 @@ class SSDHead(nn.Layer):
box_preds = [] box_preds = []
cls_scores = [] cls_scores = []
prior_boxes = [] prior_boxes = []
for feat, box_conv, score_conv in zip(feats, self.box_convs, for i, (feat, box_conv, score_conv
self.score_convs): ) in enumerate(zip(feats, self.box_convs, self.score_convs)):
box_pred = box_conv(feat) box_pred = box_conv(feat)
box_pred = paddle.transpose(box_pred, [0, 2, 3, 1]) box_pred = paddle.transpose(box_pred, [0, 2, 3, 1])
box_pred = paddle.reshape(box_pred, [0, -1, 4]) box_pred = paddle.reshape(box_pred, [0, -1, 4])
......
...@@ -114,7 +114,8 @@ class SSDLoss(nn.Layer): ...@@ -114,7 +114,8 @@ class SSDLoss(nn.Layer):
scores = paddle.concat(scores, axis=1) scores = paddle.concat(scores, axis=1)
prior_boxes = paddle.concat(anchors, axis=0) prior_boxes = paddle.concat(anchors, axis=0)
gt_label = gt_class.unsqueeze(-1) gt_label = gt_class.unsqueeze(-1)
batch_size, num_priors, num_classes = scores.shape batch_size, num_priors = scores.shape[:2]
num_classes = scores.shape[-1] - 1
def _reshape_to_2d(x): def _reshape_to_2d(x):
return paddle.flatten(x, start_axis=2) return paddle.flatten(x, start_axis=2)
...@@ -137,7 +138,8 @@ class SSDLoss(nn.Layer): ...@@ -137,7 +138,8 @@ class SSDLoss(nn.Layer):
# 2. Compute confidence for mining hard examples # 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices # 2.1. Get the target label based on matched indices
target_label, _ = self._label_target_assign(gt_label, matched_indices) target_label, _ = self._label_target_assign(
gt_label, matched_indices, mismatch_value=num_classes)
confidence = _reshape_to_2d(scores) confidence = _reshape_to_2d(scores)
# 2.2. Compute confidence loss. # 2.2. Compute confidence loss.
# Reshape confidence to 2D tensor. # Reshape confidence to 2D tensor.
...@@ -173,7 +175,10 @@ class SSDLoss(nn.Layer): ...@@ -173,7 +175,10 @@ class SSDLoss(nn.Layer):
encoded_bbox, matched_indices) encoded_bbox, matched_indices)
# 4.3. Assign classification targets # 4.3. Assign classification targets
target_label, target_conf_weight = self._label_target_assign( target_label, target_conf_weight = self._label_target_assign(
gt_label, matched_indices, neg_mask=neg_mask) gt_label,
matched_indices,
neg_mask=neg_mask,
mismatch_value=num_classes)
# 5. Compute loss. # 5. Compute loss.
# 5.1 Compute confidence loss. # 5.1 Compute confidence loss.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册