Commit 62a6b95c authored by kinghuin, committed by wuzewu

ernie-tiny support seq task

Parent bc8a7ed3
@@ -37,16 +37,23 @@ args = parser.parse_args()
 if __name__ == '__main__':
     # Load Paddlehub ERNIE pretrained model
-    module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+    module = hub.Module(name="ernie_v2_chinese_tiny")
     inputs, outputs, program = module.context(
         trainable=True, max_seq_len=args.max_seq_len)
+    if module.name.startswith("ernie_v2"):
+        use_taskid = True
+    else:
+        use_taskid = False

     # Download dataset and use SequenceLabelReader to read dataset
     dataset = hub.dataset.MSRA_NER()
     reader = hub.reader.SequenceLabelReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=use_taskid,
+        sp_model_path=module.get_spm_path(),
+        word_dict_path=module.get_word_dict_path())

     # Construct transfer learning network
     # Use "sequence_output" for token-level output.
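For readability, here is the post-change version of the demo snippet from the hunk above, assembled in one piece with descriptive comments. This is a sketch: the demo's argparse-driven `args.max_seq_len` is replaced by a literal 128 so the snippet is self-contained; everything else follows the added lines of the diff.

```python
import paddlehub as hub

if __name__ == '__main__':
    # ERNIE Tiny replaces the RoBERTa module the demo used before this commit.
    module = hub.Module(name="ernie_v2_chinese_tiny")
    inputs, outputs, program = module.context(
        trainable=True, max_seq_len=128)

    # ERNIE 2.0 style modules expect an extra task-id input.
    if module.name.startswith("ernie_v2"):
        use_taskid = True
    else:
        use_taskid = False

    # ERNIE Tiny tokenizes with a SentencePiece vocabulary, so the reader also
    # needs the sp model and the word-segmentation dict shipped with the module.
    dataset = hub.dataset.MSRA_NER()
    reader = hub.reader.SequenceLabelReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=128,
        use_task_id=use_taskid,
        sp_model_path=module.get_spm_path(),
        word_dict_path=module.get_word_dict_path())
```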
@@ -13,8 +13,8 @@ python -u text_classifier.py \
     --weight_decay=0.01 \
     --max_seq_len=128 \
     --num_epoch=3 \
-    --use_pyreader=False \
-    --use_data_parallel=False
+    --use_pyreader=True \
+    --use_data_parallel=True

 # Recommending hyper parameters for difference task
 # for ChineseGLUE:
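The two flags flipped here are consumed by the Python demo, not by the shell script itself. A rough sketch of how they are typically forwarded into the fine-tuning run configuration in the PaddleHub 1.x text-classification demo; the exact demo code is not part of this diff, so treat the `hub.RunConfig` wiring below as an assumption.

```python
import argparse
import ast

import paddlehub as hub

parser = argparse.ArgumentParser()
# ast.literal_eval turns the shell strings "True"/"False" into real booleans.
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=True)
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True)
parser.add_argument("--num_epoch", type=int, default=3)
args = parser.parse_args()

# PyReader feeds batches through an asynchronous queue; data-parallel training
# splits each batch across the visible devices.
config = hub.RunConfig(
    use_pyreader=args.use_pyreader,
    use_data_parallel=args.use_data_parallel,
    num_epoch=args.num_epoch,
    batch_size=32,
    strategy=hub.AdamWeightDecayStrategy(weight_decay=0.01))
```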
@@ -361,6 +361,34 @@ class ClassifyReader(BaseReader):
 class SequenceLabelReader(BaseReader):
+    def __init__(self,
+                 vocab_path,
+                 dataset=None,
+                 label_map_config=None,
+                 max_seq_len=512,
+                 do_lower_case=True,
+                 random_seed=None,
+                 use_task_id=False,
+                 sp_model_path=None,
+                 word_dict_path=None,
+                 in_tokens=False):
+        super(SequenceLabelReader, self).__init__(
+            vocab_path=vocab_path,
+            dataset=dataset,
+            label_map_config=label_map_config,
+            max_seq_len=max_seq_len,
+            do_lower_case=do_lower_case,
+            random_seed=random_seed,
+            use_task_id=use_task_id,
+            sp_model_path=sp_model_path,
+            word_dict_path=word_dict_path,
+            in_tokens=in_tokens)
+        if sp_model_path and word_dict_path:
+            self.tokenizer = tokenization.FullTokenizer(
+                vocab_file=vocab_path,
+                do_lower_case=do_lower_case,
+                use_sentence_piece_vocab=True)
+
     def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
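The new constructor only changes behavior when both SentencePiece paths are given; otherwise the tokenizer set up by `BaseReader` is kept, so existing modules are unaffected. A direct-construction sketch follows; the file paths are hypothetical placeholders shown only to make the new keyword arguments explicit, since in the demo they come from the module's `get_vocab_path()`, `get_spm_path()` and `get_word_dict_path()`.

```python
import paddlehub as hub

# Hypothetical file locations, for illustration only.
reader = hub.reader.SequenceLabelReader(
    dataset=hub.dataset.MSRA_NER(),
    vocab_path="/path/to/ernie_tiny/vocab.txt",
    max_seq_len=128,
    use_task_id=True,                                   # ERNIE 2.0 modules add a task-id input
    sp_model_path="/path/to/ernie_tiny/spm.model",      # both paths set -> SentencePiece-aware tokenizer
    word_dict_path="/path/to/ernie_tiny/word_dict.txt") # word-segmentation dict used with the sp model
```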
@@ -113,11 +113,17 @@ def whitespace_tokenize(text):
 class FullTokenizer(object):
     """Runs end-to-end tokenziation."""

-    def __init__(self, vocab_file, do_lower_case=True):
+    def __init__(self,
+                 vocab_file,
+                 do_lower_case=True,
+                 use_sentence_piece_vocab=False):
         self.vocab = load_vocab(vocab_file)
         self.inv_vocab = {v: k for k, v in self.vocab.items()}
         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
-        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.use_sentence_piece_vocab = use_sentence_piece_vocab
+        self.wordpiece_tokenizer = WordpieceTokenizer(
+            vocab=self.vocab,
+            use_sentence_piece_vocab=self.use_sentence_piece_vocab)

     def tokenize(self, text):
         split_tokens = []
@@ -329,10 +335,15 @@ class BasicTokenizer(object):
 class WordpieceTokenizer(object):
     """Runs WordPiece tokenziation."""

-    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+    def __init__(self,
+                 vocab,
+                 unk_token="[UNK]",
+                 max_input_chars_per_word=100,
+                 use_sentence_piece_vocab=False):
         self.vocab = vocab
         self.unk_token = unk_token
         self.max_input_chars_per_word = max_input_chars_per_word
+        self.use_sentence_piece_vocab = use_sentence_piece_vocab

     def tokenize(self, text):
         """Tokenizes a piece of text into its word pieces.
@@ -369,7 +380,9 @@ class WordpieceTokenizer(object):
                 cur_substr = None
                 while start < end:
                     substr = "".join(chars[start:end])
-                    if start > 0:
+                    if start == 0 and self.use_sentence_piece_vocab:
+                        substr = u'\u2581' + substr
+                    if start > 0 and not self.use_sentence_piece_vocab:
                         substr = "##" + substr
                     if substr in self.vocab:
                         cur_substr = substr
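This last hunk is the heart of the SentencePiece support: with `use_sentence_piece_vocab=True`, the first matched piece of a word gets the U+2581 ("▁") word-start marker and continuation pieces are looked up without the BERT-style "##" prefix. Below is a condensed, self-contained sketch of that greedy longest-match loop; the vocabularies are toy examples for illustration only.

```python
def wordpiece(token, vocab, use_sentence_piece_vocab=False, unk_token="[UNK]"):
    """Greedy longest-match-first sub-token split, as in the patched loop above."""
    chars = list(token)
    sub_tokens = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0 and use_sentence_piece_vocab:
                substr = u"\u2581" + substr   # SentencePiece word-start marker
            if start > 0 and not use_sentence_piece_vocab:
                substr = "##" + substr        # BERT-style continuation marker
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:                # no piece matched: whole token is unknown
            return [unk_token]
        sub_tokens.append(cur_substr)
        start = end
    return sub_tokens


print(wordpiece("playing", {"play", "##ing"}))
# ['play', '##ing']
print(wordpiece("playing", {u"\u2581play", "ing"}, use_sentence_piece_vocab=True))
# ['\u2581play', 'ing']  i.e. ['▁play', 'ing']
```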