diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 4a8960d3cabdac1852f42e89a6c2e61d0ca4577e..f1f1fcba7c15056dc2435e67da1adcae5d1822a6 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -851,10 +851,14 @@ class Trackable(object): """Serializes `self.get_config()` for saving.""" dereferenced_self = weak_self() if dereferenced_self: - return json.dumps( - dereferenced_self, - default=serialization.get_json_type, - sort_keys=True).encode("utf8") + try: + return json.dumps( + dereferenced_self, + default=serialization.get_json_type, + sort_keys=True).encode("utf8") + except TypeError: + # Even if get_config worked objects may have produced garbage. + return "" else: return "" return {OBJECT_CONFIG_JSON_KEY: functools.partial( diff --git a/tensorflow/python/training/tracking/base_test.py b/tensorflow/python/training/tracking/base_test.py index 4a74417e3ba9a081ad2a6c7150e63ffd3aa898fa..d76e20edf7e8dcca11c588a05cc7514625083086 100644 --- a/tensorflow/python/training/tracking/base_test.py +++ b/tensorflow/python/training/tracking/base_test.py @@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase): with self.assertRaisesRegexp(AssertionError, "foo_attr"): status.assert_consumed() + def testBuggyGetConfig(self): + + class NotSerializable(object): + pass + + class GetConfigRaisesError(base.Trackable): + + def get_config(self): + return NotSerializable() + + util.Checkpoint(obj=GetConfigRaisesError()).save( + os.path.join(self.get_temp_dir(), "ckpt")) + if __name__ == "__main__": ops.enable_eager_execution()