提交 5591ca5f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Python DefFun now creates functions that include NodeDefs

(in addition to the older FunctionDef::Node format).  Add an
optional out_names argument to Defun and names to Declare's
arguments, so that signatures for forward-declared DefFuns
can have signatures in the name (required for this change).
Change: 139919259
上级 be3e778f
......@@ -42,17 +42,18 @@ def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")
def _tensor_to_argdef(t):
def _tensor_to_argdef(t, name=None):
arg = op_def_pb2.OpDef.ArgDef()
arg.name = _make_argname_from_tensor_name(t.name)
if name is None:
arg.name = _make_argname_from_tensor_name(t.name)
else:
arg.name = name
arg.type = t.dtype.as_datatype_enum
return arg
def _get_node_def_attr(op):
# pylint: disable=protected-access
return op._node_def.attr
# pylint: enable=protected-access
def _get_node_def(op):
return op._node_def # pylint: disable=protected-access
def _add_input_array(op, start, limit, dtype, func):
......@@ -122,17 +123,66 @@ def _add_output_list(op, start, limit, dtype_lst, func):
return ret_name
def _add_op_node(op, func):
"""Converts an op to a function def node and add it to `func`."""
node = function_pb2.FunctionDef.Node()
node.op = op.type
def _get_op_def(op):
# pylint: disable=protected-access
if hasattr(op, "_sig"):
op_def = getattr(op, "_sig")
return getattr(op, "_sig")
else:
op_def = op_def_registry.get_registered_ops()[op.type]
return op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access
attrs = _get_node_def_attr(op)
def _is_in_placeholders(op, func_arg_placeholders):
return op.values() and (op.values()[0].name in func_arg_placeholders)
def _create_input_dict(function_graph, func_arg_placeholders):
"""Create a mapping from graph tensor names to function tensor names."""
input_dict = {}
for op in function_graph.get_operations():
if _is_in_placeholders(op, func_arg_placeholders):
input_dict[op.values()[0].name] = op.values()[0].name
input_dict[op.name] = op.name
else:
op_def = _get_op_def(op)
attrs = _get_node_def(op).attr
o = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
num = attrs[arg_def.number_attr].i
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
else:
num = 1
for i in range(num):
result = "%s:%s:%d" % (op.name, arg_def.name, i)
input_dict[op.values()[o].name] = result
if o == 0:
input_dict[op.name] = result
o += 1
return input_dict
def _add_op_node(op, func, input_dict):
"""Converts an op to a function def node and add it to `func`."""
# Add an entry in func.node_def
# Note that extend() makes a copy in this case, see:
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
func.node_def.extend([_get_node_def(op)])
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, (
"%s missing from %s" % (node_def.input[i], input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
# To support legacy consumers, add an entry in func.node.
# TODO(josh11b): Delete this.
node = function_pb2.FunctionDef.Node()
node.op = op.type
op_def = _get_op_def(op)
attrs = node_def.attr
if not op_def.output_arg:
node.ret.append(_make_argname_from_tensor_name(op.name))
else:
......@@ -174,12 +224,31 @@ def _add_op_node(op, func):
inp_index += 1
node.dep.extend(
[_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
for k, v in _get_node_def_attr(op).items():
for k, v in attrs.items():
node.attr[k].CopyFrom(v)
func.node.extend([node])
def _graph_to_function_def(graph, inputs, outputs):
def _replace_ret(func, original, replacement):
for n in func.node:
for i, r in enumerate(n.ret):
if r == original:
n.ret[i] = replacement
return
raise ValueError("Could not find ret == '%s'" % original)
def _replace_arg(func, original, replacement):
for n in func.node:
for i, a in enumerate(n.arg):
if a == original:
n.arg[i] = replacement
for i, d in enumerate(n.dep):
if d == original:
n.dep[i] = replacement
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`](
......@@ -195,19 +264,47 @@ def _graph_to_function_def(graph, inputs, outputs):
graph: Graph.
inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs.
Returns:
A FunctionDef protocol buffer.
Raises:
ValueError: if out_names is specified and the wrong length.
"""
func = function_pb2.FunctionDef()
func.signature.name = "_"
func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
if out_names is None:
func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
(len(out_names), len(outputs), ", ".join(out_names)))
else:
func.signature.output_arg.extend([
_tensor_to_argdef(o, n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders)
for op in graph.get_operations():
if op.values() and (op.values()[0].name in func_arg_placeholders):
if _is_in_placeholders(op, func_arg_placeholders):
continue
_add_op_node(op, func)
_add_op_node(op, func, input_dict)
if out_names is None:
for o in outputs:
k = _make_argname_from_tensor_name(o.name)
func.ret[k] = input_dict[o.name]
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
# function bodies.
k = _make_argname_from_tensor_name(o.name)
_replace_ret(func, k, n)
_replace_arg(func, k, n)
return func
......@@ -251,7 +348,6 @@ def _call(sig, *inputs, **kwargs):
Raises:
ValueError: if the arguments are invalid.
"""
if len(inputs) != len(sig.input_arg):
raise ValueError("Expected number of arguments: %d, received: %d" %
......@@ -301,7 +397,6 @@ class _FuncGraph(ops.Graph):
Each captured input's corresponding place holder is converted into a
function argument and the caller passes in the captured tensor.
"""
def __init__(self, *args, **kwargs):
......@@ -385,7 +480,6 @@ def get_extra_inputs():
returned list of tensors are those accessed inside the function body
but defined outside the function body so far. Otherwise, returns an
empty list.
"""
g = ops.get_default_graph()
if isinstance(g, _FuncGraph):
......@@ -402,7 +496,6 @@ def get_extra_args():
returned list of place holders are those used inside the function
body corresponding those returned by get_extra_inputs(). Otherwise,
returns an empty list.
"""
g = ops.get_default_graph()
if isinstance(g, _FuncGraph):
......@@ -429,6 +522,7 @@ class _DefinedFunction(object):
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
**kwargs):
"""Creates _DefinedFunction.
......@@ -443,6 +537,8 @@ class _DefinedFunction(object):
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: An optional list of strings for the function return value
names.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
......@@ -455,6 +551,7 @@ class _DefinedFunction(object):
self._func_name = func_name
self._grad_func = grad_func
self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs
self._definition = None # Constructed lazily.
......@@ -531,7 +628,8 @@ class _DefinedFunction(object):
inputs.extend(temp_graph.extra_args)
# Build the FunctionDef
self._definition = _graph_to_function_def(temp_graph, inputs, outputs)
self._definition = _graph_to_function_def(
temp_graph, inputs, outputs, out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def.
kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs)
......@@ -556,6 +654,7 @@ class _DefinedFunction(object):
for s in slist:
update_str(s)
# TODO(josh11b): Switch .node to .node_def
for n in sorted(self._definition.node, key=lambda n: n.ret[0]):
update_strs(n.ret)
update_str(n.op)
......@@ -661,6 +760,7 @@ class _OverloadedFunction(object):
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
**kwargs):
"""Creates _DefinedFunction.
......@@ -673,6 +773,7 @@ class _OverloadedFunction(object):
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: A list of strings for the function return value names.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
......@@ -686,6 +787,7 @@ class _OverloadedFunction(object):
assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
self._grad_func = grad_func
self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs
self._overload = {}
......@@ -709,6 +811,7 @@ class _OverloadedFunction(object):
name = "_".join([name, key])
defined = _DefinedFunction(self._func, self._argnames, input_types, name,
None, self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
_ = defined.name # Fully instantiate the function definition.
if self._grad_func:
......@@ -802,11 +905,15 @@ class Defun(object):
This will be called by tf.gradients to add the gradient ops
to the graph. At most one of grad_func and python_grad_func
can be specified.
out_names = (optional). A list of strings, one per output
tensor.
"""
self._input_types = input_types
self._func_name = kwargs.pop("func_name", None)
self._grad_func = kwargs.pop("grad_func", None)
self._python_grad_func = kwargs.pop("python_grad_func", None)
self._out_names = kwargs.pop("out_names", None)
self._extra_kwargs = kwargs
def __call__(self, func):
......@@ -833,7 +940,7 @@ class Defun(object):
if self._input_types:
# If Defun is given a list of types for the inputs, the number
# of of input types should be compatible with 'func'.
# of input types should be compatible with 'func'.
num = len(self._input_types)
if num < min_args or num > max_args:
raise ValueError(
......@@ -841,17 +948,20 @@ class Defun(object):
"input types.")
return _DefinedFunction(func, argnames, self._input_types,
self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0:
return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called.
return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)
class Declare(object):
......@@ -861,38 +971,41 @@ class Declare(object):
later during a graph construction.
For example,
# Declares a function Foo, which takes a tf.int32 and a
# tf.float32 as inputs and returns a tf.float32 as its output.
foo = Declare("Foo", [tf.int32, tf.float32], [tf.float32])
# Declares a function Foo, which takes a tf.int32 named "n" and a
# tf.float32 named "n" as inputs and returns a tf.float32 named "z"
# as its output.
foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)],
[("z", tf.float32)])
# Defines a function Bar calls Foo.
@tf.Defun(tf.float32)
def Bar(x):
return foo(6, x)
# Defines Foo.
@tf.Defun(tf.int32, tf.float32)
# Defines Foo, with output named "z".
@tf.Defun(tf.int32, tf.float32, out_names=["z"])
def Foo(n, x):
... # Calculation.
return result
"""
def __init__(self, func_name, input_types, output_types):
def __init__(self, func_name, inputs, outputs):
"""Creates a `Declare` object.
Args:
func_name: The name of the function.
input_types: A list of data types of function arguments.
output_types: A list of data types of function return values.
inputs: A list of (name, data type) pairs of function arguments.
outputs: A list of (name, data type) pairs of function return values.
"""
self._sig = op_def_pb2.OpDef()
self._sig.name = func_name
def _to_argdef_list(types):
return [op_def_pb2.OpDef.ArgDef(type=_.as_datatype_enum) for _ in types]
def _to_argdef_list(args):
return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
for n, t in args]
self._sig.input_arg.extend(_to_argdef_list(input_types))
self._sig.output_arg.extend(_to_argdef_list(output_types))
self._sig.input_arg.extend(_to_argdef_list(inputs))
self._sig.output_arg.extend(_to_argdef_list(outputs))
def __call__(self, *inputs, **kwargs):
inputs = [ops.convert_to_tensor(_) for _ in inputs]
......
......@@ -485,9 +485,9 @@ class FunctionTest(tf.test.TestCase):
self.assertAllClose(vals[2], vals[3])
def testDeclareTypeMistake(self):
foo = function.Declare("Foo", [tf.float32], [tf.float32])
foo = function.Declare("Foo", [("x", tf.float32)], [("y", tf.float32)])
@function.Defun(tf.float32, func_name="Foo")
@function.Defun(tf.float32, func_name="Foo", out_names=["y"])
def Foo(x):
return x * x + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册