diff --git a/tensorflow/python/saved_model/builder.py b/tensorflow/python/saved_model/builder.py index 43b97cf70c651455bcbaf22ffbd6ccac0644764d..24126c01935a02d695d775479b508598725dc36a 100644 --- a/tensorflow/python/saved_model/builder.py +++ b/tensorflow/python/saved_model/builder.py @@ -249,11 +249,18 @@ class SavedModelBuilder(object): proto_meta_graph_def = self._saved_model.meta_graphs.add() proto_meta_graph_def.CopyFrom(meta_graph_def) + def _maybe_clear_devices(self, clear_devices): + if not clear_devices: + return + for node in ops.get_default_graph().as_graph_def().node: + node.device = "" + def add_meta_graph(self, tags, signature_def_map=None, assets_collection=None, - legacy_init_op=None): + legacy_init_op=None, + clear_devices=False): """Adds the current meta graph to the SavedModel. Creates a Saver in the current scope and uses the Saver to export the meta @@ -268,7 +275,9 @@ class SavedModelBuilder(object): that this collection should be a subset of the assets saved as part of the first meta graph in the SavedModel. legacy_init_op: Op or group of ops to execute after the restore op upon a - load. + load. + clear_devices: Set to true if the device info on the default graph should + be cleared. Raises: AssertionError: If the variables for the SavedModel have not been saved @@ -279,6 +288,8 @@ class SavedModelBuilder(object): "Variables and assets have not been saved yet. " "Please invoke `add_meta_graph_and_variables()` first.") + self._maybe_clear_devices(clear_devices) + # Save asset files and write them to disk, if any. self._save_and_write_assets(assets_collection) @@ -300,7 +311,8 @@ class SavedModelBuilder(object): tags, signature_def_map=None, assets_collection=None, - legacy_init_op=None): + legacy_init_op=None, + clear_devices=False): """Adds the current meta graph to the SavedModel and saves variables. Creates a Saver to save the variables from the provided session. Exports the @@ -318,11 +330,15 @@ class SavedModelBuilder(object): assets_collection: Assets collection to be saved with SavedModel. legacy_init_op: Op or group of ops to execute after the restore op upon a load. + clear_devices: Set to true if the device info on the default graph should + be cleared. """ if self._has_saved_variables: raise AssertionError("Variables and assets have already been saved. " "Please invoke `add_meta_graph()` instead.") + self._maybe_clear_devices(clear_devices) + # Save asset files and write them to disk, if any. self._save_and_write_assets(assets_collection) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index a50620e113c8a4d50abfc757efed64b73dfadc91..6f2132b49241f9b33a5acdcd288c8913058fea1f 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -553,6 +553,29 @@ class SavedModelTest(tf.test.TestCase): tf.get_collection("init_op")[0].run() self.assertEqual(3, tf.get_collection("v")[2].eval()) + def testClearDevices(self): + export_dir = os.path.join(tf.test.get_temp_dir(), "test_clear_devices") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Specify a device and save a variable. + tf.reset_default_graph() + with tf.Session( + target="", + config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: + with sess.graph.device("/cpu:0"): + self._init_and_validate_variable(sess, "v", 42) + builder.add_meta_graph_and_variables( + sess, [tag_constants.TRAINING], clear_devices=True) + + # Save the SavedModel to disk. + builder.save() + + # Restore the graph with a single predefined tag whose variables were saved + # without any device information. + with self.test_session(graph=tf.Graph()) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + self.assertEqual(42, tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval()) + if __name__ == "__main__": tf.test.main()