提交 ee2d40d4 编写于 作者: F FlyingQianMM

add coco pretrained weights for detection

上级 dede0136
...@@ -201,7 +201,7 @@ class BaseAPI: ...@@ -201,7 +201,7 @@ class BaseAPI:
if backbone == "HRNet": if backbone == "HRNet":
backbone = backbone + "_W{}".format(self.width) backbone = backbone + "_W{}".format(self.width)
pretrain_weights = get_pretrain_weights( pretrain_weights = get_pretrain_weights(
pretrain_weights, self.model_type, backbone, pretrain_dir) pretrain_weights, class_name, backbone, pretrain_dir)
if startup_prog is None: if startup_prog is None:
startup_prog = fluid.default_startup_program() startup_prog = fluid.default_startup_program()
self.exe.run(startup_prog) self.exe.run(startup_prog)
......
import paddlex import paddlex
import paddlex.utils.logging as logging
import paddlehub as hub import paddlehub as hub
import os import os
import os.path as osp import os.path as osp
...@@ -73,16 +74,58 @@ image_pretrain = { ...@@ -73,16 +74,58 @@ image_pretrain = {
} }
coco_pretrain = { coco_pretrain = {
'YOLOv3_DarkNet53':
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
'YOLOv3_MobileNetV1':
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
'YOLOv3_MobileNetV3_large':
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams',
'YOLOv3_ResNet34':
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
'YOLOv3_ResNet50_vd':
'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar',
'FasterRCNN_ResNet50':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar',
'FasterRCNN_ResNet50_vd':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar',
'FasterRCNN_ResNet101':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
'FasterRCNN_ResNet101_vd':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar',
'FasterRCNN_HRNet_W18':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar',
'MaskRCNN_ResNet50':
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar',
'MaskRCNN_ResNet50_vd':
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar',
'MaskRCNN_ResNet101':
'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
'MaskRCNN_ResNet101_vd':
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz' 'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
} }
def get_pretrain_weights(flag, model_type, backbone, save_dir): def get_pretrain_weights(flag, class_name, backbone, save_dir):
if flag is None: if flag is None:
return None return None
elif osp.isdir(flag): elif osp.isdir(flag):
return flag return flag
elif flag == 'IMAGENET': warning_info = "{} supports to be finetuned with weights pretrained on the IMAGENET dataset only, so pretrain_weights is forced to be set to IMAGENET"
if flag == 'COCO':
if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
model_name = '{}_{}'.format(class_name, backbone)
logging.warning(warning_info.format(model_name))
flag = 'IMAGENET'
elif class_name == 'HRNet':
logging.warning(warning_info.format(class_name))
flag = 'IMAGENET'
if flag == 'CITYSCAPES':
model_name = '{}_{}'.format(class_name, backbone)
if flag == 'IMAGENET':
new_save_dir = save_dir new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'): if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir new_save_dir = paddlex.pretrain_dir
...@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): ...@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone = 'MobileNetV3_small_x1_0_ssld' backbone = 'MobileNetV3_small_x1_0_ssld'
elif backbone == 'MobileNetV3_large_ssld': elif backbone == 'MobileNetV3_large_ssld':
backbone = 'MobileNetV3_large_x1_0_ssld' backbone = 'MobileNetV3_large_x1_0_ssld'
if model_type == 'detector': if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
if backbone == 'ResNet50': if backbone == 'ResNet50':
backbone = 'DetResNet50' backbone = 'DetResNet50'
assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
...@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): ...@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
new_save_dir = save_dir new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'): if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir new_save_dir = paddlex.pretrain_dir
if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
backbone = '{}_{}'.format(class_name, backbone)
url = coco_pretrain[backbone] url = coco_pretrain[backbone]
fname = osp.split(url)[-1].split('.')[0] fname = osp.split(url)[-1].split('.')[0]
# paddlex.utils.download_and_decompress(url, path=new_save_dir) # paddlex.utils.download_and_decompress(url, path=new_save_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册