diff --git a/demo/sequence-labeling/sequence_label.py b/demo/sequence-labeling/sequence_label.py
index f8db6da3479569ced4df4d9e8dd0336dfa8007aa..2fe72b5b16d2864d78ed8b296ad6a410fb72f5b2 100644
--- a/demo/sequence-labeling/sequence_label.py
+++ b/demo/sequence-labeling/sequence_label.py
@@ -37,16 +37,23 @@ args = parser.parse_args()
 
 if __name__ == '__main__':
     # Load Paddlehub ERNIE pretrained model
-    module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
+    module = hub.Module(name="ernie_v2_chinese_tiny")
     inputs, outputs, program = module.context(
         trainable=True, max_seq_len=args.max_seq_len)
+    if module.name.startswith("ernie_v2"):
+        use_taskid = True
+    else:
+        use_taskid = False
 
     # Download dataset and use SequenceLabelReader to read dataset
     dataset = hub.dataset.MSRA_NER()
     reader = hub.reader.SequenceLabelReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=use_taskid,
+        sp_model_path=module.get_spm_path(),
+        word_dict_path=module.get_word_dict_path())
 
     # Construct transfer learning network
     # Use "sequence_output" for token-level output.
diff --git a/demo/text-classification/run_classifier.sh b/demo/text-classification/run_classifier.sh
index fad3e471b58db3d8d3faf052d26696c453ed48d2..d297cb7493d0990064f8fa79528a6fef0dcb971e 100644
--- a/demo/text-classification/run_classifier.sh
+++ b/demo/text-classification/run_classifier.sh
@@ -13,8 +13,8 @@ python -u text_classifier.py \
     --weight_decay=0.01 \
     --max_seq_len=128 \
     --num_epoch=3 \
-    --use_pyreader=False \
-    --use_data_parallel=False
+    --use_pyreader=True \
+    --use_data_parallel=True
 
 # Recommending hyper parameters for difference task
 # for ChineseGLUE:
diff --git a/paddlehub/reader/nlp_reader.py b/paddlehub/reader/nlp_reader.py
index e37885438ea8823855050cd63afd96d5727a4323..3bcc219f6ec3ee2ea7024dc102b117a6ee7e7121 100644
--- a/paddlehub/reader/nlp_reader.py
+++ b/paddlehub/reader/nlp_reader.py
@@ -361,6 +361,34 @@ class ClassifyReader(BaseReader):
 
 
 class SequenceLabelReader(BaseReader):
+    def __init__(self,
+                 vocab_path,
+                 dataset=None,
+                 label_map_config=None,
+                 max_seq_len=512,
+                 do_lower_case=True,
+                 random_seed=None,
+                 use_task_id=False,
+                 sp_model_path=None,
+                 word_dict_path=None,
+                 in_tokens=False):
+        super(SequenceLabelReader, self).__init__(
+            vocab_path=vocab_path,
+            dataset=dataset,
+            label_map_config=label_map_config,
+            max_seq_len=max_seq_len,
+            do_lower_case=do_lower_case,
+            random_seed=random_seed,
+            use_task_id=use_task_id,
+            sp_model_path=sp_model_path,
+            word_dict_path=word_dict_path,
+            in_tokens=in_tokens)
+        if sp_model_path and word_dict_path:
+            self.tokenizer = tokenization.FullTokenizer(
+                vocab_file=vocab_path,
+                do_lower_case=do_lower_case,
+                use_sentence_piece_vocab=True)
+
     def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
diff --git a/paddlehub/reader/tokenization.py b/paddlehub/reader/tokenization.py
index fab4121ff4a147dde007c4f19e468cf9f9917b0c..ef49ed76fd82d0a0b58cfe2b3bc7122eb9e8acac 100644
--- a/paddlehub/reader/tokenization.py
+++ b/paddlehub/reader/tokenization.py
@@ -113,11 +113,17 @@ def whitespace_tokenize(text):
 class FullTokenizer(object):
     """Runs end-to-end tokenziation."""
 
-    def __init__(self, vocab_file, do_lower_case=True):
+    def __init__(self,
+                 vocab_file,
+                 do_lower_case=True,
+                 use_sentence_piece_vocab=False):
         self.vocab = load_vocab(vocab_file)
         self.inv_vocab = {v: k for k, v in self.vocab.items()}
         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
-        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.use_sentence_piece_vocab = use_sentence_piece_vocab
+        self.wordpiece_tokenizer = WordpieceTokenizer(
+            vocab=self.vocab,
+            use_sentence_piece_vocab=self.use_sentence_piece_vocab)
 
     def tokenize(self, text):
         split_tokens = []
@@ -329,10 +335,15 @@ class BasicTokenizer(object):
 class WordpieceTokenizer(object):
     """Runs WordPiece tokenziation."""
 
-    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+    def __init__(self,
+                 vocab,
+                 unk_token="[UNK]",
+                 max_input_chars_per_word=100,
+                 use_sentence_piece_vocab=False):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word
+       self.use_sentence_piece_vocab = use_sentence_piece_vocab
 
     def tokenize(self, text):
         """Tokenizes a piece of text into its word pieces.
@@ -369,7 +380,9 @@ class WordpieceTokenizer(object):
             cur_substr = None
             while start < end:
                 substr = "".join(chars[start:end])
-                if start > 0:
+                if start == 0 and self.use_sentence_piece_vocab:
+                    substr = u'\u2581' + substr
+                if start > 0 and not self.use_sentence_piece_vocab:
                     substr = "##" + substr
                 if substr in self.vocab:
                     cur_substr = substr
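
Note on the tokenization change: with a SentencePiece-style vocab (as used by the ernie_v2 modules whose model file comes from module.get_spm_path()), the word-initial piece carries the U+2581 marker and continuation pieces have no "##" prefix, which is why the use_sentence_piece_vocab flag is threaded from FullTokenizer down to WordpieceTokenizer. The snippet below is an illustrative sketch only, not part of the patch: the standalone wordpiece() helper and toy_vocab are hypothetical, but the matching loop mirrors the branch added to WordpieceTokenizer.tokenize above.

# Illustrative sketch of the new use_sentence_piece_vocab behaviour.
# wordpiece() and toy_vocab are hypothetical stand-ins; the real reader
# loads its vocab via load_vocab(module.get_vocab_path()).

def wordpiece(token, vocab, use_sentence_piece_vocab=False, unk_token="[UNK]"):
    """Greedy longest-match-first sub-token split, as in the patched loop."""
    chars = list(token)
    sub_tokens = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0 and use_sentence_piece_vocab:
                substr = u"\u2581" + substr   # SentencePiece word-start marker
            if start > 0 and not use_sentence_piece_vocab:
                substr = "##" + substr        # classic WordPiece continuation marker
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            return [unk_token]                # no piece matched: whole token -> [UNK]
        sub_tokens.append(cur_substr)
        start = end
    return sub_tokens

toy_vocab = {u"\u2581tok", "tok", "en", "ization", "##en", "##ization"}
print(wordpiece("tokenization", toy_vocab))                                 # ['tok', '##en', '##ization']
print(wordpiece("tokenization", toy_vocab, use_sentence_piece_vocab=True))  # ['▁tok', 'en', 'ization']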