未验证 提交 400250ed 编写于 作者: H haoyuying 提交者: GitHub

fix download dataset issue

上级 558ac0b8
...@@ -5,5 +5,5 @@ if __name__ == '__main__': ...@@ -5,5 +5,5 @@ if __name__ == '__main__':
model = hub.Module( model = hub.Module(
name='resnet50_vd_imagenet_ssld', name='resnet50_vd_imagenet_ssld',
label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"], label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"],
oad_checkpoint='/PATH/TO/CHECKPOINT') load_checkpoint='/PATH/TO/CHECKPOINT')
result = model.predict(['flower.jpg']) result = model.predict(['flower.jpg'])
...@@ -15,20 +15,27 @@ ...@@ -15,20 +15,27 @@
import os import os
import filelock
import paddlehub.env as hubenv import paddlehub.env as hubenv
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from paddlehub.utils import log, utils, xarfile from paddlehub.utils import log, utils, xarfile
def download_data(url): def download_data(url):
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(hubenv.DATA_HOME, save_name)
if not os.path.exists(output_path):
get_path_from_url(url, hubenv.DATA_HOME)
def _wrapper(Dataset): def _wrapper(Dataset):
return Dataset
def _download_dataset_from_url(*args, **kwargs):
save_name = os.path.basename(url).split('.')[0]
output_path = os.path.join(hubenv.DATA_HOME, save_name)
lock = filelock.FileLock(os.path.join(hubenv.TMP_HOME, save_name))
with lock:
if not os.path.exists(output_path):
default_downloader.download_file_and_uncompress(url, hubenv.DATA_HOME, True)
return Dataset(*args, **kwargs)
return _download_dataset_from_url
return _wrapper return _wrapper
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册