提交 a0d1c58d 编写于 作者: W wuzewu

fix serialization bug

上级 1ebde3bd
...@@ -48,6 +48,12 @@ def get_variable_info(var): ...@@ -48,6 +48,12 @@ def get_variable_info(var):
def from_param_to_flexible_data(param, flexible_data): def from_param_to_flexible_data(param, flexible_data):
def paddle_obj_filter(pyobj):
return isinstance(pyobj, fluid.framework.Variable) or isinstance(
pyobj, fluid.framework.Block) or isinstance(
pyobj, fluid.framework.Program) or isinstance(
pyobj, fluid.framework.Operator)
flexible_data.type = module_desc_pb2.MAP flexible_data.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(param.trainable, from_pyobj_to_flexible_data(param.trainable,
flexible_data.map.data['trainable']) flexible_data.map.data['trainable'])
...@@ -55,10 +61,14 @@ def from_param_to_flexible_data(param, flexible_data): ...@@ -55,10 +61,14 @@ def from_param_to_flexible_data(param, flexible_data):
flexible_data.map.data['do_model_average']) flexible_data.map.data['do_model_average'])
from_pyobj_to_flexible_data(param.optimize_attr, from_pyobj_to_flexible_data(param.optimize_attr,
flexible_data.map.data['optimize_attr']) flexible_data.map.data['optimize_attr'])
from_pyobj_to_flexible_data(param.regularizer, from_pyobj_to_flexible_data(
flexible_data.map.data['regularizer']) param.regularizer,
from_pyobj_to_flexible_data(param.gradient_clip_attr, flexible_data.map.data['regularizer'],
flexible_data.map.data['gradient_clip_attr']) obj_filter=paddle_obj_filter)
from_pyobj_to_flexible_data(
param.gradient_clip_attr,
flexible_data.map.data['gradient_clip_attr'],
obj_filter=paddle_obj_filter)
def from_flexible_data_to_param(flexible_data): def from_flexible_data_to_param(flexible_data):
......
...@@ -51,7 +51,23 @@ def get_keyed_type_of_pyobj(pyobj): ...@@ -51,7 +51,23 @@ def get_keyed_type_of_pyobj(pyobj):
return module_desc_pb2.STRING return module_desc_pb2.STRING
def from_pyobj_to_flexible_data(pyobj, flexible_data): def get_pykey(key, keyed_type):
if keyed_type == module_desc_pb2.BOOLEAN:
return bool(key)
elif keyed_type == module_desc_pb2.INT:
return int(key)
elif keyed_type == module_desc_pb2.STRING:
return str(key)
elif keyed_type == module_desc_pb2.FLOAT:
return float(key)
return str(key)
#TODO(wuzewu): solving the problem of circular references
def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None):
if obj_filter and obj_filter(pyobj):
logger.info("filter python object")
return
if isinstance(pyobj, bool): if isinstance(pyobj, bool):
flexible_data.type = module_desc_pb2.BOOLEAN flexible_data.type = module_desc_pb2.BOOLEAN
flexible_data.b = pyobj flexible_data.b = pyobj
...@@ -67,25 +83,31 @@ def from_pyobj_to_flexible_data(pyobj, flexible_data): ...@@ -67,25 +83,31 @@ def from_pyobj_to_flexible_data(pyobj, flexible_data):
elif isinstance(pyobj, list) or isinstance(pyobj, tuple): elif isinstance(pyobj, list) or isinstance(pyobj, tuple):
flexible_data.type = module_desc_pb2.LIST flexible_data.type = module_desc_pb2.LIST
for index, obj in enumerate(pyobj): for index, obj in enumerate(pyobj):
from_pyobj_to_flexible_data(obj, from_pyobj_to_flexible_data(
flexible_data.list.data[str(index)]) obj, flexible_data.list.data[str(index)], obj_filter)
elif isinstance(pyobj, set): elif isinstance(pyobj, set):
flexible_data.type = module_desc_pb2.SET flexible_data.type = module_desc_pb2.SET
for index, obj in enumerate(list(pyobj)): for index, obj in enumerate(list(pyobj)):
from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)]) from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)],
obj_filter)
elif isinstance(pyobj, dict): elif isinstance(pyobj, dict):
flexible_data.type = module_desc_pb2.MAP flexible_data.type = module_desc_pb2.MAP
for key, value in pyobj.items(): for key, value in pyobj.items():
from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)]) from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)],
obj_filter)
flexible_data.map.keyType[str(key)] = get_keyed_type_of_pyobj(key) flexible_data.map.keyType[str(key)] = get_keyed_type_of_pyobj(key)
elif isinstance(pyobj, type(None)): elif isinstance(pyobj, type(None)):
flexible_data.type = module_desc_pb2.NONE flexible_data.type = module_desc_pb2.NONE
else: else:
flexible_data.type = module_desc_pb2.OBJECT flexible_data.type = module_desc_pb2.OBJECT
flexible_data.name = str(pyobj.__class__.__name__) flexible_data.name = str(pyobj.__class__.__name__)
if not hasattr(pyobj, "__dict__"):
logger.warning(
"python obj %s has not __dict__ attr" % flexible_data.name)
return
for key, value in pyobj.__dict__.items(): for key, value in pyobj.__dict__.items():
from_pyobj_to_flexible_data(value, from_pyobj_to_flexible_data(
flexible_data.object.data[str(key)]) value, flexible_data.object.data[str(key)], obj_filter)
flexible_data.object.keyType[str(key)] = get_keyed_type_of_pyobj( flexible_data.object.keyType[str(key)] = get_keyed_type_of_pyobj(
key) key)
...@@ -112,7 +134,7 @@ def from_flexible_data_to_pyobj(flexible_data): ...@@ -112,7 +134,7 @@ def from_flexible_data_to_pyobj(flexible_data):
elif flexible_data.type == module_desc_pb2.MAP: elif flexible_data.type == module_desc_pb2.MAP:
result = {} result = {}
for key, value in flexible_data.map.data.items(): for key, value in flexible_data.map.data.items():
key = flexible_data.map.keyType[key] key = get_pykey(key, flexible_data.map.keyType[key])
result[key] = from_flexible_data_to_pyobj(value) result[key] = from_flexible_data_to_pyobj(value)
elif flexible_data.type == module_desc_pb2.NONE: elif flexible_data.type == module_desc_pb2.NONE:
result = None result = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册