提交 a0d1c58d 编写于 作者: W wuzewu

fix serialization bug

上级 1ebde3bd
......@@ -48,6 +48,12 @@ def get_variable_info(var):
def from_param_to_flexible_data(param, flexible_data):
def paddle_obj_filter(pyobj):
return isinstance(pyobj, fluid.framework.Variable) or isinstance(
pyobj, fluid.framework.Block) or isinstance(
pyobj, fluid.framework.Program) or isinstance(
pyobj, fluid.framework.Operator)
flexible_data.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(param.trainable,
flexible_data.map.data['trainable'])
......@@ -55,10 +61,14 @@ def from_param_to_flexible_data(param, flexible_data):
flexible_data.map.data['do_model_average'])
from_pyobj_to_flexible_data(param.optimize_attr,
flexible_data.map.data['optimize_attr'])
from_pyobj_to_flexible_data(param.regularizer,
flexible_data.map.data['regularizer'])
from_pyobj_to_flexible_data(param.gradient_clip_attr,
flexible_data.map.data['gradient_clip_attr'])
from_pyobj_to_flexible_data(
param.regularizer,
flexible_data.map.data['regularizer'],
obj_filter=paddle_obj_filter)
from_pyobj_to_flexible_data(
param.gradient_clip_attr,
flexible_data.map.data['gradient_clip_attr'],
obj_filter=paddle_obj_filter)
def from_flexible_data_to_param(flexible_data):
......
......@@ -51,7 +51,23 @@ def get_keyed_type_of_pyobj(pyobj):
return module_desc_pb2.STRING
def from_pyobj_to_flexible_data(pyobj, flexible_data):
def get_pykey(key, keyed_type):
if keyed_type == module_desc_pb2.BOOLEAN:
return bool(key)
elif keyed_type == module_desc_pb2.INT:
return int(key)
elif keyed_type == module_desc_pb2.STRING:
return str(key)
elif keyed_type == module_desc_pb2.FLOAT:
return float(key)
return str(key)
#TODO(wuzewu): solving the problem of circular references
def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None):
if obj_filter and obj_filter(pyobj):
logger.info("filter python object")
return
if isinstance(pyobj, bool):
flexible_data.type = module_desc_pb2.BOOLEAN
flexible_data.b = pyobj
......@@ -67,25 +83,31 @@ def from_pyobj_to_flexible_data(pyobj, flexible_data):
elif isinstance(pyobj, list) or isinstance(pyobj, tuple):
flexible_data.type = module_desc_pb2.LIST
for index, obj in enumerate(pyobj):
from_pyobj_to_flexible_data(obj,
flexible_data.list.data[str(index)])
from_pyobj_to_flexible_data(
obj, flexible_data.list.data[str(index)], obj_filter)
elif isinstance(pyobj, set):
flexible_data.type = module_desc_pb2.SET
for index, obj in enumerate(list(pyobj)):
from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)])
from_pyobj_to_flexible_data(obj, flexible_data.set.data[str(index)],
obj_filter)
elif isinstance(pyobj, dict):
flexible_data.type = module_desc_pb2.MAP
for key, value in pyobj.items():
from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)])
from_pyobj_to_flexible_data(value, flexible_data.map.data[str(key)],
obj_filter)
flexible_data.map.keyType[str(key)] = get_keyed_type_of_pyobj(key)
elif isinstance(pyobj, type(None)):
flexible_data.type = module_desc_pb2.NONE
else:
flexible_data.type = module_desc_pb2.OBJECT
flexible_data.name = str(pyobj.__class__.__name__)
if not hasattr(pyobj, "__dict__"):
logger.warning(
"python obj %s has not __dict__ attr" % flexible_data.name)
return
for key, value in pyobj.__dict__.items():
from_pyobj_to_flexible_data(value,
flexible_data.object.data[str(key)])
from_pyobj_to_flexible_data(
value, flexible_data.object.data[str(key)], obj_filter)
flexible_data.object.keyType[str(key)] = get_keyed_type_of_pyobj(
key)
......@@ -112,7 +134,7 @@ def from_flexible_data_to_pyobj(flexible_data):
elif flexible_data.type == module_desc_pb2.MAP:
result = {}
for key, value in flexible_data.map.data.items():
key = flexible_data.map.keyType[key]
key = get_pykey(key, flexible_data.map.keyType[key])
result[key] = from_flexible_data_to_pyobj(value)
elif flexible_data.type == module_desc_pb2.NONE:
result = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册