提交 5528b566 编写于 作者: A Allen Lavoie 提交者: TensorFlower Gardener

Ignore get_configs that aren't JSON-serializable when saving checkpoints

This is extra "nice to have" metadata, and otherwise it looks like a checkpointing error. Not worth bothering people about.

PiperOrigin-RevId: 235247436
上级 fbe459fd
...@@ -851,10 +851,14 @@ class Trackable(object): ...@@ -851,10 +851,14 @@ class Trackable(object):
"""Serializes `self.get_config()` for saving.""" """Serializes `self.get_config()` for saving."""
dereferenced_self = weak_self() dereferenced_self = weak_self()
if dereferenced_self: if dereferenced_self:
return json.dumps( try:
dereferenced_self, return json.dumps(
default=serialization.get_json_type, dereferenced_self,
sort_keys=True).encode("utf8") default=serialization.get_json_type,
sort_keys=True).encode("utf8")
except TypeError:
# Even if get_config worked objects may have produced garbage.
return ""
else: else:
return "" return ""
return {OBJECT_CONFIG_JSON_KEY: functools.partial( return {OBJECT_CONFIG_JSON_KEY: functools.partial(
......
...@@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase): ...@@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase):
with self.assertRaisesRegexp(AssertionError, "foo_attr"): with self.assertRaisesRegexp(AssertionError, "foo_attr"):
status.assert_consumed() status.assert_consumed()
def testBuggyGetConfig(self):
class NotSerializable(object):
pass
class GetConfigRaisesError(base.Trackable):
def get_config(self):
return NotSerializable()
util.Checkpoint(obj=GetConfigRaisesError()).save(
os.path.join(self.get_temp_dir(), "ckpt"))
if __name__ == "__main__": if __name__ == "__main__":
ops.enable_eager_execution() ops.enable_eager_execution()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册