From 5528b56647d39908ab0c9d2e97ff939d18f9b30f Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 22 Feb 2019 13:26:41 -0800 Subject: [PATCH] Ignore get_configs that aren't JSON-serializable when saving checkpoints This is extra "nice to have" metadata, and otherwise it looks like a checkpointing error. Not worth bothering people about. PiperOrigin-RevId: 235247436 --- tensorflow/python/training/tracking/base.py | 12 ++++++++---- tensorflow/python/training/tracking/base_test.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 4a8960d3cab..f1f1fcba7c1 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -851,10 +851,14 @@ class Trackable(object): """Serializes `self.get_config()` for saving.""" dereferenced_self = weak_self() if dereferenced_self: - return json.dumps( - dereferenced_self, - default=serialization.get_json_type, - sort_keys=True).encode("utf8") + try: + return json.dumps( + dereferenced_self, + default=serialization.get_json_type, + sort_keys=True).encode("utf8") + except TypeError: + # Even if get_config worked objects may have produced garbage. + return "" else: return "" return {OBJECT_CONFIG_JSON_KEY: functools.partial( diff --git a/tensorflow/python/training/tracking/base_test.py b/tensorflow/python/training/tracking/base_test.py index 4a74417e3ba..d76e20edf7e 100644 --- a/tensorflow/python/training/tracking/base_test.py +++ b/tensorflow/python/training/tracking/base_test.py @@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase): with self.assertRaisesRegexp(AssertionError, "foo_attr"): status.assert_consumed() + def testBuggyGetConfig(self): + + class NotSerializable(object): + pass + + class GetConfigRaisesError(base.Trackable): + + def get_config(self): + return NotSerializable() + + util.Checkpoint(obj=GetConfigRaisesError()).save( + os.path.join(self.get_temp_dir(), "ckpt")) + if __name__ == "__main__": ops.enable_eager_execution() -- GitLab