From a0d1c58dfd959041361b7763a3c3f6586f186bb0 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Thu, 14 Feb 2019 17:28:57 +0800 Subject: [PATCH] fix serialization bug --- paddle_hub/paddle_helper.py | 18 ++++++++++++++---- paddle_hub/utils.py | 38 +++++++++++++++++++++++++++++-------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/paddle_hub/paddle_helper.py b/paddle_hub/paddle_helper.py index 6b3c79af..7c1de9cb 100644 --- a/paddle_hub/paddle_helper.py +++ b/paddle_hub/paddle_helper.py @@ -48,6 +48,12 @@ def get_variable_info(var): def from_param_to_flexible_data(param, flexible_data): + def paddle_obj_filter(pyobj): + return isinstance(pyobj, fluid.framework.Variable) or isinstance( + pyobj, fluid.framework.Block) or isinstance( + pyobj, fluid.framework.Program) or isinstance( + pyobj, fluid.framework.Operator) + flexible_data.type = module_desc_pb2.MAP from_pyobj_to_flexible_data(param.trainable, flexible_data.map.data['trainable']) @@ -55,10 +61,14 @@ def from_param_to_flexible_data(param, flexible_data): flexible_data.map.data['do_model_average']) from_pyobj_to_flexible_data(param.optimize_attr, flexible_data.map.data['optimize_attr']) - from_pyobj_to_flexible_data(param.regularizer, - flexible_data.map.data['regularizer']) - from_pyobj_to_flexible_data(param.gradient_clip_attr, - flexible_data.map.data['gradient_clip_attr']) + from_pyobj_to_flexible_data( + param.regularizer, + flexible_data.map.data['regularizer'], + obj_filter=paddle_obj_filter) + from_pyobj_to_flexible_data( + param.gradient_clip_attr, + flexible_data.map.data['gradient_clip_attr'], + obj_filter=paddle_obj_filter) def from_flexible_data_to_param(flexible_data): diff --git a/paddle_hub/utils.py b/paddle_hub/utils.py index 8da1f34b..0bcdd9ed 100644 --- a/paddle_hub/utils.py +++ b/paddle_hub/utils.py @@ -51,7 +51,23 @@ def get_keyed_type_of_pyobj(pyobj): return module_desc_pb2.STRING -def from_pyobj_to_flexible_data(pyobj, flexible_data): +def get_pykey(key, keyed_type): + if keyed_type == module_desc_pb2.BOOLEAN: + return bool(key) + elif keyed_type == module_desc_pb2.INT: + return int(key) + elif keyed_type == module_desc_pb2.STRING: + return str(key) + elif keyed_type == module_desc_pb2.FLOAT: + return float(key) + return str(key) + + +#TODO(wuzewu): solving the problem of circular references +def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None): + if obj_filter and obj_filter(pyobj): + logger.info("filter python object") + return if isinstance(pyobj, bool): flexible_data.type = module_desc_pb2.BOOLEAN flexible_data.b = pyobj @@ -67,25 +83,31 @@ def from_pyobj_to_flexible_data(pyobj, flexible_data): elif isinstance(pyobj, list) or isinstance(pyobj, tuple): flexible_data.type = module_desc_pb2.LIST for index, obj in enumerate(pyobj): - from_pyobj_to_flexible_data(obj, - flexible_data.list.data[str(index)]) + from_pyobj_to_flexible_data( + obj, flexible_data.list.data[str(index)], obj_filter) elif isinstance(pyobj, set): flexible_data.type = module_desc_pb2.SET for index, obj in enumerate(list(pyobj)): - from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)]) + from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)], + obj_filter) elif isinstance(pyobj, dict): flexible_data.type = module_desc_pb2.MAP for key, value in pyobj.items(): - from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)]) + from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)], + obj_filter) flexible_data.map.keyType[str(key)] = get_keyed_type_of_pyobj(key) elif isinstance(pyobj, type(None)): flexible_data.type = module_desc_pb2.NONE else: flexible_data.type = module_desc_pb2.OBJECT flexible_data.name = str(pyobj.__class__.__name__) + if not hasattr(pyobj, "__dict__"): + logger.warning( + "python obj %s has not __dict__ attr" % flexible_data.name) + return for key, value in pyobj.__dict__.items(): - from_pyobj_to_flexible_data(value, - flexible_data.object.data[str(key)]) + from_pyobj_to_flexible_data( + value, flexible_data.object.data[str(key)], obj_filter) flexible_data.object.keyType[str(key)] = get_keyed_type_of_pyobj( key) @@ -112,7 +134,7 @@ def from_flexible_data_to_pyobj(flexible_data): elif flexible_data.type == module_desc_pb2.MAP: result = {} for key, value in flexible_data.map.data.items(): - key = flexible_data.map.keyType[key] + key = get_pykey(key, flexible_data.map.keyType[key]) result[key] = from_flexible_data_to_pyobj(value) elif flexible_data.type == module_desc_pb2.NONE: result = None -- GitLab