提交 e0f979b8 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix `predict` with `run_eagerly=True`

PiperOrigin-RevId: 225257343
上级 22af085f
......@@ -246,6 +246,21 @@ class CorrectnessTest(test.TestCase):
layer(1.) # Plain-value inputs are only valid in eager mode.
self.assertEqual(1, len(layer.losses))
def test_predict_correctness(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
i3 = keras.layers.Input(shape=(4, 5))
o = keras.layers.add([i1, i2, i3])
model = keras.models.Model([i1, i2, i3], o)
model.run_eagerly = True
x1 = np.random.random((2, 4, 5))
x2 = np.random.random((2, 4, 5))
x3 = np.random.random((2, 4, 5))
out = model.predict([x1, x2, x3])
self.assertAllClose(out, x1 + x2 + x3)
if __name__ == '__main__':
ops.enable_eager_execution()
......
......@@ -49,7 +49,7 @@ def model_iteration(model,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
shuffle=False,
initial_epoch=0,
mode='train',
batch_size=None,
......@@ -246,8 +246,10 @@ def model_iteration(model,
# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode='train')
evaluate_generator = functools.partial(model_iteration, mode='test')
predict_generator = functools.partial(model_iteration, mode='predict')
evaluate_generator = functools.partial(
model_iteration, mode='test', shuffle=False)
predict_generator = functools.partial(
model_iteration, mode='predict', shuffle=False)
def _get_next_batch(output_generator, mode):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册