提交 bd707811 编写于 作者: W wuzewu

Add dygraph support

上级 15d102eb
......@@ -89,7 +89,7 @@ def moduleinfo(name, version, author, author_email, summary, type):
return _wrapper
class Module(object):
class Module(fluid.dygraph.Layer):
def __new__(cls,
name=None,
directory=None,
......@@ -121,7 +121,7 @@ class Module(object):
module = Module.init_with_directory(
directory=directory, **kwargs)
else:
module = object.__new__(cls)
module = fluid.dygraph.Layer.__new__(cls)
return module
......@@ -131,6 +131,7 @@ class Module(object):
module_dir=None,
version=None,
**kwargs):
super(Module, self).__init__()
# Avoid module being initialized multiple times
if "_is_initialize" in self.__dict__ and self._is_initialize:
return
......@@ -145,6 +146,14 @@ class Module(object):
self._initialize(**kwargs)
self._is_initialize = True
self._code_version = "v2"
self._model_runner = None
@property
def model_runner(self):
if not self._model_runner:
self._model_runner = fluid.dygraph.StaticModelRunner(
self.default_pretrained_model_path)
return self._model_runner
def _get_func_name(self, current_cls, module_func_dict):
mod = current_cls.__module__ + "." + current_cls.__name__
......@@ -248,6 +257,9 @@ class Module(object):
def _initialize(self):
pass
def forward(self, *args):
return self.model_runner(*args)
class ModuleHelper(object):
def __init__(self, directory):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册