提交 a27f4307 编写于 作者: W wuzewu

record variable info

上级 b125aa1b
......@@ -30,7 +30,7 @@ from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list, mkdir
from paddle_hub.utils import to_list, mkdir, from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.paddle_helper import from_param_to_flexible_data, get_variable_info, from_flexible_data_to_param
from paddle_hub.version import __version__
......@@ -91,6 +91,19 @@ class Module(object):
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _process_variable_info(self):
var_infos = self.config.desc.extra_info.map.data['var_infos']
for var_info in var_infos.map.data:
idx = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['block_id'])
stop_gradient = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = self.inference_program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def __call__(self, sign_name="default", trainable=False):
""" Call default signature and return results
"""
......@@ -139,21 +152,23 @@ class Module(object):
logger.info("**feed_target_names**\n{}".format(self.feed_target_names))
logger.info("**fetch_targets**\n{}".format(self.fetch_targets))
self._process_parameter()
self._process_variable_info()
program = self.get_inference_program().clone()
_process_op_attr(program=program, is_test=False)
_set_param_trainable(program=program, trainable=trainable)
_process_op_attr(program=self.inference_program, is_test=False)
_set_param_trainable(
program=self.inference_program, trainable=trainable)
for key, value in feed_dict.items():
var = program.global_block().var(HUB_VAR_PREFIX + value)
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
feed_dict[key] = var
for key, value in fetch_dict.items():
var = program.global_block().var(HUB_VAR_PREFIX + value)
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
fetch_dict[key] = var
return feed_dict, fetch_dict, program
return feed_dict, fetch_dict, self.inference_program
def get_inference_program(self):
return self.inference_program
......@@ -256,7 +271,7 @@ class ModuleConfig(object):
return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
def create_module(sign_arr, module_dir=None, word_dict=None, exe=None):
""" Create a module from main program
"""
assert sign_arr, "signature array should not be None"
......@@ -291,7 +306,6 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
module_desc.auth_info.paddle_version = paddle.__version__
logger.info("hub version is %s" % __version__)
logger.info("paddle version is %s" % paddle.__version__)
program = program.clone()
# save asset
if word_dict is None:
......@@ -312,9 +326,20 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
param_attr = param_attrs.map.data[param.name]
from_param_to_flexible_data(param, param_attr)
# save Variable Info
var_infos = extra_info.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(var.stop_gradient,
var_info.map.data['stop_gradient'])
from_pyobj_to_flexible_data(block.idx,
var_info.map.data['block_id'])
# save signarture info
sign_map = module_desc.sign2var
program = sign_arr[0].get_inputs()[0].block.program
for sign in sign_arr:
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
......@@ -335,7 +360,8 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
fetch_var.alias = fetch_names[index]
# save inference program
if not place:
program = program.clone()
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
save_model_dir = os.path.join(module_dir, "model")
......
......@@ -66,7 +66,6 @@ def get_pykey(key, keyed_type):
#TODO(wuzewu): solving the problem of circular references
def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None):
if obj_filter and obj_filter(pyobj):
logger.info("filter python object")
return
if isinstance(pyobj, bool):
flexible_data.type = module_desc_pb2.BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册