diff --git a/tools/converter.py b/tools/converter.py index c2469d7b10ed908e6559e0323cbd2137f06496b8..3567db2f6e0c0ad8ca186c1441c8c174300bc098 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -656,6 +656,28 @@ def print_configuration(configs): MaceLogger.summary(StringFormatter.table(header, data, title)) +def download_file(url, dst, num_retries=3): + from six.moves import urllib + + def _progress(block_num, block_size, total_size): + sys.stdout.write( + '\r>> Downloading %s %.1f%%' % (url, + float(block_num * block_size) / + float(total_size) * 100.0)) + sys.stdout.flush() + + try: + urllib.request.urlretrieve(url, dst, _progress) + MaceLogger.info('\nDownloaded successfully.') + except (urllib.URLError, urllib.ContentTooShortError) as e: + MaceLogger.warning('Download error:', e.reason) + if num_retries > 0: + return download_file(url, dst, num_retries - 1) + else: + return False + return True + + def get_model_files(model_file_path, model_sha256_checksum, model_output_dir, @@ -670,8 +692,9 @@ def get_model_files(model_file_path, if not os.path.exists(model_file) or \ sha256_checksum(model_file) != model_sha256_checksum: MaceLogger.info("Downloading model, please wait ...") - six.moves.urllib.request.urlretrieve(model_file_path, model_file) - MaceLogger.info("Model downloaded successfully.") + if not download_file(model_file_path, model_file): + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "Model download failed.") if sha256_checksum(model_file) != model_sha256_checksum: MaceLogger.error(ModuleName.MODEL_CONVERTER, @@ -684,8 +707,9 @@ def get_model_files(model_file_path, if not os.path.exists(weight_file) or \ sha256_checksum(weight_file) != weight_sha256_checksum: MaceLogger.info("Downloading model weight, please wait ...") - six.moves.urllib.request.urlretrieve(weight_file_path, weight_file) - MaceLogger.info("Model weight downloaded successfully.") + if not download_file(weight_file_path, weight_file): + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "Model download failed.") if weight_file: if sha256_checksum(weight_file) != weight_sha256_checksum: