未验证 提交 9adff317 编写于 作者: K kinghuin 提交者: GitHub

surport ernie module v2 (#442)

* surport ernie module v2
上级 b1f19840
......@@ -120,9 +120,7 @@ def get_depth_parameter(main_program):
return updated_depth_params_dict
def set_gradual_unfreeze(main_program, unfreeze_depths):
depth_params_dict = get_depth_parameter(main_program)
def set_gradual_unfreeze(depth_params_dict, unfreeze_depths):
for depth in unfreeze_depths:
for index, param in enumerate(depth_params_dict[depth]):
depth_params_dict[depth][index].stop_gradient = False
......@@ -509,7 +507,7 @@ class CombinedStrategy(DefaultStrategy):
if self.max_depth > 0 and self.epoch <= self.scheduler[
"gradual_unfreeze"]["blocks"]:
set_gradual_unfreeze(
self.main_program,
depth_params_dict=self.depth_params_dict,
unfreeze_depths=self.
sorted_depth[:self.max_depth * self.epoch //
self.scheduler["gradual_unfreeze"]["blocks"]])
......
......@@ -76,7 +76,7 @@ class LocalModuleManager(object):
sys.modules[_item.__module__].__file__)
if issubclass(
_item,
hub.Module) and _file.startwith(module_file):
hub.Module) and _file.startswith(module_file):
version = _item._version
break
sys.path.pop(0)
......
......@@ -137,7 +137,8 @@ class Module(object):
_run_func_name = self._get_func_name(self.__class__,
_module_runnable_func)
self._run_func = getattr(self, _run_func_name)
self._run_func = getattr(self,
_run_func_name) if _run_func_name else None
self._serving_func_name = self._get_func_name(self.__class__,
_module_serving_func)
self._directory = directory
......
......@@ -26,6 +26,7 @@ import six
import numpy as np
import paddle.fluid as fluid
from paddlehub.common import paddle_helper
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddlehub as hub
from paddlehub.common.logger import logger
......@@ -265,6 +266,9 @@ class TransformerModule(NLPBaseModule):
logger.info("Load pretraining parameters from {}.".format(
pretraining_params_path))
def param_prefix(self):
return "@HUB_%s@" % self.name
def context(
self,
max_seq_len=128,
......@@ -330,8 +334,13 @@ class TransformerModule(NLPBaseModule):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# To be compatible with the module v1
vars = filter(lambda var: "tmp" not in var,
list(module_program.global_block().vars.keys())[4:])
paddle_helper.add_vars_prefix(
program=module_program, prefix=self.param_prefix(), vars=vars)
self.init_pretraining_params(
exe, self.params_path, main_program=startup_program)
exe, self.params_path, main_program=module_program)
self.params_layer = {}
for param in module_program.global_block().iter_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册