提交 b75b2f7d 编写于 作者: E Elton Viana 提交者: Gabriel de Marmiesse

Small refactors on the keras.utils module (#13388)

* Use .format calls for string interpolation on utils

* Use generators over listcomps whenever possible to save memory
上级 c8f66d15
......@@ -32,19 +32,19 @@ def normalize_tuple(value, n, name):
try:
value_tuple = tuple(value)
except TypeError:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {}'.format(name, n, value))
if len(value_tuple) != n:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {}'.format(name, n, value))
for single_value in value_tuple:
try:
int(single_value)
except ValueError:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value) + ' '
'including element ' + str(single_value) + ' of '
'type ' + str(type(single_value)))
raise ValueError('The `{}` argument must be a tuple of {} '
'integers. Received: {} including element {} '
'of type {}'.format(name, n, value, single_value,
type(single_value)))
return value_tuple
......@@ -55,7 +55,7 @@ def normalize_padding(value):
allowed.add('full')
if padding not in allowed:
raise ValueError('The `padding` argument must be one of "valid", "same" '
'(or "causal" for Conv1D). Received: ' + str(padding))
'(or "causal" for Conv1D). Received: {}'.format(padding))
return padding
......
......@@ -67,8 +67,8 @@ if sys.version_info[0] == 2:
break
with closing(urlopen(url, data)) as response, open(filename, 'wb') as fd:
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
else:
from six.moves.urllib.request import urlretrieve
......@@ -195,10 +195,10 @@ def get_file(fname,
# File found; verify integrity if a hash was provided.
if file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
print('A local file was found, but it seems to be '
'incomplete or outdated because the ' + hash_algorithm +
' file hash does not match the original value of ' +
file_hash + ' so we will re-download the data.')
print('A local file was found, but it seems to be incomplete'
' or outdated because the {} file hash does not match '
'the original value of {} so we will re-download the '
'data.'.format(hash_algorithm, file_hash))
download = True
else:
download = True
......@@ -725,9 +725,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
while self.queue.qsize() > 0:
last_ones.append(self.queue.get(block=True))
# Wait for them to complete
list(map(lambda f: f.wait(), last_ones))
[f.wait() for f in last_ones]
# Keep the good ones
last_ones = [future.get() for future in last_ones if future.successful()]
last_ones = (future.get() for future in last_ones if future.successful())
for inputs in last_ones:
if inputs is not None:
yield inputs
......
......@@ -126,7 +126,7 @@ def deserialize_keras_object(identifier, module_objects=None,
# In this case we are dealing with a Keras config dictionary.
config = identifier
if 'class_name' not in config or 'config' not in config:
raise ValueError('Improper config format: ' + str(config))
raise ValueError('Improper config format: {}'.format(config))
class_name = config['class_name']
if custom_objects and class_name in custom_objects:
cls = custom_objects[class_name]
......@@ -136,8 +136,8 @@ def deserialize_keras_object(identifier, module_objects=None,
module_objects = module_objects or {}
cls = module_objects.get(class_name)
if cls is None:
raise ValueError('Unknown ' + printable_module_name +
': ' + class_name)
raise ValueError('Unknown {}: {}'.format(printable_module_name,
class_name))
if hasattr(cls, 'from_config'):
custom_objects = custom_objects or {}
if has_arg(cls.from_config, 'custom_objects'):
......@@ -163,12 +163,12 @@ def deserialize_keras_object(identifier, module_objects=None,
else:
fn = module_objects.get(function_name)
if fn is None:
raise ValueError('Unknown ' + printable_module_name +
':' + function_name)
raise ValueError('Unknown {}: {}'.format(printable_module_name,
function_name))
return fn
else:
raise ValueError('Could not interpret serialized ' +
printable_module_name + ': ' + identifier)
raise ValueError('Could not interpret serialized '
'{}: {}'.format(printable_module_name, identifier))
def func_dump(func):
......@@ -514,7 +514,7 @@ def unpack_singleton(x):
def object_list_uid(object_list):
object_list = to_list(object_list)
return ', '.join([str(abs(id(x))) for x in object_list])
return ', '.join((str(abs(id(x))) for x in object_list))
def is_all_none(iterable_or_element):
......
......@@ -153,7 +153,7 @@ def multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False):
if not gpus:
# Using all visible GPUs when not specifying `gpus`
# e.g. CUDA_VISIBLE_DEVICES=0,2 python keras_mgpu.py
gpus = len([x for x in available_devices if '/gpu:' in x])
gpus = len((x for x in available_devices if '/gpu:' in x))
if isinstance(gpus, (list, tuple)):
if len(gpus) <= 1:
......
......@@ -149,7 +149,7 @@ def model_to_dot(model,
inputlabels = str(layer.input_shape)
elif hasattr(layer, 'input_shapes'):
inputlabels = ', '.join(
[str(ishape) for ishape in layer.input_shapes])
(str(ishape) for ishape in layer.input_shapes))
else:
inputlabels = 'multiple'
label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册