Unverified commit 290ded7a, authored by Jack Zhou, committed by GitHub

Optimize FasterTokenizer (#36701)

* optimize fast tokenizer
Parent eca78a9f
@@ -100,9 +100,14 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
// String is converted into wstring failedly.
return;
}
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
continue;
}
@@ -110,25 +115,24 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
ch = do_lower_case(ch);
}
if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
} else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
} else {
-      dest_text += ch;
+      cache_text += ch;
}
}
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
}
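Note: the rewritten BasicTokenizer::Tokenize above drops the old two-pass scheme (building a space-separated dest_text and then running boost::split) in favor of a single pass that appends finished tokens directly into res. Below is a minimal standalone sketch of the same idea; the predicates IsDelim/IsSpace are simplified stand-ins for the real IsChineseChar/IsPunctuation/IsWhiteSpace, and the free Tokenize function is illustrative only.

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins for the tokenizer's character predicates.
static bool IsDelim(wchar_t ch) {
  return ch == L',' || ch == L'.' || (ch >= 0x4E00 && ch <= 0x9FFF);
}
static bool IsSpace(wchar_t ch) {
  return ch == L' ' || ch == L'\t' || ch == L'\n';
}

// Single pass: tokens are emitted directly into `res`, so there is no
// intermediate space-joined wstring and no second splitting pass.
std::vector<std::wstring> Tokenize(const std::wstring& text) {
  std::vector<std::wstring> res;
  std::wstring cache;
  auto flush = [&]() {
    if (!cache.empty()) {
      res.emplace_back(cache);
      cache.clear();
    }
  };
  for (wchar_t ch : text) {
    if (IsDelim(ch)) {         // punctuation / CJK character becomes its own token
      flush();
      res.emplace_back(std::wstring{ch});
    } else if (IsSpace(ch)) {  // whitespace only terminates the current token
      flush();
    } else {
      cache += ch;
    }
  }
  flush();
  return res;
}

int main() {
  for (const auto& tok : Tokenize(L"hello, world 你好")) {
    std::wcout << tok << L'\n';
  }
  return 0;
}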
WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
const size_t max_input_chars_per_word /* = 100 */)
: vocab_(vocab),
unk_token_(unk_token),
max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
}
void WordPieceTokenizer::Tokenize(const wstring& text,
@@ -178,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
}
}
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
bool do_lower_case /* = false */,
const wstring& unk_token /* = L"[UNK]" */,
const wstring& pad_token /* = L"[PAD]" */,
@@ -196,11 +200,11 @@ BertTokenizer::BertTokenizer(framework::Vocab* vocab,
vocab_(vocab),
basic_tokenizer_(do_lower_case_),
word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
all_special_tokens_ = vector<wstring>(
{unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
......
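Note: because the constructors now take const framework::Vocab*, the lookups switch from (*vocab_)[token] to vocab_->at(token). On a map type, operator[] is non-const (it inserts a default value when the key is absent), while at() is const-qualified and throws for unknown tokens. A minimal sketch of the distinction, under the assumption that framework::Vocab behaves like an unordered_map from token to id; the Vocab alias and LookupId helper below are illustrative, not Paddle APIs.

#include <cstdint>
#include <string>
#include <unordered_map>

// Assumed shape of the vocabulary for illustration: token -> id.
using Vocab = std::unordered_map<std::wstring, std::int64_t>;

std::int64_t LookupId(const Vocab* vocab, const std::wstring& token) {
  // (*vocab)[token] would not compile here: operator[] is non-const
  // because it inserts a default value for a missing key.
  // at() is const and throws std::out_of_range for unknown tokens.
  return vocab->at(token);
}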
@@ -56,13 +56,13 @@ class BasicTokenizer {
class WordPieceTokenizer {
public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
const wstring& unk_token = L"[UNK]",
const size_t max_input_chars_per_word = 100);
void Tokenize(const wstring& text, vector<int64_t>* output) const;
private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
wstring unk_token_{L"[UNK]"};
int64_t unk_token_id_;
size_t max_input_chars_per_word_;
@@ -70,7 +70,8 @@ class WordPieceTokenizer {
class BertTokenizer {
public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
const wstring& unk_token = L"[UNK]",
const wstring& pad_token = L"[PAD]",
const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
bool do_lower_case_;
wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
BasicTokenizer basic_tokenizer_;
WordPieceTokenizer word_piece_tokenizer_;
int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,21 +141,20 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
return;
}
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
size_t batch_max_seq_len = 0;
size_t batch_size = text->size();
vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
batch_size);
if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
} else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
}
for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
static_cast<int64_t>(batch_max_seq_len)}));
auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
for (size_t i = 0; i < batch_size; i++) {
auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
(batch_max_seq_len - seq_len) * sizeof(T));
}
-    delete tokenizer_ptr;
}
};
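Note: in the kernel, the heap-allocated tokenizer_ptr (which required a const_cast and a matching delete) is replaced by a stack-allocated BertTokenizer, so the object is destroyed automatically on every exit path and cannot leak. A minimal sketch of the pattern; the Tokenizer type and Compute* functions below are placeholders, not the real Paddle classes.

#include <iostream>

struct Tokenizer {
  ~Tokenizer() { std::cout << "destroyed\n"; }
  void Encode() const { std::cout << "encoded\n"; }
};

void ComputeHeap(bool early_exit) {
  Tokenizer* t = new Tokenizer();
  if (early_exit) return;  // leaks t: the delete below is never reached
  t->Encode();
  delete t;
}

void ComputeStack(bool early_exit) {
  Tokenizer t;  // destructor runs automatically on every return path
  if (early_exit) return;
  t.Encode();
}

int main() {
  ComputeStack(false);
  ComputeStack(true);
  return 0;
}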
......