Commit 9f2a816c authored by Sherry Moore, committed by TensorFlower Gardener

Added missing call to strip_name_scope for WhileContext frames.

Added a test to demonstrate how to compose larger graphs from several scoped
subgraphs.
Change: 137574327
Parent 27cb0773
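For orientation, a minimal sketch of the composition workflow this change enables, using the 0.x-era `tf.*` API seen in this commit. The scope and tensor names below are illustrative, and passing the returned MetaGraphDef straight to `import_scoped_meta_graph` is assumed to be equivalent to passing a written file, as the test does:

```python
import tensorflow as tf
from tensorflow.python.framework import meta_graph

# Build a graph whose "hidden1" scope consumes a tensor defined outside it.
g1 = tf.Graph()
with g1.as_default():
  images = tf.placeholder(tf.float32, [None, 784], name="images")
  with tf.name_scope("hidden1"):
    weights = tf.Variable(tf.zeros([784, 128]), name="weights")
    tf.nn.relu(tf.matmul(images, weights), name="Relu")

# Export only the "hidden1" subgraph; "images" becomes an unbound input.
mg, var_list = meta_graph.export_scoped_meta_graph(
    graph=g1, export_scope="hidden1")

# Re-import it under a new scope in a fresh graph, binding the unbound
# input to a new tensor through input_map.
g2 = tf.Graph()
with g2.as_default():
  new_images = tf.placeholder(tf.float32, [None, 784], name="new_images")
  meta_graph.import_scoped_meta_graph(
      mg, graph=g2,
      input_map={"$unbound_inputs_images": new_images},
      import_scope="new_hidden1")
```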
......@@ -40,6 +40,10 @@ from tensorflow.python.training import training_util
from tensorflow.python.util import compat
# Prefix to be added to unbound input names so they are easily identifiable.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"
def _node_def(from_node_def, export_scope, unbound_inputs):
"""Create a `NodeDef` proto with export_scope stripped.
......@@ -57,7 +61,8 @@ def _node_def(from_node_def, export_scope, unbound_inputs):
not node_def.input[i].lstrip("^").startswith(export_scope)):
# Adds "$unbound_inputs_" prefix to the unbound name so they are easily
# identifiable.
node_def.input[i] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
node_def.input[i] = re.sub(r"([\^]|^)(.*)",
r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
compat.as_str(v))
unbound_inputs.append(node_def.input[i])
else:
......@@ -210,6 +215,31 @@ def _get_kind_name(item):
return kind
def _should_include_node(node_or_node_name, export_scope):
"""Returns `True` if a node should be included.
Args:
node_or_node_name: A node or `string` node name.
export_scope: `string`. Name scope under which to extract the subgraph. The
scope name will be stripped from the node definitions for easy import later
into new name scopes.
Returns:
`True` if the node should be included.
"""
if not isinstance(node_or_node_name, six.string_types):
try:
node_name = node_or_node_name.name
except AttributeError:
# Keep the object that we don't know how to process.
return True
else:
node_name = node_or_node_name
return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or
(not export_scope or node_name.startswith(export_scope)))
def add_collection_def(meta_graph_def, key, graph=None,
export_scope=None):
"""Adds a collection to MetaGraphDef protocol buffer.
......@@ -232,6 +262,9 @@ def add_collection_def(meta_graph_def, key, graph=None,
graph = graph or ops.get_default_graph()
collection_list = graph.get_collection(key)
# Remove nodes that should not be exported from the collection list.
collection_list = [x for x in collection_list if
_should_include_node(x, export_scope)]
if not collection_list:
return
......@@ -555,7 +588,7 @@ def export_scoped_meta_graph(filename=None,
graph_def.versions.CopyFrom(graph._graph_def_versions)
bytesize = 0
for key in sorted(graph._nodes_by_name):
if key.startswith(export_scope):
if _should_include_node(key, export_scope):
value = graph._nodes_by_name[key]
# pylint: enable=protected-access
graph_def.node.extend([_node_def(value.node_def, export_scope,
......@@ -572,6 +605,8 @@ def export_scoped_meta_graph(filename=None,
# If we would like such information included in the exported meta_graph,
# add them to a special unbound_inputs collection.
if unbound_inputs_col_name:
# Clears the unbound_inputs collection.
graph.clear_collection(unbound_inputs_col_name)
for k in unbound_inputs:
graph.add_to_collection(unbound_inputs_col_name, k)
......@@ -579,7 +614,8 @@ def export_scoped_meta_graph(filename=None,
variables = graph.get_collection(ops.GraphKeys.VARIABLES,
scope=export_scope)
for v in variables:
var_list[ops.strip_name_scope(v.name, export_scope)] = v
if _should_include_node(v, export_scope):
var_list[ops.strip_name_scope(v.name, export_scope)] = v
scoped_meta_graph_def = create_meta_graph_def(
graph_def=graph_def,
......
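The filter added to `add_collection_def` and `export_scoped_meta_graph` above boils down to a simple name test: a node survives export if it is an unbound-input placeholder or lives under the export scope. A standalone sketch of the same predicate (a hypothetical helper, not the library function itself):

```python
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

def should_include(node_name, export_scope):
  # Unbound-input placeholders are always kept; anything else must live
  # under the export scope, unless no scope filter was requested.
  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)

assert should_include("hidden1/weights", "hidden1")
assert should_include("$unbound_inputs_images", "hidden1")
assert not should_include("softmax_linear/weights", "hidden1")
```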
......@@ -135,7 +135,7 @@ class SimpleMetaGraphTest(tf.test.TestCase):
class ScopedMetaGraphTest(tf.test.TestCase):
def _testScopedExport(self, test_dir, exported_filename, ckpt_filename):
def _testScopedExport(self, test_dir, exported_filenames):
graph = tf.Graph()
with graph.as_default():
# Creates an inference graph.
......@@ -176,7 +176,7 @@ class ScopedMetaGraphTest(tf.test.TestCase):
return it + 1, biases2
_, biases2 = control_flow_ops.while_loop(
loop_cond, loop_body,
[tf.constant(0), tf.Variable(tf.zeros([32]))])
[tf.constant(0), tf.Variable(tf.zeros([32]), name="biases")])
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights2) + biases2)
# Linear
with tf.name_scope("softmax_linear"):
......@@ -187,18 +187,30 @@ class ScopedMetaGraphTest(tf.test.TestCase):
biases3 = tf.Variable(tf.zeros([10]), name="biases")
logits = tf.matmul(hidden2, weights3) + biases3
tf.add_to_collection("logits", logits)
orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, exported_filename),
# Exports each sub-graph.
# Exports the first one with unbound_inputs_col_name set to default.
orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, exported_filenames[0]),
graph=tf.get_default_graph(), export_scope="hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
var_names = [v.name for _, v in var_list.items()]
self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
sorted(var_names))
return orig_meta_graph
# Exports the rest with no unbound_inputs_col_name.
orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, exported_filenames[1]),
graph=tf.get_default_graph(), export_scope="hidden2",
unbound_inputs_col_name=None)
orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, exported_filenames[2]),
graph=tf.get_default_graph(), export_scope="softmax_linear",
unbound_inputs_col_name=None)
def _testScopedImport(self, test_dir, exported_filename,
new_exported_filename, ckpt_filename):
return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
def _testScopedImport(self, test_dir, exported_filenames):
graph = tf.Graph()
# Create all the missing inputs.
with graph.as_default():
......@@ -207,73 +219,83 @@ class ScopedMetaGraphTest(tf.test.TestCase):
with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filename), graph=graph,
os.path.join(test_dir, exported_filenames[0]), graph=graph,
import_scope="new_hidden1")
with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filename), graph=graph,
os.path.join(test_dir, exported_filenames[0]), graph=graph,
input_map={"image:0": new_image},
import_scope="new_hidden1")
# Verifies we can import the original "hidden1" into "new_hidden1".
var_list = meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filename), graph=graph,
os.path.join(test_dir, exported_filenames[0]), graph=graph,
input_map={"$unbound_inputs_images": new_image},
import_scope="new_hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
new_var_names = [v.name for _, v in var_list.items()]
self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
sorted(new_var_names))
hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
with graph.as_default():
# Hidden 2
with tf.name_scope("hidden2"):
weights = tf.Variable(
tf.truncated_normal([128, 32],
stddev=1.0 / math.sqrt(float(128))),
name="weights")
# The use of control_flow_ops.while_loop here is purely for adding test
# coverage of the save and restore of control flow contexts (which doesn't
# make any sense here from a machine learning perspective). A typical biases
# term is a simple Variable without the conditions.
def loop_cond(it, _):
return it < 2
def loop_body(it, biases):
biases += tf.constant(0.1, shape=[32])
return it + 1, biases
_, biases = control_flow_ops.while_loop(
loop_cond, loop_body,
[tf.constant(0), tf.Variable(tf.zeros([32]))])
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
# Linear
with tf.name_scope("softmax_linear"):
weights = tf.Variable(
tf.truncated_normal([32, 10],
stddev=1.0 / math.sqrt(float(32))),
name="weights")
biases = tf.Variable(tf.zeros([10]), name="biases")
logits = tf.matmul(hidden2, weights) + biases
tf.add_to_collection("logits", logits)
# Verifies we can import the original "hidden2" into "new_hidden2".
hidden1 = tf.identity(graph.as_graph_element("new_hidden1/Relu:0"),
name="hidden1/Relu")
var_list = meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filenames[1]), graph=graph,
input_map={"$unbound_inputs_hidden1/Relu": hidden1},
import_scope="new_hidden2", unbound_inputs_col_name=None)
new_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, new_exported_filename),
graph=graph, export_scope="new_hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
new_var_names = [v.name for _, v in var_list.items()]
self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"],
sorted(new_var_names))
return new_meta_graph
# Verifies we can import the original "softmax_linear" into
# "new_softmax_linear".
hidden2 = tf.identity(graph.as_graph_element("new_hidden2/Relu:0"),
name="hidden2/Relu")
var_list = meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filenames[2]), graph=graph,
input_map={"$unbound_inputs_hidden2/Relu": hidden2},
import_scope="new_softmax_linear", unbound_inputs_col_name=None)
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
new_var_names = [v.name for _, v in var_list.items()]
self.assertEqual(["new_softmax_linear/biases:0",
"new_softmax_linear/weights:0"],
sorted(new_var_names))
# Exports the scoped meta graphs again.
new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
graph=graph, export_scope="new_hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph(
graph=graph, export_scope="new_hidden2",
unbound_inputs_col_name=None)
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph(
graph=graph, export_scope="new_softmax_linear",
unbound_inputs_col_name=None)
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
# Verifies that we can export the subgraph under "hidden1" and import
# it into "new_hidden1" in a new graph.
return [new_meta_graph1, new_meta_graph2, new_meta_graph3]
# Verifies that we can export the subgraph under each layer and import
# them into new layers in a new graph.
def testScopedExportAndImport(self):
test_dir = _TestDir("scoped_export_import")
ckpt_filename = "ckpt"
orig_meta_graph = self._testScopedExport(
test_dir, "exported_hidden1.pbtxt", ckpt_filename)
new_meta_graph = self._testScopedImport(
test_dir, "exported_hidden1.pbtxt", "exported_new_hidden1.pbtxt",
ckpt_filename)
self.assertProtoEquals(orig_meta_graph, new_meta_graph)
filenames = ["exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
"exported_softmax_linear.pbtxt"]
orig_meta_graphs = self._testScopedExport(test_dir, filenames)
new_meta_graphs = self._testScopedImport(test_dir, filenames)
# Delete the unbound_inputs collections to allow directly calling assertProtoEquals.
del orig_meta_graphs[0].collection_def["unbound_inputs"]
del new_meta_graphs[0].collection_def["unbound_inputs"]
for a, b in zip(orig_meta_graphs, new_meta_graphs):
self.assertProtoEquals(a, b)
def _testScopedExportWithQueue(self, test_dir, exported_filename):
graph = tf.Graph()
......
......@@ -2780,6 +2780,18 @@ class Graph(object):
with self._lock:
return [x for x in self._collections if isinstance(x, six.string_types)]
def clear_collection(self, name):
"""Clears all values in a collection.
Args:
name: The key for the collection. The `GraphKeys` class contains many
standard names for collections.
"""
self._check_not_finalized()
with self._lock:
if name in self._collections:
del self._collections[name]
@contextlib.contextmanager
def _original_op(self, op):
"""Python 'with' handler to help annotate ops with their originator.
......
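`Graph.clear_collection` is new in this change; `export_scoped_meta_graph` calls it so repeated exports do not keep appending to the unbound-inputs collection. A minimal usage sketch (0.x `tf.*` API; the collection key mirrors the default `unbound_inputs_col_name`):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tf.add_to_collection("unbound_inputs", "hidden1/Relu")

# Removes every value stored under the key; a later get_collection returns [].
g.clear_collection("unbound_inputs")
assert g.get_collection("unbound_inputs") == []
```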
......@@ -1959,7 +1959,9 @@ class WhileContext(ControlFlowContext):
context_def.pivot_name = ops.strip_name_scope(
self._pivot.name, export_scope)
if self._loop_exits:
context_def.loop_exit_names.extend([l.name for l in self._loop_exits])
context_def.loop_exit_names.extend(
[ops.strip_name_scope(l.name, export_scope)
for l in self._loop_exits])
context_def.values_def.MergeFrom(
super(WhileContext, self)._to_proto(
export_scope=export_scope))
......
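The last hunk is the fix named in the commit message: loop-exit names recorded in a `WhileContext` proto now have the export scope stripped, matching how the pivot and other names were already handled. A tiny standalone sketch of the stripping behaviour relied on here (a hypothetical helper mirroring `ops.strip_name_scope` as it is used above):

```python
def strip_scope(name, export_scope):
  # Drop a leading "<export_scope>/" prefix; leave unrelated names untouched.
  prefix = export_scope + "/"
  return name[len(prefix):] if name.startswith(prefix) else name

# Before the fix, loop_exit_names kept the scoped form on the left, which
# broke re-importing the exported WhileContext under a new scope.
assert strip_scope("hidden1/while/Exit:0", "hidden1") == "while/Exit:0"
assert strip_scope("global_step:0", "hidden1") == "global_step:0"
```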