Unverified commit e25b1cfa authored by oyjxer, committed by GitHub

del useless note in tokenization.py (#773)

* add ernie-m code

* del useless note in tokenization.py
Co-authored-by: pangchao04 <pangchao04@baidu.com>
Parent a948f14d
@@ -26,62 +26,11 @@ import re
import unicodedata
import six
from six.moves import range
-#import tensorflow as tf
import sentencepiece as spm
SPIECE_UNDERLINE = u"▁".encode("utf-8")
-def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
-  """Checks whether the casing config is consistent with the checkpoint name."""
-  # The casing has to be passed in by the user and there is no explicit check
-  # as to whether it matches the checkpoint. The casing information probably
-  # should have been stored in the bert_config.json file, but it's not, so
-  # we have to heuristically detect it to validate.
-  if not init_checkpoint:
-    return
-  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt",
-               six.ensure_str(init_checkpoint))
-  if m is None:
-    return
-  model_name = m.group(1)
-  lower_models = [
-      "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
-      "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
-  ]
-  cased_models = [
-      "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
-      "multi_cased_L-12_H-768_A-12"
-  ]
-  is_bad_config = False
-  if model_name in lower_models and not do_lower_case:
-    is_bad_config = True
-    actual_flag = "False"
-    case_name = "lowercased"
-    opposite_flag = "True"
-  if model_name in cased_models and do_lower_case:
-    is_bad_config = True
-    actual_flag = "True"
-    case_name = "cased"
-    opposite_flag = "False"
-  if is_bad_config:
-    raise ValueError(
-        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
-        "However, `%s` seems to be a %s model, so you "
-        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
-        "how the model was pre-training. If this error is wrong, please "
-        "just comment out this check." % (actual_flag, init_checkpoint,
-                                          model_name, case_name, opposite_flag))
def clean_text(text):
  """Performs invalid character removal and whitespace cleanup on text."""
  text = text.replace(u"“", u'"')\
@@ -101,7 +50,6 @@ def clean_text(text):
  return "".join(output)
def preprocess_text(inputs, remove_space=True, lower=False):
  """preprocess data by removing extra space and normalize data."""
@@ -126,7 +74,6 @@ def preprocess_text(inputs, remove_space=True, lower=False):
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  """turn sentences into word pieces."""
-  # liujiaxiang: add for ernie-albert, mainly consider for “/”/‘/’/— causing too many unk
  text = clean_text(text)
  if six.PY2 and isinstance(text, six.text_type):
@@ -153,7 +100,6 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    else:
      new_pieces.append(piece)
-  # note(zhiliny): convert back to unicode for py2
  if six.PY2 and return_unicode:
    ret_pieces = []
    for piece in new_pieces:
@@ -166,7 +112,6 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
def encode_ids(sp_model, text, sample=False):
  pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
  ids = [sp_model.PieceToId(piece) for piece in pieces]
  return ids
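
For orientation only (not part of the commit): a minimal sketch of how the two helpers kept above are typically driven. The model path is a placeholder, and the call pattern assumes the module-level functions exactly as they appear in this diff.

import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")  # placeholder path to a trained SentencePiece model

text = u"hello world"
# encode_pieces() cleans the text and returns sub-word strings, e.g. [u'▁hello', u'▁world'];
# encode_ids() runs the same segmentation and maps each piece to an id via sp_model.PieceToId().
pieces = encode_pieces(sp_model, text, return_unicode=False)
ids = encode_ids(sp_model, text)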
@@ -215,18 +160,6 @@ def printable_text(text):
    raise ValueError("Not running on Python2 or Python 3?")
-#def load_vocab(vocab_file):
-#  """Loads a vocabulary file into a dictionary."""
-#  vocab = collections.OrderedDict()
-#  with tf.gfile.GFile(vocab_file, "r") as reader:
-#    while True:
-#      token = convert_to_unicode(reader.readline())
-#      if not token:
-#        break
-#      token = token.strip().split()[0]
-#      if token not in vocab:
-#        vocab[token] = len(vocab)
-#  return vocab
def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
@@ -275,19 +208,9 @@ class FullTokenizer(object):
    self.sp_model = None
    if model_file:
      self.sp_model = spm.SentencePieceProcessor()
-      #tf.logging.info("loading sentence piece model")
      self.sp_model.Load(model_file)
-      # Note(mingdachen): For the purpose of consisent API, we are
-      # generating a vocabulary for the sentence piece tokenizer.
-      #self.vocab = {self.sp_model.IdToPiece(i): i for i
-      #              in range(self.sp_model.GetPieceSize())}
      self.vocab = load_vocab(vocab_file)
-      # import pdb; pdb.set_trace()
    else:
-      #self.vocab = load_vocab(vocab_file)
-      #self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
-      #self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-      # (liujiaxiang) comment useless code for a better diff code
      raise ValueError('albert use spm by default')
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
@@ -295,11 +218,6 @@ class FullTokenizer(object):
    if self.sp_model:
      split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
    else:
-      #split_tokens = []
-      #for token in self.basic_tokenizer.tokenize(text):
-      #  for sub_token in self.wordpiece_tokenizer.tokenize(token):
-      #    split_tokens.append(sub_token)
-      # (liujiaxiang) comment useless code for a better diff code
      raise ValueError('albert use spm by default')
    return split_tokens
@@ -309,12 +227,9 @@ class FullTokenizer(object):
    import tok as tok_protocol
    text = " ".join([t.token for t in tok_list])
-    #split_tokens = encode_pieces(self.sp_model, text, return_unicode=True)
    split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
    ids = self.convert_tokens_to_ids(split_tokens)
-    # +1 for head _ : 'hello world' -> ['_hello', '_world']
    if not (len(preprocess_text(''.join(split_tokens))) == len(text) + 1):
      return None
@@ -349,14 +264,6 @@ class FullTokenizer(object):
      position_to_nth[i] = nth_tok
    return position_to_nth
-  # def convert_tokens_to_ids(self, tokens):
-  #   if self.sp_model:
-  #     #tf.logging.info("using sentence piece tokenzier.")
-  #     return [self.sp_model.PieceToId(
-  #         printable_text(token)) for token in tokens]
-  #   else:
-  #     return convert_by_vocab(self.vocab, tokens)
  def convert_tokens_to_ids(self, tokens):
    tokens_out = []
    for i in tokens:
...
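
For context only (also not part of the commit): a rough usage sketch of the FullTokenizer path that the deletions above touch. The constructor keywords (vocab_file, model_file) are inferred from the attributes visible in the diff and are assumptions, not a confirmed signature.

# Assumed keyword names; the real __init__ may take different arguments.
tokenizer = FullTokenizer(vocab_file="vocab.txt", model_file="spiece.model")

tokens = tokenizer.tokenize(u"hello world")    # sub-word pieces via encode_pieces()
ids = tokenizer.convert_tokens_to_ids(tokens)  # ids looked up through the loaded vocab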