Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
24788265
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
24788265
编写于
11月 19, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix module compatibility issues
上级
c4f19c8e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
3 deletion
+26
-3
paddlehub/__init__.py
paddlehub/__init__.py
+5
-1
paddlehub/compat/module/module_v1.py
paddlehub/compat/module/module_v1.py
+21
-2
未找到文件。
paddlehub/__init__.py
浏览文件 @
24788265
...
...
@@ -48,7 +48,11 @@ sys.modules['paddlehub.common.logger'] = log
sys
.
modules
[
'paddlehub.common.paddle_helper'
]
=
paddle_utils
sys
.
modules
[
'paddlehub.common.utils'
]
=
utils
sys
.
modules
[
'paddlehub.reader'
]
=
task
sys
.
modules
[
'paddlehub.reader.batching'
]
=
task
.
batch
AdamWeightDecayStrategy
=
lambda
:
0
ULMFiTStrategy
=
lambda
params_layer
=
0
:
0
common
=
EasyDict
(
paddle_helper
=
paddle_utils
)
dataset
=
EasyDict
(
Couplet
=
couplet
.
Couplet
)
AdamWeightDecayStrategy
=
lambda
:
0
finetune
=
EasyDict
(
strategy
=
EasyDict
(
ULMFiTStrategy
=
ULMFiTStrategy
))
logger
=
EasyDict
(
logger
=
log
.
logger
)
paddlehub/compat/module/module_v1.py
浏览文件 @
24788265
...
...
@@ -118,8 +118,8 @@ class ModuleV1(object):
op
.
_set_attr
(
'op_callstack'
,
[
''
])
@
paddle_utils
.
run_in_static_mode
def
context
(
self
,
signature
:
str
=
None
,
for_test
:
bool
=
False
,
trainable
:
bool
=
True
)
->
Tuple
[
dict
,
dict
,
paddle
.
static
.
Program
]:
def
context
(
self
,
signature
:
str
=
None
,
for_test
:
bool
=
False
,
trainable
:
bool
=
True
,
max_seq_len
:
int
=
128
)
->
Tuple
[
dict
,
dict
,
paddle
.
static
.
Program
]:
'''Get module context information, including graph structure and graph input and output variables.'''
program
=
self
.
program
.
clone
(
for_test
=
for_test
)
paddle_utils
.
remove_feed_fetch_op
(
program
)
...
...
@@ -141,8 +141,27 @@ class ModuleV1(object):
for
param
in
program
.
all_parameters
():
param
.
trainable
=
trainable
# The bert series model saved by ModuleV1 sets max_seq_len to 512 by default. We need to adjust max_seq_len
# according to the parameters in actual use.
if
'bert'
in
self
.
name
or
self
.
name
.
startswith
(
'ernie'
):
self
.
_update_bert_max_seq_len
(
program
,
feed_dict
,
max_seq_len
)
return
feed_dict
,
fetch_dict
,
program
def
_update_bert_max_seq_len
(
self
,
program
:
paddle
.
static
.
Program
,
feed_dict
:
dict
,
max_seq_len
:
int
=
128
):
MAX_SEQ_LENGTH
=
512
if
max_seq_len
>
MAX_SEQ_LENGTH
or
max_seq_len
<=
0
:
raise
ValueError
(
"max_seq_len({}) should be in the range of [1, {}]"
.
format
(
max_seq_len
,
MAX_SEQ_LENGTH
))
log
.
logger
.
info
(
"Set maximum sequence length of input tensor to {}"
.
format
(
max_seq_len
))
if
self
.
name
.
startswith
(
"ernie_v2"
):
feed_list
=
[
"input_ids"
,
"position_ids"
,
"segment_ids"
,
"input_mask"
,
"task_ids"
]
else
:
feed_list
=
[
"input_ids"
,
"position_ids"
,
"segment_ids"
,
"input_mask"
]
for
tensor_name
in
feed_list
:
seq_tensor_shape
=
[
-
1
,
max_seq_len
,
1
]
log
.
logger
.
info
(
"The shape of input tensor[{}] set to {}"
.
format
(
tensor_name
,
seq_tensor_shape
))
program
.
global_block
().
var
(
feed_dict
[
tensor_name
].
name
).
desc
.
set_shape
(
seq_tensor_shape
)
@
paddle_utils
.
run_in_static_mode
def
__call__
(
self
,
sign_name
:
str
,
data
:
dict
,
use_gpu
:
bool
=
False
,
batch_size
:
int
=
1
,
**
kwargs
):
'''Call the specified signature function for prediction.'''
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录