diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc
index 49457af8f00c80debabf4ed76e456dacc580be00..42047021b408a8bd2582d68faaf814ff145dc1a0 100644
--- a/paddle/fluid/operators/string/faster_tokenizer_op.cc
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc
@@ -100,9 +100,14 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
     // String is converted into wstring failedly.
     return;
   }
-
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
     if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
       continue;
     }
@@ -110,25 +115,24 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
       ch = do_lower_case(ch);
     }
     if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
     } else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
     } else {
-      dest_text += ch;
+      cache_text += ch;
     }
   }
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
     const size_t max_input_chars_per_word /* = 100 */)
     : vocab_(vocab),
       unk_token_(unk_token),
       max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
 }
 
 void WordPieceTokenizer::Tokenize(const wstring& text,
@@ -178,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
   }
 }
 
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
                              bool do_lower_case /* = false */,
                              const wstring& unk_token /* = L"[UNK]" */,
                              const wstring& pad_token /* = L"[PAD]" */,
@@ -196,11 +200,11 @@ BertTokenizer::BertTokenizer(framework::Vocab* vocab,
       vocab_(vocab),
       basic_tokenizer_(do_lower_case_),
       word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
 
   all_special_tokens_ = vector<wstring>(
       {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.h b/paddle/fluid/operators/string/faster_tokenizer_op.h
old mode 100755
new mode 100644
index d9b7fa26a6704b82bb5e9911109f0064260db3b7..5218b7c2eaa51da062dff1ac80d08b0b2e50d906
--- a/paddle/fluid/operators/string/faster_tokenizer_op.h
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.h
@@ -56,13 +56,13 @@ class BasicTokenizer {
 
 class WordPieceTokenizer {
  public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
                               const wstring& unk_token = L"[UNK]",
                               const size_t max_input_chars_per_word = 100);
   void Tokenize(const wstring& text, vector<int64_t>* output) const;
 
  private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   wstring unk_token_{L"[UNK]"};
   int64_t unk_token_id_;
   size_t max_input_chars_per_word_;
@@ -70,7 +70,8 @@ class WordPieceTokenizer {
 
 class BertTokenizer {
  public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
                          const wstring& unk_token = L"[UNK]",
                          const wstring& pad_token = L"[PAD]",
                          const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
   bool do_lower_case_;
   wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
   string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   BasicTokenizer basic_tokenizer_;
   WordPieceTokenizer word_piece_tokenizer_;
   int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,21 +141,20 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
 
     vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
        batch_size);
     if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     } else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     }
 
     for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
                               static_cast<int64_t>(batch_max_seq_len)}));
     auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
 
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
     for (size_t i = 0; i < batch_size; i++) {
       auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
       auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
                   (batch_max_seq_len - seq_len) * sizeof(T));
     }
-    delete tokenizer_ptr;
   }
 };
 
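
Note: the first two hunks replace the old build-a-spaced-string-then-boost::split approach with direct token emission. Below is a minimal, self-contained sketch of that splitting pattern; the character-class checks (std::iswpunct / std::iswspace) are simplified stand-ins for Paddle's IsChineseChar/IsPunctuation/IsWhiteSpace helpers, and SplitLikeBasicTokenizer is a hypothetical name used only for illustration.

// Sketch of the cache-based splitting used in the new BasicTokenizer::Tokenize:
// tokens are pushed as soon as a boundary is seen, instead of building a
// space-delimited copy of the whole input and splitting it afterwards.
#include <cwctype>
#include <iostream>
#include <string>
#include <vector>

void SplitLikeBasicTokenizer(const std::wstring& text,
                             std::vector<std::wstring>* res) {
  std::wstring cache_text = L"";
  auto PushCacheText = [&]() {
    if (!cache_text.empty()) {
      res->emplace_back(cache_text);
      cache_text = L"";
    }
  };
  for (auto ch : text) {
    if (std::iswpunct(ch)) {                // stand-in for IsChineseChar/IsPunctuation
      PushCacheText();                      // flush the word collected so far
      res->emplace_back(std::wstring{ch});  // punctuation becomes its own token
    } else if (std::iswspace(ch)) {         // stand-in for IsWhiteSpace
      PushCacheText();                      // whitespace only ends the current word
    } else {
      cache_text += ch;                     // accumulate a normal character
    }
  }
  PushCacheText();                          // flush the trailing word, if any
}

int main() {
  std::vector<std::wstring> tokens;
  SplitLikeBasicTokenizer(L"hello, world", &tokens);
  for (const auto& t : tokens) {
    std::wcout << L"[" << t << L"] ";       // prints: [hello] [,] [world]
  }
  std::wcout << std::endl;
}

Emitting tokens at each boundary avoids materializing an intermediate copy of the entire input, which is what dropping boost::split in the diff achieves.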
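The switch from (*vocab_)[token] to vocab_->at(token) goes hand in hand with taking the vocabulary as const framework::Vocab*: on a map-like container, operator[] is non-const because it inserts a default-constructed value for a missing key, while at() is const-qualified and throws on a missing key. A small sketch, assuming the vocabulary behaves like std::unordered_map (Vocab and LookupId below are illustrative names, not Paddle APIs):

#include <cstdint>
#include <string>
#include <unordered_map>

using Vocab = std::unordered_map<std::wstring, std::int64_t>;  // stand-in type

std::int64_t LookupId(const Vocab* vocab, const std::wstring& token) {
  // return (*vocab)[token];  // does not compile: operator[] is non-const
  return vocab->at(token);    // const-safe; throws std::out_of_range if absent
}

The same const-correctness change is what lets the kernel drop the const_cast and construct BertTokenizer directly on the stack, so the trailing delete tokenizer_ptr; removed in the last hunk is no longer needed.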