提交 df8e306d 编写于 作者: S Sukriti Ramesh 提交者: TensorFlower Gardener

Add support for clear devices in SavedModel.

Change: 137772478
上级 f6d70256
......@@ -249,11 +249,18 @@ class SavedModelBuilder(object):
proto_meta_graph_def = self._saved_model.meta_graphs.add()
proto_meta_graph_def.CopyFrom(meta_graph_def)
def _maybe_clear_devices(self, clear_devices):
if not clear_devices:
return
for node in ops.get_default_graph().as_graph_def().node:
node.device = ""
def add_meta_graph(self,
tags,
signature_def_map=None,
assets_collection=None,
legacy_init_op=None):
legacy_init_op=None,
clear_devices=False):
"""Adds the current meta graph to the SavedModel.
Creates a Saver in the current scope and uses the Saver to export the meta
......@@ -269,6 +276,8 @@ class SavedModelBuilder(object):
the first meta graph in the SavedModel.
legacy_init_op: Op or group of ops to execute after the restore op upon a
load.
clear_devices: Set to true if the device info on the default graph should
be cleared.
Raises:
AssertionError: If the variables for the SavedModel have not been saved
......@@ -279,6 +288,8 @@ class SavedModelBuilder(object):
"Variables and assets have not been saved yet. "
"Please invoke `add_meta_graph_and_variables()` first.")
self._maybe_clear_devices(clear_devices)
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)
......@@ -300,7 +311,8 @@ class SavedModelBuilder(object):
tags,
signature_def_map=None,
assets_collection=None,
legacy_init_op=None):
legacy_init_op=None,
clear_devices=False):
"""Adds the current meta graph to the SavedModel and saves variables.
Creates a Saver to save the variables from the provided session. Exports the
......@@ -318,11 +330,15 @@ class SavedModelBuilder(object):
assets_collection: Assets collection to be saved with SavedModel.
legacy_init_op: Op or group of ops to execute after the restore op upon a
load.
clear_devices: Set to true if the device info on the default graph should
be cleared.
"""
if self._has_saved_variables:
raise AssertionError("Variables and assets have already been saved. "
"Please invoke `add_meta_graph()` instead.")
self._maybe_clear_devices(clear_devices)
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)
......
......@@ -553,6 +553,29 @@ class SavedModelTest(tf.test.TestCase):
tf.get_collection("init_op")[0].run()
self.assertEqual(3, tf.get_collection("v")[2].eval())
def testClearDevices(self):
export_dir = os.path.join(tf.test.get_temp_dir(), "test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Specify a device and save a variable.
tf.reset_default_graph()
with tf.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(
sess, [tag_constants.TRAINING], clear_devices=True)
# Save the SavedModel to disk.
builder.save()
# Restore the graph with a single predefined tag whose variables were saved
# without any device information.
with self.test_session(graph=tf.Graph()) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
self.assertEqual(42, tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册