Commit d18c5645 authored by Gabriel de Marmiesse, committed by François Chollet

Refactoring: Simplified some if-else by creating a function similar to `to_list`. (#10679)

* Refactoring: Simplified some if-else by creating a function similar to `to_list`.

* Changed the name and the description of `first_or_list`.
Parent 066aa6ab
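In short: `unpack_singleton` is the inverse-direction counterpart of `to_list` — `to_list` wraps a lone object into a list, while `unpack_singleton` collapses a one-element list back into its element. A minimal sketch of the call-site pattern this commit collapses (the variable names and values below are illustrative, not taken from any one file):

```python
from keras.utils.generic_utils import to_list, unpack_singleton

outs = ['single_output']  # illustrative value, not from the diff

# Before this commit, call sites repeated the same check by hand:
result = outs[0] if len(outs) == 1 else outs

# After, the branching lives in one shared helper:
result = unpack_singleton(outs)   # -> 'single_output'

# to_list is the mirror image: it wraps a lone object back into a list.
assert to_list(result) == outs
```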
......@@ -14,6 +14,7 @@ from ..utils.layer_utils import count_params
 from ..utils.generic_utils import has_arg
 from ..utils.generic_utils import object_list_uid
 from ..utils.generic_utils import to_list
+from ..utils.generic_utils import unpack_singleton
 from ..utils.generic_utils import is_all_none
 from ..legacy import interfaces
......@@ -427,10 +428,7 @@ class Layer(object):
                                          'and thus cannot be built. '
                                          'You can build it manually via: '
                                          '`layer.build(batch_input_shape)`')
-                if len(input_shapes) == 1:
-                    self.build(input_shapes[0])
-                else:
-                    self.build(input_shapes)
+                self.build(unpack_singleton(input_shapes))
                 self.built = True

                 # Load weights that were specified at layer instantiation.
......@@ -468,10 +466,7 @@ class Layer(object):
                 if x in inputs_ls:
                     x = K.identity(x)
                 output_ls_copy.append(x)
-            if len(output_ls_copy) == 1:
-                output = output_ls_copy[0]
-            else:
-                output = output_ls_copy
+            output = unpack_singleton(output_ls_copy)

             # Inferring the output shape is only relevant for Theano.
             if all([s is not None
......@@ -668,10 +663,7 @@ class Layer(object):
                          ', but the layer has only ' +
                          str(len(self._inbound_nodes)) + ' inbound nodes.')
         values = getattr(self._inbound_nodes[node_index], attr)
-        if len(values) == 1:
-            return values[0]
-        else:
-            return values
+        return unpack_singleton(values)

     def get_input_shape_at(self, node_index):
         """Retrieves the input shape(s) of a layer at a given node.
......@@ -897,10 +889,7 @@ class Layer(object):
             [str(node.input_shapes) for node in self._inbound_nodes])
         if len(all_input_shapes) == 1:
             input_shapes = self._inbound_nodes[0].input_shapes
-            if len(input_shapes) == 1:
-                return input_shapes[0]
-            else:
-                return input_shapes
+            return unpack_singleton(input_shapes)
         else:
             raise AttributeError('The layer "' + str(self.name) +
                                  ' has multiple inbound nodes, '
......@@ -932,10 +921,7 @@ class Layer(object):
             [str(node.output_shapes) for node in self._inbound_nodes])
         if len(all_output_shapes) == 1:
             output_shapes = self._inbound_nodes[0].output_shapes
-            if len(output_shapes) == 1:
-                return output_shapes[0]
-            else:
-                return output_shapes
+            return unpack_singleton(output_shapes)
         else:
             raise AttributeError('The layer "' + str(self.name) +
                                  ' has multiple inbound nodes, '
......@@ -1326,9 +1312,7 @@ def _collect_previous_mask(input_tensors):
             masks.append(mask)
         else:
             masks.append(None)
-    if len(masks) == 1:
-        return masks[0]
-    return masks
+    return unpack_singleton(masks)


 def _to_snake_case(name):
......@@ -1357,6 +1341,4 @@ def _collect_input_shape(input_tensors):
             shapes.append(K.int_shape(x))
         except TypeError:
             shapes.append(None)
-    if len(shapes) == 1:
-        return shapes[0]
-    return shapes
+    return unpack_singleton(shapes)
......@@ -8,6 +8,7 @@ from .base_layer import Layer
 from .base_layer import Node
 from .. import backend as K
 from ..legacy import interfaces
+from ..utils.generic_utils import unpack_singleton


 class InputLayer(Layer):
......@@ -177,7 +178,4 @@ def Input(shape=None, batch_shape=None,
     # Return tensor including _keras_shape and _keras_history.
     # Note that in this case train_output and test_output are the same pointer.
     outputs = input_layer._inbound_nodes[0].output_tensors
-    if len(outputs) == 1:
-        return outputs[0]
-    else:
-        return outputs
+    return unpack_singleton(outputs)
......@@ -23,6 +23,7 @@ from ..utils.layer_utils import get_source_inputs
 from ..utils.generic_utils import has_arg
 from ..utils.generic_utils import to_list
 from ..utils.generic_utils import object_list_uid
+from ..utils.generic_utils import unpack_singleton
 from ..legacy import interfaces

 try:
......@@ -267,10 +268,7 @@ class Network(Layer):
             node = layer._inbound_nodes[node_index]
             mask = node.output_masks[tensor_index]
             masks.append(mask)
-        if len(masks) == 1:
-            mask = masks[0]
-        else:
-            mask = masks
+        mask = unpack_singleton(masks)
         self._output_mask_cache[mask_cache_key] = mask

         # Build self.input_names and self.output_names.
......@@ -539,9 +537,7 @@ class Network(Layer):
                                     'Found input_spec = ' +
                                     str(layer.input_spec))
                 specs += layer.input_spec
-        if len(specs) == 1:
-            return specs[0]
-        return specs
+        return unpack_singleton(specs)

     def call(self, inputs, mask=None):
         """Calls the model on new inputs.
......@@ -605,8 +601,8 @@ class Network(Layer):
         cache_key = ', '.join([str(x) for x in input_shapes])
         if cache_key in self._output_shape_cache:
             output_shapes = self._output_shape_cache[cache_key]
-            if isinstance(output_shapes, list) and len(output_shapes) == 1:
-                return output_shapes[0]
+            if isinstance(output_shapes, list):
+                return unpack_singleton(output_shapes)
             return output_shapes
         else:
             # Bad luck, we have to run the graph manually.
......@@ -643,10 +639,7 @@ class Network(Layer):
                             input_shape = layers_to_output_shapes[shape_key]
                             input_shapes.append(input_shape)
-                        if len(input_shapes) == 1:
-                            output_shape = layer.compute_output_shape(input_shapes[0])
-                        else:
-                            output_shape = layer.compute_output_shape(input_shapes)
+                        output_shape = layer.compute_output_shape(unpack_singleton(input_shapes))
                         output_shapes = to_list(output_shape)
                         node_index = layer._inbound_nodes.index(node)
......@@ -669,8 +662,8 @@ class Network(Layer):
                 output_shapes.append(layers_to_output_shapes[key])
             # Store in cache.
             self._output_shape_cache[cache_key] = output_shapes
-            if isinstance(output_shapes, list) and len(output_shapes) == 1:
-                return output_shapes[0]
+            if isinstance(output_shapes, list):
+                return unpack_singleton(output_shapes)
             return output_shapes

     def run_internal_graph(self, inputs, masks=None):
......@@ -781,12 +774,10 @@ class Network(Layer):
                     # Update _keras_shape.
                     if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
-                        if len(computed_tensors) == 1:
-                            shapes = to_list(layer.compute_output_shape(computed_tensors[0]._keras_shape))
-                            uses_learning_phase = computed_tensors[0]._uses_learning_phase
-                        else:
-                            shapes = to_list(layer.compute_output_shape([x._keras_shape for x in computed_tensors]))
-                            uses_learning_phase = any([x._uses_learning_phase for x in computed_tensors])
+                        input_shapes = unpack_singleton([x._keras_shape for x in computed_tensors])
+                        shapes = to_list(layer.compute_output_shape(input_shapes))
+                        uses_learning_phase = any([x._uses_learning_phase for x in computed_tensors])
                         for x, s in zip(output_tensors, shapes):
                             x._keras_shape = s
                             x._uses_learning_phase = getattr(x, '_uses_learning_phase', False) or uses_learning_phase
......@@ -814,26 +805,18 @@ class Network(Layer):
         cache_key = object_list_uid(inputs)
         cache_key += '_' + object_list_uid(masks)
-        if len(output_tensors) == 1:
-            output_tensors = output_tensors[0]
-            self._output_tensor_cache[cache_key] = output_tensors
-        else:
-            self._output_tensor_cache[cache_key] = output_tensors
+        output_tensors = unpack_singleton(output_tensors)
+        self._output_tensor_cache[cache_key] = output_tensors
-        if len(output_masks) == 1:
-            output_masks = output_masks[0]
-            self._output_mask_cache[cache_key] = output_masks
-        else:
-            self._output_mask_cache[cache_key] = output_masks
+        output_masks = unpack_singleton(output_masks)
+        self._output_mask_cache[cache_key] = output_masks
         if output_shapes is not None:
             input_shapes = [x._keras_shape for x in inputs]
             cache_key = ', '.join([str(x) for x in input_shapes])
-            if len(output_shapes) == 1:
-                output_shapes = output_shapes[0]
-                self._output_shape_cache[cache_key] = output_shapes
-            else:
-                self._output_shape_cache[cache_key] = output_shapes
+            output_shapes = unpack_singleton(output_shapes)
+            self._output_shape_cache[cache_key] = output_shapes
         return output_tensors, output_masks, output_shapes

     def get_config(self):
......@@ -1000,10 +983,7 @@ class Network(Layer):
             # Call layer on its inputs, thus creating the node
             # and building the layer if needed.
             if input_tensors:
-                if len(input_tensors) == 1:
-                    layer(input_tensors[0], **kwargs)
-                else:
-                    layer(input_tensors, **kwargs)
+                layer(unpack_singleton(input_tensors), **kwargs)

         def process_layer(layer_data):
             """Deserializes a layer, then call it on appropriate inputs.
......@@ -25,6 +25,7 @@ from .. import optimizers
 from .. import losses
 from .. import metrics as metrics_module
 from ..utils.generic_utils import slice_arrays
+from ..utils.generic_utils import unpack_singleton
 from ..legacy import interfaces
......@@ -620,16 +621,10 @@ class Model(Network):
         if outputs is None:
             # Obtain symbolic outputs by calling the model.
-            if len(self.inputs) == 1:
-                if self._expects_training_arg:
-                    outputs = self.call(self.inputs[0], training=training)
-                else:
-                    outputs = self.call(self.inputs[0])
+            if self._expects_training_arg:
+                outputs = self.call(unpack_singleton(self.inputs), training=training)
             else:
-                if self._expects_training_arg:
-                    outputs = self.call(self.inputs, training=training)
-                else:
-                    outputs = self.call(self.inputs)
+                outputs = self.call(unpack_singleton(self.inputs))
         if isinstance(outputs, (list, tuple)):
             outputs = list(outputs)
         else:
......@@ -1218,9 +1213,7 @@ class Model(Network):
             ins = x + y + sample_weights
         self._make_train_function()
         outputs = self.train_function(ins)
-        if len(outputs) == 1:
-            return outputs[0]
-        return outputs
+        return unpack_singleton(outputs)

     def test_on_batch(self, x, y, sample_weight=None):
         """Test the model on a single batch of samples.
......@@ -1259,9 +1252,7 @@ class Model(Network):
             ins = x + y + sample_weights
         self._make_test_function()
         outputs = self.test_function(ins)
-        if len(outputs) == 1:
-            return outputs[0]
-        return outputs
+        return unpack_singleton(outputs)

     def predict_on_batch(self, x):
         """Returns predictions for a single batch of samples.
......@@ -1279,9 +1270,7 @@ class Model(Network):
             ins = x
         self._make_predict_function()
         outputs = self.predict_function(ins)
-        if len(outputs) == 1:
-            return outputs[0]
-        return outputs
+        return unpack_singleton(outputs)

     @interfaces.legacy_generator_methods_support
     def fit_generator(self, generator,
......@@ -14,6 +14,7 @@ from .. import backend as K
 from .. import callbacks as cbks
 from ..utils.generic_utils import Progbar
 from ..utils.generic_utils import slice_arrays
+from ..utils.generic_utils import unpack_singleton


 def fit_loop(model, f, ins,
......@@ -306,9 +307,7 @@ def predict_loop(model, f, ins, batch_size=32, verbose=0, steps=None):
                 outs[i][batch_start:batch_end] = batch_out
             if verbose == 1:
                 progbar.update(batch_end)
-        if len(outs) == 1:
-            return outs[0]
-        return outs
+        return unpack_singleton(outs)


 def test_loop(model, f, ins, batch_size=None, verbose=0, steps=None):
......@@ -415,6 +414,4 @@ def test_loop(model, f, ins, batch_size=None, verbose=0, steps=None):
         for i in range(len(outs)):
             if i not in stateful_metric_indices:
                 outs[i] /= num_samples
-    if len(outs) == 1:
-        return outs[0]
-    return outs
+    return unpack_singleton(outs)
......@@ -12,6 +12,7 @@ from ..utils.data_utils import Sequence
 from ..utils.data_utils import GeneratorEnqueuer
 from ..utils.data_utils import OrderedEnqueuer
 from ..utils.generic_utils import Progbar
+from ..utils.generic_utils import unpack_singleton
 from .. import callbacks as cbks
......@@ -375,9 +376,7 @@ def evaluate_generator(model, generator,
                                        weights=batch_sizes))
         else:
             averages.append(float(outs_per_batch[-1][i]))
-    if len(averages) == 1:
-        return averages[0]
-    return averages
+    return unpack_singleton(averages)


 def predict_generator(model, generator,
......@@ -461,6 +461,22 @@ def to_list(x):
     return [x]


+def unpack_singleton(x):
+    """Gets the first element if the iterable has only one value.
+
+    Otherwise return the iterable.
+
+    # Argument:
+        x: A list or tuple.
+
+    # Returns:
+        The same iterable or the first element.
+    """
+    if len(x) == 1:
+        return x[0]
+    return x
+
+
 def object_list_uid(object_list):
     object_list = to_list(object_list)
     return ', '.join([str(abs(id(x))) for x in object_list])
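A short usage sketch of the new helper (the literal values are illustrative): a length-1 iterable collapses to its element, anything else passes through untouched. Note that the argument must support `len()`, so generators are not accepted; every call site touched by this commit passes a plain list.

```python
from keras.utils.generic_utils import unpack_singleton

assert unpack_singleton([42]) == 42        # one element: unwrap it
assert unpack_singleton([1, 2]) == [1, 2]  # several elements: returned as-is
assert unpack_singleton(()) == ()          # empty tuple: also returned as-is
```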