diff --git a/ernie-m/reader/tokenization.py b/ernie-m/reader/tokenization.py
index c7a549417f026caac72b56d856ff564aefb565e9..96ff87c2c36515808c2a387e91bf96282cd6ce83 100644
--- a/ernie-m/reader/tokenization.py
+++ b/ernie-m/reader/tokenization.py
@@ -26,62 +26,11 @@ import re
 import unicodedata
 import six
 from six.moves import range
-#import tensorflow as tf
 import sentencepiece as spm
 
 SPIECE_UNDERLINE = u"▁".encode("utf-8")
 
 
-def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
-  """Checks whether the casing config is consistent with the checkpoint name."""
-
-  # The casing has to be passed in by the user and there is no explicit check
-  # as to whether it matches the checkpoint. The casing information probably
-  # should have been stored in the bert_config.json file, but it's not, so
-  # we have to heuristically detect it to validate.
-
-  if not init_checkpoint:
-    return
-
-  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt",
-               six.ensure_str(init_checkpoint))
-  if m is None:
-    return
-
-  model_name = m.group(1)
-
-  lower_models = [
-      "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
-      "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
-  ]
-
-  cased_models = [
-      "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
-      "multi_cased_L-12_H-768_A-12"
-  ]
-
-  is_bad_config = False
-  if model_name in lower_models and not do_lower_case:
-    is_bad_config = True
-    actual_flag = "False"
-    case_name = "lowercased"
-    opposite_flag = "True"
-
-  if model_name in cased_models and do_lower_case:
-    is_bad_config = True
-    actual_flag = "True"
-    case_name = "cased"
-    opposite_flag = "False"
-
-  if is_bad_config:
-    raise ValueError(
-        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
-        "However, `%s` seems to be a %s model, so you "
-        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
-        "how the model was pre-training. If this error is wrong, please "
-        "just comment out this check." % (actual_flag, init_checkpoint,
-                                          model_name, case_name, opposite_flag))
-
 def clean_text(text):
   """Performs invalid character removal and whitespace cleanup on text."""
   text = text.replace(u"“", u'"')\
@@ -101,7 +50,6 @@ def clean_text(text):
   return "".join(output)
 
 
-
 def preprocess_text(inputs, remove_space=True, lower=False):
   """preprocess data by removing extra space and normalize data."""
 
@@ -126,7 +74,6 @@ def preprocess_text(inputs, remove_space=True, lower=False):
 
 def encode_pieces(sp_model, text, return_unicode=True, sample=False):
   """turn sentences into word pieces."""
-  # liujiaxiang: add for ernie-albert, mainly consider for “/”/‘/’/— causing too many unk
   text = clean_text(text)
 
   if six.PY2 and isinstance(text, six.text_type):
@@ -153,7 +100,6 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
     else:
       new_pieces.append(piece)
 
-  # note(zhiliny): convert back to unicode for py2
   if six.PY2 and return_unicode:
     ret_pieces = []
    for piece in new_pieces:
@@ -166,7 +112,6 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
 
 
 def encode_ids(sp_model, text, sample=False):
-
   pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
   ids = [sp_model.PieceToId(piece) for piece in pieces]
   return ids
@@ -215,18 +160,6 @@ def printable_text(text):
     raise ValueError("Not running on Python2 or Python 3?")
 
 
-#def load_vocab(vocab_file):
-#  """Loads a vocabulary file into a dictionary."""
-#  vocab = collections.OrderedDict()
-#  with tf.gfile.GFile(vocab_file, "r") as reader:
-#    while True:
-#      token = convert_to_unicode(reader.readline())
-#      if not token:
-#        break
-#      token = token.strip().split()[0]
-#      if token not in vocab:
-#        vocab[token] = len(vocab)
-#  return vocab
 def load_vocab(vocab_file):
   """Loads a vocabulary file into a dictionary."""
   vocab = collections.OrderedDict()
@@ -275,19 +208,9 @@ class FullTokenizer(object):
     self.sp_model = None
     if model_file:
       self.sp_model = spm.SentencePieceProcessor()
-      #tf.logging.info("loading sentence piece model")
       self.sp_model.Load(model_file)
-      # Note(mingdachen): For the purpose of consisent API, we are
-      # generating a vocabulary for the sentence piece tokenizer.
-      #self.vocab = {self.sp_model.IdToPiece(i): i for i
-      #              in range(self.sp_model.GetPieceSize())}
       self.vocab = load_vocab(vocab_file)
-      # import pdb; pdb.set_trace()
     else:
-      #self.vocab = load_vocab(vocab_file)
-      #self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
-      #self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-      # (liujiaxiang) comment useless code for a better diff code
       raise ValueError('albert use spm by default')
 
     self.inv_vocab = {v: k for k, v in self.vocab.items()}
@@ -295,11 +218,6 @@ class FullTokenizer(object):
     if self.sp_model:
       split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
     else:
-      #split_tokens = []
-      #for token in self.basic_tokenizer.tokenize(text):
-      #  for sub_token in self.wordpiece_tokenizer.tokenize(token):
-      #    split_tokens.append(sub_token)
-      # (liujiaxiang) comment useless code for a better diff code
       raise ValueError('albert use spm by default')
 
     return split_tokens
@@ -309,12 +227,9 @@ class FullTokenizer(object):
     import tok as tok_protocol
     text = " ".join([t.token for t in tok_list])
 
-    #split_tokens = encode_pieces(self.sp_model, text, return_unicode=True)
     split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
     ids = self.convert_tokens_to_ids(split_tokens)
 
-    # +1 for head _ : 'hello world' -> ['_hello', '_world']
-
    if not (len(preprocess_text(''.join(split_tokens))) == len(text) + 1):
       return None
 
@@ -349,14 +264,6 @@ class FullTokenizer(object):
       position_to_nth[i] = nth_tok
     return position_to_nth
 
-#  def convert_tokens_to_ids(self, tokens):
-#    if self.sp_model:
-#      #tf.logging.info("using sentence piece tokenzier.")
-#      return [self.sp_model.PieceToId(
-#          printable_text(token)) for token in tokens]
-#    else:
-#      return convert_by_vocab(self.vocab, tokens)
-
   def convert_tokens_to_ids(self, tokens):
     tokens_out = []
     for i in tokens:
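
Reviewer note: a minimal usage sketch of the SentencePiece path this patch keeps (the module-level encode_pieces / encode_ids helpers). The "spiece.model" path and the reader.tokenization import path below are placeholders for illustration, not part of the patch.

    # Illustration only: exercise the helpers retained by this diff.
    # Assumes a trained SentencePiece model at "spiece.model" (hypothetical path)
    # and that this module is importable as reader.tokenization (hypothetical).
    import sentencepiece as spm

    from reader.tokenization import encode_pieces, encode_ids

    sp_model = spm.SentencePieceProcessor()
    sp_model.Load("spiece.model")  # hypothetical SentencePiece model file

    text = "hello world"
    # clean_text() is applied inside encode_pieces before piece segmentation
    pieces = encode_pieces(sp_model, text, return_unicode=False)  # e.g. SPIECE_UNDERLINE-prefixed pieces
    ids = encode_ids(sp_model, text)  # ids obtained via sp_model.PieceToId on each piece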