提交 c25282f9 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Adds support for arbitrarily nested `inputs` and `outputs` in

`keras.backend.function`.

PiperOrigin-RevId: 224886577
上级 34145277
......@@ -2926,17 +2926,12 @@ class GraphExecutionFunction(object):
def __init__(self, inputs, outputs, updates=None, name=None,
**session_kwargs):
updates = updates or []
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` to a Keras backend function '
'should be a list or tuple.')
if not isinstance(outputs, (list, tuple)):
raise TypeError('`outputs` of a Keras backend function '
'should be a list or tuple.')
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a Keras backend function '
'should be a list or tuple.')
self.inputs = list(inputs)
self.outputs = list(outputs)
self.inputs = nest.flatten(inputs)
self._outputs_structure = outputs
self.outputs = nest.flatten(outputs)
with ops.control_dependencies(self.outputs):
updates_ops = []
for update in updates:
......@@ -3033,8 +3028,7 @@ class GraphExecutionFunction(object):
self.fetch_callbacks[fetch](output)
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
inputs = nest.flatten(inputs)
session = get_session()
feed_arrays = []
......@@ -3077,7 +3071,8 @@ class GraphExecutionFunction(object):
fetched = self._callable_fn(*array_vals,
run_metadata=self.run_metadata)
self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
return nest.pack_sequence_as(self._outputs_structure,
fetched[:len(self.outputs)])
class EagerExecutionFunction(object):
......@@ -3093,17 +3088,12 @@ class EagerExecutionFunction(object):
def __init__(self, inputs, outputs, updates=None, name=None):
updates = updates or []
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` to a Keras backend function '
'should be a list or tuple.')
if not isinstance(outputs, (list, tuple)):
raise TypeError('`outputs` of a Keras backend function '
'should be a list or tuple.')
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a Keras backend function '
'should be a list or tuple.')
self.inputs = list(inputs)
self.outputs = list(outputs)
self.inputs = nest.flatten(inputs)
self._outputs_structure = outputs
self.outputs = nest.flatten(outputs)
self.name = name
graph = get_graph()
......@@ -3153,6 +3143,7 @@ class EagerExecutionFunction(object):
x.op.inputs[0])
def __call__(self, inputs):
inputs = nest.flatten(inputs)
converted_inputs = []
for tensor, value in zip(self.inputs, inputs):
if value is None:
......@@ -3169,7 +3160,8 @@ class EagerExecutionFunction(object):
value = math_ops.cast(value, tensor.dtype)
converted_inputs.append(value)
outputs = self._graph_fn(*converted_inputs)
return [x.numpy() for x in outputs]
return nest.pack_sequence_as(self._outputs_structure,
[x.numpy() for x in outputs])
@tf_export('keras.backend.function')
......
......@@ -1695,6 +1695,39 @@ class BackendGraphTests(test.TestCase):
self.assertEqual(callback.times_called, 1)
self.assertEqual(callback.callback_result, 200)
@test_util.run_in_graph_and_eager_modes
def test_function_dict_outputs(self):
x_ph = keras.backend.placeholder(shape=(), name='x')
y_ph = keras.backend.placeholder(shape=(), name='y')
outputs = {'x*y': y_ph * x_ph, 'x*x': x_ph * x_ph}
f = keras.backend.function(inputs=[x_ph, y_ph], outputs=outputs)
x, y = 2., 5.
results = f([x, y])
self.assertEqual(results['x*y'], 10.)
self.assertEqual(results['x*x'], 4)
@test_util.run_in_graph_and_eager_modes
def test_function_dict_inputs(self):
placeholders = {
'x': keras.backend.placeholder(shape=()),
'y': keras.backend.placeholder(shape=())
}
outputs = [placeholders['x'] * placeholders['y']]
f = keras.backend.function(inputs=placeholders, outputs=outputs)
results = f({'x': 2., 'y': 3.})
self.assertEqual(results[0], 6.)
@test_util.run_in_graph_and_eager_modes
def test_function_single_input_output(self):
x_ph = keras.backend.placeholder(shape=(), name='x')
output = x_ph * x_ph
f = keras.backend.function(x_ph, output)
result = f(2.)
self.assertEqual(result, 4.)
def test_placeholder(self):
x = keras.backend.placeholder(shape=(3, 4))
self.assertEqual(x.get_shape().as_list(), [3, 4])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册