Commit 3e1aa4bd authored by kinghuin and committed by wuzewu

support ernie-tiny cls

Parent 271883bf
......@@ -33,7 +33,6 @@ parser.add_argument("--max_seq_len", type=int, default=512, help="Number of word
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use taskid ,if yes to use ernie v2.")
args = parser.parse_args()
# yapf: enable.
......@@ -43,7 +42,7 @@ if __name__ == '__main__':
# Download dataset and use ClassifyReader to read dataset
if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
module = hub.Module(name="ernie_v2_chinese_tiny")
metrics_choices = ["acc"]
elif args.dataset.lower() == "tnews":
dataset = hub.dataset.TNews()
......@@ -75,60 +74,36 @@ if __name__ == '__main__':
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["f1", "acc"]
# The first metric will be chosen for evaluation. Ref: task.py:799
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["f1", "acc"]
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["matthews", "acc"]
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI_m")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower() == "mnli_mm":
dataset = hub.dataset.GLUE("MNLI_mm")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
......@@ -137,19 +112,22 @@ if __name__ == '__main__':
else:
raise ValueError("%s dataset is not defined" % args.dataset)
# Check metric
support_metrics = ["acc", "f1", "matthews"]
for metric in metrics_choices:
if metric not in support_metrics:
raise ValueError("\"%s\" metric is not defined" % metric)
# Start preparing parameters for reader and task according to the module
# For ernie_v2, it has an additional embedding named task_id
# For ernie_v2_chinese_tiny, it uses an additional sentence piece vocab to tokenize
if module.name.startswith("ernie_v2"):
use_taskid = True
else:
use_taskid = False
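# Editor's note (not part of this diff): ernie_v2_chinese_tiny matches the
# "ernie_v2" prefix, so use_taskid becomes True and task_ids is fed below,
# in addition to the sentence-piece assets picked up by the reader.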
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
......@@ -163,9 +141,18 @@ if __name__ == '__main__':
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
if args.use_taskid:
if use_taskid:
feed_list.append(inputs["task_ids"].name)
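# Editor's sketch (not part of this diff): the sentence-level feature that the
# comment above refers to comes from the module outputs, roughly:
#   pooled_output = outputs["pooled_output"]      # whole-sentence classification
#   sequence_output = outputs["sequence_output"]  # token-level tasks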
# Finish preparing parameters for reader and task according to the module
# Define reader
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=use_taskid,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
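# Editor's note (not part of this diff): for ernie_v2_chinese_tiny both asset
# lookups resolve to real files, so ClassifyReader builds the sentence-piece
# tokenizer; for other modules they return None and the reader keeps the
# default FullTokenizer, e.g.
#   print(module.get_spm_path(), module.get_word_dict_path())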
# Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
......
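The remainder of the script is elided above. As an editor's sketch (not part of this diff, following the standard PaddleHub classification demo), the strategy, run config and classification task are typically wired up as below; flags such as args.learning_rate, args.weight_decay, args.num_epoch, args.checkpoint_dir and args.use_gpu are assumed to be defined in the elided argparse section.

# Editor's sketch only; not part of this commit.
strategy = hub.AdamWeightDecayStrategy(
    learning_rate=args.learning_rate,   # assumed CLI flag
    weight_decay=args.weight_decay,     # assumed CLI flag
    lr_scheduler="linear_decay")
config = hub.RunConfig(
    use_data_parallel=args.use_data_parallel,
    use_pyreader=args.use_pyreader,
    use_cuda=args.use_gpu,              # assumed CLI flag
    num_epoch=args.num_epoch,           # assumed CLI flag
    batch_size=args.batch_size,
    checkpoint_dir=args.checkpoint_dir, # assumed CLI flag
    strategy=strategy)
cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=outputs["pooled_output"],
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=config,
    metrics_choices=metrics_choices)
cls_task.finetune_and_eval()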
......@@ -320,6 +320,19 @@ class Module(object):
        for assets_file in self.assets:
            if "vocab.txt" in assets_file:
                return assets_file
        return None

    def get_word_dict_path(self):
        for assets_file in self.assets:
            if "dict.wordseg.pickle" in assets_file:
                return assets_file
        return None

    def get_spm_path(self):
        for assets_file in self.assets:
            if "spm_cased_simp_sampled.model" in assets_file:
                return assets_file
        return None
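
    # Editor's note (not part of this diff): both helpers mirror get_vocab_path()
    # above and return None when the asset is absent, e.g.
    #   module.get_spm_path()        # path to spm_cased_simp_sampled.model, or None
    #   module.get_word_dict_path()  # path to dict.wordseg.pickle, or None
    # Only ernie_v2_chinese_tiny ships these assets, so only it activates the
    # sentence-piece tokenizer in ClassifyReader.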
    def _recover_from_desc(self):
        # recover signature
......
......@@ -44,10 +44,16 @@ class BaseReader(object):
                 do_lower_case=True,
                 random_seed=None,
                 use_task_id=False,
                 sp_model_path=None,
                 word_dict_path=None,
                 in_tokens=False):
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        if sp_model_path and word_dict_path:
            self.tokenizer = tokenization.WSSPTokenizer(
                vocab_path, sp_model_path, word_dict_path, ws=True, lower=True)
        else:
            self.tokenizer = tokenization.FullTokenizer(
                vocab_file=vocab_path, do_lower_case=do_lower_case)
        self.vocab = self.tokenizer.vocab
        self.dataset = dataset
        self.pad_id = self.vocab["[PAD]"]
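        # Editor's note (not part of this diff): WSSPTokenizer exposes the same
        # vocab / convert_tokens_to_ids interface as FullTokenizer, so the rest
        # of BaseReader works unchanged regardless of which tokenizer was built.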
......
......@@ -22,6 +22,8 @@ import collections
import io
import unicodedata
import six
import sentencepiece as spm
import pickle
def convert_to_unicode(text):
......@@ -154,6 +156,54 @@ class CharTokenizer(object):
return convert_by_vocab(self.inv_vocab, ids)
class WSSPTokenizer(object):
    def __init__(self, vocab_file, sp_model_dir, word_dict, ws=True,
                 lower=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.ws = ws
        self.lower = lower
        self.dict = pickle.load(open(word_dict, 'rb'), encoding='utf8')
        self.sp_model = spm.SentencePieceProcessor()
        self.window_size = 5
        self.sp_model.Load(sp_model_dir)

    def cut(self, chars):
        words = []
        idx = 0
        while idx < len(chars):
            matched = False
            for i in range(self.window_size, 0, -1):
                cand = chars[idx:idx + i]
                if cand in self.dict:
                    words.append(cand)
                    matched = True
                    break
            if not matched:
                i = 1
                words.append(chars[idx])
            idx += i
        return words
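
    # Editor's note (not part of this diff): cut() is greedy forward maximum
    # matching over the word-seg dictionary, trying the longest candidate first
    # within a window of window_size characters and falling back to a single
    # character when nothing matches. For example, with only "天安门" in the
    # dictionary:
    #   cut("天安门广场") -> ["天安门", "广", "场"]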
    def tokenize(self, text):
        sen = text.decode('utf8')
        if self.ws:
            sen = [s for s in self.cut(sen) if s != ' ']
        else:
            sen = sen.split(' ')
        if self.lower:
            sen = [s.lower() for s in sen]
        sen = ' '.join(sen)
        ret = self.sp_model.EncodeAsPieces(sen)
        return ret

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)
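
A minimal usage sketch for the new tokenizer (editor's illustration, not part of this commit; the asset file names are the ones module.py above looks for, and tokenize() expects utf-8 encoded bytes):

# Editor's sketch only; asset paths and input text are illustrative.
tokenizer = WSSPTokenizer(
    vocab_file="vocab.txt",
    sp_model_dir="spm_cased_simp_sampled.model",
    word_dict="dict.wordseg.pickle",
    ws=True,
    lower=True)
tokens = tokenizer.tokenize(u"今天天气很好".encode("utf8"))  # word-seg + sentencepiece pieces
ids = tokenizer.convert_tokens_to_ids(tokens)  # assumes all pieces are in the vocab
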
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
......