提交 bd707811 编写于 作者: W wuzewu

Add dygraph support

上级 15d102eb
...@@ -89,7 +89,7 @@ def moduleinfo(name, version, author, author_email, summary, type): ...@@ -89,7 +89,7 @@ def moduleinfo(name, version, author, author_email, summary, type):
return _wrapper return _wrapper
class Module(object): class Module(fluid.dygraph.Layer):
def __new__(cls, def __new__(cls,
name=None, name=None,
directory=None, directory=None,
...@@ -121,7 +121,7 @@ class Module(object): ...@@ -121,7 +121,7 @@ class Module(object):
module = Module.init_with_directory( module = Module.init_with_directory(
directory=directory, **kwargs) directory=directory, **kwargs)
else: else:
module = object.__new__(cls) module = fluid.dygraph.Layer.__new__(cls)
return module return module
...@@ -131,6 +131,7 @@ class Module(object): ...@@ -131,6 +131,7 @@ class Module(object):
module_dir=None, module_dir=None,
version=None, version=None,
**kwargs): **kwargs):
super(Module, self).__init__()
# Avoid module being initialized multiple times # Avoid module being initialized multiple times
if "_is_initialize" in self.__dict__ and self._is_initialize: if "_is_initialize" in self.__dict__ and self._is_initialize:
return return
...@@ -145,6 +146,14 @@ class Module(object): ...@@ -145,6 +146,14 @@ class Module(object):
self._initialize(**kwargs) self._initialize(**kwargs)
self._is_initialize = True self._is_initialize = True
self._code_version = "v2" self._code_version = "v2"
self._model_runner = None
@property
def model_runner(self):
if not self._model_runner:
self._model_runner = fluid.dygraph.StaticModelRunner(
self.default_pretrained_model_path)
return self._model_runner
def _get_func_name(self, current_cls, module_func_dict): def _get_func_name(self, current_cls, module_func_dict):
mod = current_cls.__module__ + "." + current_cls.__name__ mod = current_cls.__module__ + "." + current_cls.__name__
...@@ -248,6 +257,9 @@ class Module(object): ...@@ -248,6 +257,9 @@ class Module(object):
def _initialize(self): def _initialize(self):
pass pass
def forward(self, *args):
return self.model_runner(*args)
class ModuleHelper(object): class ModuleHelper(object):
def __init__(self, directory): def __init__(self, directory):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册