Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
f5dc2a65
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“ba3b2eb3a5c288bd898d057a77682cecf043836c”上不存在“doc/design/graph.html”
提交
f5dc2a65
编写于
4月 23, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update nlp_module
上级
bd707811
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
9 addition
and
2 deletion
+9
-2
paddlehub/module/module.py
paddlehub/module/module.py
+2
-2
paddlehub/module/nlp_module.py
paddlehub/module/nlp_module.py
+7
-0
未找到文件。
paddlehub/module/module.py
浏览文件 @
f5dc2a65
...
@@ -257,8 +257,8 @@ class Module(fluid.dygraph.Layer):
...
@@ -257,8 +257,8 @@ class Module(fluid.dygraph.Layer):
def
_initialize
(
self
):
def
_initialize
(
self
):
pass
pass
def
forward
(
self
,
*
args
):
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
model_runner
(
*
args
)
return
self
.
model_runner
(
*
args
,
**
kwargs
)
class
ModuleHelper
(
object
):
class
ModuleHelper
(
object
):
...
...
paddlehub/module/nlp_module.py
浏览文件 @
f5dc2a65
...
@@ -353,6 +353,13 @@ class TransformerModule(NLPBaseModule):
...
@@ -353,6 +353,13 @@ class TransformerModule(NLPBaseModule):
return
inputs
,
outputs
,
module_program
return
inputs
,
outputs
,
module_program
@
property
def
model_runner
(
self
):
if
not
self
.
_model_runner
:
self
.
_model_runner
=
fluid
.
dygraph
.
StaticModelRunner
(
self
.
params_path
)
return
self
.
_model_runner
def
get_embedding
(
self
,
texts
,
use_gpu
=
False
,
batch_size
=
1
):
def
get_embedding
(
self
,
texts
,
use_gpu
=
False
,
batch_size
=
1
):
"""
"""
get pooled_output and sequence_output for input texts.
get pooled_output and sequence_output for input texts.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录