提交 4b74fc54 编写于 作者: F Frédéric Branchaud-Charron 提交者: François Chollet

Fix sequence bug (#9513)

* Made Sequence iterable

* Made it python2 compliant
上级 7ef5244a
......@@ -2162,7 +2162,10 @@ class Model(Container):
val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
validation_generator = val_enqueuer.get()
else:
validation_generator = validation_data
if isinstance(validation_data, Sequence):
validation_generator = iter(validation_data)
else:
validation_generator = validation_data
else:
if len(validation_data) == 2:
val_x, val_y = validation_data
......@@ -2194,7 +2197,10 @@ class Model(Container):
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
output_generator = generator
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
callback_model.stop_training = False
# Construct epoch logs.
......@@ -2366,7 +2372,10 @@ class Model(Container):
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
output_generator = generator
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
while steps_done < steps:
generator_output = next(output_generator)
......@@ -2490,7 +2499,10 @@ class Model(Container):
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
output_generator = generator
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
if verbose == 1:
progbar = Progbar(target=steps)
......
......@@ -366,6 +366,12 @@ class Sequence(object):
"""
pass
def __iter__(self):
"""Create an infinite generator that iterate over the Sequence."""
while True:
for item in (self[i] for i in range(len(self))):
yield item
# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
......
......@@ -6,12 +6,21 @@ import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.utils.test_utils import keras_test
from keras.utils import Sequence
STEPS_PER_EPOCH = 100
STEPS = 100
WORKERS = 4
class DummySequence(Sequence):
def __getitem__(self, idx):
return np.zeros([10, 2]), np.ones([10])
def __len__(self):
return 10
@pytest.fixture
def in_tmpdir(tmpdir):
"""Runs a function in a temporary directory.
......@@ -175,6 +184,22 @@ def test_multiprocessing_training():
workers=0,
use_multiprocessing=False)
# - For Sequence
model.fit_generator(DummySequence(),
steps_per_epoch=STEPS_PER_EPOCH,
validation_data=custom_generator(True),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=True)
model.fit_generator(DummySequence(),
steps_per_epoch=STEPS_PER_EPOCH,
validation_data=custom_generator(True),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=False)
# Test invalid use cases
def invalid_generator():
while True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册