diff --git a/avocado/utils/download.py b/avocado/utils/download.py index b4bf33468b17767d6c0ecff1630bc0106570dc76..2cab13a16fb60a446677f6b1987dec57e3e51f9b 100644 --- a/avocado/utils/download.py +++ b/avocado/utils/download.py @@ -25,6 +25,7 @@ import urllib2 from . import aurl from . import output +from . import crypto log = logging.getLogger('avocado.test') @@ -113,15 +114,7 @@ def url_download_interactive(url, output_file, title='', chunk_size=102400): output_file.close() -def get_file(src, dst, permissions=None): - """ - Get a file from src and put it in dest, returning dest path. - - :param src: source path or URL. May be local or a remote file. - :param dst: destination path. - :param permissions: (optional) set access permissions. - :return: destination path. - """ +def _get_file(src, dst, permissions=None): if src == dst: return @@ -133,3 +126,55 @@ def get_file(src, dst, permissions=None): if permissions: os.chmod(dst, permissions) return dst + + +def get_file(src, dst, permissions=None, hash_expected=None, + hash_algorithm="md5", download_retries=1): + """ + Gets a file from a source location, optionally using caching. + + If no hash_expected is provided, simply download the file. Else, + keep trying to download the file until download_failures exceeds + download_retries or the hashes match. + + If the hashes match, return dst. If download_failures exceeds + download_retries, raise an EnvironmentError. + + :param src: source path or URL. May be local or a remote file. + :param dst: destination path. + :param permissions: (optional) set access permissions. + :param hash_expected: Hash string that we expect the file downloaded to + have. + :param hash_algorithm: Algorithm used to calculate the hash string + (md5, sha1). + :param download_retries: Number of times we are going to retry a failed + download. + :raise: EnvironmentError. + :return: destination path. + """ + def _verify_hash(filename): + if os.path.isfile(filename): + return crypto.hash_file(filename, algorithm=hash_algorithm) + return None + + if hash_expected is None: + return _get_file(src, dst, permissions) + + download_failures = 0 + hash_file = _verify_hash(dst) + + while not hash_file == hash_expected: + hash_file = _verify_hash(_get_file(src, dst, permissions)) + if hash_file != hash_expected: + log.error("It seems that dst %s is corrupted" % dst) + download_failures += 1 + if download_failures > download_retries: + raise EnvironmentError("Failed to retrieve %s. " + "Possible reasons - Network connectivity " + "problems or incorrect hash_expected " + "provided -> '%s'" % + (src, hash_expected)) + else: + log.error("Retrying download of src %s", src) + + return dst