提交 cae77c0c 编写于 作者: Q qianlong

BasicTokenizer not case fold on preserverd words

上级 d6d93f16
......@@ -15,11 +15,16 @@
*/
#include "dataset/text/kernels/basic_tokenizer_op.h"
#include <memory>
#include <queue>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "unicode/errorcode.h"
#include "unicode/normalizer2.h"
#include "unicode/utypes.h"
namespace mindspore {
namespace dataset {
const bool BasicTokenizerOp::kDefLowerCase = false;
......@@ -40,8 +45,8 @@ const char BasicTokenizerOp::kCommonPattern[] =
"|[\\x{2B820}-\\x{2CEAF}]"
"|[\\x{F900}-\\x{FAFF}]"
"|[\\x{2F800}-\\x{2FA1F}]";
const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|";
const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|";
const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"};
BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, NormalizeForm normalization_form,
bool preserve_unused_token)
: lower_case_(lower_case),
......@@ -67,6 +72,69 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
regex_tokenizer_ = std::make_unique<RegexTokenizerOp>(delim_pattern, keep_delim_pattern);
}
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text,
const std::unordered_set<std::string> &unused_words,
std::string *outupt) {
icu::ErrorCode error;
const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
outupt->clear();
// 1. get start and end offsets of not case fold strs
std::queue<std::pair<int, int>> offsets; // offsets of not used words
int start = -1;
int len = 0;
for (int i = 0; i < text.length(); i++) {
if (text[i] == '[') {
start = i;
++len;
} else if (text[i] == ']' && start >= 0) {
++len;
std::string word(text.substr(start, len));
if (unused_words.find(word) != unused_words.end()) {
offsets.push(std::make_pair(start, start + len - 1));
}
start = -1;
len = 0;
} else if (start >= 0) {
++len;
}
}
// 2. Do not apply case fold on `unused_words`
start = 0;
for (int i = 0; i < text.length();) {
std::string_view process_text;
std::string preserve_token;
if (offsets.empty()) {
i = text.length();
process_text = text.substr(start, i - start);
} else {
preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1);
process_text = text.substr(start, offsets.front().first - start);
i = offsets.front().second + 1;
offsets.pop();
}
std::string temp;
icu::StringByteSink<std::string> sink(&temp);
nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error);
*outupt += temp + preserve_token;
}
return Status::OK();
}
Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
std::vector<std::string> strs(input->Size());
int i = 0;
for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++]));
}
*output = std::make_shared<Tensor>(std::move(strs), input->shape());
return Status::OK();
}
Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
......@@ -75,8 +143,13 @@ Status BasicTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
std::shared_ptr<Tensor> cur_input;
std::shared_ptr<Tensor> processed_tensor;
if (lower_case_) {
// to lower case
RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
if (!preserve_unused_token_) {
// to lower case
RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor));
} else {
// to lower case except words in kUnusedWords
RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input, &processed_tensor));
}
cur_input = processed_tensor;
// strip accent characters
RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor));
......
......@@ -17,6 +17,7 @@
#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
#include <memory>
#include <string>
#include <unordered_set>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
......@@ -45,9 +46,15 @@ class BasicTokenizerOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
protected:
Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set<std::string> &unused_words,
std::string *outupt);
Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
private:
static const char kCommonPattern[];
static const char kUnusedPattern[];
static const std::unordered_set<std::string> kUnusedWords;
bool lower_case_;
bool keep_whitespace_;
NormalizeForm normalization_form_;
......
......@@ -10,5 +10,6 @@ unused [SEP]
unused [UNK]
unused [PAD]
unused [MASK]
12+/-28=40/-16
Hello World!
\ No newline at end of file
[unused1]
[unused10]
12+/-28=40/-16
\ No newline at end of file
......@@ -27,7 +27,7 @@ vocab_bert = [
"繁", "體", "字", "嘿", "哈", "大", "笑", "嘻",
"i", "am", "mak", "make", "small", "mistake", "##s", "during", "work", "##ing", "hour",
"😀", "😃", "😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I",
"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"
"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"
]
pad = '<pad>'
test_paras = [
......@@ -69,22 +69,40 @@ test_paras = [
# test preserved tokens
dict(
first=8,
last=12,
last=14,
expect_str=[
['[UNK]', '[CLS]'],
['[UNK]', '[SEP]'],
['[UNK]', '[UNK]'],
['[UNK]', '[PAD]'],
['[UNK]', '[MASK]'],
['[unused1]'],
['[unused10]']
],
lower_case=False,
vocab_list=vocab_bert,
preserve_unused_token=True,
),
dict(
first=8,
last=14,
expect_str=[
['[UNK]', '[CLS]'],
['[UNK]', '[SEP]'],
['[UNK]', '[UNK]'],
['[UNK]', '[PAD]'],
['[UNK]', '[MASK]'],
['[unused1]'],
['[unused10]']
],
lower_case=True,
vocab_list=vocab_bert,
preserve_unused_token=True,
),
# test special symbol
dict(
first=13,
last=13,
first=15,
last=15,
expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']],
preserve_unused_token=True,
vocab_list=vocab_bert
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册