diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 1b28432fbf71f92f6ca64ba88b0ea3bb361e3b70..83305c31e1ffc6be7f25c404e3b6e50b635fe18e 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -30,7 +30,7 @@ from paddle_hub.downloader import download_and_uncompress from paddle_hub import module_desc_pb2 from paddle_hub.logger import logger from paddle_hub.signature import Signature -from paddle_hub.utils import to_list, mkdir +from paddle_hub.utils import to_list, mkdir, from_pyobj_to_flexible_data, from_flexible_data_to_pyobj from paddle_hub.paddle_helper import from_param_to_flexible_data, get_variable_info, from_flexible_data_to_param from paddle_hub.version import __version__ @@ -91,6 +91,19 @@ class Module(object): stop_gradient=var.stop_gradient, is_data=var.is_data) + def _process_variable_info(self): + var_infos = self.config.desc.extra_info.map.data['var_infos'] + for var_info in var_infos.map.data: + idx = from_flexible_data_to_pyobj( + var_infos.map.data[var_info].map.data['block_id']) + stop_gradient = from_flexible_data_to_pyobj( + var_infos.map.data[var_info].map.data['stop_gradient']) + block = self.inference_program.blocks[idx] + var_name = HUB_VAR_PREFIX + var_info + if var_name in block.vars: + var = block.vars[var_name] + var.stop_gradient = stop_gradient + def __call__(self, sign_name="default", trainable=False): """ Call default signature and return results """ @@ -139,21 +152,23 @@ class Module(object): logger.info("**feed_target_names**\n{}".format(self.feed_target_names)) logger.info("**fetch_targets**\n{}".format(self.fetch_targets)) self._process_parameter() + self._process_variable_info() - program = self.get_inference_program().clone() - - _process_op_attr(program=program, is_test=False) - _set_param_trainable(program=program, trainable=trainable) + _process_op_attr(program=self.inference_program, is_test=False) + _set_param_trainable( + program=self.inference_program, trainable=trainable) for key, value in feed_dict.items(): - var = program.global_block().var(HUB_VAR_PREFIX + value) + var = self.inference_program.global_block().var(HUB_VAR_PREFIX + + value) feed_dict[key] = var for key, value in fetch_dict.items(): - var = program.global_block().var(HUB_VAR_PREFIX + value) + var = self.inference_program.global_block().var(HUB_VAR_PREFIX + + value) fetch_dict[key] = var - return feed_dict, fetch_dict, program + return feed_dict, fetch_dict, self.inference_program def get_inference_program(self): return self.inference_program @@ -256,7 +271,7 @@ class ModuleConfig(object): return os.path.join(meta_path, PARAM_FILENAME) -def create_module(sign_arr, module_dir=None, word_dict=None, place=None): +def create_module(sign_arr, module_dir=None, word_dict=None, exe=None): """ Create a module from main program """ assert sign_arr, "signature array should not be None" @@ -291,7 +306,6 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None): module_desc.auth_info.paddle_version = paddle.__version__ logger.info("hub version is %s" % __version__) logger.info("paddle version is %s" % paddle.__version__) - program = program.clone() # save asset if word_dict is None: @@ -312,9 +326,20 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None): param_attr = param_attrs.map.data[param.name] from_param_to_flexible_data(param, param_attr) + # save Variable Info + var_infos = extra_info.map.data['var_infos'] + var_infos.type = module_desc_pb2.MAP + for block in program.blocks: + for var in block.vars.values(): + var_info = var_infos.map.data[var.name] + var_info.type = module_desc_pb2.MAP + from_pyobj_to_flexible_data(var.stop_gradient, + var_info.map.data['stop_gradient']) + from_pyobj_to_flexible_data(block.idx, + var_info.map.data['block_id']) + # save signarture info sign_map = module_desc.sign2var - program = sign_arr[0].get_inputs()[0].block.program for sign in sign_arr: if sign.get_name() in sign_map: raise "Error! sign_arr contains repeat signatrue %s" % sign @@ -335,9 +360,10 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None): fetch_var.alias = fetch_names[index] # save inference program - if not place: + program = program.clone() + if not exe: place = fluid.CPUPlace() - exe = fluid.Executor(place=place) + exe = fluid.Executor(place=place) save_model_dir = os.path.join(module_dir, "model") mkdir(save_model_dir) fluid.io.save_inference_model( diff --git a/paddle_hub/utils.py b/paddle_hub/utils.py index f68dc59b8990cfc09d7bde9afcd006bc33ab1161..aee28de9289baff62cab5901105765732a6d5cf9 100644 --- a/paddle_hub/utils.py +++ b/paddle_hub/utils.py @@ -66,7 +66,6 @@ def get_pykey(key, keyed_type): #TODO(wuzewu): solving the problem of circular references def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None): if obj_filter and obj_filter(pyobj): - logger.info("filter python object") return if isinstance(pyobj, bool): flexible_data.type = module_desc_pb2.BOOLEAN