未验证 提交 6fbdb5c0 编写于 作者: W wangxinxin08 提交者: GitHub

fix problem of s2anet while deploying (#4624)

* fix problem of s2anet while deploying

* correct problem of s2anet TestReader

* add dota default category
上级 b5c534c2
......@@ -17,4 +17,3 @@ EvalDataset:
TestDataset:
!ImageFolder
anno_path: trainval_split/s2anet_trainval_paddle_coco.json
dataset_dir: dataset/DOTA_1024_s2anet/
......@@ -29,8 +29,6 @@ EvalReader:
TestReader:
inputs_def:
image_shape: [3, 1024, 1024]
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
......
......@@ -24,4 +24,3 @@ S2ANetHead:
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05]
reg_loss_type: 'l1'
is_training: True
......@@ -28,4 +28,3 @@ S2ANetHead:
use_sigmoid_cls: True
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05]
is_training: True
......@@ -39,7 +39,8 @@ def get_categories(metric_type, anno_file=None, arch=None):
if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'})
if metric_type.lower() == 'coco' or metric_type.lower() == 'rbox' or metric_type.lower() == 'snipercoco':
if metric_type.lower() == 'coco' or metric_type.lower(
) == 'rbox' or metric_type.lower() == 'snipercoco':
if anno_file and os.path.isfile(anno_file):
# lazy import pycocotools here
from pycocotools.coco import COCO
......@@ -53,6 +54,9 @@ def get_categories(metric_type, anno_file=None, arch=None):
# anno file not exist, load default categories of COCO17
else:
if metric_type.lower() == 'rbox':
return _dota_category()
return _coco17_category()
elif metric_type.lower() == 'voc':
......@@ -294,6 +298,34 @@ def _coco17_category():
return clsid2catid, catid2name
def _dota_category():
"""
Get class id to category id map and category id
to category name map of dota dataset
"""
catid2name = {
0: 'background',
1: 'plane',
2: 'baseball-diamond',
3: 'bridge',
4: 'ground-track-field',
5: 'small-vehicle',
6: 'large-vehicle',
7: 'ship',
8: 'tennis-court',
9: 'basketball-court',
10: 'storage-tank',
11: 'soccer-ball-field',
12: 'roundabout',
13: 'harbor',
14: 'swimming-pool',
15: 'helicopter'
}
catid2name.pop(0)
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
return clsid2catid, catid2name
def _vocall_category():
"""
Get class id to category id map and category id
......
......@@ -233,8 +233,7 @@ class S2ANetHead(nn.Layer):
anchor_assign=RBoxAssigner().__dict__,
reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
cls_loss_weight=[1.1, 1.05],
reg_loss_type='l1',
is_training=True):
reg_loss_type='l1'):
super(S2ANetHead, self).__init__()
self.stacked_convs = stacked_convs
self.feat_in = feat_in
......@@ -260,8 +259,6 @@ class S2ANetHead(nn.Layer):
self.alpha = 1.0
self.beta = 1.0
self.reg_loss_type = reg_loss_type
self.is_training = is_training
self.s2anet_head_out = None
# anchor
......@@ -451,12 +448,10 @@ class S2ANetHead(nn.Layer):
init_anchors = self.rect2rbox(init_anchors)
self.base_anchors_list.append(init_anchors)
if self.is_training:
if self.training:
refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors)
else:
fam_reg1 = fam_reg.clone()
fam_reg1.stop_gradient = True
refine_anchor = self.bbox_decode(fam_reg1, init_anchors)
refine_anchor = self.bbox_decode(fam_reg, init_anchors)
self.refine_anchor_list.append(refine_anchor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册