Commit f889492f authored by chenxuyi

Make examples compatible with ERNIE tiny

Parent 72e21235
@@ -4,6 +4,7 @@ import re
 from propeller import log
 import itertools
 from propeller.paddle.data import Dataset
+import pickle
 import six
@@ -101,7 +102,7 @@ class SpaceTokenizer(object):
 class CharTokenizer(object):
-    def __init__(self, vocab, lower=True):
+    def __init__(self, vocab, lower=True, sentencepiece_style_vocab=False):
         """
         char tokenizer (wordpiece english)
         normed txt(space seperated or not) => list of word-piece
@@ -110,6 +111,7 @@ class CharTokenizer(object):
         #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
         self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
         self.lower = lower
+        self.sentencepiece_style_vocab = sentencepiece_style_vocab

     def __call__(self, sen):
         if len(sen) == 0:
@@ -119,11 +121,51 @@ class CharTokenizer(object):
             sen = sen.lower()
         res = []
         for match in self.pat.finditer(sen):
-            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
+            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]', sentencepiece_style_vocab=self.sentencepiece_style_vocab)
             res.extend(words)
         return res

+class WSSPTokenizer(object):
+    def __init__(self, sp_model_dir, word_dict, ws=True, lower=True):
+        self.ws = ws
+        self.lower = lower
+        self.dict = pickle.load(open(word_dict, 'rb'), encoding='utf8')
+        import sentencepiece as spm
+        self.sp_model = spm.SentencePieceProcessor()
+        self.window_size = 5
+        self.sp_model.Load(sp_model_dir)
+
+    def cut(self, chars):
+        words = []
+        idx = 0
+        while idx < len(chars):
+            matched = False
+            for i in range(self.window_size, 0, -1):
+                cand = chars[idx: idx + i]
+                if cand in self.dict:
+                    words.append(cand)
+                    matched = True
+                    break
+            if not matched:
+                i = 1
+                words.append(chars[idx])
+            idx += i
+        return words
+
+    def __call__(self, sen):
+        sen = sen.decode('utf8')
+        if self.ws:
+            sen = [s for s in self.cut(sen) if s != ' ']
+        else:
+            sen = sen.split(' ')
+        if self.lower:
+            sen = [s.lower() for s in sen]
+        sen = ' '.join(sen)
+        ret = self.sp_model.EncodeAsPieces(sen)
+        return ret
+
 def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
     token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
     token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
...
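Note: the new WSSPTokenizer first segments raw text with a greedy forward-maximum-matching pass over the pickled word dictionary, then lowercases, space-joins the words, and runs sentencepiece EncodeAsPieces on the result. Below is a minimal, self-contained sketch of the matching step only; a toy in-memory set stands in for the file passed via --word_dict, and the entries and sample sentence are made up for illustration.

```python
# Standalone sketch of the greedy forward-maximum-matching in WSSPTokenizer.cut.
word_dict = {'深度学习', '深度', '学习', '自然语言'}  # hypothetical dictionary entries
window_size = 5  # same window size the tokenizer uses

def cut(chars):
    words = []
    idx = 0
    while idx < len(chars):
        matched = False
        # try the longest candidate first, then shrink the window
        for i in range(window_size, 0, -1):
            cand = chars[idx: idx + i]
            if cand in word_dict:
                words.append(cand)
                matched = True
                break
        if not matched:
            # no dictionary hit: emit a single character and advance by one
            i = 1
            words.append(chars[idx])
        idx += i
    return words

print(cut('深度学习很有趣'))  # -> ['深度学习', '很', '有', '趣']
```

Because candidates are tried longest-first within the window, the longest dictionary entry always wins; characters with no dictionary hit fall through one at a time.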
@@ -55,7 +55,7 @@ class ClassificationErnieModel(propeller.train.Model):
         pos_ids = L.cast(pos_ids, 'int64')
         pos_ids.stop_gradient = True
         input_mask.stop_gradient = True
-        task_ids = L.zeros_like(src_ids) + self.hparam.task_id #this shit wont use at the moment
+        task_ids = L.zeros_like(src_ids) + self.hparam.task_id
         task_ids.stop_gradient = True

         ernie = ErnieModel(
@@ -128,6 +128,8 @@ if __name__ == '__main__':
     parser.add_argument('--vocab_file', type=str, required=True)
     parser.add_argument('--do_predict', action='store_true')
     parser.add_argument('--warm_start_from', type=str)
+    parser.add_argument('--sentence_piece_model', type=str, default=None)
+    parser.add_argument('--word_dict', type=str, default=None)
     args = parser.parse_args()
     run_config = propeller.parse_runconfig(args)
     hparams = propeller.parse_hparam(args)
@@ -138,7 +140,12 @@ if __name__ == '__main__':
     cls_id = vocab['[CLS]']
     unk_id = vocab['[UNK]']
-    tokenizer = utils.data.CharTokenizer(vocab.keys())
+    if args.sentence_piece_model is not None:
+        if args.word_dict is None:
+            raise ValueError('--word_dict not specified in subword model')
+        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
+    else:
+        tokenizer = utils.data.CharTokenizer(vocab.keys())

     def tokenizer_func(inputs):
         '''avoid pickle error'''
@@ -179,7 +186,7 @@ if __name__ == '__main__':
     dev_ds.data_shapes = shapes
     dev_ds.data_types = types

-    varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
+    varname_to_warmstart = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
     warm_start_dir = args.warm_start_from
     ws = propeller.WarmStartSetting(
             predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
...
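The warm-start predicate in this file also moves from a loose prefix match to a fully anchored pattern, so only parameter tensors (fc weights/biases, layer-norm scale/bias, embeddings) are restored from the checkpoint directory. A small sketch of the behavioural difference follows; the variable names are hypothetical examples in an ERNIE-like naming style, not taken from this commit.

```python
import re

# Old, loose pattern vs. the anchored pattern introduced by this commit.
old_pat = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
new_pat = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')

names = [
    'encoder_layer_0_multi_head_att_query_fc.w_0',  # parameter: both patterns match
    'encoder_layer_0_post_att_layer_norm_scale',    # parameter: both patterns match
    'word_embedding',                               # parameter: both patterns match
    'pooled_fc.w_0.moment1_0',                      # optimizer state: only the old pattern matches
]
for name in names:
    print(name, bool(old_pat.match(name)), bool(new_pat.match(name)))
```

The tighter anchors avoid picking up names that merely share a prefix with a parameter, such as optimizer moment tensors.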
...@@ -32,7 +32,6 @@ import paddle.fluid.layers as L ...@@ -32,7 +32,6 @@ import paddle.fluid.layers as L
from model.ernie import ErnieModel from model.ernie import ErnieModel
from optimization import optimization from optimization import optimization
import tokenization
import utils.data import utils.data
from propeller import log from propeller import log
...@@ -121,7 +120,7 @@ class SequenceLabelErnieModel(propeller.train.Model): ...@@ -121,7 +120,7 @@ class SequenceLabelErnieModel(propeller.train.Model):
def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train): def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train):
label_map = {v: i for i, v in enumerate(label_list)} label_map = {v: i for i, v in enumerate(label_list)}
no_entity_id = label_map['O'] no_entity_id = label_map['O']
delimiter = '' delimiter = b''
def read_bio_data(filename): def read_bio_data(filename):
ds = propeller.data.Dataset.from_file(filename) ds = propeller.data.Dataset.from_file(filename)
...@@ -132,10 +131,10 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_ ...@@ -132,10 +131,10 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
while 1: while 1:
line = next(iterator) line = next(iterator)
cols = line.rstrip(b'\n').split(b'\t') cols = line.rstrip(b'\n').split(b'\t')
tokens = cols[0].split(delimiter)
labels = cols[1].split(delimiter)
if len(cols) != 2: if len(cols) != 2:
continue continue
tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
labels = tokenization.convert_to_unicode(cols[1]).split(delimiter)
if len(tokens) != len(labels) or len(tokens) == 0: if len(tokens) != len(labels) or len(tokens) == 0:
continue continue
yield [tokens, labels] yield [tokens, labels]
...@@ -151,7 +150,8 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_ ...@@ -151,7 +150,8 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
ret_tokens = [] ret_tokens = []
ret_labels = [] ret_labels = []
for token, label in zip(tokens, labels): for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token) sub_token = tokenizer(token)
label = label.decode('utf8')
if len(sub_token) == 0: if len(sub_token) == 0:
continue continue
ret_tokens.extend(sub_token) ret_tokens.extend(sub_token)
...@@ -179,7 +179,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_ ...@@ -179,7 +179,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
labels = labels[: max_seqlen - 2] labels = labels[: max_seqlen - 2]
tokens = ['[CLS]'] + tokens + ['[SEP]'] tokens = ['[CLS]'] + tokens + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(tokens) token_ids = [vocab[t] for t in tokens]
label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id] label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id]
token_type_ids = [0] * len(token_ids) token_type_ids = [0] * len(token_ids)
input_seqlen = len(token_ids) input_seqlen = len(token_ids)
...@@ -211,7 +211,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_ ...@@ -211,7 +211,7 @@ def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_
def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen): def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen):
delimiter = '' delimiter = b''
def stdin_gen(): def stdin_gen():
if six.PY3: if six.PY3:
...@@ -232,9 +232,9 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql ...@@ -232,9 +232,9 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
while 1: while 1:
line, = next(iterator) line, = next(iterator)
cols = line.rstrip(b'\n').split(b'\t') cols = line.rstrip(b'\n').split(b'\t')
tokens = cols[0].split(delimiter)
if len(cols) != 1: if len(cols) != 1:
continue continue
tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
if len(tokens) == 0: if len(tokens) == 0:
continue continue
yield tokens, yield tokens,
...@@ -247,7 +247,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql ...@@ -247,7 +247,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
tokens, = next(iterator) tokens, = next(iterator)
ret_tokens = [] ret_tokens = []
for token in tokens: for token in tokens:
sub_token = tokenizer.tokenize(token) sub_token = tokenizer(token)
if len(sub_token) == 0: if len(sub_token) == 0:
continue continue
ret_tokens.extend(sub_token) ret_tokens.extend(sub_token)
...@@ -266,7 +266,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql ...@@ -266,7 +266,7 @@ def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seql
tokens = tokens[: max_seqlen - 2] tokens = tokens[: max_seqlen - 2]
tokens = ['[CLS]'] + tokens + ['[SEP]'] tokens = ['[CLS]'] + tokens + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(tokens) token_ids = [vocab[t] for t in tokens]
token_type_ids = [0] * len(token_ids) token_type_ids = [0] * len(token_ids)
input_seqlen = len(token_ids) input_seqlen = len(token_ids)
...@@ -296,13 +296,15 @@ if __name__ == '__main__': ...@@ -296,13 +296,15 @@ if __name__ == '__main__':
parser.add_argument('--data_dir', type=str, required=True) parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--vocab_file', type=str, required=True) parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--do_predict', action='store_true') parser.add_argument('--do_predict', action='store_true')
parser.add_argument('--use_sentence_piece_vocab', action='store_true')
parser.add_argument('--warm_start_from', type=str) parser.add_argument('--warm_start_from', type=str)
args = parser.parse_args() args = parser.parse_args()
run_config = propeller.parse_runconfig(args) run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args) hparams = propeller.parse_hparam(args)
tokenizer = tokenization.FullTokenizer(args.vocab_file)
vocab = tokenizer.vocab vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(open(args.vocab_file, 'r', encoding='utf8'))}
tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=args.use_sentence_piece_vocab)
sep_id = vocab['[SEP]'] sep_id = vocab['[SEP]']
cls_id = vocab['[CLS]'] cls_id = vocab['[CLS]']
unk_id = vocab['[UNK]'] unk_id = vocab['[UNK]']
...@@ -358,7 +360,7 @@ if __name__ == '__main__': ...@@ -358,7 +360,7 @@ if __name__ == '__main__':
from_dir=warm_start_dir from_dir=warm_start_dir
) )
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1']) best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
propeller.train.train_and_eval( propeller.train.train_and_eval(
model_class_or_model_fn=SequenceLabelErnieModel, model_class_or_model_fn=SequenceLabelErnieModel,
params=hparams, params=hparams,
...@@ -387,7 +389,6 @@ if __name__ == '__main__': ...@@ -387,7 +389,6 @@ if __name__ == '__main__':
predict_ds.data_types = types predict_ds.data_types = types
rev_label_map = {i: v for i, v in enumerate(label_list)} rev_label_map = {i: v for i, v in enumerate(label_list)}
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams) learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams)
for pred, _ in learner.predict(predict_ds, ckpt=-1): for pred, _ in learner.predict(predict_ds, ckpt=-1):
pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()]) pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()])
......
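With tokenization.FullTokenizer removed, the vocab here becomes a plain dict built from the vocab file and token ids come from a direct lookup. A minimal sketch of that replacement; 'vocab.txt' and the sample tokens are placeholders, the script itself reads args.vocab_file.

```python
# Build the id map the way the script now does: one entry per line,
# token in the first tab-separated column, line number as the id.
# 'vocab.txt' is a placeholder path; the real file comes from --vocab_file.
vocab = {line.strip().split('\t')[0]: i
         for i, line in enumerate(open('vocab.txt', 'r', encoding='utf8'))}

tokens = ['[CLS]', '你', '好', '[SEP]']  # illustrative tokens only
token_ids = [vocab[t] for t in tokens]   # plain lookup; unseen tokens raise KeyError
token_type_ids = [0] * len(token_ids)
```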
@@ -146,6 +146,7 @@ if __name__ == '__main__':
     parser.add_argument('--data_dir', type=str, required=True)
     parser.add_argument('--warm_start_from', type=str)
     parser.add_argument('--sentence_piece_model', type=str, default=None)
+    parser.add_argument('--word_dict', type=str, default=None)
     args = parser.parse_args()
     run_config = propeller.parse_runconfig(args)
     hparams = propeller.parse_hparam(args)
@@ -157,7 +158,9 @@ if __name__ == '__main__':
     unk_id = vocab['[UNK]']

     if args.sentence_piece_model is not None:
-        tokenizer = utils.data.JBSPTokenizer(args.sentence_piece_model, jb=True, lower=True)
+        if args.word_dict is None:
+            raise ValueError('--word_dict not specified in subword model')
+        tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
     else:
         tokenizer = utils.data.CharTokenizer(vocab.keys())
@@ -218,7 +221,7 @@ if __name__ == '__main__':
            from_dir=warm_start_dir
        )

-       best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
+       best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
        propeller.train_and_eval(
            model_class_or_model_fn=RankingErnieModel,
            params=hparams,
@@ -258,6 +261,7 @@ if __name__ == '__main__':
        est = propeller.Learner(RankingErnieModel, run_config, hparams)
        for qid, res in est.predict(predict_ds, ckpt=-1):
            print('%d\t%d\t%.5f\t%.5f' % (qid[0], np.argmax(res), res[0], res[1]))
+
        #for i in predict_ds:
        #    sen = i[0]
        #    for ss in np.squeeze(sen):
...