提交 5528b566 编写于 作者: A Allen Lavoie 提交者: TensorFlower Gardener

Ignore get_configs that aren't JSON-serializable when saving checkpoints

This is extra "nice to have" metadata, and otherwise it looks like a checkpointing error. Not worth bothering people about.

PiperOrigin-RevId: 235247436
上级 fbe459fd
......@@ -851,10 +851,14 @@ class Trackable(object):
"""Serializes `self.get_config()` for saving."""
dereferenced_self = weak_self()
if dereferenced_self:
return json.dumps(
dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
try:
return json.dumps(
dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
except TypeError:
# Even if get_config worked objects may have produced garbage.
return ""
else:
return ""
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
......
......@@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase):
with self.assertRaisesRegexp(AssertionError, "foo_attr"):
status.assert_consumed()
def testBuggyGetConfig(self):
class NotSerializable(object):
pass
class GetConfigRaisesError(base.Trackable):
def get_config(self):
return NotSerializable()
util.Checkpoint(obj=GetConfigRaisesError()).save(
os.path.join(self.get_temp_dir(), "ckpt"))
if __name__ == "__main__":
ops.enable_eager_execution()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册