diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index ff5ed76c0f732b385b7913e8a2fdc66af9363fc7..b1831f38afb4f15796e4eaaacce6bc37f975578a 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -1,7 +1,7 @@ +import requests import hashlib import os import shutil -import urllib2 __all__ = ['DATA_HOME', 'download', 'md5file'] @@ -11,31 +11,6 @@ if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) -def download(url, package_name, md5): - filename = os.path.split(url)[-1] - assert DATA_HOME is not None - filepath = os.path.join(DATA_HOME, md5) - if not os.path.exists(filepath): - os.makedirs(filepath) - __full_file__ = os.path.join(filepath, filename) - - def __file_ok__(): - if not os.path.exists(__full_file__): - return False - md5_hash = hashlib.md5() - with open(__full_file__, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5_hash.update(chunk) - - return md5_hash.hexdigest() == md5 - - while not __file_ok__(): - response = urllib2.urlopen(url) - with open(__full_file__, mode='wb') as of: - shutil.copyfileobj(fsrc=response, fdst=of) - return __full_file__ - - def md5file(fname): hash_md5 = hashlib.md5() f = open(fname, "rb") @@ -43,3 +18,18 @@ def md5file(fname): hash_md5.update(chunk) f.close() return hash_md5.hexdigest() + + +def download(url, module_name, md5sum): + dirname = os.path.join(DATA_HOME, module_name) + if not os.path.exists(dirname): + os.makedirs(dirname) + + filename = os.path.join(dirname, url.split('/')[-1]) + if not (os.path.exists(filename) and md5file(filename) == md5sum): + # If file doesn't exist or MD5 doesn't match, then download. + r = requests.get(url, stream=True) + with open(filename, 'w') as f: + shutil.copyfileobj(r.raw, f) + + return filename diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index d2f97f06de3db95d24c518ddf92afd6c5fd726b5..0672a4671430a9834ae01882fe8c75f1ec101b0f 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -5,12 +5,18 @@ import tempfile class TestCommon(unittest.TestCase): def test_md5file(self): _, temp_path =tempfile.mkstemp() - f = open(temp_path, 'w') - f.write("Hello\n") - f.close() + with open(temp_path, 'w') as f: + f.write("Hello\n") self.assertEqual( '09f7e02f1290be211da707a266f153b3', paddle.v2.dataset.common.md5file(temp_path)) + def test_download(self): + yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460' + self.assertEqual( + paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460', + paddle.v2.dataset.common.download( + yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d')) + if __name__ == '__main__': unittest.main()