“ba3b2eb3a5c288bd898d057a77682cecf043836c”上不存在“doc/design/graph.html”
提交 f5dc2a65 编写于 作者: W wuzewu

update nlp_module

上级 bd707811
...@@ -257,8 +257,8 @@ class Module(fluid.dygraph.Layer): ...@@ -257,8 +257,8 @@ class Module(fluid.dygraph.Layer):
def _initialize(self): def _initialize(self):
pass pass
def forward(self, *args): def forward(self, *args, **kwargs):
return self.model_runner(*args) return self.model_runner(*args, **kwargs)
class ModuleHelper(object): class ModuleHelper(object):
......
...@@ -353,6 +353,13 @@ class TransformerModule(NLPBaseModule): ...@@ -353,6 +353,13 @@ class TransformerModule(NLPBaseModule):
return inputs, outputs, module_program return inputs, outputs, module_program
@property
def model_runner(self):
if not self._model_runner:
self._model_runner = fluid.dygraph.StaticModelRunner(
self.params_path)
return self._model_runner
def get_embedding(self, texts, use_gpu=False, batch_size=1): def get_embedding(self, texts, use_gpu=False, batch_size=1):
""" """
get pooled_output and sequence_output for input texts. get pooled_output and sequence_output for input texts.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册