From 37e2b92089ed583ba9e73f615444c3b080cd1b63 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 27 Feb 2017 23:41:32 +0000 Subject: [PATCH] Add md5file into dataset/common.py, and unit test in tests/common_test.py --- python/paddle/v2/dataset/common.py | 13 +++++++++++-- python/paddle/v2/dataset/tests/common_test.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 python/paddle/v2/dataset/tests/common_test.py diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index ae4a5383b03..ff5ed76c0f7 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -3,7 +3,7 @@ import os import shutil import urllib2 -__all__ = ['DATA_HOME', 'download'] +__all__ = ['DATA_HOME', 'download', 'md5file'] DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') @@ -11,7 +11,7 @@ if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) -def download(url, md5): +def download(url, package_name, md5): filename = os.path.split(url)[-1] assert DATA_HOME is not None filepath = os.path.join(DATA_HOME, md5) @@ -34,3 +34,12 @@ def download(url, md5): with open(__full_file__, mode='wb') as of: shutil.copyfileobj(fsrc=response, fdst=of) return __full_file__ + + +def md5file(fname): + hash_md5 = hashlib.md5() + f = open(fname, "rb") + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + f.close() + return hash_md5.hexdigest() diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py new file mode 100644 index 00000000000..d2f97f06de3 --- /dev/null +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -0,0 +1,16 @@ +import paddle.v2.dataset.common +import unittest +import tempfile + +class TestCommon(unittest.TestCase): + def test_md5file(self): + _, temp_path =tempfile.mkstemp() + f = open(temp_path, 'w') + f.write("Hello\n") + f.close() + self.assertEqual( + '09f7e02f1290be211da707a266f153b3', + paddle.v2.dataset.common.md5file(temp_path)) + +if __name__ == '__main__': + unittest.main() -- GitLab