提交 9f64d907 编写于 作者: W wuzewu

Add pretrained model download script

上级 411386ff
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from download_util import download_file_and_uncompress
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "..", "test")
sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
def download_cityscapes_dataset(savepath, extrapath):
......
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from download_util import download_file_and_uncompress
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "..", "test")
sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
def download_pet_dataset(savepath, extrapath):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import shutil
import requests
import sys
import tarfile
import zipfile
import platform
import functools
lasttime = time.time()
FLUSH_INTERVAL = 0.1
def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(dl / total_length * 100)))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
def _uncompress_file(filepath, extrapath, delete_file, print_progress):
if print_progress:
print("Uncompress %s" % os.path.basename(filepath))
if filepath.endswith("zip"):
handler = _uncompress_file_zip
elif filepath.endswith("tgz"):
handler = _uncompress_file_tar
else:
handler = functools.partial(_uncompress_file_tar, mode="r")
for total_num, index in handler(filepath, extrapath):
if print_progress:
done = int(50 * float(index) / total_num)
progress(
"[%-50s] %.2f%%" % ('=' * done, float(index / total_num * 100)))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
if delete_file:
os.remove(filepath)
def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index
files.close()
yield total_num, index
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index
files.close()
yield total_num, index
def download_file_and_uncompress(url,
savepath=None,
extrapath=None,
print_progress=True,
cover=False,
delete_file=True):
if savepath is None:
savepath = "."
if extrapath is None:
extrapath = "."
savename = url.split("/")[-1]
savepath = os.path.join(savepath, savename)
extraname = ".".join(savename.split(".")[:-1])
extraname = os.path.join(extrapath, extraname)
if cover:
if os.path.exists(savepath):
shutil.rmtree(savepath)
if os.path.exists(extraname):
shutil.rmtree(extraname)
if not os.path.exists(extraname):
if not os.path.exists(savepath):
_download_file(url, savepath, print_progress)
_uncompress_file(savepath, extrapath, delete_file, print_progress)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "..", "test")
sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
model_urls = {
"deeplabv3plus_mobilenetv2-1-0_bn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz",
"unet_bn_coco": "https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz"
}
if __name__ == "__main__":
if len(sys.argv) != 2:
print("usage:\n python download_model.py ${MODEL_NAME}")
exit(1)
model_name = sys.argv[1]
if not model_name in model_urls.keys():
print("Only support: \n {}".format("\n ".join(
list(model_urls.keys()))))
exit(1)
url = model_urls[model_name]
download_file_and_uncompress(
url=url,
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH,
extraname=model_name)
print("Pretrained Model download success!")
......@@ -121,6 +121,7 @@ def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
def download_file_and_uncompress(url,
savepath=None,
extrapath=None,
extraname=None,
print_progress=True,
cover=False,
delete_file=True):
......@@ -132,19 +133,25 @@ def download_file_and_uncompress(url,
savename = url.split("/")[-1]
savepath = os.path.join(savepath, savename)
extraname = ".".join(savename.split(".")[:-1])
extraname = os.path.join(extrapath, extraname)
savename = ".".join(savename.split(".")[:-1])
savename = os.path.join(extrapath, savename)
extraname = savename if extraname is None else os.path.join(
extrapath, extraname)
if cover:
if os.path.exists(savepath):
shutil.rmtree(savepath)
if os.path.exists(savename):
shutil.rmtree(savename)
if os.path.exists(extraname):
shutil.rmtree(extraname)
if not os.path.exists(extraname):
if not os.path.exists(savename):
if not os.path.exists(savepath):
_download_file(url, savepath, print_progress)
_uncompress_file(savepath, extrapath, delete_file, print_progress)
shutil.move(savename, extraname)
def _pdseg(command, flags, options, devices):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册