Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
f4d9b46b
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f4d9b46b
编写于
3月 02, 2021
作者:
S
Steffy-zxf
提交者:
GitHub
3月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the compatibility error caused by the upgrade of PretrainedTokenizer
上级
71d0cc9d
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
130 addition
and
83 deletion
+130
-83
paddlehub/datasets/base_nlp_dataset.py
paddlehub/datasets/base_nlp_dataset.py
+62
-22
paddlehub/module/nlp_module.py
paddlehub/module/nlp_module.py
+62
-53
paddlehub/utils/utils.py
paddlehub/utils/utils.py
+5
-7
requirements.txt
requirements.txt
+1
-1
未找到文件。
paddlehub/datasets/base_nlp_dataset.py
浏览文件 @
f4d9b46b
...
...
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
,
List
,
Optional
,
Union
,
Tuple
import
csv
import
io
import
os
from
typing
import
Dict
,
List
,
Optional
,
Union
,
Tuple
import
numpy
as
np
import
paddle
import
paddlenlp
from
packaging.version
import
Version
from
paddlehub.env
import
DATA_HOME
from
paddlenlp.transformers
import
PretrainedTokenizer
...
...
@@ -27,7 +29,6 @@ from paddlehub.utils.utils import download, reseg_token_label, pad_sequence, tru
from
paddlehub.utils.xarfile
import
is_xarfile
,
unarchive
class
InputExample
(
object
):
"""
The input data structure of Transformer modules (BERT, ERNIE and so on).
...
...
@@ -233,7 +234,16 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
records
=
[]
for
example
in
examples
:
if
isinstance
(
self
.
tokenizer
,
PretrainedTokenizer
):
record
=
self
.
tokenizer
.
encode
(
text
=
example
.
text_a
,
text_pair
=
example
.
text_b
,
max_seq_len
=
self
.
max_seq_len
)
if
Version
(
paddlenlp
.
__version__
)
<=
Version
(
'2.0.0rc2'
):
record
=
self
.
tokenizer
.
encode
(
text
=
example
.
text_a
,
text_pair
=
example
.
text_b
,
max_seq_len
=
self
.
max_seq_len
)
else
:
record
=
self
.
tokenizer
(
text
=
example
.
text_a
,
text_pair
=
example
.
text_b
,
max_seq_len
=
self
.
max_seq_len
,
pad_to_max_seq_len
=
True
,
return_length
=
True
)
elif
isinstance
(
self
.
tokenizer
,
JiebaTokenizer
):
pad_token
=
self
.
tokenizer
.
vocab
.
pad_token
...
...
@@ -246,7 +256,9 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
ids
=
pad_sequence
(
ids
,
self
.
max_seq_len
,
pad_token_id
)
record
=
{
'text'
:
ids
,
'seq_len'
:
seq_len
}
else
:
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
if
not
record
:
logger
.
info
(
...
...
@@ -260,17 +272,26 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
def
__getitem__
(
self
,
idx
):
record
=
self
.
records
[
idx
]
if
isinstance
(
self
.
tokenizer
,
PretrainedTokenizer
):
input_ids
=
np
.
array
(
record
[
'input_ids'
])
if
Version
(
paddlenlp
.
__version__
)
>=
Version
(
'2.0.0rc5'
):
token_type_ids
=
np
.
array
(
record
[
'token_type_ids'
])
else
:
token_type_ids
=
record
[
'segment_ids'
]
if
'label'
in
record
.
keys
():
return
np
.
array
(
record
[
'input_ids'
]),
np
.
array
(
record
[
'segment_ids'
])
,
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
return
input_ids
,
token_type_ids
,
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
else
:
return
np
.
array
(
record
[
'input_ids'
]),
np
.
array
(
record
[
'segment_ids'
])
return
input_ids
,
token_type_ids
elif
isinstance
(
self
.
tokenizer
,
JiebaTokenizer
):
if
'label'
in
record
.
keys
():
return
np
.
array
(
record
[
'text'
]),
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
else
:
return
np
.
array
(
record
[
'text'
])
else
:
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
def
__len__
(
self
):
return
len
(
self
.
records
)
...
...
@@ -303,6 +324,7 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
is_file_with_header(:obj:bool, `optional`, default to :obj: False) :
Whether or not the file is with the header introduction.
"""
def
__init__
(
self
,
base_path
:
str
,
tokenizer
:
Union
[
PretrainedTokenizer
,
JiebaTokenizer
],
...
...
@@ -311,7 +333,7 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
data_file
:
str
=
None
,
label_file
:
str
=
None
,
label_list
:
list
=
None
,
split_char
:
str
=
"
\002
"
,
split_char
:
str
=
"
\002
"
,
no_entity_label
:
str
=
"O"
,
ignore_label
:
int
=
-
100
,
is_file_with_header
:
bool
=
False
):
...
...
@@ -365,7 +387,15 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
pad_token
=
self
.
tokenizer
.
pad_token
tokens
,
labels
=
reseg_token_label
(
tokenizer
=
self
.
tokenizer
,
tokens
=
tokens
,
labels
=
labels
)
if
Version
(
paddlenlp
.
__version__
)
<=
Version
(
'2.0.0rc2'
):
record
=
self
.
tokenizer
.
encode
(
text
=
tokens
,
max_seq_len
=
self
.
max_seq_len
)
else
:
record
=
self
.
tokenizer
(
text
=
tokens
,
max_seq_len
=
self
.
max_seq_len
,
pad_to_max_seq_len
=
True
,
is_split_into_words
=
True
,
return_length
=
True
)
elif
isinstance
(
self
.
tokenizer
,
JiebaTokenizer
):
pad_token
=
self
.
tokenizer
.
vocab
.
pad_token
...
...
@@ -379,12 +409,13 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
record
=
{
'text'
:
ids
,
'seq_len'
:
seq_len
}
else
:
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
if
not
record
:
logger
.
info
(
"The text %s has been dropped as it has no words in the vocab after tokenization."
%
example
.
text_a
)
"The text %s has been dropped as it has no words in the vocab after tokenization."
%
example
.
text_a
)
continue
# convert labels into record
...
...
@@ -395,37 +426,46 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
elif
isinstance
(
self
.
tokenizer
,
JiebaTokenizer
):
tokens_with_specical_token
=
[
self
.
tokenizer
.
vocab
.
to_tokens
(
id_
)
for
id_
in
record
[
'text'
]]
else
:
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
tokens_index
=
0
for
token
in
tokens_with_specical_token
:
if
tokens_index
<
len
(
tokens
)
and
token
==
tokens
[
tokens_index
]:
record
[
"label"
].
append
(
self
.
label_list
.
index
(
labels
[
tokens_index
]))
if
tokens_index
<
len
(
tokens
)
and
token
==
tokens
[
tokens_index
]:
record
[
"label"
].
append
(
self
.
label_list
.
index
(
labels
[
tokens_index
]))
tokens_index
+=
1
elif
token
in
[
pad_token
]:
record
[
"label"
].
append
(
self
.
ignore_label
)
# label of special token
else
:
record
[
"label"
].
append
(
self
.
label_list
.
index
(
self
.
no_entity_label
))
record
[
"label"
].
append
(
self
.
label_list
.
index
(
self
.
no_entity_label
))
records
.
append
(
record
)
return
records
def
__getitem__
(
self
,
idx
):
record
=
self
.
records
[
idx
]
if
isinstance
(
self
.
tokenizer
,
PretrainedTokenizer
):
input_ids
=
np
.
array
(
record
[
'input_ids'
])
seq_lens
=
np
.
array
(
record
[
'seq_len'
])
if
Version
(
paddlenlp
.
__version__
)
>=
Version
(
'2.0.0rc5'
):
token_type_ids
=
np
.
array
(
record
[
'token_type_ids'
])
else
:
token_type_ids
=
np
.
array
(
record
[
'segment_ids'
])
if
'label'
in
record
.
keys
():
return
np
.
array
(
record
[
'input_ids'
]),
np
.
array
(
record
[
'segment_ids'
]),
np
.
array
(
record
[
'seq_len'
])
,
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
return
input_ids
,
token_type_ids
,
seq_lens
,
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
else
:
return
np
.
array
(
record
[
'input_ids'
]),
np
.
array
(
record
[
'segment_ids'
]),
np
.
array
(
record
[
'seq_len'
])
return
input_ids
,
token_type_ids
,
seq_lens
elif
isinstance
(
self
.
tokenizer
,
JiebaTokenizer
):
if
'label'
in
record
.
keys
():
return
np
.
array
(
record
[
'text'
]),
np
.
array
(
record
[
'seq_len'
]),
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
else
:
return
np
.
array
(
record
[
'text'
]),
np
.
array
(
record
[
'seq_len'
])
else
:
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
raise
RuntimeError
(
"Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
.
format
(
type
(
self
.
tokenizer
)))
def
__len__
(
self
):
return
len
(
self
.
records
)
paddlehub/module/nlp_module.py
浏览文件 @
f4d9b46b
...
...
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# FIXME(zhangxuefei): remove this file after paddlenlp is released.
import
copy
import
functools
import
inspect
...
...
@@ -25,6 +22,7 @@ from typing import List, Tuple
import
paddle
import
paddle.nn
as
nn
from
packaging.version
import
Version
from
paddle.dataset.common
import
DATA_HOME
from
paddle.utils.download
import
get_path_from_url
from
paddlehub.module.module
import
serving
,
RunModule
,
runnable
...
...
@@ -32,11 +30,11 @@ from paddlehub.module.module import serving, RunModule, runnable
from
paddlehub.utils.log
import
logger
from
paddlehub.utils.utils
import
reseg_token_label
import
paddlenlp
from
paddlenlp.embeddings.token_embedding
import
EMBEDDING_HOME
,
EMBEDDING_URL_ROOT
from
paddlenlp.data
import
JiebaTokenizer
from
paddlehub.compat.module.nlp_module
import
DataFormatError
__all__
=
[
'PretrainedModel'
,
'register_base_model'
,
...
...
@@ -357,14 +355,9 @@ class TextServing(object):
"""
A base class for text model which supports serving.
"""
@
serving
def
predict_method
(
self
,
data
:
List
[
List
[
str
]],
max_seq_len
:
int
=
128
,
batch_size
:
int
=
1
,
use_gpu
:
bool
=
False
):
def
predict_method
(
self
,
data
:
List
[
List
[
str
]],
max_seq_len
:
int
=
128
,
batch_size
:
int
=
1
,
use_gpu
:
bool
=
False
):
"""
Run predict method as a service.
Serving as a task which is specified from serving config.
...
...
@@ -391,20 +384,16 @@ class TextServing(object):
if
self
.
task
==
'token-cls'
:
# remove labels of [CLS] token and pad tokens
results
=
[
token_labels
[
1
:
len
(
data
[
i
][
0
])
+
1
]
for
i
,
token_labels
in
enumerate
(
results
)
]
results
=
[
token_labels
[
1
:
len
(
data
[
i
][
0
])
+
1
]
for
i
,
token_labels
in
enumerate
(
results
)]
return
results
elif
self
.
task
is
None
:
# embedding service
results
=
self
.
get_embedding
(
data
,
use_gpu
)
return
results
else
:
# unknown service
logger
.
error
(
f
'Unknown task
{
self
.
task
}
, current tasks supported:
\n
'
logger
.
error
(
f
'Unknown task
{
self
.
task
}
, current tasks supported:
\n
'
'1. seq-cls: sequence classification service;
\n
'
'2. token-cls: sequence labeling service;
\n
'
'3. None: embedding service'
)
'3. None: embedding service'
)
return
...
...
@@ -422,11 +411,33 @@ class TransformerModule(RunModule, TextServing):
if
self
.
task
==
'token-cls'
:
# Extra processing of token-cls task
tokens
=
text
[
0
].
split
(
split_char
)
text
[
0
],
_
=
reseg_token_label
(
tokenizer
=
tokenizer
,
tokens
=
tokens
)
is_split_into_words
=
True
else
:
is_split_into_words
=
False
if
len
(
text
)
==
1
:
encoded_inputs
=
tokenizer
.
encode
(
text
[
0
],
text_pair
=
None
,
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
pad_to_max_seq_len
)
if
Version
(
paddlenlp
.
__version__
)
<=
Version
(
'2.0.0rc2'
):
encoded_inputs
=
tokenizer
.
encode
(
text
[
0
],
text_pair
=
None
,
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
pad_to_max_seq_len
)
else
:
encoded_inputs
=
tokenizer
(
text
=
text
[
0
],
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
True
,
is_split_into_words
=
is_split_into_words
,
return_length
=
True
)
elif
len
(
text
)
==
2
:
encoded_inputs
=
tokenizer
.
encode
(
text
[
0
],
text_pair
=
text
[
1
],
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
pad_to_max_seq_len
)
if
Version
(
paddlenlp
.
__version__
)
<=
Version
(
'2.0.0rc2'
):
encoded_inputs
=
tokenizer
.
encode
(
text
[
0
],
text_pair
=
text
[
1
],
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
pad_to_max_seq_len
)
else
:
encoded_inputs
=
tokenizer
(
text
=
text
[
0
],
text_pair
=
text
[
1
],
max_seq_len
=
max_seq_len
,
pad_to_max_seq_len
=
True
,
is_split_into_words
=
is_split_into_words
,
return_length
=
True
)
else
:
raise
RuntimeError
(
'The input text must have one or two sequence, but got %d. Please check your inputs.'
%
len
(
text
))
...
...
@@ -442,7 +453,14 @@ class TransformerModule(RunModule, TextServing):
examples
=
[]
for
text
in
data
:
encoded_inputs
=
self
.
_convert_text_to_input
(
tokenizer
,
text
,
max_seq_len
,
split_char
)
examples
.
append
((
encoded_inputs
[
'input_ids'
],
encoded_inputs
[
'segment_ids'
]))
input_ids
=
encoded_inputs
[
'input_ids'
]
if
Version
(
paddlenlp
.
__version__
)
>=
Version
(
'2.0.0rc5'
):
token_type_ids
=
encoded_inputs
[
'token_type_ids'
]
else
:
token_type_ids
=
encoded_inputs
[
'segment_ids'
]
examples
.
append
((
input_ids
,
token_type_ids
))
# Seperates data into some batches.
one_batch
=
[]
...
...
@@ -468,7 +486,8 @@ class TransformerModule(RunModule, TextServing):
if
self
.
task
==
'seq-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
labels
=
batch
[
2
])
elif
self
.
task
==
'token-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
seq_lengths
=
batch
[
2
],
labels
=
batch
[
3
])
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
seq_lengths
=
batch
[
2
],
labels
=
batch
[
3
])
self
.
metric
.
reset
()
return
{
'loss'
:
avg_loss
,
'metrics'
:
metric
}
...
...
@@ -485,7 +504,8 @@ class TransformerModule(RunModule, TextServing):
if
self
.
task
==
'seq-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
labels
=
batch
[
2
])
elif
self
.
task
==
'token-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
seq_lengths
=
batch
[
2
],
labels
=
batch
[
3
])
predictions
,
avg_loss
,
metric
=
self
(
input_ids
=
batch
[
0
],
token_type_ids
=
batch
[
1
],
seq_lengths
=
batch
[
2
],
labels
=
batch
[
3
])
self
.
metric
.
reset
()
return
{
'metrics'
:
metric
}
...
...
@@ -502,20 +522,14 @@ class TransformerModule(RunModule, TextServing):
if
self
.
task
is
not
None
:
raise
RuntimeError
(
"The get_embedding method is only valid when task is None, but got task %s"
%
self
.
task
)
return
self
.
predict
(
data
=
data
,
use_gpu
=
use_gpu
)
return
self
.
predict
(
data
=
data
,
use_gpu
=
use_gpu
)
def
predict
(
self
,
def
predict
(
self
,
data
:
List
[
List
[
str
]],
max_seq_len
:
int
=
128
,
split_char
:
str
=
'
\002
'
,
batch_size
:
int
=
1
,
use_gpu
:
bool
=
False
):
use_gpu
:
bool
=
False
):
"""
Predicts the data labels.
...
...
@@ -532,12 +546,10 @@ class TransformerModule(RunModule, TextServing):
"""
if
self
.
task
not
in
self
.
_tasks_supported
\
and
self
.
task
is
not
None
:
# None for getting embedding
raise
RuntimeError
(
f
'Unknown task
{
self
.
task
}
, current tasks supported:
\n
'
raise
RuntimeError
(
f
'Unknown task
{
self
.
task
}
, current tasks supported:
\n
'
'1. seq-cls: sequence classification;
\n
'
'2. token-cls: sequence labeling;
\n
'
'3. None: embedding'
)
'3. None: embedding'
)
paddle
.
set_device
(
'gpu'
)
if
use_gpu
else
paddle
.
set_device
(
'cpu'
)
...
...
@@ -563,10 +575,7 @@ class TransformerModule(RunModule, TextServing):
results
.
extend
(
token_labels
)
elif
self
.
task
==
None
:
sequence_output
,
pooled_output
=
self
(
input_ids
,
segment_ids
)
results
.
append
([
pooled_output
.
squeeze
(
0
).
numpy
().
tolist
(),
sequence_output
.
squeeze
(
0
).
numpy
().
tolist
()
])
results
.
append
([
pooled_output
.
squeeze
(
0
).
numpy
().
tolist
(),
sequence_output
.
squeeze
(
0
).
numpy
().
tolist
()])
return
results
...
...
@@ -575,6 +584,7 @@ class EmbeddingServing(object):
"""
A base class for embedding model which supports serving.
"""
@
serving
def
calc_similarity
(
self
,
data
:
List
[
List
[
str
]]):
"""
...
...
@@ -593,8 +603,7 @@ class EmbeddingServing(object):
for
word
in
word_pair
:
if
self
.
get_idx_from_word
(
word
)
==
\
self
.
get_idx_from_word
(
self
.
vocab
.
unk_token
):
raise
RuntimeError
(
f
'Word "
{
word
}
" is not in vocab. Please check your inputs.'
)
raise
RuntimeError
(
f
'Word "
{
word
}
" is not in vocab. Please check your inputs.'
)
results
.
append
(
str
(
self
.
cosine_sim
(
*
word_pair
)))
return
results
...
...
paddlehub/utils/utils.py
浏览文件 @
f4d9b46b
...
...
@@ -336,12 +336,11 @@ def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
'''
if
labels
:
if
len
(
tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of tokens must be same with labels"
)
raise
ValueError
(
"The length of tokens must be same with labels"
)
ret_tokens
=
[]
ret_labels
=
[]
for
token
,
label
in
zip
(
tokens
,
labels
):
sub_token
=
tokenizer
(
token
)
sub_token
=
tokenizer
.
_tokenize
(
token
)
if
len
(
sub_token
)
==
0
:
continue
ret_tokens
.
extend
(
sub_token
)
...
...
@@ -354,13 +353,12 @@ def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
ret_labels
.
extend
([
sub_label
]
*
(
len
(
sub_token
)
-
1
))
if
len
(
ret_tokens
)
!=
len
(
ret_labels
):
raise
ValueError
(
"The length of ret_tokens can't match with labels"
)
raise
ValueError
(
"The length of ret_tokens can't match with labels"
)
return
ret_tokens
,
ret_labels
else
:
ret_tokens
=
[]
for
token
in
tokens
:
sub_token
=
tokenizer
(
token
)
sub_token
=
tokenizer
.
_tokenize
(
token
)
if
len
(
sub_token
)
==
0
:
continue
ret_tokens
.
extend
(
sub_token
)
...
...
@@ -376,7 +374,7 @@ def pad_sequence(ids: List[int], max_seq_len: int, pad_token_id: int):
assert
len
(
ids
)
<=
max_seq_len
,
\
f
'The input length
{
len
(
ids
)
}
is greater than max_seq_len
{
max_seq_len
}
. '
\
'Please check the input list and max_seq_len if you really want to pad a sequence.'
return
ids
[:]
+
[
pad_token_id
]
*
(
max_seq_len
-
len
(
ids
))
return
ids
[:]
+
[
pad_token_id
]
*
(
max_seq_len
-
len
(
ids
))
def
trunc_sequence
(
ids
:
List
[
int
],
max_seq_len
:
int
):
...
...
requirements.txt
浏览文件 @
f4d9b46b
...
...
@@ -16,4 +16,4 @@ tqdm
visualdl
>= 2.0.0
# gunicorn not support windows
gunicorn
>= 19.10.0; sys_platform != "win32"
paddlenlp
>= 2.0.0b2
\ No newline at end of file
paddlenlp
>= 2.0.0rc5
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录