diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 1ea4758e9434ba157b486c9b562fccbf067529e6..86b5a8fd41ef4b33ed87147fc8bc75b79e4c780c 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -84,6 +84,25 @@ class Module(object): self.config = ModuleConfig(self.module_dir) self.config.load() + self._process_parameter() + + def _process_parameter(self): + global_block = self.inference_program.global_block() + for param in self.config.desc.parameters: + name = param.name + if name in global_block.vars: + var = global_block.vars[name] + global_block.create_parameter( + name=name, + trainable=param.trainable, + shape=var.shape, + dtype=var.dtype, + optimize_attr={'learning_rate': param.learning_rate}, + type=var.type, + lod_level=var.lod_level, + error_clip=var.error_clip, + stop_gradient=var.stop_gradient, + is_data=var.is_data) def _construct_feed_dict(self, inputs): """ Construct feed dict according to user's inputs and module config. @@ -253,8 +272,9 @@ class ModuleConfig(object): with open(pb_path, "rb") as fi: self.desc.ParseFromString(fi.read()) - print("self.desc.sign2var", - self.desc.sign2var["default"].feed_desc[0].var_name) + +# print("self.desc.sign2var", +# self.desc.sign2var["default"].feed_desc[0].var_name) if self.desc.contain_assets: # load assets