提交 336a4fef 编写于 作者: F FlyingQianMM

use hub to download hr pretrained weights

上级 bb3030a8
...@@ -99,11 +99,10 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): ...@@ -99,11 +99,10 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone = 'DetResNet50' backbone = 'DetResNet50'
assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
backbone) backbone)
if backbone.startswith("HRNet"): # url = image_pretrain[backbone]
url = image_pretrain[backbone] # fname = osp.split(url)[-1].split('.')[0]
fname = osp.split(url)[-1].split('.')[0] # paddlex.utils.download_and_decompress(url, path=new_save_dir)
paddlex.utils.download_and_decompress(url, path=new_save_dir) # return osp.join(new_save_dir, fname)
return osp.join(new_save_dir, fname)
try: try:
hub.download(backbone, save_path=new_save_dir) hub.download(backbone, save_path=new_save_dir)
except Exception as e: except Exception as e:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册