提交 4b74fc54 编写于 作者: F Frédéric Branchaud-Charron 提交者: François Chollet

Fix sequence bug (#9513)

* Made Sequence iterable

* Made it python2 compliant
上级 7ef5244a
......@@ -2161,6 +2161,9 @@ class Model(Container):
wait_time=wait_time)
val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
validation_generator = val_enqueuer.get()
else:
if isinstance(validation_data, Sequence):
validation_generator = iter(validation_data)
else:
validation_generator = validation_data
else:
......@@ -2193,6 +2196,9 @@ class Model(Container):
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
......@@ -2365,6 +2371,9 @@ class Model(Container):
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
......@@ -2489,6 +2498,9 @@ class Model(Container):
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
else:
output_generator = generator
......
......@@ -366,6 +366,12 @@ class Sequence(object):
"""
pass
def __iter__(self):
"""Create an infinite generator that iterate over the Sequence."""
while True:
for item in (self[i] for i in range(len(self))):
yield item
# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
......
......@@ -6,12 +6,21 @@ import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.utils.test_utils import keras_test
from keras.utils import Sequence
STEPS_PER_EPOCH = 100
STEPS = 100
WORKERS = 4
class DummySequence(Sequence):
def __getitem__(self, idx):
return np.zeros([10, 2]), np.ones([10])
def __len__(self):
return 10
@pytest.fixture
def in_tmpdir(tmpdir):
"""Runs a function in a temporary directory.
......@@ -175,6 +184,22 @@ def test_multiprocessing_training():
workers=0,
use_multiprocessing=False)
# - For Sequence
model.fit_generator(DummySequence(),
steps_per_epoch=STEPS_PER_EPOCH,
validation_data=custom_generator(True),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=True)
model.fit_generator(DummySequence(),
steps_per_epoch=STEPS_PER_EPOCH,
validation_data=custom_generator(True),
validation_steps=1,
max_queue_size=10,
workers=0,
use_multiprocessing=False)
# Test invalid use cases
def invalid_generator():
while True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册