Unverified commit 9adff317, authored by kinghuin, committed by GitHub

support ernie module v2 (#442)

* support ernie module v2
Parent: b1f19840
@@ -120,9 +120,7 @@ def get_depth_parameter(main_program):
     return updated_depth_params_dict


-def set_gradual_unfreeze(main_program, unfreeze_depths):
-    depth_params_dict = get_depth_parameter(main_program)
+def set_gradual_unfreeze(depth_params_dict, unfreeze_depths):
     for depth in unfreeze_depths:
         for index, param in enumerate(depth_params_dict[depth]):
             depth_params_dict[depth][index].stop_gradient = False
@@ -509,7 +507,7 @@ class CombinedStrategy(DefaultStrategy):
         if self.max_depth > 0 and self.epoch <= self.scheduler[
                 "gradual_unfreeze"]["blocks"]:
             set_gradual_unfreeze(
-                self.main_program,
+                depth_params_dict=self.depth_params_dict,
                 unfreeze_depths=self.
                 sorted_depth[:self.max_depth * self.epoch //
                              self.scheduler["gradual_unfreeze"]["blocks"]])
......
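The refactor above makes `set_gradual_unfreeze` take a prebuilt `depth_params_dict` instead of rescanning `main_program` on every call. Below is a minimal, paddle-free sketch of the unfreeze schedule: mock objects stand in for fluid parameters, and the names (`MockParam`, `block_%d_w`) are invented for illustration; only the slicing logic mirrors the strategy code in the diff.

```python
# Sketch of gradual unfreezing over a depth -> [params] mapping like the one
# produced by get_depth_parameter(). Parameters start frozen and the deepest
# blocks are unfrozen first as training epochs advance.
class MockParam:
    def __init__(self, name):
        self.name = name
        self.stop_gradient = True  # frozen by default


def set_gradual_unfreeze(depth_params_dict, unfreeze_depths):
    for depth in unfreeze_depths:
        for param in depth_params_dict[depth]:
            param.stop_gradient = False


depth_params_dict = {d: [MockParam("block_%d_w" % d)] for d in range(1, 5)}
sorted_depth = sorted(depth_params_dict, reverse=True)  # top blocks first
max_depth, blocks = len(sorted_depth), 4

for epoch in range(1, blocks + 1):
    # Same slice as in CombinedStrategy: more blocks unfreeze as epochs advance.
    set_gradual_unfreeze(
        depth_params_dict,
        unfreeze_depths=sorted_depth[:max_depth * epoch // blocks])
    unfrozen = [d for d in sorted_depth
                if not depth_params_dict[d][0].stop_gradient]
    print("epoch", epoch, "unfrozen depths:", unfrozen)
```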
@@ -76,7 +76,7 @@ class LocalModuleManager(object):
                             sys.modules[_item.__module__].__file__)
                         if issubclass(
                                 _item,
-                                hub.Module) and _file.startwith(module_file):
+                                hub.Module) and _file.startswith(module_file):
                             version = _item._version
                             break
             sys.path.pop(0)
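This hunk fixes a typo: `str` has no `startwith` method, so the old check would raise `AttributeError` instead of comparing paths. A tiny demonstration (the paths are made up):

```python
module_file = "/home/user/.paddlehub/modules/ernie"
_file = "/home/user/.paddlehub/modules/ernie/module.py"

print(_file.startswith(module_file))  # True: the corrected spelling works
try:
    _file.startwith(module_file)      # old spelling
except AttributeError as e:
    print(e)                          # 'str' object has no attribute 'startwith'
```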
......
@@ -137,7 +137,8 @@ class Module(object):
         _run_func_name = self._get_func_name(self.__class__,
                                              _module_runnable_func)
-        self._run_func = getattr(self, _run_func_name)
+        self._run_func = getattr(self,
+                                 _run_func_name) if _run_func_name else None
         self._serving_func_name = self._get_func_name(self.__class__,
                                                        _module_serving_func)
         self._directory = directory
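With module v2, a Module subclass may not declare a runnable function at all, so `_run_func_name` can be `None`; the change above guards the `getattr` accordingly. A standalone sketch of the pattern, with an illustrative marker attribute in place of PaddleHub's actual decorator machinery:

```python
def get_func_name(cls, marker_attr):
    # Return the name of the first method tagged with marker_attr, if any.
    for name, member in vars(cls).items():
        if callable(member) and getattr(member, marker_attr, False):
            return name
    return None


class Demo:
    def run(self, text):
        return text.upper()
    run._is_runnable = True  # stands in for a @runnable-style decorator


run_func_name = get_func_name(Demo, "_is_runnable")
# Guarded lookup: modules without a runnable function simply get None.
run_func = getattr(Demo(), run_func_name) if run_func_name else None
print(run_func("hello") if run_func else "module is not runnable")
```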
......
@@ -26,6 +26,7 @@ import six
 import numpy as np
 import paddle.fluid as fluid
+from paddlehub.common import paddle_helper
 from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
 import paddlehub as hub
 from paddlehub.common.logger import logger
@@ -265,6 +266,9 @@ class TransformerModule(NLPBaseModule):
         logger.info("Load pretraining parameters from {}.".format(
             pretraining_params_path))

+    def param_prefix(self):
+        return "@HUB_%s@" % self.name
+
     def context(
             self,
             max_seq_len=128,
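The new `param_prefix()` returns a per-module namespace tag such as `@HUB_ernie@`; prefixing variable names with it keeps a module's parameters from colliding with those of other modules or the downstream task program. A rough illustration of the naming scheme (the variable names are invented):

```python
def param_prefix(name):
    # Mirrors TransformerModule.param_prefix(): "@HUB_<module name>@"
    return "@HUB_%s@" % name


var_names = ["word_embedding", "encoder_layer_0_ffn_fc_0.w_0"]
prefixed = [param_prefix("ernie") + v for v in var_names]
print(prefixed)
# ['@HUB_ernie@word_embedding', '@HUB_ernie@encoder_layer_0_ffn_fc_0.w_0']
```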
@@ -330,8 +334,13 @@ class TransformerModule(NLPBaseModule):
             place = fluid.CPUPlace()
         exe = fluid.Executor(place)

+        # To be compatible with the module v1
+        vars = filter(lambda var: "tmp" not in var,
+                      list(module_program.global_block().vars.keys())[4:])
+        paddle_helper.add_vars_prefix(
+            program=module_program, prefix=self.param_prefix(), vars=vars)
         self.init_pretraining_params(
-            exe, self.params_path, main_program=startup_program)
+            exe, self.params_path, main_program=module_program)

         self.params_layer = {}
         for param in module_program.global_block().iter_parameters():
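The hunk above renames the selected variables of the freshly built module program with the `@HUB_<name>@` prefix for compatibility with modules packaged as v1, and then loads the pretrained parameters into `module_program` rather than `startup_program`. Below is a paddle-free sketch of the variable-selection step only: the `[4:]` slice and the `"tmp"` filter reproduce the diff, while every variable name is invented for illustration.

```python
# Hypothetical variable table of a module program; the first few entries play
# the role of variables that should keep their original names.
all_vars = ["feed", "fetch", "input_ids", "position_ids",
            "word_embedding", "tmp_0", "encoder_layer_0_ffn_fc_0.w_0", "tmp_1"]

# Same selection as the diff: skip the first four entries and anything temporary.
vars_to_prefix = [v for v in all_vars[4:] if "tmp" not in v]

prefix = "@HUB_ernie@"
renamed = {v: prefix + v for v in vars_to_prefix}
print(renamed)
# {'word_embedding': '@HUB_ernie@word_embedding',
#  'encoder_layer_0_ffn_fc_0.w_0': '@HUB_ernie@encoder_layer_0_ffn_fc_0.w_0'}
```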
......