提交 790860aa 编写于 作者: W wuzewu

Remove redundant code

上级 54638fd6
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
conf = { conf = {
"ssd": { "ssd": {
"with_background": True, "with_background": True,
...@@ -65,7 +64,9 @@ ssd_train_ops = [ ...@@ -65,7 +64,9 @@ ssd_train_ops = [
dict(op='ArrangeSSD') dict(op='ArrangeSSD')
] ]
ssd_eval_fields = ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'] ssd_eval_fields = [
'image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'
]
ssd_eval_ops = [ ssd_eval_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False), dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='NormalizeBox'), dict(op='NormalizeBox'),
...@@ -139,7 +140,8 @@ yolo_train_ops = [ ...@@ -139,7 +140,8 @@ yolo_train_ops = [
dict(op='RandomCrop'), dict(op='RandomCrop'),
dict(op='RandomFlipImage', is_normalized=False), dict(op='RandomFlipImage', is_normalized=False),
dict(op='Resize', target_dim=608, interp='random'), dict(op='Resize', target_dim=608, interp='random'),
dict(op='NormalizePermute', dict(
op='NormalizePermute',
mean=[123.675, 116.28, 103.53], mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]), std=[58.395, 57.120, 57.375]),
dict(op='NormalizeBox'), dict(op='NormalizeBox'),
...@@ -194,18 +196,28 @@ feed_config = { ...@@ -194,18 +196,28 @@ feed_config = {
}, },
"rcnn": { "rcnn": {
"train": { "train": {
"fields": ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'], "fields":
"OPS": rcnn_train_ops, ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"IS_PADDING": True, "OPS":
"COARSEST_STRIDE": 32, rcnn_train_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
}, },
"dev": { "dev": {
"fields": ['image', 'im_info', 'im_id', 'im_shape', 'gt_box', "fields": [
'gt_label', 'is_difficult'], 'image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label',
"OPS": rcnn_eval_ops, 'is_difficult'
"IS_PADDING": True, ],
"COARSEST_STRIDE": 32, "OPS":
"USE_PADDED_IM_INFO": True, rcnn_eval_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
"USE_PADDED_IM_INFO":
True,
}, },
"predict": { "predict": {
"fields": ['image', 'im_info', 'im_id', 'im_shape'], "fields": ['image', 'im_info', 'im_id', 'im_shape'],
...@@ -222,8 +234,10 @@ feed_config = { ...@@ -222,8 +234,10 @@ feed_config = {
"RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] "RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
}, },
"dev": { "dev": {
"fields": ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'], "fields":
"OPS": yolo_eval_ops, ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS":
yolo_eval_ops,
}, },
"predict": { "predict": {
"fields": ['image', 'im_size', 'im_id'], "fields": ['image', 'im_size', 'im_id'],
...@@ -231,64 +245,3 @@ feed_config = { ...@@ -231,64 +245,3 @@ feed_config = {
}, },
}, },
} }
def get_model_type(module_name):
if 'yolo' in module_name:
return 'yolo'
elif 'ssd' in module_name:
return 'ssd'
elif 'rcnn' in module_name:
return 'rcnn'
else:
raise ValueError("module {} not supported".format(module_name))
def get_feed_list(module_name, input_dict, input_dict_pred=None):
pred_feed_list = None
if 'yolo' in module_name:
img = input_dict["image"]
im_size = input_dict["im_size"]
feed_list = [img.name, im_size.name]
elif 'ssd' in module_name:
image = input_dict["image"]
# image_shape = input_dict["im_shape"]
image_shape = input_dict["im_size"]
feed_list = [image.name, image_shape.name]
elif 'rcnn' in module_name:
image = input_dict['image']
im_info = input_dict['im_info']
gt_bbox = input_dict['gt_bbox']
gt_class = input_dict['gt_class']
is_crowd = input_dict['is_crowd']
feed_list = [image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name]
assert input_dict_pred is not None
image = input_dict_pred['image']
im_info = input_dict_pred['im_info']
im_shape = input_dict['im_shape']
pred_feed_list = [image.name, im_info.name, im_shape.name]
else:
raise NotImplementedError
return feed_list, pred_feed_list
def get_mid_feature(module_name, output_dict, output_dict_pred=None):
feature_pred = None
if 'yolo' in module_name:
feature = output_dict['head_features']
elif 'ssd' in module_name:
feature = output_dict['body_features']
elif 'rcnn' in module_name:
head_feat = output_dict['head_feat']
rpn_cls_loss = output_dict['rpn_cls_loss']
rpn_reg_loss = output_dict['rpn_reg_loss']
generate_proposal_labels = output_dict['generate_proposal_labels']
feature = [head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels]
assert output_dict_pred is not None
head_feat = output_dict_pred['head_feat']
rois = output_dict_pred['rois']
feature_pred = [head_feat, rois]
else:
raise NotImplementedError
return feature, feature_pred
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册