提交 5591ca5f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Python DefFun now creates functions that include NodeDefs

(in addition to the older FunctionDef::Node format).  Add an
optional out_names argument to Defun and names to Declare's
arguments, so that signatures for forward-declared DefFuns
can have signatures in the name (required for this change).
Change: 139919259
上级 be3e778f
...@@ -42,17 +42,18 @@ def _make_argname_from_tensor_name(name): ...@@ -42,17 +42,18 @@ def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o") return re.sub(":0$", "", name).replace(":", "_o")
def _tensor_to_argdef(t, name=None):
  """Returns an `OpDef.ArgDef` describing tensor `t`.

  Args:
    t: A tensor; its `name` and `dtype` populate the ArgDef.
    name: Optional explicit argument name. If None, the name is derived
      from the tensor name (":0" stripped, other ":" turned into "_o").

  Returns:
    An `op_def_pb2.OpDef.ArgDef` with `name` and `type` set.
  """
  arg = op_def_pb2.OpDef.ArgDef()
  if name is None:
    arg.name = _make_argname_from_tensor_name(t.name)
  else:
    arg.name = name
  arg.type = t.dtype.as_datatype_enum
  return arg
def _get_node_def_attr(op): def _get_node_def(op):
# pylint: disable=protected-access return op._node_def # pylint: disable=protected-access
return op._node_def.attr
# pylint: enable=protected-access
def _add_input_array(op, start, limit, dtype, func): def _add_input_array(op, start, limit, dtype, func):
...@@ -122,17 +123,66 @@ def _add_output_list(op, start, limit, dtype_lst, func): ...@@ -122,17 +123,66 @@ def _add_output_list(op, start, limit, dtype_lst, func):
return ret_name return ret_name
def _add_op_node(op, func): def _get_op_def(op):
"""Converts an op to a function def node and add it to `func`."""
node = function_pb2.FunctionDef.Node()
node.op = op.type
# pylint: disable=protected-access # pylint: disable=protected-access
if hasattr(op, "_sig"): if hasattr(op, "_sig"):
op_def = getattr(op, "_sig") return getattr(op, "_sig")
else: else:
op_def = op_def_registry.get_registered_ops()[op.type] return op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access # pylint: enable=protected-access
attrs = _get_node_def_attr(op)
def _is_in_placeholders(op, func_arg_placeholders):
return op.values() and (op.values()[0].name in func_arg_placeholders)
def _create_input_dict(function_graph, func_arg_placeholders):
  """Creates a mapping from graph tensor names to function tensor names.

  Placeholder ops (the function's inputs) map to themselves. Every other
  op's outputs map to the "op:arg_name:index" form used inside a
  FunctionDef's NodeDefs; the bare op name maps to its first output.

  Args:
    function_graph: The `Graph` being converted to a function.
    func_arg_placeholders: Set of tensor names that are function inputs.

  Returns:
    Dict mapping graph tensor (and op) names to function-body names.
  """
  input_dict = {}
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.values()[0].name] = op.values()[0].name
      input_dict[op.name] = op.name
    else:
      op_def = _get_op_def(op)
      attrs = _get_node_def(op).attr
      o = 0  # Flat index over all of the op's output tensors.
      for arg_def in op_def.output_arg:
        # An output arg may stand for N tensors (number_attr) or a
        # typed list (type_list_attr); otherwise exactly one tensor.
        if arg_def.number_attr:
          num = attrs[arg_def.number_attr].i
        elif arg_def.type_list_attr:
          num = len(attrs[arg_def.type_list_attr].list.type)
        else:
          num = 1
        for i in range(num):
          result = "%s:%s:%d" % (op.name, arg_def.name, i)
          input_dict[op.values()[o].name] = result
          if o == 0:
            input_dict[op.name] = result
          o += 1
  return input_dict
def _add_op_node(op, func, input_dict):
"""Converts an op to a function def node and add it to `func`."""
# Add an entry in func.node_def
# Note that extend() makes a copy in this case, see:
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
func.node_def.extend([_get_node_def(op)])
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, (
"%s missing from %s" % (node_def.input[i], input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
# To support legacy consumers, add an entry in func.node.
# TODO(josh11b): Delete this.
node = function_pb2.FunctionDef.Node()
node.op = op.type
op_def = _get_op_def(op)
attrs = node_def.attr
if not op_def.output_arg: if not op_def.output_arg:
node.ret.append(_make_argname_from_tensor_name(op.name)) node.ret.append(_make_argname_from_tensor_name(op.name))
else: else:
...@@ -174,12 +224,31 @@ def _add_op_node(op, func): ...@@ -174,12 +224,31 @@ def _add_op_node(op, func):
inp_index += 1 inp_index += 1
node.dep.extend( node.dep.extend(
[_make_argname_from_tensor_name(x.name) for x in op.control_inputs]) [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
for k, v in _get_node_def_attr(op).items(): for k, v in attrs.items():
node.attr[k].CopyFrom(v) node.attr[k].CopyFrom(v)
func.node.extend([node]) func.node.extend([node])
def _graph_to_function_def(graph, inputs, outputs): def _replace_ret(func, original, replacement):
for n in func.node:
for i, r in enumerate(n.ret):
if r == original:
n.ret[i] = replacement
return
raise ValueError("Could not find ret == '%s'" % original)
def _replace_arg(func, original, replacement):
for n in func.node:
for i, a in enumerate(n.arg):
if a == original:
n.arg[i] = replacement
for i, d in enumerate(n.dep):
if d == original:
n.dep[i] = replacement
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a `FunctionDef` proto from a graph, using the
  supplied tensors as the function's inputs and outputs. The resulting
  signature is named "_"; callers are expected to rename it.

  Args:
    graph: Graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if out_names is specified and the wrong length.
  """
  func = function_pb2.FunctionDef()
  func.signature.name = "_"
  func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
  if out_names is None:
    func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
  elif len(outputs) != len(out_names):
    raise ValueError(
        "Length of out_names (%d) does not match number of outputs (%d): %s" %
        (len(out_names), len(outputs), ", ".join(out_names)))
  else:
    func.signature.output_arg.extend([
        _tensor_to_argdef(o, n) for o, n in zip(outputs, out_names)])
  func_arg_placeholders = set([i.name for i in inputs])
  input_dict = _create_input_dict(graph, func_arg_placeholders)
  for op in graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      continue
    _add_op_node(op, func, input_dict)
  if out_names is None:
    for o in outputs:
      k = _make_argname_from_tensor_name(o.name)
      func.ret[k] = input_dict[o.name]
  else:
    for o, n in zip(outputs, out_names):
      func.ret[n] = input_dict[o.name]
      # TODO(josh11b): Delete this once we switch fully to NodeDefs for
      # function bodies.
      k = _make_argname_from_tensor_name(o.name)
      _replace_ret(func, k, n)
      _replace_arg(func, k, n)
  return func
...@@ -251,7 +348,6 @@ def _call(sig, *inputs, **kwargs): ...@@ -251,7 +348,6 @@ def _call(sig, *inputs, **kwargs):
Raises: Raises:
ValueError: if the arguments are invalid. ValueError: if the arguments are invalid.
""" """
if len(inputs) != len(sig.input_arg): if len(inputs) != len(sig.input_arg):
raise ValueError("Expected number of arguments: %d, received: %d" % raise ValueError("Expected number of arguments: %d, received: %d" %
...@@ -301,7 +397,6 @@ class _FuncGraph(ops.Graph): ...@@ -301,7 +397,6 @@ class _FuncGraph(ops.Graph):
Each captured input's corresponding place holder is converted into a Each captured input's corresponding place holder is converted into a
function argument and the caller passes in the captured tensor. function argument and the caller passes in the captured tensor.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -385,7 +480,6 @@ def get_extra_inputs(): ...@@ -385,7 +480,6 @@ def get_extra_inputs():
returned list of tensors are those accessed inside the function body returned list of tensors are those accessed inside the function body
but defined outside the function body so far. Otherwise, returns an but defined outside the function body so far. Otherwise, returns an
empty list. empty list.
""" """
g = ops.get_default_graph() g = ops.get_default_graph()
if isinstance(g, _FuncGraph): if isinstance(g, _FuncGraph):
...@@ -402,7 +496,6 @@ def get_extra_args(): ...@@ -402,7 +496,6 @@ def get_extra_args():
returned list of place holders are those used inside the function returned list of place holders are those used inside the function
body corresponding those returned by get_extra_inputs(). Otherwise, body corresponding those returned by get_extra_inputs(). Otherwise,
returns an empty list. returns an empty list.
""" """
g = ops.get_default_graph() g = ops.get_default_graph()
if isinstance(g, _FuncGraph): if isinstance(g, _FuncGraph):
...@@ -429,6 +522,7 @@ class _DefinedFunction(object): ...@@ -429,6 +522,7 @@ class _DefinedFunction(object):
func_name=None, func_name=None,
grad_func=None, grad_func=None,
python_grad_func=None, python_grad_func=None,
out_names=None,
**kwargs): **kwargs):
"""Creates _DefinedFunction. """Creates _DefinedFunction.
...@@ -443,6 +537,8 @@ class _DefinedFunction(object): ...@@ -443,6 +537,8 @@ class _DefinedFunction(object):
to None. to None.
python_grad_func: A python callable implementing the gradient of python_grad_func: A python callable implementing the gradient of
the function python-side. the function python-side.
out_names: An optional list of strings for the function return value
names.
**kwargs: The keyword arguments. **kwargs is passed to every call **kwargs: The keyword arguments. **kwargs is passed to every call
site of this function. site of this function.
...@@ -455,6 +551,7 @@ class _DefinedFunction(object): ...@@ -455,6 +551,7 @@ class _DefinedFunction(object):
self._func_name = func_name self._func_name = func_name
self._grad_func = grad_func self._grad_func = grad_func
self._python_grad_func = python_grad_func self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs self._extra_kwargs = kwargs
self._definition = None # Constructed lazily. self._definition = None # Constructed lazily.
...@@ -531,7 +628,8 @@ class _DefinedFunction(object): ...@@ -531,7 +628,8 @@ class _DefinedFunction(object):
inputs.extend(temp_graph.extra_args) inputs.extend(temp_graph.extra_args)
# Build the FunctionDef # Build the FunctionDef
self._definition = _graph_to_function_def(temp_graph, inputs, outputs) self._definition = _graph_to_function_def(
temp_graph, inputs, outputs, out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def. # Extra kwargs are treated as attrs on the function def.
kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs) kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs)
...@@ -556,6 +654,7 @@ class _DefinedFunction(object): ...@@ -556,6 +654,7 @@ class _DefinedFunction(object):
for s in slist: for s in slist:
update_str(s) update_str(s)
# TODO(josh11b): Switch .node to .node_def
for n in sorted(self._definition.node, key=lambda n: n.ret[0]): for n in sorted(self._definition.node, key=lambda n: n.ret[0]):
update_strs(n.ret) update_strs(n.ret)
update_str(n.op) update_str(n.op)
...@@ -661,6 +760,7 @@ class _OverloadedFunction(object): ...@@ -661,6 +760,7 @@ class _OverloadedFunction(object):
func_name=None, func_name=None,
grad_func=None, grad_func=None,
python_grad_func=None, python_grad_func=None,
out_names=None,
**kwargs): **kwargs):
"""Creates _DefinedFunction. """Creates _DefinedFunction.
...@@ -673,6 +773,7 @@ class _OverloadedFunction(object): ...@@ -673,6 +773,7 @@ class _OverloadedFunction(object):
to None. to None.
python_grad_func: A python callable implementing the gradient of python_grad_func: A python callable implementing the gradient of
the function python-side. the function python-side.
out_names: A list of strings for the function return value names.
**kwargs: The keyword arguments. **kwargs is passed to every call **kwargs: The keyword arguments. **kwargs is passed to every call
site of this function. site of this function.
...@@ -686,6 +787,7 @@ class _OverloadedFunction(object): ...@@ -686,6 +787,7 @@ class _OverloadedFunction(object):
assert grad_func is None or isinstance(grad_func, _OverloadedFunction) assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
self._grad_func = grad_func self._grad_func = grad_func
self._python_grad_func = python_grad_func self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs self._extra_kwargs = kwargs
self._overload = {} self._overload = {}
...@@ -709,6 +811,7 @@ class _OverloadedFunction(object): ...@@ -709,6 +811,7 @@ class _OverloadedFunction(object):
name = "_".join([name, key]) name = "_".join([name, key])
defined = _DefinedFunction(self._func, self._argnames, input_types, name, defined = _DefinedFunction(self._func, self._argnames, input_types, name,
None, self._python_grad_func, None, self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs) **self._extra_kwargs)
_ = defined.name # Fully instantiate the function definition. _ = defined.name # Fully instantiate the function definition.
if self._grad_func: if self._grad_func:
...@@ -802,11 +905,15 @@ class Defun(object): ...@@ -802,11 +905,15 @@ class Defun(object):
This will be called by tf.gradients to add the gradient ops This will be called by tf.gradients to add the gradient ops
to the graph. At most one of grad_func and python_grad_func to the graph. At most one of grad_func and python_grad_func
can be specified. can be specified.
out_names = (optional). A list of strings, one per output
tensor.
""" """
self._input_types = input_types self._input_types = input_types
self._func_name = kwargs.pop("func_name", None) self._func_name = kwargs.pop("func_name", None)
self._grad_func = kwargs.pop("grad_func", None) self._grad_func = kwargs.pop("grad_func", None)
self._python_grad_func = kwargs.pop("python_grad_func", None) self._python_grad_func = kwargs.pop("python_grad_func", None)
self._out_names = kwargs.pop("out_names", None)
self._extra_kwargs = kwargs self._extra_kwargs = kwargs
def __call__(self, func): def __call__(self, func):
...@@ -833,7 +940,7 @@ class Defun(object): ...@@ -833,7 +940,7 @@ class Defun(object):
if self._input_types: if self._input_types:
# If Defun is given a list of types for the inputs, the number # If Defun is given a list of types for the inputs, the number
# of of input types should be compatible with 'func'. # of input types should be compatible with 'func'.
num = len(self._input_types) num = len(self._input_types)
if num < min_args or num > max_args: if num < min_args or num > max_args:
raise ValueError( raise ValueError(
...@@ -841,17 +948,20 @@ class Defun(object): ...@@ -841,17 +948,20 @@ class Defun(object):
"input types.") "input types.")
return _DefinedFunction(func, argnames, self._input_types, return _DefinedFunction(func, argnames, self._input_types,
self._func_name, self._grad_func, self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs) self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list. # 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0: if min_args == 0 and max_args == 0:
return _DefinedFunction(func, [], [], self._func_name, self._grad_func, return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs) self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence # Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called. # its definition needs to be deferred until it's called.
return _OverloadedFunction(func, argnames, self._func_name, self._grad_func, return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs) self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
class Declare(object): class Declare(object):
...@@ -861,38 +971,41 @@ class Declare(object): ...@@ -861,38 +971,41 @@ class Declare(object):
later during a graph construction. later during a graph construction.
For example, For example,
# Declares a function Foo, which takes a tf.int32 and a # Declares a function Foo, which takes a tf.int32 named "n" and a
# tf.float32 as inputs and returns a tf.float32 as its output. # tf.float32 named "x" as inputs and returns a tf.float32 named "z"
foo = Declare("Foo", [tf.int32, tf.float32], [tf.float32]) # as its output.
foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)],
[("z", tf.float32)])
# Defines a function Bar calls Foo. # Defines a function Bar calls Foo.
@tf.Defun(tf.float32) @tf.Defun(tf.float32)
def Bar(x): def Bar(x):
return foo(6, x) return foo(6, x)
# Defines Foo. # Defines Foo, with output named "z".
@tf.Defun(tf.int32, tf.float32) @tf.Defun(tf.int32, tf.float32, out_names=["z"])
def Foo(n, x): def Foo(n, x):
... # Calculation. ... # Calculation.
return result return result
""" """
def __init__(self, func_name, inputs, outputs):
  """Creates a `Declare` object.

  Args:
    func_name: The name of the function.
    inputs: A list of (name, data type) pairs of function arguments.
    outputs: A list of (name, data type) pairs of function return values.
  """
  self._sig = op_def_pb2.OpDef()
  self._sig.name = func_name

  def _to_argdef_list(args):
    # Each (name, dtype) pair becomes a named ArgDef in the signature.
    return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
            for n, t in args]

  self._sig.input_arg.extend(_to_argdef_list(inputs))
  self._sig.output_arg.extend(_to_argdef_list(outputs))
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
inputs = [ops.convert_to_tensor(_) for _ in inputs] inputs = [ops.convert_to_tensor(_) for _ in inputs]
......
...@@ -485,9 +485,9 @@ class FunctionTest(tf.test.TestCase): ...@@ -485,9 +485,9 @@ class FunctionTest(tf.test.TestCase):
self.assertAllClose(vals[2], vals[3]) self.assertAllClose(vals[2], vals[3])
def testDeclareTypeMistake(self): def testDeclareTypeMistake(self):
foo = function.Declare("Foo", [tf.float32], [tf.float32]) foo = function.Declare("Foo", [("x", tf.float32)], [("y", tf.float32)])
@function.Defun(tf.float32, func_name="Foo") @function.Defun(tf.float32, func_name="Foo", out_names=["y"])
def Foo(x): def Foo(x):
return x * x + 1 return x * x + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册