未验证 提交 1f3255d7 编写于 作者: S Sergei Slashchinin 提交者: GitHub

Merge pull request #18591 from sl-sergei:download_utilities

Scripts for downloading models in DNN samples

* Initial commit. Utility classes and functions for downloading files

* updated download script

* Support YAML parsing, update download script and configs

* Fix problem with archived files

* fix models.yml

* Move download utilities to more appropriate place

* Fix script description

* Update README

* update utilities for broader range of files

* fix loading with no hashsum provided

* remove unnecessary import

* fix for Python2

* Add usage examples for downloadFile function

* Add more secure cache folder selection

* Remove trailing whitespaces

* Fix indentation

* Update function interface

* Change function for temp dir, change entry name in models.yml

* Update getCacheDirectory function call

* Return python implementation for cache directory selection, use more specific env variable

* Fix whitespace
上级 fdeac73a
*.caffemodel
*.pb
*.weights
\ No newline at end of file
......@@ -19,6 +19,36 @@ Check `-h` option to know which values are used by default:
python object_detection.py opencv_fd -h
```
### Sample models
You can download sample models using ```download_models.py```. For example, the following command will download network weights for OpenCV Face Detector model and store them in FaceDetector folder:
```bash
python download_models.py --save_dir FaceDetector opencv_fd
```
You can use default configuration files adopted for OpenCV from [here](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn).
You also can use the script to download necessary files from your code. Assume you have the following code inside ```your_script.py```:
```python
from download_models import downloadFile
filepath1 = downloadFile("https://drive.google.com/uc?export=download&id=0B3gersZ2cHIxRm5PMWRoTkdHdHc", None, filename="MobileNetSSD_deploy.caffemodel", save_dir="save_dir_1")
filepath2 = downloadFile("https://drive.google.com/uc?export=download&id=0B3gersZ2cHIxRm5PMWRoTkdHdHc", "994d30a8afaa9e754d17d2373b2d62a7dfbaaf7a", filename="MobileNetSSD_deploy.caffemodel")
print(filepath1)
print(filepath2)
# Your code
```
By running the following commands, you will get **MobileNetSSD_deploy.caffemodel** file:
```bash
export OPENCV_DOWNLOAD_DATA_PATH=download_folder
python your_script.py
```
**Note** that you can provide a directory using **save_dir** parameter or via **OPENCV_SAVE_DIR** environment variable.
#### Face detection
[An origin model](https://github.com/opencv/opencv/tree/3.4/samples/dnn/face_detector)
with single precision floating point weights has been quantized using [TensorFlow framework](https://www.tensorflow.org/).
......@@ -48,7 +78,7 @@ AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ] | 0.528 | 0.528 |
```
## References
* [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py)
* [Models downloading script](https://github.com/opencv/opencv/samples/dnn/download_models.py)
* [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
* [How to import models from TensorFlow Object Detection API](https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API)
* [Names of classes from different datasets](https://github.com/opencv/opencv/tree/3.4/samples/data/dnn)
'''
Helper module to download extra data from Internet
'''
from __future__ import print_function
import os
import cv2
import sys
import yaml
import argparse
import tarfile
import platform
import tempfile
import hashlib
import requests
import shutil
from pathlib import Path
from datetime import datetime
if sys.version_info[0] < 3:
from urllib2 import urlopen
else:
from urllib.request import urlopen
import xml.etree.ElementTree as ET
__all__ = ["downloadFile"]
class HashMismatchException(Exception):
def __init__(self, expected, actual):
Exception.__init__(self)
self.expected = expected
self.actual = actual
def __str__(self):
return 'Hash mismatch: expected {} vs actual of {}'.format(self.expected, self.actual)
def getHashsumFromFile(filepath):
sha = hashlib.sha1()
if os.path.exists(filepath):
print(' there is already a file with the same name')
with open(filepath, 'rb') as f:
while True:
buf = f.read(10*1024*1024)
if not buf:
break
sha.update(buf)
hashsum = sha.hexdigest()
return hashsum
def checkHashsum(expected_sha, filepath, silent=True):
print(' expected SHA1: {}'.format(expected_sha))
actual_sha = getHashsumFromFile(filepath)
print(' actual SHA1:{}'.format(actual_sha))
hashes_matched = expected_sha == actual_sha
if not hashes_matched and not silent:
raise HashMismatchException(expected_sha, actual_sha)
return hashes_matched
def isArchive(filepath):
return tarfile.is_tarfile(filepath)
class DownloadInstance:
def __init__(self, **kwargs):
self.name = kwargs.pop('name')
self.filename = kwargs.pop('filename')
self.loader = kwargs.pop('loader', None)
self.save_dir = kwargs.pop('save_dir')
self.sha = kwargs.pop('sha', None)
def __str__(self):
return 'DownloadInstance <{}>'.format(self.name)
def get(self):
print(" Working on " + self.name)
print(" Getting file " + self.filename)
if self.sha is None:
print(' No expected hashsum provided, loading file')
else:
filepath = os.path.join(self.save_dir, self.sha, self.filename)
if checkHashsum(self.sha, filepath):
print(' hash match - file already exists, skipping')
return filepath
else:
print(' hash didn\'t match, loading file')
if not os.path.exists(self.save_dir):
print(' creating directory: ' + self.save_dir)
os.makedirs(self.save_dir)
print(' hash check failed - loading')
assert self.loader
try:
self.loader.load(self.filename, self.sha, self.save_dir)
print(' done')
print(' file {}'.format(self.filename))
if self.sha is None:
download_path = os.path.join(self.save_dir, self.filename)
self.sha = getHashsumFromFile(download_path)
new_dir = os.path.join(self.save_dir, self.sha)
if not os.path.exists(new_dir):
os.makedirs(new_dir)
filepath = os.path.join(new_dir, self.filename)
if not (os.path.exists(filepath)):
shutil.move(download_path, new_dir)
print(' No expected hashsum provided, actual SHA is {}'.format(self.sha))
else:
checkHashsum(self.sha, filepath, silent=False)
except Exception as e:
print(" There was some problem with loading file {} for {}".format(self.filename, self.name))
print(" Exception: {}".format(e))
return
print(" Finished " + self.name)
return filepath
class Loader(object):
MB = 1024*1024
BUFSIZE = 10*MB
def __init__(self, download_name, download_sha, archive_member = None):
self.download_name = download_name
self.download_sha = download_sha
self.archive_member = archive_member
def load(self, requested_file, sha, save_dir):
if self.download_sha is None:
download_dir = save_dir
else:
# create a new folder in save_dir to avoid possible name conflicts
download_dir = os.path.join(save_dir, self.download_sha)
if not os.path.exists(download_dir):
os.makedirs(download_dir)
download_path = os.path.join(download_dir, self.download_name)
print(" Preparing to download file " + self.download_name)
if checkHashsum(self.download_sha, download_path):
print(' hash match - file already exists, no need to download')
else:
filesize = self.download(download_path)
print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
if self.download_sha is not None:
checkHashsum(self.download_sha, download_path, silent=False)
if self.download_name == requested_file:
return
else:
if isArchive(download_path):
if sha is not None:
extract_dir = os.path.join(save_dir, sha)
else:
extract_dir = save_dir
if not os.path.exists(extract_dir):
os.makedirs(extract_dir)
self.extract(requested_file, download_path, extract_dir)
else:
raise Exception("Downloaded file has different name")
def download(self, filepath):
print("Warning: download is not implemented, this is a base class")
return 0
def extract(self, requested_file, archive_path, save_dir):
filepath = os.path.join(save_dir, requested_file)
try:
with tarfile.open(archive_path) as f:
if self.archive_member is None:
pathDict = dict((os.path.split(elem)[1], os.path.split(elem)[0]) for elem in f.getnames())
self.archive_member = pathDict[requested_file]
assert self.archive_member in f.getnames()
self.save(filepath, f.extractfile(self.archive_member))
except Exception as e:
print(' catch {}'.format(e))
def save(self, filepath, r):
with open(filepath, 'wb') as f:
print(' progress ', end="")
sys.stdout.flush()
while True:
buf = r.read(self.BUFSIZE)
if not buf:
break
f.write(buf)
print('>', end="")
sys.stdout.flush()
class URLLoader(Loader):
def __init__(self, download_name, download_sha, url, archive_member = None):
super(URLLoader, self).__init__(download_name, download_sha, archive_member)
self.download_name = download_name
self.download_sha = download_sha
self.url = url
def download(self, filepath):
r = urlopen(self.url, timeout=60)
self.printRequest(r)
self.save(filepath, r)
return os.path.getsize(filepath)
def printRequest(self, r):
def getMB(r):
d = dict(r.info())
for c in ['content-length', 'Content-Length']:
if c in d:
return int(d[c]) / self.MB
return '<unknown>'
print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))
class GDriveLoader(Loader):
BUFSIZE = 1024 * 1024
PROGRESS_SIZE = 10 * 1024 * 1024
def __init__(self, download_name, download_sha, gid, archive_member = None):
super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
self.download_name = download_name
self.download_sha = download_sha
self.gid = gid
def download(self, filepath):
session = requests.Session() # re-use cookies
URL = "https://docs.google.com/uc?export=download"
response = session.get(URL, params = { 'id' : self.gid }, stream = True)
def get_confirm_token(response): # in case of large files
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
token = get_confirm_token(response)
if token:
params = { 'id' : self.gid, 'confirm' : token }
response = session.get(URL, params = params, stream = True)
sz = 0
progress_sz = self.PROGRESS_SIZE
with open(filepath, "wb") as f:
for chunk in response.iter_content(self.BUFSIZE):
if not chunk:
continue # keep-alive
f.write(chunk)
sz += len(chunk)
if sz >= progress_sz:
progress_sz += self.PROGRESS_SIZE
print('>', end='')
sys.stdout.flush()
print('')
return sz
def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
spec_param = url
loader = URLLoader
if download_name is None:
download_name = filename
if download_sha is None:
download_sha = sha
if "drive.google.com" in url:
token = ""
token_part = url.rsplit('/', 1)[-1]
if "&id=" not in token_part:
token_part = url.rsplit('/', 1)[-2]
for param in token_part.split("&"):
if param.startswith("id="):
token = param[3:]
if token:
loader = GDriveLoader
spec_param = token
else:
print("Warning: possibly wrong Google Drive link")
return DownloadInstance(
name=instance_name,
filename=filename,
sha=sha,
save_dir=save_dir,
loader=loader(download_name, download_sha, spec_param, archive_member)
)
def getSaveDir():
env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
if env_path:
save_dir = env_path
else:
# TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
if platform.system() == "Darwin":
#On Apple devices
temp_env = os.environ.get("TMPDIR", None)
if temp_env is None or not os.path.isdir(temp_env):
temp_dir = Path("/tmp")
print("Using world accessible cache directory. This may be not secure: ", temp_dir)
else:
temp_dir = temp_env
elif platform.system() == "Windows":
temp_dir = tempfile.gettempdir()
else:
xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
if (xdg_cache_env and xdg_cache_env[0] and os.path.isdir(xdg_cache_env)):
temp_dir = xdg_cache_env
else:
home_env = os.environ.get("HOME", None)
if (home_env and home_env[0] and os.path.isdir(home_env)):
home_path = os.path.join(home_env, ".cache/")
if os.path.isdir(home_path):
temp_dir = home_path
else:
temp_dir = tempfile.gettempdir()
print("Using world accessible cache directory. This may be not secure: ", temp_dir)
save_dir = os.path.join(temp_dir, "downloads")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
return save_dir
def downloadFile(url, sha=None, save_dir=None, filename=None):
if save_dir is None:
save_dir = getSaveDir()
if filename is None:
filename = "download_" + datetime.now().__str__()
name = filename
return produceDownloadInstance(name, filename, sha, url, save_dir).get()
def parseMetalinkFile(metalink_filepath, save_dir):
NS = {'ml': 'urn:ietf:params:xml:ns:metalink'}
models = []
for file_elem in ET.parse(metalink_filepath).getroot().findall('ml:file', NS):
url = file_elem.find('ml:url', NS).text
fname = file_elem.attrib['name']
name = file_elem.find('ml:identity', NS).text
hash_sum = file_elem.find('ml:hash', NS).text
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir))
return models
def parseYAMLFile(yaml_filepath, save_dir):
models = []
with open(yaml_filepath, 'r') as stream:
data_loaded = yaml.safe_load(stream)
for name, params in data_loaded.items():
load_info = params.get("load_info", None)
if load_info:
fname = os.path.basename(params.get("model"))
hash_sum = load_info.get("sha1")
url = load_info.get("url")
download_sha = load_info.get("download_sha")
download_name = load_info.get("download_name")
archive_member = load_info.get("member")
models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
download_name=download_name, download_sha=download_sha, archive_member=archive_member))
return models
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.')
parser.add_argument('--save_dir', action="store", default=os.getcwd(),
help='Path to the directory to store downloaded files')
parser.add_argument('model_name', type=str, default="", nargs='?', action="store",
help='name of the model to download')
args = parser.parse_args()
models = []
save_dir = args.save_dir
selected_model_name = args.model_name
models.extend(parseMetalinkFile('face_detector/weights.meta4', save_dir))
models.extend(parseYAMLFile('models.yml', save_dir))
for m in models:
print(m)
if selected_model_name and not m.name.startswith(selected_model_name):
continue
print('Model: ' + selected_model_name)
m.get()
\ No newline at end of file
#!/usr/bin/env python
from __future__ import print_function
import hashlib
import time
import sys
import xml.etree.ElementTree as ET
if sys.version_info[0] < 3:
from urllib2 import urlopen
else:
from urllib.request import urlopen
class HashMismatchException(Exception):
def __init__(self, expected, actual):
Exception.__init__(self)
self.expected = expected
self.actual = actual
def __str__(self):
return 'Hash mismatch: {} vs {}'.format(self.expected, self.actual)
class MetalinkDownloader(object):
BUFSIZE = 10*1024*1024
NS = {'ml': 'urn:ietf:params:xml:ns:metalink'}
tick = 0
def download(self, metalink_file):
status = True
for file_elem in ET.parse(metalink_file).getroot().findall('ml:file', self.NS):
url = file_elem.find('ml:url', self.NS).text
fname = file_elem.attrib['name']
hash_sum = file_elem.find('ml:hash', self.NS).text
print('*** {}'.format(fname))
try:
self.verify(hash_sum, fname)
except Exception as ex:
print(' {}'.format(ex))
try:
print(' {}'.format(url))
with open(fname, 'wb') as file_stream:
self.buffered_read(urlopen(url), file_stream.write)
self.verify(hash_sum, fname)
except Exception as ex:
print(' {}'.format(ex))
print(' FAILURE')
status = False
continue
print(' SUCCESS')
return status
def print_progress(self, msg, timeout = 0):
if time.time() - self.tick > timeout:
print(msg, end='')
sys.stdout.flush()
self.tick = time.time()
def buffered_read(self, in_stream, processing):
self.print_progress(' >')
while True:
buf = in_stream.read(self.BUFSIZE)
if not buf:
break
processing(buf)
self.print_progress('>', 5)
print(' done')
def verify(self, hash_sum, fname):
sha = hashlib.sha1()
with open(fname, 'rb') as file_stream:
self.buffered_read(file_stream, sha.update)
if hash_sum != sha.hexdigest():
raise HashMismatchException(hash_sum, sha.hexdigest())
if __name__ == '__main__':
sys.exit(0 if MetalinkDownloader().download('weights.meta4') else 1)
<?xml version="1.0" encoding="UTF-8"?>
<metalink xmlns="urn:ietf:params:xml:ns:metalink">
<file name="res10_300x300_ssd_iter_140000_fp16.caffemodel">
<identity>OpenCV face detector FP16 weights</identity>
<identity>opencv_face_detector_fp16</identity>
<hash type="sha-1">31fc22bfdd907567a04bb45b7cfad29966caddc1</hash>
<url>https://raw.githubusercontent.com/opencv/opencv_3rdparty/dnn_samples_face_detector_20180205_fp16/res10_300x300_ssd_iter_140000_fp16.caffemodel</url>
</file>
<file name="opencv_face_detector_uint8.pb">
<identity>OpenCV face detector UINT8 weights</identity>
<identity>opencv_face_detector_uint8</identity>
<hash type="sha-1">4f2fdf6f231d759d7bbdb94353c5a68690f3d2ae</hash>
<url>https://raw.githubusercontent.com/opencv/opencv_3rdparty/dnn_samples_face_detector_20180220_uint8/opencv_face_detector_uint8.pb</url>
</file>
......
%YAML:1.0
%YAML 1.0
---
################################################################################
# Object detection models.
################################################################################
# OpenCV's face detection network
opencv_fd:
load_info:
url: "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel"
sha1: "15aa726b4d46d9f023526d85537db81cbc8dd566"
model: "opencv_face_detector.caffemodel"
config: "opencv_face_detector.prototxt"
mean: [104, 177, 123]
......@@ -19,6 +22,9 @@ opencv_fd:
# YOLO object detection family from Darknet (https://pjreddie.com/darknet/yolo/)
# Might be used for all YOLOv2, TinyYolov2, YOLOv3, YOLOv4 and TinyYolov4
yolo:
load_info:
url: "https://pjreddie.com/media/files/yolov3.weights"
sha1: "520878f12e97cf820529daea502acca380f1cb8e"
model: "yolov3.weights"
config: "yolov3.cfg"
mean: [0, 0, 0]
......@@ -30,6 +36,9 @@ yolo:
sample: "object_detection"
tiny-yolo-voc:
load_info:
url: "https://pjreddie.com/media/files/yolov2-tiny-voc.weights"
sha1: "24b4bd049fc4fa5f5e95f684a8967e65c625dff9"
model: "tiny-yolo-voc.weights"
config: "tiny-yolo-voc.cfg"
mean: [0, 0, 0]
......@@ -42,6 +51,9 @@ tiny-yolo-voc:
# Caffe implementation of SSD model from https://github.com/chuanqi305/MobileNet-SSD
ssd_caffe:
load_info:
url: "https://drive.google.com/uc?export=download&id=0B3gersZ2cHIxRm5PMWRoTkdHdHc"
sha1: "994d30a8afaa9e754d17d2373b2d62a7dfbaaf7a"
model: "MobileNetSSD_deploy.caffemodel"
config: "MobileNetSSD_deploy.prototxt"
mean: [127.5, 127.5, 127.5]
......@@ -54,6 +66,12 @@ ssd_caffe:
# TensorFlow implementation of SSD model from https://github.com/tensorflow/models/tree/master/research/object_detection
ssd_tf:
load_info:
url: "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz"
sha1: "9e4bcdd98f4c6572747679e4ce570de4f03a70e2"
download_sha: "6157ddb6da55db2da89dd561eceb7f944928e317"
download_name: "ssd_mobilenet_v1_coco_2017_11_17.tar.gz"
member: "ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb"
model: "ssd_mobilenet_v1_coco_2017_11_17.pb"
config: "ssd_mobilenet_v1_coco_2017_11_17.pbtxt"
mean: [0, 0, 0]
......@@ -66,6 +84,12 @@ ssd_tf:
# TensorFlow implementation of Faster-RCNN model from https://github.com/tensorflow/models/tree/master/research/object_detection
faster_rcnn_tf:
load_info:
url: "http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz"
sha1: "f2e4bf386b9bb3e25ddfcbbd382c20f417e444f3"
download_sha: "c710f25e5c6a3ce85fe793d5bf266d581ab1c230"
download_name: "faster_rcnn_inception_v2_coco_2018_01_28.tar.gz"
member: "faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb"
model: "faster_rcnn_inception_v2_coco_2018_01_28.pb"
config: "faster_rcnn_inception_v2_coco_2018_01_28.pbtxt"
mean: [0, 0, 0]
......@@ -81,6 +105,9 @@ faster_rcnn_tf:
# SqueezeNet v1.1 from https://github.com/DeepScale/SqueezeNet
squeezenet:
load_info:
url: "https://raw.githubusercontent.com/DeepScale/SqueezeNet/b5c3f1a23713c8b3fd7b801d229f6b04c64374a5/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel"
sha1: "3397f026368a45ae236403ccc81cfcbe8ebe1bd0"
model: "squeezenet_v1.1.caffemodel"
config: "squeezenet_v1.1.prototxt"
mean: [0, 0, 0]
......@@ -93,6 +120,9 @@ squeezenet:
# Googlenet from https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet
googlenet:
load_info:
url: "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel"
sha1: "405fc5acd08a3bb12de8ee5e23a96bec22f08204"
model: "bvlc_googlenet.caffemodel"
config: "bvlc_googlenet.prototxt"
mean: [104, 117, 123]
......@@ -110,6 +140,9 @@ googlenet:
# ENet road scene segmentation network from https://github.com/e-lab/ENet-training
# Works fine for different input sizes.
enet:
load_info:
url: "https://www.dropbox.com/s/tdde0mawbi5dugq/Enet-model-best.net?dl=1"
sha1: "b4123a73bf464b9ebe9cfc4ab9c2d5c72b161315"
model: "Enet-model-best.net"
mean: [0, 0, 0]
scale: 0.00392
......@@ -120,6 +153,9 @@ enet:
sample: "segmentation"
fcn8s:
load_info:
url: "http://dl.caffe.berkeleyvision.org/fcn8s-heavy-pascal.caffemodel"
sha1: "c449ea74dd7d83751d1357d6a8c323fcf4038962"
model: "fcn8s-heavy-pascal.caffemodel"
config: "fcn8s-heavy-pascal.prototxt"
mean: [0, 0, 0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册