提交 9bfb72bc 编写于 作者: A A. Unique TensorFlower 提交者: Gunhan Gulsoy

Added retries to download function.

Change: 139964468
上级 e088c887
......@@ -44,6 +44,18 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
py_test(
name = "base_test",
size = "small",
srcs = ["base_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:framework_test_lib",
],
)
py_test(
name = "load_csv_test",
size = "small",
......
......@@ -23,7 +23,9 @@ import collections
import csv
import os
from os import path
import random
import tempfile
import time
import numpy as np
from six.moves import urllib
......@@ -121,6 +123,73 @@ def load_boston(data_path=None):
features_dtype=np.float)
def retry(initial_delay,
max_delay,
factor=2.0,
jitter=0.25,
is_retriable=None):
"""Simple decorator for wrapping retriable functions.
Args:
initial_delay: the initial delay.
factor: each subsequent retry, the delay is multiplied by this value.
(must be >= 1).
jitter: to avoid lockstep, the returned delay is multiplied by a random
number between (1-jitter) and (1+jitter). To add a 20% jitter, set
jitter = 0.2. Must be < 1.
max_delay: the maximum delay allowed (actual max is
max_delay * (1 + jitter).
is_retriable: (optional) a function that takes an Exception as an argument
and returns true if retry should be applied.
"""
if factor < 1:
raise ValueError('factor must be >= 1; was %f' % (factor,))
if jitter >= 1:
raise ValueError('jitter must be < 1; was %f' % (jitter,))
# Generator to compute the individual delays
def delays():
delay = initial_delay
while delay <= max_delay:
yield delay * random.uniform(1 - jitter, 1 + jitter)
delay *= factor
def wrap(fn):
"""Wrapper function factory invoked by decorator magic."""
def wrapped_fn(*args, **kwargs):
"""The actual wrapper function that applies the retry logic."""
for delay in delays():
try:
return fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except)
if is_retriable is None:
continue
if is_retriable(e):
time.sleep(delay)
else:
raise
return fn(*args, **kwargs)
return wrapped_fn
return wrap
_RETRIABLE_ERRNOS = {
110, # Connection timed out [socket.py]
}
def _is_retriable(e):
return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS
@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
def urlretrieve_with_retry(url, filename):
urllib.request.urlretrieve(url, filename)
def maybe_download(filename, work_directory, source_url):
"""Download the data from source url, unless it's already here.
......@@ -138,7 +207,7 @@ def maybe_download(filename, work_directory, source_url):
if not gfile.Exists(filepath):
with tempfile.NamedTemporaryFile() as tmpfile:
temp_file_name = tmpfile.name
urllib.request.urlretrieve(source_url, temp_file_name)
urlretrieve_with_retry(source_url, temp_file_name)
gfile.Copy(temp_file_name, filepath)
with gfile.GFile(filepath) as f:
size = f.size()
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import base
mock = tf.test.mock
_TIMEOUT = IOError(110, "timeout")
class BaseTest(tf.test.TestCase):
"""Test load csv functions."""
def testUrlretrieveRetriesOnIOError(self):
with mock.patch.object(base, "time") as mock_time:
with mock.patch.object(base, "urllib") as mock_urllib:
mock_urllib.request.urlretrieve.side_effect = [
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
None
]
base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")
# Assert full backoff was tried
actual_list = [arg[0][0] for arg in mock_time.sleep.call_args_list]
expected_list = [1, 2, 4, 8, 16]
for actual, expected in zip(actual_list, expected_list):
self.assertLessEqual(abs(actual - expected), 0.25 * expected)
self.assertEquals(len(actual_list), len(expected_list))
def testUrlretrieveRaisesAfterRetriesAreExhausted(self):
with mock.patch.object(base, "time") as mock_time:
with mock.patch.object(base, "urllib") as mock_urllib:
mock_urllib.request.urlretrieve.side_effect = [
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
_TIMEOUT,
]
with self.assertRaises(IOError):
base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")
# Assert full backoff was tried
actual_list = [arg[0][0] for arg in mock_time.sleep.call_args_list]
expected_list = [1, 2, 4, 8, 16]
for actual, expected in zip(actual_list, expected_list):
self.assertLessEqual(abs(actual - expected), 0.25 * expected)
self.assertEquals(len(actual_list), len(expected_list))
def testUrlretrieveRaisesOnNonRetriableErrorWithoutRetry(self):
with mock.patch.object(base, "time") as mock_time:
with mock.patch.object(base, "urllib") as mock_urllib:
mock_urllib.request.urlretrieve.side_effect = [
IOError(2, "No such file or directory"),
]
with self.assertRaises(IOError):
base.urlretrieve_with_retry("http://dummy.com", "/tmp/dummy")
# Assert no retries
self.assertFalse(mock_time.called)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册