提交 a9a68bc9 编写于 作者: F Frédéric Branchaud-Charron 提交者: François Chollet

Reuse validation enqueuer (#10476)

* Reuse validation enqueuer

* Fix CI

* Indent val_data correctly

* Typo
上级 8a8a19ba
......@@ -102,26 +102,46 @@ def fit_generator(model,
val_enqueuer = None
try:
if do_validation and not val_gen:
# Prepare data for validation
if len(validation_data) == 2:
val_x, val_y = validation_data
val_sample_weight = None
elif len(validation_data) == 3:
val_x, val_y, val_sample_weight = validation_data
if do_validation:
if val_gen and workers > 0:
# Create an Enqueuer that can be reused
val_data = validation_data
if isinstance(val_data, Sequence):
val_enqueuer = OrderedEnqueuer(val_data,
use_multiprocessing=use_multiprocessing)
validation_steps = len(val_data)
else:
val_enqueuer = GeneratorEnqueuer(val_data,
use_multiprocessing=use_multiprocessing)
val_enqueuer.start(workers=workers,
max_queue_size=max_queue_size)
val_enqueuer_gen = val_enqueuer.get()
elif val_gen:
val_data = validation_data
if isinstance(val_data, Sequence):
val_enqueuer_gen = iter(val_data)
else:
val_enqueuer_gen = val_data
else:
raise ValueError('`validation_data` should be a tuple '
'`(val_x, val_y, val_sample_weight)` '
'or `(val_x, val_y)`. Found: ' +
str(validation_data))
val_x, val_y, val_sample_weights = model._standardize_user_data(
val_x, val_y, val_sample_weight)
val_data = val_x + val_y + val_sample_weights
if model.uses_learning_phase and not isinstance(K.learning_phase(),
int):
val_data += [0.]
for cbk in callbacks:
cbk.validation_data = val_data
# Prepare data for validation
if len(validation_data) == 2:
val_x, val_y = validation_data
val_sample_weight = None
elif len(validation_data) == 3:
val_x, val_y, val_sample_weight = validation_data
else:
raise ValueError('`validation_data` should be a tuple '
'`(val_x, val_y, val_sample_weight)` '
'or `(val_x, val_y)`. Found: ' +
str(validation_data))
val_x, val_y, val_sample_weights = model._standardize_user_data(
val_x, val_y, val_sample_weight)
val_data = val_x + val_y + val_sample_weights
if model.uses_learning_phase and not isinstance(K.learning_phase(),
int):
val_data += [0.]
for cbk in callbacks:
cbk.validation_data = val_data
if workers > 0:
if is_sequence:
......@@ -204,11 +224,9 @@ def fit_generator(model,
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = model.evaluate_generator(
validation_data,
val_enqueuer_gen,
validation_steps,
workers=workers,
use_multiprocessing=use_multiprocessing,
max_queue_size=max_queue_size)
workers=0)
else:
# No need for try/except because
# data has already been validated.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册