Unverified commit 290ded7a, authored by Jack Zhou, committed by GitHub

Optimize FasterTokenizer (#36701)

* optimize fast tokenizer
Parent eca78a9f
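
The diff below makes three related changes: BasicTokenizer::Tokenize now emits tokens in a single pass over the input, accumulating characters in cache_text and flushing them through a small PushCacheText lambda instead of building a space-delimited string and running boost::split over it; the vocab is passed as const framework::Vocab* and looked up with at(); and the kernel constructs BertTokenizer on the stack instead of pairing new with delete. As a rough, self-contained sketch of the single-pass splitting idea only (the helper names IsDelimiter and SplitSinglePass are illustrative and not part of the Paddle sources):

#include <cwctype>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the tokenizer's whitespace checks.
bool IsDelimiter(wchar_t ch) {
  return std::iswspace(static_cast<wint_t>(ch)) != 0;
}

// Single-pass split: accumulate into cache_text, flush on each delimiter,
// so no intermediate delimited string and no second splitting pass.
void SplitSinglePass(const std::wstring& text, std::vector<std::wstring>* res) {
  std::wstring cache_text;
  auto PushCacheText = [&]() {
    if (!cache_text.empty()) {
      res->emplace_back(cache_text);
      cache_text.clear();
    }
  };
  for (wchar_t ch : text) {
    if (IsDelimiter(ch)) {
      PushCacheText();   // a token just ended
    } else {
      cache_text += ch;  // still inside a token
    }
  }
  PushCacheText();  // flush the trailing token, if any
}

int main() {
  std::vector<std::wstring> tokens;
  SplitSinglePass(L"faster  tokenizer op", &tokens);
  std::cout << tokens.size() << " tokens" << std::endl;  // prints "3 tokens"
  return 0;
}

Inside the real tokenizer loop the same flush-on-delimiter pattern also emits Chinese characters and punctuation as single-character tokens, as the first two hunks show.
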
@@ -100,9 +100,14 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
     // String is converted into wstring failedly.
     return;
   }
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
     if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
       continue;
     }
@@ -110,25 +115,24 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
       ch = do_lower_case(ch);
     }
     if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
     } else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
     } else {
-      dest_text += ch;
+      cache_text += ch;
     }
   }
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
     const size_t max_input_chars_per_word /* = 100 */)
     : vocab_(vocab),
      unk_token_(unk_token),
      max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
 }
 
 void WordPieceTokenizer::Tokenize(const wstring& text,
@@ -178,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
   }
 }
 
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
                              bool do_lower_case /* = false */,
                              const wstring& unk_token /* = L"[UNK]" */,
                              const wstring& pad_token /* = L"[PAD]" */,
@@ -196,11 +200,11 @@ BertTokenizer::BertTokenizer(framework::Vocab* vocab,
       vocab_(vocab),
       basic_tokenizer_(do_lower_case_),
       word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
   all_special_tokens_ = vector<wstring>(
       {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
...
@@ -56,13 +56,13 @@ class BasicTokenizer {
 class WordPieceTokenizer {
  public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
                               const wstring& unk_token = L"[UNK]",
                               const size_t max_input_chars_per_word = 100);
 
   void Tokenize(const wstring& text, vector<int64_t>* output) const;
 
  private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   wstring unk_token_{L"[UNK]"};
   int64_t unk_token_id_;
   size_t max_input_chars_per_word_;
@@ -70,7 +70,8 @@ class WordPieceTokenizer {
 
 class BertTokenizer {
  public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
                          const wstring& unk_token = L"[UNK]",
                          const wstring& pad_token = L"[PAD]",
                          const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
   bool do_lower_case_;
   wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
   string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   BasicTokenizer basic_tokenizer_;
   WordPieceTokenizer word_piece_tokenizer_;
   int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,21 +141,20 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
 
     vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
         batch_size);
     if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     } else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     }
 
     for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
                                  static_cast<int64_t>(batch_max_seq_len)}));
     auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
 
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
     for (size_t i = 0; i < batch_size; i++) {
       auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
       auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
                   (batch_max_seq_len - seq_len) * sizeof(T));
     }
-    delete tokenizer_ptr;
   }
 };
...
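
The remaining edits are const-correctness and ownership cleanups: taking the vocab as const framework::Vocab* forces lookups through at() rather than operator[] and lets the kernel drop its const_cast, and building BertTokenizer on the stack removes the manual delete. A minimal sketch of both points, assuming a std::unordered_map stand-in for framework::Vocab and a hypothetical TinyTokenizer class (neither is the real Paddle type):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for framework::Vocab; the real type may differ, but the
// at()/operator[] swap in the diff suggests a map-like interface.
using Vocab = std::unordered_map<std::wstring, std::int64_t>;

// Hypothetical miniature of the tokenizer classes above: it only borrows
// the vocab through a const pointer, so lookups must use at(), which is
// const and throws on a missing key; operator[] would not compile here
// because it is non-const and may insert a default value.
class TinyTokenizer {
 public:
  explicit TinyTokenizer(const Vocab* vocab)
      : vocab_(vocab), unk_token_id_(vocab_->at(L"[UNK]")) {}
  std::int64_t GetUnkTokenID() const { return unk_token_id_; }

 private:
  const Vocab* vocab_;
  std::int64_t unk_token_id_;
};

int main() {
  Vocab vocab{{L"[PAD]", 0}, {L"[UNK]", 100}};
  // Built on the stack, like BertTokenizer in the kernel after this commit:
  // no new/delete pair, cleanup happens automatically at end of scope.
  TinyTokenizer tokenizer(&vocab);
  std::cout << tokenizer.GetUnkTokenID() << std::endl;  // prints 100
  return 0;
}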