提交 dd07a4c4 编写于 作者: G Gabriel de Marmiesse 提交者: Frédéric Branchaud-Charron

Refactoring: Simplified some code by using the `to_list` function. (#10678)

### Summary
We have the method `to_list` in keras. Let's use it to make the codebase simpler!
### Related Issues

### PR Overview

- [ ] This PR requires new unit tests [y/n] (make sure tests are included)
- [ ] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
- [x] This PR is backwards compatible [y/n]
- [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
上级 b2979c25
......@@ -25,6 +25,7 @@ from .. import optimizers
from .. import losses
from .. import metrics as metrics_module
from ..utils.generic_utils import slice_arrays
from ..utils.generic_utils import to_list
from ..utils.generic_utils import unpack_singleton
from ..legacy import interfaces
......@@ -155,8 +156,7 @@ class Model(Network):
masks = self.compute_mask(self.inputs, mask=None)
if masks is None:
masks = [None for _ in self.outputs]
if not isinstance(masks, list):
masks = [masks]
masks = to_list(masks)
# Prepare loss weights.
if loss_weights is None:
......
......@@ -14,6 +14,7 @@ from .. import backend as K
from .. import callbacks as cbks
from ..utils.generic_utils import Progbar
from ..utils.generic_utils import slice_arrays
from ..utils.generic_utils import to_list
from ..utils.generic_utils import unpack_singleton
......@@ -152,8 +153,7 @@ def fit_loop(model, f, ins,
callbacks.on_batch_begin(step_index, batch_logs)
outs = f(ins)
if not isinstance(outs, list):
outs = [outs]
outs = to_list(outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
......@@ -165,8 +165,7 @@ def fit_loop(model, f, ins,
val_outs = test_loop(model, val_f, val_ins,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
val_outs = to_list(val_outs)
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
......@@ -198,8 +197,7 @@ def fit_loop(model, f, ins,
ins_batch[i] = ins_batch[i].toarray()
outs = f(ins_batch)
if not isinstance(outs, list):
outs = [outs]
outs = to_list(outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
......@@ -212,8 +210,7 @@ def fit_loop(model, f, ins,
val_outs = test_loop(model, val_f, val_ins,
batch_size=batch_size,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
val_outs = to_list(val_outs)
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
......@@ -267,8 +264,7 @@ def predict_loop(model, f, ins, batch_size=32, verbose=0, steps=None):
unconcatenated_outs = []
for step in range(steps):
batch_outs = f(ins)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
batch_outs = to_list(batch_outs)
if step == 0:
for batch_out in batch_outs:
unconcatenated_outs.append([])
......@@ -296,8 +292,7 @@ def predict_loop(model, f, ins, batch_size=32, verbose=0, steps=None):
ins_batch[i] = ins_batch[i].toarray()
batch_outs = f(ins_batch)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
batch_outs = to_list(batch_outs)
if batch_index == 0:
# Pre-allocate the results arrays.
for batch_out in batch_outs:
......
......@@ -12,6 +12,7 @@ from ..utils.data_utils import Sequence
from ..utils.data_utils import GeneratorEnqueuer
from ..utils.data_utils import OrderedEnqueuer
from ..utils.generic_utils import Progbar
from ..utils.generic_utils import to_list
from ..utils.generic_utils import unpack_singleton
from .. import callbacks as cbks
......@@ -211,8 +212,7 @@ def fit_generator(model,
sample_weight=sample_weight,
class_weight=class_weight)
if not isinstance(outs, list):
outs = [outs]
outs = to_list(outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
......@@ -236,8 +236,7 @@ def fit_generator(model,
batch_size=batch_size,
sample_weight=val_sample_weights,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
val_outs = to_list(val_outs)
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
......@@ -342,8 +341,7 @@ def evaluate_generator(model, generator,
'or (x, y). Found: ' +
str(generator_output))
outs = model.test_on_batch(x, y, sample_weight=sample_weight)
if not isinstance(outs, list):
outs = [outs]
outs = to_list(outs)
outs_per_batch.append(outs)
if x is None or len(x) == 0:
......@@ -450,8 +448,7 @@ def predict_generator(model, generator,
x = generator_output
outs = model.predict_on_batch(x)
if not isinstance(outs, list):
outs = [outs]
outs = to_list(outs)
if not all_outs:
for out in outs:
......
......@@ -9,6 +9,7 @@ import numpy as np
from .. import backend as K
from .. import losses
from ..utils.generic_utils import to_list
def standardize_single_array(x):
......@@ -321,8 +322,7 @@ def collect_metrics(metrics, output_names):
nested_metrics = []
for name in output_names:
output_metrics = metrics.get(name, [])
if not isinstance(output_metrics, list):
output_metrics = [output_metrics]
output_metrics = to_list(output_metrics)
nested_metrics.append(output_metrics)
return nested_metrics
else:
......
......@@ -8,6 +8,7 @@ import warnings
from ..engine import Layer, InputSpec
from .. import backend as K
from ..utils import conv_utils
from ..utils.generic_utils import to_list
from .. import regularizers
from .. import constraints
from .. import activations
......@@ -521,10 +522,8 @@ class Recurrent(Layer):
# Compute the full input spec, including state
input_spec = self.input_spec
state_spec = self.state_spec
if not isinstance(input_spec, list):
input_spec = [input_spec]
if not isinstance(state_spec, list):
state_spec = [state_spec]
input_spec = to_list(input_spec)
state_spec = to_list(state_spec)
self.input_spec = input_spec + state_spec
# Compute the full inputs, including state
......
......@@ -224,8 +224,7 @@ def multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False):
# Apply model on slice
# (creating a model replica on the target device).
outputs = model(inputs)
if not isinstance(outputs, list):
outputs = [outputs]
outputs = to_list(outputs)
# Save the outputs for merging back together later.
for o in range(len(outputs)):
......
......@@ -11,6 +11,7 @@ import numpy as np
from ..utils.np_utils import to_categorical
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import to_list
from ..models import Sequential
......@@ -291,8 +292,7 @@ class KerasClassifier(BaseWrapper):
y = to_categorical(y)
outputs = self.model.evaluate(x, y, **kwargs)
if not isinstance(outputs, list):
outputs = [outputs]
outputs = to_list(outputs)
for name, output in zip(self.model.metrics_names, outputs):
if name == 'acc':
return output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册