diff --git a/paddlehub/common/detection_config.py b/paddlehub/common/detection_config.py index 6fdae982110de5fe4a63509ec4c6090293e6fe10..b48a6b56e9c0d68669c07137aa8d2978a6634a89 100644 --- a/paddlehub/common/detection_config.py +++ b/paddlehub/common/detection_config.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - conf = { "ssd": { "with_background": True, @@ -65,7 +64,9 @@ ssd_train_ops = [ dict(op='ArrangeSSD') ] -ssd_eval_fields = ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'] +ssd_eval_fields = [ + 'image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult' +] ssd_eval_ops = [ dict(op='DecodeImage', to_rgb=True, with_mixup=False), dict(op='NormalizeBox'), @@ -139,9 +140,10 @@ yolo_train_ops = [ dict(op='RandomCrop'), dict(op='RandomFlipImage', is_normalized=False), dict(op='Resize', target_dim=608, interp='random'), - dict(op='NormalizePermute', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.120, 57.375]), + dict( + op='NormalizePermute', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.120, 57.375]), dict(op='NormalizeBox'), dict(op='ArrangeYOLO'), ] @@ -194,18 +196,28 @@ feed_config = { }, "rcnn": { "train": { - "fields": ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'], - "OPS": rcnn_train_ops, - "IS_PADDING": True, - "COARSEST_STRIDE": 32, + "fields": + ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'], + "OPS": + rcnn_train_ops, + "IS_PADDING": + True, + "COARSEST_STRIDE": + 32, }, "dev": { - "fields": ['image', 'im_info', 'im_id', 'im_shape', 'gt_box', - 'gt_label', 'is_difficult'], - "OPS": rcnn_eval_ops, - "IS_PADDING": True, - "COARSEST_STRIDE": 32, - "USE_PADDED_IM_INFO": True, + "fields": [ + 'image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label', + 'is_difficult' + ], + "OPS": + rcnn_eval_ops, + "IS_PADDING": + True, + "COARSEST_STRIDE": + 32, + "USE_PADDED_IM_INFO": + True, }, "predict": { "fields": ['image', 'im_info', 'im_id', 'im_shape'], @@ -222,8 +234,10 @@ feed_config = { "RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] }, "dev": { - "fields": ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'], - "OPS": yolo_eval_ops, + "fields": + ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'], + "OPS": + yolo_eval_ops, }, "predict": { "fields": ['image', 'im_size', 'im_id'], @@ -231,64 +245,3 @@ feed_config = { }, }, } - - -def get_model_type(module_name): - if 'yolo' in module_name: - return 'yolo' - elif 'ssd' in module_name: - return 'ssd' - elif 'rcnn' in module_name: - return 'rcnn' - else: - raise ValueError("module {} not supported".format(module_name)) - - -def get_feed_list(module_name, input_dict, input_dict_pred=None): - pred_feed_list = None - if 'yolo' in module_name: - img = input_dict["image"] - im_size = input_dict["im_size"] - feed_list = [img.name, im_size.name] - elif 'ssd' in module_name: - image = input_dict["image"] - # image_shape = input_dict["im_shape"] - image_shape = input_dict["im_size"] - feed_list = [image.name, image_shape.name] - elif 'rcnn' in module_name: - image = input_dict['image'] - im_info = input_dict['im_info'] - gt_bbox = input_dict['gt_bbox'] - gt_class = input_dict['gt_class'] - is_crowd = input_dict['is_crowd'] - feed_list = [image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name] - assert input_dict_pred is not None - image = input_dict_pred['image'] - im_info = input_dict_pred['im_info'] - im_shape = input_dict['im_shape'] - pred_feed_list = [image.name, im_info.name, im_shape.name] - else: - raise NotImplementedError - return feed_list, pred_feed_list - - -def get_mid_feature(module_name, output_dict, output_dict_pred=None): - feature_pred = None - if 'yolo' in module_name: - feature = output_dict['head_features'] - elif 'ssd' in module_name: - feature = output_dict['body_features'] - elif 'rcnn' in module_name: - head_feat = output_dict['head_feat'] - rpn_cls_loss = output_dict['rpn_cls_loss'] - rpn_reg_loss = output_dict['rpn_reg_loss'] - generate_proposal_labels = output_dict['generate_proposal_labels'] - feature = [head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels] - assert output_dict_pred is not None - head_feat = output_dict_pred['head_feat'] - rois = output_dict_pred['rois'] - feature_pred = [head_feat, rois] - else: - raise NotImplementedError - return feature, feature_pred -