提交 91115ab6 编写于 作者: Y Yi Wang

Use module name and raw data filename as the local filename

上级 37e2b920
import requests
import hashlib import hashlib
import os import os
import shutil import shutil
import urllib2
__all__ = ['DATA_HOME', 'download', 'md5file'] __all__ = ['DATA_HOME', 'download', 'md5file']
...@@ -11,31 +11,6 @@ if not os.path.exists(DATA_HOME): ...@@ -11,31 +11,6 @@ if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME) os.makedirs(DATA_HOME)
def download(url, package_name, md5):
filename = os.path.split(url)[-1]
assert DATA_HOME is not None
filepath = os.path.join(DATA_HOME, md5)
if not os.path.exists(filepath):
os.makedirs(filepath)
__full_file__ = os.path.join(filepath, filename)
def __file_ok__():
if not os.path.exists(__full_file__):
return False
md5_hash = hashlib.md5()
with open(__full_file__, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest() == md5
while not __file_ok__():
response = urllib2.urlopen(url)
with open(__full_file__, mode='wb') as of:
shutil.copyfileobj(fsrc=response, fdst=of)
return __full_file__
def md5file(fname): def md5file(fname):
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
f = open(fname, "rb") f = open(fname, "rb")
...@@ -43,3 +18,18 @@ def md5file(fname): ...@@ -43,3 +18,18 @@ def md5file(fname):
hash_md5.update(chunk) hash_md5.update(chunk)
f.close() f.close()
return hash_md5.hexdigest() return hash_md5.hexdigest()
def download(url, module_name, md5sum):
dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum):
# If file doesn't exist or MD5 doesn't match, then download.
r = requests.get(url, stream=True)
with open(filename, 'w') as f:
shutil.copyfileobj(r.raw, f)
return filename
...@@ -5,12 +5,18 @@ import tempfile ...@@ -5,12 +5,18 @@ import tempfile
class TestCommon(unittest.TestCase): class TestCommon(unittest.TestCase):
def test_md5file(self): def test_md5file(self):
_, temp_path =tempfile.mkstemp() _, temp_path =tempfile.mkstemp()
f = open(temp_path, 'w') with open(temp_path, 'w') as f:
f.write("Hello\n") f.write("Hello\n")
f.close()
self.assertEqual( self.assertEqual(
'09f7e02f1290be211da707a266f153b3', '09f7e02f1290be211da707a266f153b3',
paddle.v2.dataset.common.md5file(temp_path)) paddle.v2.dataset.common.md5file(temp_path))
def test_download(self):
yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
self.assertEqual(
paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
paddle.v2.dataset.common.download(
yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册