未验证 提交 6fbdb5c0 编写于 作者: W wangxinxin08 提交者: GitHub

fix problem of s2anet while deploying (#4624)

* fix problem of s2anet while deploying

* correct problem of s2anet TestReader

* add dota default category
上级 b5c534c2
...@@ -17,4 +17,3 @@ EvalDataset: ...@@ -17,4 +17,3 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: trainval_split/s2anet_trainval_paddle_coco.json anno_path: trainval_split/s2anet_trainval_paddle_coco.json
dataset_dir: dataset/DOTA_1024_s2anet/
...@@ -29,8 +29,6 @@ EvalReader: ...@@ -29,8 +29,6 @@ EvalReader:
TestReader: TestReader:
inputs_def:
image_shape: [3, 1024, 1024]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True} - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
......
...@@ -24,4 +24,3 @@ S2ANetHead: ...@@ -24,4 +24,3 @@ S2ANetHead:
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05] cls_loss_weight: [1.1, 1.05]
reg_loss_type: 'l1' reg_loss_type: 'l1'
is_training: True
...@@ -28,4 +28,3 @@ S2ANetHead: ...@@ -28,4 +28,3 @@ S2ANetHead:
use_sigmoid_cls: True use_sigmoid_cls: True
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05] cls_loss_weight: [1.1, 1.05]
is_training: True
...@@ -39,7 +39,8 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -39,7 +39,8 @@ def get_categories(metric_type, anno_file=None, arch=None):
if arch == 'keypoint_arch': if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'}) return (None, {'id': 'keypoint'})
if metric_type.lower() == 'coco' or metric_type.lower() == 'rbox' or metric_type.lower() == 'snipercoco': if metric_type.lower() == 'coco' or metric_type.lower(
) == 'rbox' or metric_type.lower() == 'snipercoco':
if anno_file and os.path.isfile(anno_file): if anno_file and os.path.isfile(anno_file):
# lazy import pycocotools here # lazy import pycocotools here
from pycocotools.coco import COCO from pycocotools.coco import COCO
...@@ -53,6 +54,9 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -53,6 +54,9 @@ def get_categories(metric_type, anno_file=None, arch=None):
# anno file not exist, load default categories of COCO17 # anno file not exist, load default categories of COCO17
else: else:
if metric_type.lower() == 'rbox':
return _dota_category()
return _coco17_category() return _coco17_category()
elif metric_type.lower() == 'voc': elif metric_type.lower() == 'voc':
...@@ -294,6 +298,34 @@ def _coco17_category(): ...@@ -294,6 +298,34 @@ def _coco17_category():
return clsid2catid, catid2name return clsid2catid, catid2name
def _dota_category():
"""
Get class id to category id map and category id
to category name map of dota dataset
"""
catid2name = {
0: 'background',
1: 'plane',
2: 'baseball-diamond',
3: 'bridge',
4: 'ground-track-field',
5: 'small-vehicle',
6: 'large-vehicle',
7: 'ship',
8: 'tennis-court',
9: 'basketball-court',
10: 'storage-tank',
11: 'soccer-ball-field',
12: 'roundabout',
13: 'harbor',
14: 'swimming-pool',
15: 'helicopter'
}
catid2name.pop(0)
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
return clsid2catid, catid2name
def _vocall_category(): def _vocall_category():
""" """
Get class id to category id map and category id Get class id to category id map and category id
......
...@@ -233,8 +233,7 @@ class S2ANetHead(nn.Layer): ...@@ -233,8 +233,7 @@ class S2ANetHead(nn.Layer):
anchor_assign=RBoxAssigner().__dict__, anchor_assign=RBoxAssigner().__dict__,
reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1], reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
cls_loss_weight=[1.1, 1.05], cls_loss_weight=[1.1, 1.05],
reg_loss_type='l1', reg_loss_type='l1'):
is_training=True):
super(S2ANetHead, self).__init__() super(S2ANetHead, self).__init__()
self.stacked_convs = stacked_convs self.stacked_convs = stacked_convs
self.feat_in = feat_in self.feat_in = feat_in
...@@ -260,8 +259,6 @@ class S2ANetHead(nn.Layer): ...@@ -260,8 +259,6 @@ class S2ANetHead(nn.Layer):
self.alpha = 1.0 self.alpha = 1.0
self.beta = 1.0 self.beta = 1.0
self.reg_loss_type = reg_loss_type self.reg_loss_type = reg_loss_type
self.is_training = is_training
self.s2anet_head_out = None self.s2anet_head_out = None
# anchor # anchor
...@@ -451,12 +448,10 @@ class S2ANetHead(nn.Layer): ...@@ -451,12 +448,10 @@ class S2ANetHead(nn.Layer):
init_anchors = self.rect2rbox(init_anchors) init_anchors = self.rect2rbox(init_anchors)
self.base_anchors_list.append(init_anchors) self.base_anchors_list.append(init_anchors)
if self.is_training: if self.training:
refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors) refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors)
else: else:
fam_reg1 = fam_reg.clone() refine_anchor = self.bbox_decode(fam_reg, init_anchors)
fam_reg1.stop_gradient = True
refine_anchor = self.bbox_decode(fam_reg1, init_anchors)
self.refine_anchor_list.append(refine_anchor) self.refine_anchor_list.append(refine_anchor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册