From ee2d40d43e3ba3d647caae152335355472474097 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Sun, 7 Jun 2020 13:15:45 +0800 Subject: [PATCH] add coco pretrained weights for detection --- paddlex/cv/models/base.py | 2 +- paddlex/cv/models/utils/pretrain_weights.py | 51 +++++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index ac8989f..54848c8 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -201,7 +201,7 @@ class BaseAPI: if backbone == "HRNet": backbone = backbone + "_W{}".format(self.width) pretrain_weights = get_pretrain_weights( - pretrain_weights, self.model_type, backbone, pretrain_dir) + pretrain_weights, class_name, backbone, pretrain_dir) if startup_prog is None: startup_prog = fluid.default_startup_program() self.exe.run(startup_prog) diff --git a/paddlex/cv/models/utils/pretrain_weights.py b/paddlex/cv/models/utils/pretrain_weights.py index 3abbdd9..235b576 100644 --- a/paddlex/cv/models/utils/pretrain_weights.py +++ b/paddlex/cv/models/utils/pretrain_weights.py @@ -1,4 +1,5 @@ import paddlex +import paddlex.utils.logging as logging import paddlehub as hub import os import os.path as osp @@ -73,16 +74,58 @@ image_pretrain = { } coco_pretrain = { + 'YOLOv3_DarkNet53': + 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar', + 'YOLOv3_MobileNetV1': + 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar', + 'YOLOv3_MobileNetV3_large': + 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams', + 'YOLOv3_ResNet34': + 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar', + 'YOLOv3_ResNet50_vd': + 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar', + 'FasterRCNN_ResNet50': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar', + 'FasterRCNN_ResNet50_vd': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar', + 'FasterRCNN_ResNet101': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar', + 'FasterRCNN_ResNet101_vd': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar', + 'FasterRCNN_HRNet_W18': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar', + 'MaskRCNN_ResNet50': + 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar', + 'MaskRCNN_ResNet50_vd': + 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar', + 'MaskRCNN_ResNet101': + 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar', + 'MaskRCNN_ResNet101_vd': + 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar', 'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz' } -def get_pretrain_weights(flag, model_type, backbone, save_dir): +def get_pretrain_weights(flag, class_name, backbone, save_dir): if flag is None: return None elif osp.isdir(flag): return flag - elif flag == 'IMAGENET': + warning_info = "{} supports to be finetuned with weights pretrained on the IMAGENET dataset only, so pretrain_weights is forced to be set to IMAGENET" + if flag == 'COCO': + if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \ + class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \ + class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']: + model_name = '{}_{}'.format(class_name, backbone) + logging.warning(warning_info.format(model_name)) + flag = 'IMAGENET' + elif class_name == 'HRNet': + logging.warning(warning_info.format(class_name)) + flag = 'IMAGENET' + if flag == 'CITYSCAPES': + model_name = '{}_{}'.format(class_name, backbone) + + if flag == 'IMAGENET': new_save_dir = save_dir if hasattr(paddlex, 'pretrain_dir'): new_save_dir = paddlex.pretrain_dir @@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): backbone = 'MobileNetV3_small_x1_0_ssld' elif backbone == 'MobileNetV3_large_ssld': backbone = 'MobileNetV3_large_x1_0_ssld' - if model_type == 'detector': + if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']: if backbone == 'ResNet50': backbone = 'DetResNet50' assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( @@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): new_save_dir = save_dir if hasattr(paddlex, 'pretrain_dir'): new_save_dir = paddlex.pretrain_dir + if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']: + backbone = '{}_{}'.format(class_name, backbone) url = coco_pretrain[backbone] fname = osp.split(url)[-1].split('.')[0] # paddlex.utils.download_and_decompress(url, path=new_save_dir) -- GitLab