Commit 41b0cfa4 authored by Allen Lavoie, committed by TensorFlower Gardener

Fix import of SavedModels which use Defun

We weren't properly handling the op-type-is-function-name calling convention.

PiperOrigin-RevId: 258441889
Parent 179b3884
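For context, the convention named in the message: a function created with Defun is called by emitting a graph node whose op type is the function's name, rather than a PartitionedCall node carrying the function in an "f" attr. A minimal sketch of the convention (not part of the commit; assumes TF 1.x-style graph building and the private framework.function module):

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function

@function.Defun(tf.float32)
def plus_one(x):
  return x + 1.

with tf.Graph().as_default():
  out = plus_one(tf.constant(1.))
  # The call node's op type is the generated function name, not a generic
  # call op, so importers must recognize op types that name functions.
  print(out.op.type)  # e.g. "plus_one_<hash>"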
@@ -354,6 +354,12 @@ class _EagerDefinedFunction(object):
input_ops = set(arg.op for arg in inputs)
operations = [op for op in graph.get_operations() if op not in input_ops]
graph_output_names = graph._output_names # pylint: disable=protected-access
if (graph_output_names is not None
and all(t in graph_output_names for t in outputs)):
output_names = [compat.as_bytes(graph_output_names[t]) for t in outputs]
else:
output_names = []
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
@@ -361,7 +367,7 @@ class _EagerDefinedFunction(object):
[o._c_op for o in operations], # pylint: disable=protected-access
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
[],
output_names,
[o._c_op for o in graph.control_outputs], # pylint: disable=protected-access
[], # control_output_names
None,
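For context on the hunk above: TF_GraphToFunction expects either one name per output tensor or an empty list, in which case it generates default names, so the selection is all-or-nothing. A hedged restatement as a standalone helper (helper name hypothetical):

from tensorflow.python.util import compat

def _select_output_names(graph_output_names, outputs):
  # Use the names recorded on the FuncGraph only if every output has one;
  # otherwise return [] and let the C API generate positional names.
  if (graph_output_names is not None
      and all(t in graph_output_names for t in outputs)):
    return [compat.as_bytes(graph_output_names[t]) for t in outputs]
  return []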
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
UnliftableError = op_selector.UnliftableError
@@ -58,7 +59,7 @@ _ControlMutation = collections.namedtuple(
["copied_op", "old_graph_op"])
def _copy_non_source(op, graph, op_map):
def _copy_non_source(op, graph, op_map, base_graph):
"""Copy an op directly to a given graph.
Generally `op`'s inputs should already have been copied. If this is not the
@@ -70,6 +71,7 @@ def _copy_non_source(op, graph, op_map):
op: The op to be copied.
graph: The destination graph.
op_map: A dict mapping ops and tensors in the old graph to the new one.
base_graph: The graph we're copying from, for any necessary functions.
Returns:
A tuple of (required_inputs, required_control_inputs):
required_inputs:
@@ -113,6 +115,11 @@
# to signal that the op was built inside a tpu_replicate context; if we're
# lifting it to another graph we're similarly lifting it into another context.
with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
# pylint: disable=protected-access
f = base_graph._functions.get(op.type, None)
if f is not None and compat.as_str(f.name) not in graph._functions:
f.add_to_graph(graph)
# pylint: enable=protected-access
copied_op = graph.create_op(
op_type=op.type,
inputs=copied_inputs,
@@ -133,7 +140,8 @@
for mutation in control_mutations])
def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
base_graph):
"""Create a source in a graph based on a Tensor from a different graph.
This function creates a placeholder analog of `s` in a graph with the
@@ -156,6 +164,7 @@ def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
graph or simply create a vanilla placeholder.
inverse_captures: A dict mapping s back to the Tensor or Variable that it
captures.
base_graph: The graph being copied from.
"""
if handle_captures and s in inverse_captures:
copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
@@ -163,7 +172,8 @@ def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
# Copy the default value to the graph.
default_value = s.op.inputs[0]
unavailable_inputs, unavailable_control_inputs = _copy_non_source(
op=default_value.op, graph=graph, op_map=op_map)
op=default_value.op, graph=graph, op_map=op_map,
base_graph=base_graph)
if unavailable_inputs or unavailable_control_inputs:
raise AssertionError(
"Could not copy source node {} because it has inputs."
@@ -289,7 +299,8 @@ def lift_to_graph(init_tensors, graph, sources=None,
graph=graph,
op_map=op_map,
handle_captures=handle_captures,
inverse_captures=inverse_captures)
inverse_captures=inverse_captures,
base_graph=base_graph)
for s in sources:
source_ops.add(s.op)
_copy_source(
@@ -297,7 +308,8 @@ def lift_to_graph(init_tensors, graph, sources=None,
graph=graph,
op_map=op_map,
handle_captures=handle_captures,
inverse_captures=inverse_captures)
inverse_captures=inverse_captures,
base_graph=base_graph)
input_mutations = []
control_mutations = []
@@ -305,7 +317,7 @@
if op in source_ops:
continue
new_input_mutations, new_control_mutations = _copy_non_source(
op=op, graph=graph, op_map=op_map)
op=op, graph=graph, op_map=op_map, base_graph=base_graph)
input_mutations.extend(new_input_mutations)
control_mutations.extend(new_control_mutations)
@@ -326,4 +338,3 @@ def lift_to_graph(init_tensors, graph, sources=None,
# pylint: enable=protected-access
return op_map
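The base_graph parameter threaded through the helpers above lets _copy_non_source consult the source graph's function library: when the op being copied is itself a function call (its op type names a function), the definition must be registered in the destination graph before create_op will accept the node. A hedged sketch of that check as a standalone helper (helper name hypothetical; relies on the private graph._functions map):

from tensorflow.python.util import compat

def _copy_callee_if_needed(op, base_graph, dest_graph):
  # pylint: disable=protected-access
  # graph._functions maps function names to defined-function objects; copy
  # the callee across once if the destination graph doesn't have it yet.
  f = base_graph._functions.get(op.type, None)
  if f is not None and compat.as_str(f.name) not in dest_graph._functions:
    f.add_to_graph(dest_graph)
  # pylint: enable=protected-access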
@@ -193,6 +193,10 @@ class FuncGraph(ops.Graph):
self._watched_variables = weakref.WeakSet()
self.outer_graph = ops.get_default_graph()
self.captures = py_collections.OrderedDict()
# If not None, records the names of output args of this function. Used to
# preserve the output names in the signature of a serialized+deserialized
# function. Private at the moment mostly because it's often out of date.
self._output_names = None
self.deferred_captures = py_collections.OrderedDict()
# Inherit capture-by-value from outer graph.
if capture_by_value is not None:
@@ -46,7 +46,8 @@ def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
a shape is None, the corresponding input placeholder will have unknown
shape.
copy_functions: Whether to copy all functions that exist in the default graph
(regardless of whether they are used) to the created FuncGraph.
(regardless of whether they are used) to the created FuncGraph. Functions
required for graph import are copied in any case.
Returns:
A FuncGraph.
@@ -95,10 +96,16 @@ def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
for output_index, shape in enumerate(
output_shapes.list.shape[:len(op.outputs)]):
op.outputs[output_index].set_shape(shape)
output_names = {}
for ret_arg_def, tensor_name in zip(
fdef.signature.output_arg, output_tensor_names):
output_names[func_graph.get_tensor_by_name(tensor_name)] = (
ret_arg_def.name)
func_graph._output_names = output_names # pylint: disable=protected-access
return func_graph
def _is_function(fname):
def is_function(fname):
"""Checks for a function definition with `fname` in the current context."""
if context.executing_eagerly():
return context.context().has_function(fname)
@@ -124,7 +131,8 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
`fdef.signature.input_arg`. If a shape is None, the corresponding input
placeholder will have unknown shape.
copy_functions: Whether to copy all functions that exist in the default graph
(regardless of whether they are used) to the created GraphDef.
(regardless of whether they are used) to the created GraphDef. Directly
referenced functions are copied in any case.
Returns:
A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
@@ -140,10 +148,18 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
producer=versions.GRAPH_DEF_VERSION,
min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
default_graph = ops.get_default_graph()
copied_functions = set()
# Copy *all* functions from the outer graph to `graph_def` so that both
# direct and indirect references are safely handled.
if copy_functions:
ops.get_default_graph()._copy_functions_to_graph_def(graph_def, 0) # pylint: disable=protected-access
# pylint: disable=protected-access
default_graph._copy_functions_to_graph_def(graph_def, 0)
for function_name in default_graph._functions.keys():
copied_functions.add(function_name)
# pylint: enable=protected-access
if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
raise ValueError("Length of input_shapes must match the number of " +
@@ -184,17 +200,27 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
nested_to_flat_tensor_name[control_name] = control_name
for node_def in fdef.node_def:
op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
f = default_graph._functions.get(node_def.op, None) # pylint: disable=protected-access
if f is not None and hasattr(f, "signature"):
op_def = f.signature
if node_def.op not in copied_functions:
# Since this function is referenced as an op type, we have no choice but
# to copy it into the GraphDef if we want downstream tools to process
# it.
graph_def.library.function.append(f.definition)
copied_functions.add(node_def.op)
else:
op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
for attr in op_def.attr:
if attr.type == "func":
fname = node_def.attr[attr.name].func.name
if not _is_function(fname):
if not is_function(fname):
raise ValueError("%s function not found." % fname)
elif attr.type == "list(func)":
for fn in node_def.attr[attr.name].list.func:
fname = fn.name
if not _is_function(fname):
if not is_function(fname):
raise ValueError("%s function not found." % fname)
# Iterate over output_args in op_def to build the map.
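The net effect in this file: with copy_functions=False the bulk library copy is skipped, but every function still reachable from the FunctionDef has to resolve, whether referenced through a func or list(func) attr or directly as an op type; the latter are now appended to graph_def.library on demand, with copied_functions guarding against duplicates. A hedged usage sketch (assumes fdef is a FunctionDef whose callees are already registered in the default graph):

from tensorflow.python.framework import function_def_to_graph as fdg

# Deserialize one FunctionDef without copying the entire outer library;
# directly referenced functions are still copied in on demand.
func_graph = fdg.function_def_to_graph(fdef, copy_functions=False)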
@@ -22,6 +22,7 @@ import collections
import re
from tensorflow.core.framework import function_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
@@ -271,7 +272,7 @@ def recreate_function(saved_function, concrete_functions):
decorator_argspec=function_spec.fullargspec)
def load_function_def_library(library):
def load_function_def_library(library, load_shared_name_suffix=None):
"""Load a set of functions as concrete functions without captured inputs.
Function names are manipulated during load such that they do not overlap
@@ -279,6 +280,8 @@ def load_function_def_library(library):
Args:
library: FunctionDefLibrary proto message.
load_shared_name_suffix: If specified, used to uniquify shared
names. Otherwise a unique suffix is generated.
Returns:
Map of original function names in the library to instances of
@@ -289,14 +292,16 @@
"""
functions = {}
load_shared_name_suffix = "_load_{}".format(ops.uid())
if load_shared_name_suffix is None:
load_shared_name_suffix = "_load_{}".format(ops.uid())
for fdef in _sort_function_defs(library):
copy = _fix_fdef(fdef, functions, load_shared_name_suffix)
# There is no need to copy functions into the function def graph. It leads
# to an O(n^2) increase in memory when importing functions, and the extra
# function definitions are a no-op since they were already imported as
# functions earlier (due to the topological-sort import).
# There is no need to copy all functions into the function def graph. It
# leads to an O(n^2) increase in memory when importing functions, and the
# extra function definitions are a no-op since they were already imported
# as functions earlier and are passed in explicitly (due to the
# topological-sort import).
func_graph = function_def_lib.function_def_to_graph(
copy, copy_functions=False)
@@ -304,6 +309,8 @@
functions[dep].add_to_graph(func_graph)
func = function_lib.ConcreteFunction(func_graph)
func.add_to_graph()
if context.executing_eagerly():
func.add_to_graph(ops.get_default_graph())
functions[fdef.signature.name] = func
@@ -347,6 +354,39 @@ def _sort_function_defs(library):
return [reverse[x] for x in output]
def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
"""Replace functions calls and shared names in `node_def`."""
if "_gradient_op_type" in node_def.attr:
if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]:
# TODO(andresp): This code assumes that the gradient registered for this
# function call is the default gradient for the function and not a
# custom one.
fname = node_def.attr["f"].func.name
node_def.attr["_gradient_op_type"].s = compat.as_bytes(
functions[fname]._gradient_name) # pylint: disable=protected-access
else:
logging.warning("Importing a function (%s) with ops with custom "
"gradients. Will likely fail if a gradient is "
"requested.", debug_name)
if node_def.op in functions:
node_def.op = functions[node_def.op].name
for _, attr_value in node_def.attr.items():
if attr_value.func.name:
attr_value.func.name = functions[attr_value.func.name].name
# TODO(b/124205571): Avoid accidental sharing and destruction of restored
# resources. For now uniquify "shared_name" when loading functions to avoid
# sharing.
if "shared_name" in node_def.attr:
if node_def.attr["shared_name"].s:
node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
else:
# Blank shared_name attributes would use the node name, so we'll start
# with that when uniquifying.
node_def.attr["shared_name"].s = (
compat.as_bytes(node_def.name) + compat.as_bytes(shared_name_suffix))
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
"""Fixes a FunctionDef proto to be loaded in current context.
@@ -367,41 +407,25 @@ def _fix_fdef(orig_fdef, functions, shared_name_suffix):
fdef = function_pb2.FunctionDef()
fdef.CopyFrom(orig_fdef)
for node_def in fdef.node_def:
if "_gradient_op_type" in node_def.attr:
if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]:
# TODO(andresp): This code assumes that the gradient registered for this
# function call is the default gradient for the function and not a
# custom one.
fname = node_def.attr["f"].func.name
node_def.attr["_gradient_op_type"].s = compat.as_bytes(
functions[fname]._gradient_name) # pylint: disable=protected-access
else:
logging.warning("Importing a function (%s) with ops with custom "
"gradients. Will likely fail if a gradient is "
"requested.", fdef.signature.name)
for _, attr_value in node_def.attr.items():
if attr_value.func.name:
attr_value.func.name = functions[attr_value.func.name].name
# TODO(b/124205571): Avoid accidental sharing and destruction of restored
# resources. For now uniquify "shared_name" when loading functions to avoid
# sharing.
if "shared_name" in node_def.attr:
node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
fix_node_def(node_def, functions, shared_name_suffix, fdef.signature.name)
fdef.signature.name = _clean_function_name(fdef.signature.name)
return fdef
def _list_function_deps(fdef):
"""Find functions referenced in `fdef`."""
# TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
# when listing deps and when fixing them. `function_def_to_graph` also
# requires fixes.
deps = set()
for node_def in fdef.node_def:
for _, attr_value in node_def.attr.items():
if attr_value.WhichOneof("value") == "func":
deps.add(attr_value.func.name)
if function_def_lib.is_function(node_def.op):
deps.add(node_def.op)
else:
for _, attr_value in node_def.attr.items():
if attr_value.WhichOneof("value") == "func":
deps.add(attr_value.func.name)
return deps
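On the shared_name rewrite in fix_node_def above: a blank shared_name makes the runtime fall back to the node name, so blank values are seeded with the node name before the per-load suffix is appended; without a unique suffix, two loads of the same SavedModel could silently share, and later destroy, each other's resources. A hedged restatement of the rule (helper name hypothetical):

def _uniquified_shared_name(node_name, shared_name, suffix):
  # Seed blank shared_names with the node name (the runtime's default),
  # then append the per-load suffix so separate loads don't share state.
  base = shared_name if shared_name else node_name
  return base + suffix

# _uniquified_shared_name("table", "", "_load_42") == "table_load_42"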
@@ -36,6 +36,7 @@ from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@@ -1742,6 +1743,22 @@ class LoadTest(test.TestCase, parameterized.TestCase):
gen_resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=False)
def test_function_called_as_operation(self, cycles):
@framework_function.Defun(dtypes.float32)
def inner(x):
return x + 1.
@def_function.function(
input_signature=[tensor_spec.TensorSpec([], dtypes.float32)])
def outer(x):
return inner(x)
root = module.Module()
root.f = outer
imported = self.cycle(root, cycles)
self.assertAllClose(2., imported.f(constant_op.constant(1.)))
class SingleCycleTests(test.TestCase, parameterized.TestCase):
@@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.training import monitored_session
@@ -152,6 +153,23 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
def load(self, tags):
"""Creates an object from the MetaGraph identified by `tags`."""
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
load_shared_name_suffix = "_load_{}".format(ops.uid())
functions = function_deserialization.load_function_def_library(
meta_graph_def.graph_def.library,
load_shared_name_suffix=load_shared_name_suffix)
# Replace existing functions in the MetaGraphDef with renamed functions so
# we don't have duplicates or name collisions.
meta_graph_def.graph_def.library.Clear()
for function in functions.values():
meta_graph_def.graph_def.library.function.append(
function._inference_function.definition) # pylint: disable=protected-access
# We've renamed functions and shared names. We need the same operation on
# the GraphDef itself for consistency.
for node_def in meta_graph_def.graph_def.node:
function_deserialization.fix_node_def(node_def, functions,
load_shared_name_suffix,
debug_name="MetaGraph import")
load_graph_returns = [None]
wrapped = wrap_function.wrap_function(
functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
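Because load_function_def_library renames functions, the loader above rebuilds the MetaGraphDef's library from the renamed copies and then runs fix_node_def over the main GraphDef with the same suffix, keeping call sites and definitions consistent. From the user's side this sits behind tf.saved_model.load; a hedged sketch mirroring the V1 test below (path assumed to hold a V1-format SavedModel containing a Defun):

import tensorflow as tf

imported = tf.saved_model.load(path)  # V1 SavedModels route through this loader
fn = imported.signatures["serving_default"]
print(fn(tf.constant([42], dtype=tf.int64))["output"])  # -> [43]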
@@ -28,6 +28,7 @@ from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
@@ -513,6 +514,31 @@ class LoadTest(test.TestCase):
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
def _model_with_defun(self):
"""Generate a graph with a Defun and serialize in V1 format."""
export_graph = ops.Graph()
with export_graph.as_default():
@framework_function.Defun(dtypes.int64)
def f(x):
return x + 1
in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
out = f(in_placeholder)
with session_lib.Session() as session:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": in_placeholder},
outputs={"output": out})
return path
def test_load_defun(self):
path = self._model_with_defun()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([43], imported_fn(forty_two)["output"].numpy())
if __name__ == "__main__":
test.main()