diff --git a/python/paddle/tests/test_download.py b/python/paddle/tests/test_download.py index b8af7f6a80e72148a4f793a4de2188d3cc7a8b69..4be2dde1bccb132041723df8af7f5f36f24e133c 100644 --- a/python/paddle/tests/test_download.py +++ b/python/paddle/tests/test_download.py @@ -70,6 +70,13 @@ class TestDownload(unittest.TestCase): for url in urls: get_path_from_url(url, root_dir='./test') + def test_retry_exception(self, ): + with self.assertRaises(RuntimeError): + from paddle.utils.download import _download + _download( + 'www.baidu.com', + './test', ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py index dda8abeff21c062011ba7e558f330eb92c87836a..3ad627ddea927467caaa1524285724850a5cdc36 100644 --- a/python/paddle/utils/download.py +++ b/python/paddle/utils/download.py @@ -186,7 +186,15 @@ def _download(url, path, md5sum=None): logger.info("Downloading {} from {}".format(fname, url)) - req = requests.get(url, stream=True) + try: + req = requests.get(url, stream=True) + except Exception as e: # requests.exceptions.ConnectionError + logger.info( + "Downloading {} from {} failed {} times with exception {}". + format(fname, url, retry_cnt + 1, str(e))) + time.sleep(1) + continue + if req.status_code != 200: raise RuntimeError("Downloading from {} failed with code " "{}!".format(url, req.status_code))