Commit 1bb38efb authored by wangxiao1021

add multi-task example

Parent 4cc989d2
......@@ -4,33 +4,25 @@ import os
import io
abs_path = os.path.abspath(__file__)
dst_dir = os.path.join(os.path.dirname(abs_path), "data/mlm/")
dst_dir2 = os.path.join(os.path.dirname(abs_path), "data/match/")
dst_dir = os.path.join(os.path.dirname(abs_path), "data/match/")
if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
os.makedirs(dst_dir)
if not os.path.exists(dst_dir2) or not os.path.isdir(dst_dir2):
os.makedirs(dst_dir2)
os.mknod("./data/mlm/train.tsv")
os.mknod("./data/match/train.tsv")
with io.open("./data/mrc/train.json", "r", encoding='utf-8') as file:
data = json.load(file)["data"]
with io.open("./data/mrc/train.json", "r", encoding='utf-8') as f:
data = json.load(f)["data"]
i = 0
with open("./data/mlm/train.tsv","w") as f:
f.write("text_a\n")
with open("./data/match/train.tsv","w") as f2:
f2.write("text_a\ttext_b\tlabel\n")
for dd in data:
for d in dd["paragraphs"]:
text_a_mlm = d["context"]
l = text_a_mlm+"\n"
f.write(l.encode("utf-8"))
with open("./data/match/train.tsv","w") as f2:
f2.write("text_a\ttext_b\tlabel\n")
for dd in data:
for d in dd["paragraphs"]:
context = d["context"]
for qa in d["qas"]:
text_a = qa["question"]
answer = qa["answers"][0]
text_b = answer["text"]
start_pos = answer["answer_start"]
text_b_neg = text_a_mlm[0:start_pos]
text_b_neg = context[0:start_pos]
if len(text_b_neg) > 512:
text_b_neg = text_b_neg[-512:-1]
l1 = text_a+"\t"+text_b+"\t1\n"
......@@ -40,6 +32,5 @@ with io.open("./data/mrc/train.json", "r", encoding='utf-8') as file:
f2.write(l2.encode("utf-8"))
i += 2
f2.close()
f.close()
file.close()
f2.close()
f.close()
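For reference, a minimal self-contained sketch of what the updated converter appears to do after this change: read the SQuAD-style ./data/mrc/train.json and emit a sentence-pair matching set, pairing each question with its answer text (label 1) and with the context preceding the answer as a negative (label 0). This is a reading of the diff, not the verbatim script; note the original slices negatives with [-512:-1], which also drops the final character, while the sketch keeps it with [-512:].

import io
import json
import os

dst_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/match/")
if not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)

with io.open("./data/mrc/train.json", "r", encoding="utf-8") as f:
    data = json.load(f)["data"]

with io.open("./data/match/train.tsv", "w", encoding="utf-8") as f2:
    f2.write(u"text_a\ttext_b\tlabel\n")
    for article in data:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                question = qa["question"]
                answer = qa["answers"][0]
                # the context span before the answer is a plausible negative
                negative = context[:answer["answer_start"]][-512:]
                f2.write(u"%s\t%s\t1\n" % (question, answer["text"]))
                f2.write(u"%s\t%s\t0\n" % (question, negative))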
......@@ -7,7 +7,7 @@ from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
# configs
max_seqlen = 512
max_seqlen = 128
batch_size = 8
num_epochs = 8
lr = 3e-5
......@@ -15,7 +15,7 @@ if __name__ == '__main__':
max_query_len = 64
max_ans_len = 128
weight_decay = 0.01
print_steps = 20
print_steps = 1
num_classes = 2
random_seed = 1
dropout_prob = 0.1
......@@ -33,43 +33,36 @@ if __name__ == '__main__':
pre_params = './pretrain/ernie-zh-base/params'
config = json.load(open('./pretrain/ernie-zh-base/ernie_config.json'))
input_dim = config['hidden_size']
vocab_size = config['vocab_size']
hidden_act = config['hidden_act']
# ----------------------- for training -----------------------
# step 1-1: create readers for training
mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
# mlm_reader = palm.reader.MaskLMReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data
mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
# mlm_reader.load_data(train_file_mlm, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in readers
mrc_reader.register_with(ernie)
match_reader.register_with(ernie)
# mlm_reader.register_with(ernie)
# step 4: create task output heads
mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
mlm_head = palm.head.MaskLM(input_dim, hidden_act, dropout_prob)
# step 5-1: create a task trainer
trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
# trainer_mlm = palm.Trainer("mlm", mix_ratio=0.5)
trainer_match = palm.Trainer("match", mix_ratio=0.5)
trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])
# step 5-2: build forward graph with backbone and task head
loss_var = trainer.build_forward(ernie, [mrc_head, match_head])
# step 6-1*: use warmup
n_steps = mrc_reader.num_examples * num_epochs // batch_size
n_steps = mrc_reader.num_examples * 2 * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
# step 6-2: create an optimizer
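The mix_ratio arguments above control how often each task is sampled during joint training. A small sketch of the arithmetic, assuming PALM draws a task per step with probability proportional to its mix_ratio (MRC being the reference task at 1.0):

ratios = {"mrc": 1.0, "match": 0.5}
total = sum(ratios.values())
probs = {task: r / total for task, r in ratios.items()}
# probs == {'mrc': 2.0/3, 'match': 1.0/3}: roughly two MRC batches
# are drawn for every match batch during joint training.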
......@@ -79,12 +72,11 @@ if __name__ == '__main__':
# step 7: fit prepared reader and data
trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
# save_steps = n_steps-8
save_steps = 1520
save_steps = n_steps-batch_size
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
trainer.train(print_steps=print_steps)
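Back-of-the-envelope step math for the schedule and saver settings above, with a hypothetical example count (num_examples is whatever mrc_reader reports after load_data):

num_examples = 10000   # hypothetical; read from mrc_reader.num_examples
batch_size = 8
num_epochs = 8
# the factor 2 budgets extra steps for the mixed-in match task
n_steps = num_examples * 2 * num_epochs // batch_size   # 20000
warmup_steps = int(0.1 * n_steps)                       # 2000
save_steps = n_steps - batch_size                       # 19992, one save near the end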
......@@ -106,15 +98,15 @@ if __name__ == '__main__':
mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, mrc_pred_head)
trainer_mrc.build_predict_forward(pred_ernie, mrc_pred_head)
# step 6: load pretrained model
pred_model_path = './outputs/ckpt.step'+str(12160)
pred_ckpt = trainer.load_ckpt(pred_model_path)
pred_model_path = './outputs/ckpt.step'+str(save_steps)
pred_ckpt = trainer_mrc.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_mrc_reader, phase='predict')
trainer_mrc.fit_reader(predict_mrc_reader, phase='predict')
# step 8: predict
print('predicting..')
trainer.predict(print_steps=print_steps, output_dir="outputs/")
trainer_mrc.predict(print_steps=print_steps, output_dir="outputs/")
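Note that prediction now goes through trainer_mrc, the single-task trainer, rather than the MultiHeadTrainer. The checkpoint path can be derived from save_steps because the saver names each checkpoint directory after the global step at which it was written, e.g.:

save_steps = 19992   # hypothetical value of n_steps - batch_size
pred_model_path = './outputs/ckpt.step' + str(save_steps)
# -> './outputs/ckpt.step19992', the directory written by the final save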
......@@ -57,9 +57,9 @@ def yield_pieces(data, distribute_strategy, batch_size):
yield temp
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is_multi=False):
if postprocess_fn is None:
def postprocess_fn(batch):
def postprocess_fn(batch, id=-1, phase='train', is_multi=False):
return batch
def worker(reader, dev_count, queue):
......@@ -90,6 +90,10 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
queue.task_done()
if ret is not None:
batches, num_pad = ret
if dev_count > 1 and phase == 'train' and is_multi:
id = batches[0]['__task_id'][0]
else:
id = -1
batch_buf = []
flag_buf = []
for idx, batch in enumerate(batches):
......@@ -97,8 +101,8 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
flag = idx-len(batches) < -num_pad
# if num_pad > 0:
# num_pad -= 1
# batch = postprocess_fn(batch, id)
batch = postprocess_fn(batch)
batch = postprocess_fn(batch, id, phase, is_multi=is_multi)
# batch = postprocess_fn(batch)
batch_buf.append(batch)
flag_buf.append(flag)
yield batch_buf, flag_buf
......
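A minimal sketch of the routing idea introduced in the data_feeder change above, under the assumption that in multi-task training every device shard within one step comes from the same task: peeking at the first shard's '__task_id' is then enough to choose the feed mapping for all shards. Names mirror the diff; this is not a public PALM API.

def route_batches(batches, postprocess_fn, dev_count, phase, is_multi):
    # multi-GPU multi-task training: all shards share one task id
    if dev_count > 1 and phase == 'train' and is_multi:
        task_id = batches[0]['__task_id'][0]
    else:
        task_id = -1  # single device / predict: no per-task selection
    return [postprocess_fn(b, task_id, phase, is_multi=is_multi)
            for b in batches]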
......@@ -93,7 +93,7 @@ class MaskLM(Head):
param_attr=fluid.ParamAttr(
name=scope_name+'mask_lm_trans_fc.w_0',
initializer=_param_initializer),
bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
......
......@@ -201,9 +201,9 @@ class MultiHeadTrainer(Trainer):
feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
if gpu_dev_count > 1:
distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase, is_multi=True)
else:
distribute_feeder_fn = iterator_fn
distribute_feeder_fn = iterator_fn()
if phase == 'train':
self._train_reader = distribute_feeder_fn
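The single-device change above (iterator_fn to iterator_fn()) is the classic generator-function versus generator slip: iterator_fn returns an iterator only when called, so the training loop must be handed the call's result, not the function object. In miniature:

def iterator_fn():
    yield {'__task_id': [0]}

reader = iterator_fn()   # correct: a generator the loop can iterate over
assert next(reader)['__task_id'] == [0]
# "for batch in iterator_fn:" would raise TypeError: function is not iterable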
......@@ -277,8 +277,8 @@ class MultiHeadTrainer(Trainer):
def train_one_step(self, batch):
if dev_count > 1:
assert isinstance(batch, list)
task_id = batch[0]['__task_id'][0]
assert isinstance(batch, tuple)
task_id = batch[0][0]['__task_id'][0]
else:
assert isinstance(batch, dict)
task_id = batch['__task_id'][0]
......
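A sketch of the two batch shapes train_one_step now distinguishes, following the feeder change earlier in this commit: with multiple devices the feeder yields a (feed_dicts, flags) tuple, while a single device yields one plain feed dict.

def peek_task_id(batch, dev_count):
    if dev_count > 1:
        feed_dicts, _flags = batch            # tuple yielded by data_feeder
        return feed_dicts[0]['__task_id'][0]  # first device shard's task id
    return batch['__task_id'][0]              # plain feed dict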
......@@ -415,7 +415,7 @@ class Trainer(object):
self._raw_iterator_fn = iterator_fn
feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
if gpu_dev_count > 1:
distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
else:
distribute_feeder_fn = iterator_fn()
......@@ -718,9 +718,9 @@ class Trainer(object):
feed, mask = batch
rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)
num_fakes = decode_fake(len(rt_outputs[0]), mask, self._train_batch_size)
for _ in range(num_fakes):
for item in rt_outputs:
item.pop()
if num_fakes:
rt_outputs = [i[:-num_fakes] for i in rt_outputs]
else:
feed = self._feed_batch_process_fn(batch)
rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)
......@@ -735,9 +735,8 @@ class Trainer(object):
feed, mask = batch
rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
for _ in range(num_fakes):
for item in rt_outputs:
item.pop()
if num_fakes:
rt_outputs = [i[:-num_fakes] for i in rt_outputs]
else:
feed = self._pred_feed_batch_process_fn(batch)
rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
......
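The two hunks above replace an element-wise pop loop with a slice when trimming padded "fake" examples from the fetched outputs. A toy illustration of why the "if num_fakes:" guard matters: lst[:-0] is an empty list, so slicing unconditionally would discard everything whenever no padding was added.

rt_outputs = [[1, 2, 3, 9], [4, 5, 6, 9]]   # toy fetch results; last row is padding
num_fakes = 1
if num_fakes:
    rt_outputs = [col[:-num_fakes] for col in rt_outputs]
assert rt_outputs == [[1, 2, 3], [4, 5, 6]]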
......@@ -21,13 +21,20 @@ import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
def create_feed_batch_process_fn(net_inputs):
def feed_batch_process_fn(data):
def feed_batch_process_fn(data, id=-1, phase='train', is_multi=False):
temp = {}
for q, var in net_inputs.items():
if dev_count > 1 and phase == 'train' and is_multi:
inputs = net_inputs[id]
else:
inputs = net_inputs
for q, var in inputs.items():
if isinstance(var, str) or isinstance(var, unicode):
temp[var] = data[q]
else:
......
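A sketch of the per-task feed selection added above, assuming that in multi-task mode net_inputs is a list holding one {name: variable} mapping per task head, indexed by the routed task id (the truncated else branch presumably handles non-string variables):

def make_feed(net_inputs, data, task_id, phase, is_multi, dev_count):
    if dev_count > 1 and phase == 'train' and is_multi:
        inputs = net_inputs[task_id]   # pick this task's input variables
    else:
        inputs = net_inputs            # single task: one shared mapping
    return {var: data[name] for name, var in inputs.items()
            if isinstance(var, str)}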