Unverified commit 290ded7a
Authored on Oct 26, 2021 by Jack Zhou; committed via GitHub on Oct 26, 2021
Optimize FasterTokenizer (#36701)
* optimize fast tokenizer
Parent: eca78a9f
Showing 2 changed files with 34 additions and 31 deletions (+34, -31)
paddle/fluid/operators/string/faster_tokenizer_op.cc  +21 -17
paddle/fluid/operators/string/faster_tokenizer_op.h   +13 -14
paddle/fluid/operators/string/faster_tokenizer_op.cc
@@ -100,9 +100,14 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
     // String is converted into wstring failedly.
     return;
   }
 
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
     if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
       continue;
     }
@@ -110,25 +115,24 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
       ch = do_lower_case(ch);
     }
     if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
     } else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
     } else {
-      dest_text += ch;
+      cache_text += ch;
     }
   }
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
     const size_t max_input_chars_per_word /* = 100 */)
     : vocab_(vocab),
       unk_token_(unk_token),
       max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
 }
 
 void WordPieceTokenizer::Tokenize(const wstring& text,
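The two hunks above drop the intermediate space-padded dest_text and the final boost::split over kStripChars: tokens are now emitted into the result vector as soon as a boundary character is seen. A minimal standalone sketch of the same accumulation pattern follows; it is not the Paddle implementation, and the character-class helpers are <cwctype> stand-ins for IsChineseChar/IsPunctuation/IsWhiteSpace.

#include <cwctype>
#include <string>
#include <vector>

// Hypothetical stand-ins for the boundary checks used in the operator.
static bool IsDelimiterCh(wchar_t ch) { return std::iswpunct(ch) != 0; }
static bool IsWhiteSpaceCh(wchar_t ch) { return std::iswspace(ch) != 0; }

std::vector<std::wstring> SplitTokens(const std::wstring& text) {
  std::vector<std::wstring> res;
  std::wstring cache_text = L"";
  auto PushCacheText = [&]() {
    if (cache_text != L"") {       // flush the pending token, if any
      res.emplace_back(cache_text);
      cache_text = L"";
    }
  };
  for (auto& ch : text) {
    if (IsDelimiterCh(ch)) {       // a delimiter becomes its own token
      PushCacheText();
      res.emplace_back(std::wstring{ch});
    } else if (IsWhiteSpaceCh(ch)) {  // whitespace only ends the pending token
      PushCacheText();
    } else {
      cache_text += ch;            // accumulate regular characters
    }
  }
  PushCacheText();                 // flush the trailing token
  return res;
}

Compared with the old approach, no second pass over a padded copy of the text is needed, and no empty strings have to be stripped out afterwards.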
@@ -178,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
   }
 }
 
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
                              bool do_lower_case /* = false */,
                              const wstring& unk_token /* = L"[UNK]" */,
                              const wstring& pad_token /* = L"[PAD]" */,
@@ -196,11 +200,11 @@ BertTokenizer::BertTokenizer(framework::Vocab* vocab,
       vocab_(vocab),
       basic_tokenizer_(do_lower_case_),
       word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
 
   all_special_tokens_ = vector<wstring>(
       {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
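With the tokenizers now holding a const framework::Vocab*, the id lookups switch from (*vocab_)[token_] to vocab_->at(token_). A small sketch of why, assuming Vocab behaves like an unordered map from wide-string token to integer id (the exact alias in the framework may differ):

#include <cstdint>
#include <string>
#include <unordered_map>

// Assumed shape of the vocabulary for illustration only.
using Vocab = std::unordered_map<std::wstring, std::int64_t>;

std::int64_t LookupId(const Vocab* vocab, const std::wstring& token) {
  // (*vocab)[token] would not compile here: operator[] is non-const because it
  // may insert a default value for a missing key. at() works on a const map
  // and throws std::out_of_range instead of silently growing the vocab.
  return vocab->at(token);
}

The behavioral difference is that a missing special token now throws rather than being quietly inserted with id 0.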
paddle/fluid/operators/string/faster_tokenizer_op.h
file mode changed: 100755 → 100644
@@ -56,13 +56,13 @@ class BasicTokenizer {
 
 class WordPieceTokenizer {
  public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
                               const wstring& unk_token = L"[UNK]",
                               const size_t max_input_chars_per_word = 100);
   void Tokenize(const wstring& text, vector<int64_t>* output) const;
 
  private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   wstring unk_token_{L"[UNK]"};
   int64_t unk_token_id_;
   size_t max_input_chars_per_word_;
@@ -70,7 +70,8 @@ class WordPieceTokenizer {
 
 class BertTokenizer {
  public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
                          const wstring& unk_token = L"[UNK]",
                          const wstring& pad_token = L"[PAD]",
                          const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
   bool do_lower_case_;
   wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
   string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   BasicTokenizer basic_tokenizer_;
   WordPieceTokenizer word_piece_tokenizer_;
   int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,21 +141,20 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
 
     vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
        batch_size);
     if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     } else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     }
 
     for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
                               static_cast<int64_t>(batch_max_seq_len)}));
     auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
 
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
     for (size_t i = 0; i < batch_size; i++) {
       auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
       auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
                   (batch_max_seq_len - seq_len) * sizeof(T));
     }
-    delete tokenizer_ptr;
   }
 };
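In the kernel, the heap-allocated tokenizer, its const_cast, and the trailing delete are replaced by a stack object, so cleanup no longer depends on reaching the end of Compute(). A rough sketch of that ownership change, using hypothetical stand-ins rather than the Paddle classes:

#include <cstdint>
#include <string>
#include <unordered_map>

// Assumed, simplified shapes for illustration only.
using Vocab = std::unordered_map<std::wstring, std::int64_t>;

struct Tokenizer {
  Tokenizer(const Vocab* vocab, bool do_lower_case)
      : vocab_(vocab), do_lower_case_(do_lower_case) {}
  const Vocab* vocab_;
  bool do_lower_case_;
};

void Compute(const Vocab& vocab, bool do_lower_case) {
  // Before: a heap object plus a const_cast, which leaked if an early return
  // or an exception skipped the trailing delete:
  //   Tokenizer* tokenizer_ptr =
  //       new Tokenizer(const_cast<Vocab*>(&vocab), do_lower_case);
  //   ...
  //   delete tokenizer_ptr;
  //
  // After: an automatic object; the constructor accepts a const pointer, so
  // no const_cast is needed and destruction happens on every exit path.
  Tokenizer tokenizer(&vocab, do_lower_case);
  // ... encode batches with `tokenizer` ...
  (void)tokenizer;
}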