Commit 332f3a0c authored by: W wuzewu

Fix the problem that static graph models fail to execute with the paddle 2.0rc version

Parent 39338512
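For context (not part of the commit): starting with paddle 2.0rc, dynamic graph mode is enabled by default, so code that builds static-graph programs has to switch modes explicitly before it runs. A minimal sketch of the behavior this commit works around, assuming a stock paddle >= 2.0rc install:

```python
import paddle

# paddle >= 2.0rc starts in dynamic (imperative) mode.
print(paddle.in_dynamic_mode())  # True

# Static-graph APIs such as paddle.static.data expect static mode,
# so it has to be enabled before building a program.
paddle.enable_static()
x = paddle.static.data(name='x', shape=[-1, 10], dtype='float32')

# Switch back afterwards; the static_mode_guard added by this commit
# automates exactly this enable/restore dance around module methods.
paddle.disable_static()
```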
......@@ -29,6 +29,7 @@ class ModuleV1(object):
'''
'''
@paddle_utils.run_in_static_mode
def __init__(self, name: str = None, directory: str = None, version: str = None):
if not directory:
return
......@@ -101,7 +102,7 @@ class ModuleV1(object):
def _load_model(self):
model_path = os.path.join(self.directory, 'model')
exe = paddle.static.Executor(paddle.CPUPlace())
self.program, _, _ = paddle.io.load_inference_model(model_path, executor=exe)
self.program, _, _ = paddle.static.load_inference_model(model_path, executor=exe)
# Clear the callstack since it may leak the privacy of the creator.
for block in self.program.blocks:
......@@ -110,6 +111,7 @@ class ModuleV1(object):
continue
op._set_attr('op_callstack', [''])
@paddle_utils.run_in_static_mode
def context(self, signature: str = None, for_test: bool = False,
trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]:
'''
......@@ -136,6 +138,7 @@ class ModuleV1(object):
return feed_dict, fetch_dict, program
@paddle_utils.run_in_static_mode
def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs):
'''
'''
......@@ -186,7 +189,7 @@ class ModuleV1(object):
# previously generated.
cls_uuid = utils.md5(module_info.name + module_info.author + module_info.author_email + module_info.type +
module_info.summary + module_info.version + directory)
cls = type(cls_uuid, (cls, ), {})
cls = type('ModuleV1_{}'.format(cls_uuid), (cls, ), {})
cls.name = module_info.name
cls.author = module_info.author
......
......@@ -186,7 +186,7 @@ class TransformerModule(NLPBaseModule):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
paddle.io.load(
paddle.static.load(
executor=exe,
model_path=pretraining_params_path,
program=main_program,
......@@ -195,6 +195,7 @@ class TransformerModule(NLPBaseModule):
def param_prefix(self) -> str:
return '@HUB_%s@' % self.name
@paddle_utils.run_in_static_mode
def context(
self,
max_seq_len: int = None,
......@@ -225,23 +226,26 @@ class TransformerModule(NLPBaseModule):
startup_program = paddle.static.Program()
with paddle.static.program_guard(module_program, startup_program):
with paddle.fluid.unique_name.guard():
input_ids = paddle.data(name='input_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
position_ids = paddle.data(name='position_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
segment_ids = paddle.data(name='segment_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
input_mask = paddle.data(name='input_mask', shape=[-1, max_seq_len, 1], dtype='float32', lod_level=0)
input_ids = paddle.static.data(name='input_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
position_ids = paddle.static.data(
name='position_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
segment_ids = paddle.static.data(
name='segment_ids', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
input_mask = paddle.static.data(
name='input_mask', shape=[-1, max_seq_len, 1], dtype='float32', lod_level=0)
pooled_output, sequence_output = self.net(input_ids, position_ids, segment_ids, input_mask)
data_list = [(input_ids, position_ids, segment_ids, input_mask)]
output_name_list = [(pooled_output.name, sequence_output.name)]
if num_slots > 1:
input_ids_2 = paddle.data(
input_ids_2 = paddle.static.data(
name='input_ids_2', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
position_ids_2 = paddle.data(
position_ids_2 = paddle.static.data(
name='position_ids_2', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
segment_ids_2 = paddle.data(
segment_ids_2 = paddle.static.data(
name='segment_ids_2', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
input_mask_2 = paddle.data(
input_mask_2 = paddle.static.data(
name='input_mask_2', shape=[-1, max_seq_len, 1], dtype='float32', lod_level=0)
pooled_output_2, sequence_output_2 = self.net(input_ids_2, position_ids_2, segment_ids_2,
input_mask_2)
......@@ -249,13 +253,13 @@ class TransformerModule(NLPBaseModule):
output_name_list.append((pooled_output_2.name, sequence_output_2.name))
if num_slots > 2:
input_ids_3 = paddle.data(
input_ids_3 = paddle.static.data(
name='input_ids_3', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
position_ids_3 = paddle.data(
position_ids_3 = paddle.static.data(
name='position_ids_3', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
segment_ids_3 = paddle.data(
segment_ids_3 = paddle.static.data(
name='segment_ids_3', shape=[-1, max_seq_len, 1], dtype='int64', lod_level=0)
input_mask_3 = paddle.data(
input_mask_3 = paddle.static.data(
name='input_mask_3', shape=[-1, max_seq_len, 1], dtype='float32', lod_level=0)
pooled_output_3, sequence_output_3 = self.net(input_ids_3, position_ids_3, segment_ids_3,
input_mask_3)
......@@ -308,6 +312,7 @@ class TransformerModule(NLPBaseModule):
return inputs, outputs, module_program
@paddle_utils.run_in_static_mode
def get_embedding(self, texts: List[str], max_seq_len: int = 512, use_gpu: bool = False, batch_size: int = 1):
'''
get pooled_output and sequence_output for input texts.
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
from typing import Callable, List
......@@ -195,3 +196,27 @@ def set_op_attr(program: paddle.static.Program, is_test: bool = False):
continue
op._set_attr('is_test', is_test)
@contextlib.contextmanager
def static_mode_guard():
    '''Temporarily switch paddle to static graph mode; the previous mode is restored on exit.'''
    premode = 'static' if not paddle.in_dynamic_mode() else 'dynamic'
    if premode == 'dynamic':
        paddle.enable_static()
    yield
    if premode == 'dynamic':
        paddle.disable_static()


def run_in_static_mode(func):
    '''Decorator that executes the decorated function in static graph mode.'''

    def runner(*args, **kwargs):
        with static_mode_guard():
            return func(*args, **kwargs)

    return runner
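A hypothetical usage sketch (not part of the commit) of the helpers added above, assuming they are importable as paddlehub.compat.paddle_utils, as the other hunks in this commit do:

```python
import paddle
from paddlehub.compat import paddle_utils


@paddle_utils.run_in_static_mode
def build_feed_var():
    # Runs with static graph mode enabled; the previous mode is restored on return.
    return paddle.static.data(name='x', shape=[-1, 10], dtype='float32')


print(paddle.in_dynamic_mode())  # True by default on paddle >= 2.0rc
x = build_feed_var()             # the decorator switches to static mode for this call
print(paddle.in_dynamic_mode())  # True again after the call
```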
......@@ -546,7 +546,7 @@ class BaseTask(object):
def save_inference_model(self, dirname: str, model_filename: str = None, params_filename: str = None):
with self.phase_guard('predict'):
paddle.io.save_inference_model(
paddle.static.save_inference_model(
dirname=dirname,
executor=self.exe,
main_program=self.main_program,
......
......@@ -38,7 +38,7 @@ class TransformerEmbeddingTask(BaseTask):
def _build_net(self) -> List[paddle.static.Variable]:
# ClassifyReader will return the sequence length of an input text
self.seq_len = paddle.data(name='seq_len', shape=[1], dtype='int64', lod_level=0)
self.seq_len = paddle.static.data(name='seq_len', shape=[1], dtype='int64', lod_level=0)
return [self.pooled_feature, self.seq_feature]
def _postprocessing(self, run_states: List[RunState]) -> List[List[np.ndarray]]:
......
......@@ -145,6 +145,7 @@ class LocalModuleManager(object):
return name.replace('-', '_')
def install(self,
*,
name: str = None,
directory: str = None,
archive: str = None,
......@@ -356,5 +357,7 @@ class LocalModuleManager(object):
for path, ds, ts in xarfile.unarchive_with_progress(archive, _tdir):
bar.update(float(ds) / ts)
path = path.split(os.sep)[0]
return self._install_from_directory(os.path.join(_tdir, path))
# Sometimes the path contains '.'
path = os.path.normpath(path)
directory = os.path.join(_tdir, path.split(os.sep)[0])
return self._install_from_directory(directory)
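A small illustration (not from the commit) of why the extra os.path.normpath call matters: archive member paths sometimes start with '.', in which case the first path component would be '.' instead of the module directory.

```python
import os

# Hypothetical archive member path as returned by unarchiving.
member = os.path.join('.', 'resnet50_module', 'model')

print(member.split(os.sep)[0])                    # '.'  -> wrong install directory
print(os.path.normpath(member).split(os.sep)[0])  # 'resnet50_module'
```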
......@@ -18,12 +18,15 @@ import builtins
import inspect
import importlib
import os
import re
import sys
from typing import Callable, Generic, List, Optional
from easydict import EasyDict
import paddle
from paddlehub.utils import parser, log, utils
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.module_v1 import ModuleV1
......@@ -171,7 +174,7 @@ class Module(object):
user_module._initialize(**kwargs)
return user_module
if user_module_cls == ModuleV1:
if issubclass(user_module_cls, ModuleV1):
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
......@@ -193,7 +196,7 @@ class Module(object):
user_module._initialize(**kwargs)
return user_module
if user_module_cls == ModuleV1:
if issubclass(user_module_cls, ModuleV1):
return user_module_cls(directory=directory, **kwargs)
user_module_cls.directory = directory
......@@ -237,6 +240,23 @@ class RunModule(object):
else:
return None
# After version 2.0.0rc, paddle uses dynamic graph mode by default, which causes the execution of
# static graph models to fail, so compatibility protection is required here.
def __getattribute__(self, attr):
    _attr = object.__getattribute__(self, attr)
    # If the accessed attribute is a built-in property of the object, skip it.
    if re.match('__.*__', attr):
        return _attr
    # If the module is a dygraph model, skip it.
    elif isinstance(self, paddle.nn.Layer):
        return _attr
    # If the accessed attribute is not a bound method, skip it.
    elif not inspect.ismethod(_attr):
        return _attr
    return paddle_utils.run_in_static_mode(_attr)
@classmethod
def get_py_requirements(cls) -> List[str]:
'''
......
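An illustrative sketch (hypothetical subclass, not from the commit) of what the __getattribute__ override above achieves for modules that are not dygraph Layers: ordinary method calls are transparently executed in static graph mode.

```python
import paddle
from paddlehub.module.module import RunModule


class LegacyStaticModule(RunModule):
    '''Hypothetical module wrapping a static graph model (not a paddle.nn.Layer).'''

    def check_mode(self):
        # __getattribute__ wraps bound methods with run_in_static_mode,
        # so this body runs with static graph mode enabled.
        return paddle.in_dynamic_mode()


module = LegacyStaticModule()      # assumes no extra constructor arguments are needed
print(module.check_mode())         # False: the call ran in static graph mode
print(paddle.in_dynamic_mode())    # True: the previous dynamic mode is restored
```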