From 5591ca5f02377b27c3827b34c14b3f2f86451e06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 22 Nov 2016 09:35:21 -0800 Subject: [PATCH] Python DefFun now creates functions that include NodeDefs (in addition to the older FunctionDef::Node format). Add an optional out_names argument to Defun and names to Declare's arguments, so that forward-declared DefFuns can have names in their signatures (required for this change). Change: 139919259 --- tensorflow/python/framework/function.py | 191 +++++++++++++++---- tensorflow/python/framework/function_test.py | 4 +- 2 files changed, 154 insertions(+), 41 deletions(-) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 3faf79859c7..12063949d3b 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -42,17 +42,18 @@ def _make_argname_from_tensor_name(name): return re.sub(":0$", "", name).replace(":", "_o") -def _tensor_to_argdef(t): +def _tensor_to_argdef(t, name=None): arg = op_def_pb2.OpDef.ArgDef() - arg.name = _make_argname_from_tensor_name(t.name) + if name is None: + arg.name = _make_argname_from_tensor_name(t.name) + else: + arg.name = name arg.type = t.dtype.as_datatype_enum return arg -def _get_node_def_attr(op): - # pylint: disable=protected-access - return op._node_def.attr - # pylint: enable=protected-access +def _get_node_def(op): + return op._node_def # pylint: disable=protected-access def _add_input_array(op, start, limit, dtype, func): @@ -122,17 +123,66 @@ def _add_output_list(op, start, limit, dtype_lst, func): return ret_name -def _add_op_node(op, func): - """Converts an op to a function def node and add it to `func`.""" - node = function_pb2.FunctionDef.Node() - node.op = op.type +def _get_op_def(op): # pylint: disable=protected-access if hasattr(op, "_sig"): - op_def = getattr(op, "_sig") + return getattr(op, "_sig") else: - op_def = op_def_registry.get_registered_ops()[op.type] + 
return op_def_registry.get_registered_ops()[op.type] # pylint: enable=protected-access - attrs = _get_node_def_attr(op) + + +def _is_in_placeholders(op, func_arg_placeholders): + return op.values() and (op.values()[0].name in func_arg_placeholders) + + +def _create_input_dict(function_graph, func_arg_placeholders): + """Create a mapping from graph tensor names to function tensor names.""" + input_dict = {} + for op in function_graph.get_operations(): + if _is_in_placeholders(op, func_arg_placeholders): + input_dict[op.values()[0].name] = op.values()[0].name + input_dict[op.name] = op.name + else: + op_def = _get_op_def(op) + attrs = _get_node_def(op).attr + o = 0 + for arg_def in op_def.output_arg: + if arg_def.number_attr: + num = attrs[arg_def.number_attr].i + elif arg_def.type_list_attr: + num = len(attrs[arg_def.type_list_attr].list.type) + else: + num = 1 + for i in range(num): + result = "%s:%s:%d" % (op.name, arg_def.name, i) + input_dict[op.values()[o].name] = result + if o == 0: + input_dict[op.name] = result + o += 1 + return input_dict + + +def _add_op_node(op, func, input_dict): + """Converts an op to a function def node and add it to `func`.""" + # Add an entry in func.node_def + + # Note that extend() makes a copy in this case, see: + # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields + func.node_def.extend([_get_node_def(op)]) + node_def = func.node_def[-1] + for i in range(len(node_def.input)): + if not node_def.input[i].startswith("^"): + assert node_def.input[i] in input_dict, ( + "%s missing from %s" % (node_def.input[i], input_dict.items())) + node_def.input[i] = input_dict[node_def.input[i]] + + # To support legacy consumers, add an entry in func.node. + # TODO(josh11b): Delete this. 
+ node = function_pb2.FunctionDef.Node() + node.op = op.type + op_def = _get_op_def(op) + attrs = node_def.attr if not op_def.output_arg: node.ret.append(_make_argname_from_tensor_name(op.name)) else: @@ -174,12 +224,31 @@ def _add_op_node(op, func): inp_index += 1 node.dep.extend( [_make_argname_from_tensor_name(x.name) for x in op.control_inputs]) - for k, v in _get_node_def_attr(op).items(): + for k, v in attrs.items(): node.attr[k].CopyFrom(v) func.node.extend([node]) -def _graph_to_function_def(graph, inputs, outputs): +def _replace_ret(func, original, replacement): + for n in func.node: + for i, r in enumerate(n.ret): + if r == original: + n.ret[i] = replacement + return + raise ValueError("Could not find ret == '%s'" % original) + + +def _replace_arg(func, original, replacement): + for n in func.node: + for i, a in enumerate(n.arg): + if a == original: + n.arg[i] = replacement + for i, d in enumerate(n.dep): + if d == original: + n.dep[i] = replacement + + +def _graph_to_function_def(graph, inputs, outputs, out_names=None): """Returns `graph` as a `FunctionDef` protocol buffer. This method creates a [`FunctionDef`]( @@ -195,19 +264,47 @@ def _graph_to_function_def(graph, inputs, outputs): graph: Graph. inputs: List of tensors. Inputs to the function. outputs: List of tensors. Outputs of the function. + out_names: Optional list of string names for the outputs. Returns: A FunctionDef protocol buffer. + + Raises: + ValueError: if out_names is specified and the wrong length. 
""" func = function_pb2.FunctionDef() func.signature.name = "_" func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs]) - func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs]) + if out_names is None: + func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs]) + elif len(outputs) != len(out_names): + raise ValueError( + "Length of out_names (%d) does not match number of outputs (%d): %s" % + (len(out_names), len(outputs), ", ".join(out_names))) + else: + func.signature.output_arg.extend([ + _tensor_to_argdef(o, n) for o, n in zip(outputs, out_names)]) func_arg_placeholders = set([i.name for i in inputs]) + input_dict = _create_input_dict(graph, func_arg_placeholders) + for op in graph.get_operations(): - if op.values() and (op.values()[0].name in func_arg_placeholders): + if _is_in_placeholders(op, func_arg_placeholders): continue - _add_op_node(op, func) + _add_op_node(op, func, input_dict) + + if out_names is None: + for o in outputs: + k = _make_argname_from_tensor_name(o.name) + func.ret[k] = input_dict[o.name] + else: + for o, n in zip(outputs, out_names): + func.ret[n] = input_dict[o.name] + # TODO(josh11b): Delete this once we switch fully to NodeDefs for + # function bodies. + k = _make_argname_from_tensor_name(o.name) + _replace_ret(func, k, n) + _replace_arg(func, k, n) + return func @@ -251,7 +348,6 @@ def _call(sig, *inputs, **kwargs): Raises: ValueError: if the arguments are invalid. - """ if len(inputs) != len(sig.input_arg): raise ValueError("Expected number of arguments: %d, received: %d" % @@ -301,7 +397,6 @@ class _FuncGraph(ops.Graph): Each captured input's corresponding place holder is converted into a function argument and the caller passes in the captured tensor. - """ def __init__(self, *args, **kwargs): @@ -385,7 +480,6 @@ def get_extra_inputs(): returned list of tensors are those accessed inside the function body but defined outside the function body so far. 
Otherwise, returns an empty list. - """ g = ops.get_default_graph() if isinstance(g, _FuncGraph): @@ -402,7 +496,6 @@ def get_extra_args(): returned list of place holders are those used inside the function body corresponding those returned by get_extra_inputs(). Otherwise, returns an empty list. - """ g = ops.get_default_graph() if isinstance(g, _FuncGraph): @@ -429,6 +522,7 @@ class _DefinedFunction(object): func_name=None, grad_func=None, python_grad_func=None, + out_names=None, **kwargs): """Creates _DefinedFunction. @@ -443,6 +537,8 @@ class _DefinedFunction(object): to None. python_grad_func: A python callable implementing the gradient of the function python-side. + out_names: An optional list of strings for the function return value + names. **kwargs: The keyword arguments. **kwargs is passed to every call site of this function. @@ -455,6 +551,7 @@ class _DefinedFunction(object): self._func_name = func_name self._grad_func = grad_func self._python_grad_func = python_grad_func + self._out_names = out_names self._extra_kwargs = kwargs self._definition = None # Constructed lazily. @@ -531,7 +628,8 @@ class _DefinedFunction(object): inputs.extend(temp_graph.extra_args) # Build the FunctionDef - self._definition = _graph_to_function_def(temp_graph, inputs, outputs) + self._definition = _graph_to_function_def( + temp_graph, inputs, outputs, out_names=self._out_names) # Extra kwargs are treated as attrs on the function def. kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs) @@ -556,6 +654,7 @@ class _DefinedFunction(object): for s in slist: update_str(s) + # TODO(josh11b): Switch .node to .node_def for n in sorted(self._definition.node, key=lambda n: n.ret[0]): update_strs(n.ret) update_str(n.op) @@ -661,6 +760,7 @@ class _OverloadedFunction(object): func_name=None, grad_func=None, python_grad_func=None, + out_names=None, **kwargs): """Creates _DefinedFunction. @@ -673,6 +773,7 @@ class _OverloadedFunction(object): to None. 
python_grad_func: A python callable implementing the gradient of the function python-side. + out_names: A list of strings for the function return value names. **kwargs: The keyword arguments. **kwargs is passed to every call site of this function. @@ -686,6 +787,7 @@ class _OverloadedFunction(object): assert grad_func is None or isinstance(grad_func, _OverloadedFunction) self._grad_func = grad_func self._python_grad_func = python_grad_func + self._out_names = out_names self._extra_kwargs = kwargs self._overload = {} @@ -709,6 +811,7 @@ class _OverloadedFunction(object): name = "_".join([name, key]) defined = _DefinedFunction(self._func, self._argnames, input_types, name, None, self._python_grad_func, + out_names=self._out_names, **self._extra_kwargs) _ = defined.name # Fully instantiate the function definition. if self._grad_func: @@ -802,11 +905,15 @@ class Defun(object): This will be called by tf.gradients to add the gradient ops to the graph. At most one of grad_func and python_grad_func can be specified. + + out_names = (optional). A list of strings, one per output + tensor. """ self._input_types = input_types self._func_name = kwargs.pop("func_name", None) self._grad_func = kwargs.pop("grad_func", None) self._python_grad_func = kwargs.pop("python_grad_func", None) + self._out_names = kwargs.pop("out_names", None) self._extra_kwargs = kwargs def __call__(self, func): @@ -833,7 +940,7 @@ class Defun(object): if self._input_types: # If Defun is given a list of types for the inputs, the number - # of of input types should be compatible with 'func'. + # of input types should be compatible with 'func'. 
num = len(self._input_types) if num < min_args or num > max_args: raise ValueError( @@ -841,17 +948,20 @@ class Defun(object): "input types.") return _DefinedFunction(func, argnames, self._input_types, self._func_name, self._grad_func, - self._python_grad_func, **self._extra_kwargs) + self._python_grad_func, + out_names=self._out_names, **self._extra_kwargs) # 'func' expects no arguments and input types is an empty list. if min_args == 0 and max_args == 0: return _DefinedFunction(func, [], [], self._func_name, self._grad_func, - self._python_grad_func, **self._extra_kwargs) + self._python_grad_func, + out_names=self._out_names, **self._extra_kwargs) # Input types are unknown. It's an overloaded function and hence # its definition needs to be deferred until it's called. return _OverloadedFunction(func, argnames, self._func_name, self._grad_func, - self._python_grad_func, **self._extra_kwargs) + self._python_grad_func, + out_names=self._out_names, **self._extra_kwargs) class Declare(object): @@ -861,38 +971,41 @@ class Declare(object): later during a graph construction. For example, - # Declares a function Foo, which takes a tf.int32 and a - # tf.float32 as inputs and returns a tf.float32 as its output. - foo = Declare("Foo", [tf.int32, tf.float32], [tf.float32]) + # Declares a function Foo, which takes a tf.int32 named "n" and a + # tf.float32 named "x" as inputs and returns a tf.float32 named "z" + # as its output. + foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)], + [("z", tf.float32)]) # Defines a function Bar calls Foo. @tf.Defun(tf.float32) def Bar(x): return foo(6, x) - # Defines Foo. - @tf.Defun(tf.int32, tf.float32) + # Defines Foo, with output named "z". + @tf.Defun(tf.int32, tf.float32, out_names=["z"]) def Foo(n, x): ... # Calculation. return result """ - def __init__(self, func_name, input_types, output_types): + def __init__(self, func_name, inputs, outputs): """Creates a `Declare` object. Args: func_name: The name of the function. 
- input_types: A list of data types of function arguments. - output_types: A list of data types of function return values. + inputs: A list of (name, data type) pairs of function arguments. + outputs: A list of (name, data type) pairs of function return values. """ self._sig = op_def_pb2.OpDef() self._sig.name = func_name - def _to_argdef_list(types): - return [op_def_pb2.OpDef.ArgDef(type=_.as_datatype_enum) for _ in types] + def _to_argdef_list(args): + return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n) + for n, t in args] - self._sig.input_arg.extend(_to_argdef_list(input_types)) - self._sig.output_arg.extend(_to_argdef_list(output_types)) + self._sig.input_arg.extend(_to_argdef_list(inputs)) + self._sig.output_arg.extend(_to_argdef_list(outputs)) def __call__(self, *inputs, **kwargs): inputs = [ops.convert_to_tensor(_) for _ in inputs] diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 1660be13574..c60e2bbd1be 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -485,9 +485,9 @@ class FunctionTest(tf.test.TestCase): self.assertAllClose(vals[2], vals[3]) def testDeclareTypeMistake(self): - foo = function.Declare("Foo", [tf.float32], [tf.float32]) + foo = function.Declare("Foo", [("x", tf.float32)], [("y", tf.float32)]) - @function.Defun(tf.float32, func_name="Foo") + @function.Defun(tf.float32, func_name="Foo", out_names=["y"]) def Foo(x): return x * x + 1 -- GitLab