提交 790860aa 编写于 作者: W wuzewu

Remove redundant code

上级 54638fd6
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
conf = {
"ssd": {
"with_background": True,
......@@ -65,7 +64,9 @@ ssd_train_ops = [
dict(op='ArrangeSSD')
]
ssd_eval_fields = ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult']
ssd_eval_fields = [
'image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'
]
ssd_eval_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='NormalizeBox'),
......@@ -139,9 +140,10 @@ yolo_train_ops = [
dict(op='RandomCrop'),
dict(op='RandomFlipImage', is_normalized=False),
dict(op='Resize', target_dim=608, interp='random'),
dict(op='NormalizePermute',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
dict(
op='NormalizePermute',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
dict(op='NormalizeBox'),
dict(op='ArrangeYOLO'),
]
......@@ -194,18 +196,28 @@ feed_config = {
},
"rcnn": {
"train": {
"fields": ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"OPS": rcnn_train_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"fields":
['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"OPS":
rcnn_train_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
},
"dev": {
"fields": ['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
"OPS": rcnn_eval_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"USE_PADDED_IM_INFO": True,
"fields": [
'image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label',
'is_difficult'
],
"OPS":
rcnn_eval_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
"USE_PADDED_IM_INFO":
True,
},
"predict": {
"fields": ['image', 'im_info', 'im_id', 'im_shape'],
......@@ -222,8 +234,10 @@ feed_config = {
"RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
},
"dev": {
"fields": ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS": yolo_eval_ops,
"fields":
['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS":
yolo_eval_ops,
},
"predict": {
"fields": ['image', 'im_size', 'im_id'],
......@@ -231,64 +245,3 @@ feed_config = {
},
},
}
def get_model_type(module_name):
if 'yolo' in module_name:
return 'yolo'
elif 'ssd' in module_name:
return 'ssd'
elif 'rcnn' in module_name:
return 'rcnn'
else:
raise ValueError("module {} not supported".format(module_name))
def get_feed_list(module_name, input_dict, input_dict_pred=None):
pred_feed_list = None
if 'yolo' in module_name:
img = input_dict["image"]
im_size = input_dict["im_size"]
feed_list = [img.name, im_size.name]
elif 'ssd' in module_name:
image = input_dict["image"]
# image_shape = input_dict["im_shape"]
image_shape = input_dict["im_size"]
feed_list = [image.name, image_shape.name]
elif 'rcnn' in module_name:
image = input_dict['image']
im_info = input_dict['im_info']
gt_bbox = input_dict['gt_bbox']
gt_class = input_dict['gt_class']
is_crowd = input_dict['is_crowd']
feed_list = [image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name]
assert input_dict_pred is not None
image = input_dict_pred['image']
im_info = input_dict_pred['im_info']
im_shape = input_dict['im_shape']
pred_feed_list = [image.name, im_info.name, im_shape.name]
else:
raise NotImplementedError
return feed_list, pred_feed_list
def get_mid_feature(module_name, output_dict, output_dict_pred=None):
feature_pred = None
if 'yolo' in module_name:
feature = output_dict['head_features']
elif 'ssd' in module_name:
feature = output_dict['body_features']
elif 'rcnn' in module_name:
head_feat = output_dict['head_feat']
rpn_cls_loss = output_dict['rpn_cls_loss']
rpn_reg_loss = output_dict['rpn_reg_loss']
generate_proposal_labels = output_dict['generate_proposal_labels']
feature = [head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels]
assert output_dict_pred is not None
head_feat = output_dict_pred['head_feat']
rois = output_dict_pred['rois']
feature_pred = [head_feat, rois]
else:
raise NotImplementedError
return feature, feature_pred
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册