Commit 1bb38efb authored by wangxiao1021

add multi-task example

Parent 4cc989d2
```diff
@@ -4,33 +4,25 @@ import os
 import io
 
 abs_path = os.path.abspath(__file__)
-dst_dir = os.path.join(os.path.dirname(abs_path), "data/mlm/")
-dst_dir2 = os.path.join(os.path.dirname(abs_path), "data/match/")
+dst_dir = os.path.join(os.path.dirname(abs_path), "data/match/")
 if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
     os.makedirs(dst_dir)
-if not os.path.exists(dst_dir2) or not os.path.isdir(dst_dir2):
-    os.makedirs(dst_dir2)
-os.mknod("./data/mlm/train.tsv")
 os.mknod("./data/match/train.tsv")
 
-with io.open("./data/mrc/train.json", "r", encoding='utf-8') as file:
-    data = json.load(file)["data"]
+with io.open("./data/mrc/train.json", "r", encoding='utf-8') as f:
+    data = json.load(f)["data"]
 
 i = 0
-with open("./data/mlm/train.tsv","w") as f:
-    f.write("text_a\n")
 with open("./data/match/train.tsv","w") as f2:
     f2.write("text_a\ttext_b\tlabel\n")
     for dd in data:
         for d in dd["paragraphs"]:
-            text_a_mlm = d["context"]
-            l = text_a_mlm+"\n"
-            f.write(l.encode("utf-8"))
+            context = d["context"]
             for qa in d["qas"]:
                 text_a = qa["question"]
                 answer = qa["answers"][0]
                 text_b = answer["text"]
                 start_pos = answer["answer_start"]
-                text_b_neg = text_a_mlm[0:start_pos]
+                text_b_neg = context[0:start_pos]
                 if len(text_b_neg) > 512:
                     text_b_neg = text_b_neg[-512:-1]
                 l1 = text_a+"\t"+text_b+"\t1\n"
@@ -40,6 +32,5 @@ with io.open("./data/mrc/train.json", "r", encoding='utf-8') as file:
     f2.write(l2.encode("utf-8"))
     i +=2
 f2.close()
 f.close()
-file.close()
```
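The rewritten script drops the MLM branch and keeps only the matching-task data: each question is paired once with its gold answer (label 1) and once with the context prefix that precedes the answer span (label 0). A minimal self-contained sketch of that pair construction, using a hypothetical SQuAD-style record (not data from the repo):

```python
# Sketch only: 'sample' is a hypothetical record in the MRC train.json layout.
sample = {"paragraphs": [{
    "context": "The capital of China is Beijing.",
    "qas": [{"question": "What is the capital of China?",
             "answers": [{"text": "Beijing", "answer_start": 24}]}],
}]}

rows = ["text_a\ttext_b\tlabel"]
for para in sample["paragraphs"]:
    context = para["context"]
    for qa in para["qas"]:
        question = qa["question"]
        answer = qa["answers"][0]
        positive = answer["text"]                    # gold answer -> label 1
        negative = context[:answer["answer_start"]]  # text before the answer -> label 0
        negative = negative[-512:]                   # cap the negative at 512 chars
        rows.append("%s\t%s\t1" % (question, positive))
        rows.append("%s\t%s\t0" % (question, negative))

print("\n".join(rows))
# text_a\ttext_b\tlabel
# What is the capital of China?\tBeijing\t1
# What is the capital of China?\tThe capital of China is \t0
```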
```diff
@@ -7,7 +7,7 @@ from paddlepalm.distribute import gpu_dev_count
 
 if __name__ == '__main__':
 
     # configs
-    max_seqlen = 512
+    max_seqlen = 128
     batch_size = 8
     num_epochs = 8
     lr = 3e-5
@@ -15,7 +15,7 @@ if __name__ == '__main__':
     max_query_len = 64
     max_ans_len = 128
     weight_decay = 0.01
-    print_steps = 20
+    print_steps = 1
     num_classes = 2
     random_seed = 1
     dropout_prob = 0.1
@@ -33,43 +33,36 @@ if __name__ == '__main__':
     pre_params = './pretrain/ernie-zh-base/params'
     config = json.load(open('./pretrain/ernie-zh-base/ernie_config.json'))
     input_dim = config['hidden_size']
-    vocab_size = config['vocab_size']
-    hidden_act = config['hidden_act']
 
     # ----------------------- for training -----------------------
 
     # step 1-1: create readers for training
     mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
     match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
-    # mlm_reader = palm.reader.MaskLMReader(vocab_path, max_seqlen, seed=random_seed)
 
     # step 1-2: load the training data
     mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
     match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
-    # mlm_reader.load_data(train_file_mlm, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
 
     # step 2: create a backbone of the model to extract text features
     ernie = palm.backbone.ERNIE.from_config(config)
 
     # step 3: register the backbone in readers
     mrc_reader.register_with(ernie)
     match_reader.register_with(ernie)
-    # mlm_reader.register_with(ernie)
 
     # step 4: create task output heads
     mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
     match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
-    mlm_head = palm.head.MaskLM(input_dim, hidden_act, dropout_prob)
 
     # step 5-1: create a task trainer
     trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
-    # trainer_mlm = palm.Trainer("mlm", mix_ratio=0.5)
     trainer_match = palm.Trainer("match", mix_ratio=0.5)
     trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])
 
     # step 5-2: build forward graph with backbone and task head
     loss_var = trainer.build_forward(ernie, [mrc_head, match_head])
 
     # step 6-1*: use warmup
-    n_steps = mrc_reader.num_examples * num_epochs // batch_size
+    n_steps = mrc_reader.num_examples * 2 * num_epochs // batch_size
     warmup_steps = int(0.1 * n_steps)
     sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
 
     # step 6-2: create a optimizer
```
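The `n_steps` change reflects the multi-task schedule: with `mix_ratio=1.0` for MRC and `0.5` for match, roughly a third of the sampled batches go to the auxiliary task, so the step budget must exceed what the MRC data alone implies. A back-of-the-envelope sketch of that accounting, under the assumption that batches are sampled in proportion to `mix_ratio`; on that reading, the `* 2` in the script is a conservative round-up of the 1.5x the ratios imply:

```python
# Sketch with hypothetical sizes; assumes proportional batch sampling by mix_ratio.
num_examples = 10000   # hypothetical MRC training-set size
batch_size = 8
num_epochs = 8
mix_ratio = {"mrc": 1.0, "match": 0.5}

# Steps the target task needs on its own:
mrc_steps = num_examples * num_epochs // batch_size   # 10000

# Only mix_ratio['mrc'] / sum(ratios) of all steps are MRC batches,
# so the total budget scales by the inverse of that fraction:
total_steps = int(mrc_steps * sum(mix_ratio.values()) / mix_ratio["mrc"])   # 15000

# The script rounds this up to a factor of 2 and sizes the warmup from it:
n_steps = num_examples * 2 * num_epochs // batch_size   # 20000
warmup_steps = int(0.1 * n_steps)                       # 2000
print(mrc_steps, total_steps, n_steps, warmup_steps)
```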
```diff
@@ -79,12 +72,11 @@ if __name__ == '__main__':
     # step 7: fit prepared reader and data
     trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)
 
     # step 8-1*: load pretrained parameters
     trainer.load_pretrain(pre_params)
 
     # step 8-2*: set saver to save model
-    # save_steps = n_steps-8
-    save_steps = 1520
+    save_steps = n_steps-batch_size
     trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
 
     # step 8-3: start training
     trainer.train(print_steps=print_steps)
@@ -106,15 +98,15 @@ if __name__ == '__main__':
     mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')
 
     # step 5: build forward graph with backbone and task head
-    trainer.build_predict_forward(pred_ernie, mrc_pred_head)
+    trainer_mrc.build_predict_forward(pred_ernie, mrc_pred_head)
 
     # step 6: load pretrained model
-    pred_model_path = './outputs/ckpt.step'+str(12160)
-    pred_ckpt = trainer.load_ckpt(pred_model_path)
+    pred_model_path = './outputs/ckpt.step'+str(save_steps)
+    pred_ckpt = trainer_mrc.load_ckpt(pred_model_path)
 
     # step 7: fit prepared reader and data
-    trainer.fit_reader(predict_mrc_reader, phase='predict')
+    trainer_mrc.fit_reader(predict_mrc_reader, phase='predict')
 
     # step 8: predict
     print('predicting..')
-    trainer.predict(print_steps=print_steps, output_dir="outputs/")
+    trainer_mrc.predict(print_steps=print_steps, output_dir="outputs/")
```
```diff
@@ -57,9 +57,9 @@ def yield_pieces(data, distribute_strategy, batch_size):
             yield temp
 
-def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
+def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is_multi=False):
     if postprocess_fn is None:
-        def postprocess_fn(batch):
+        def postprocess_fn(batch, id=-1, phase='train', is_multi=False):
             return batch
 
     def worker(reader, dev_count, queue):
@@ -90,6 +90,10 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
         queue.task_done()
         if ret is not None:
             batches, num_pad = ret
+            if dev_count > 1 and phase == 'train' and is_multi:
+                id = batches[0]['__task_id'][0]
+            else:
+                id = -1
             batch_buf = []
             flag_buf = []
             for idx, batch in enumerate(batches):
@@ -97,8 +101,8 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
                 flag = idx-len(batches) < -num_pad
                 # if num_pad > 0:
                 #     num_pad -= 1
-                # batch = postprocess_fn(batch, id)
-                batch = postprocess_fn(batch)
+                batch = postprocess_fn(batch, id, phase, is_multi=is_multi)
+                # batch = postprocess_fn(batch)
                 batch_buf.append(batch)
                 flag_buf.append(flag)
             yield batch_buf, flag_buf
......
```
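In multi-GPU multi-task training, all sub-batches of one step come from the same task, so the feeder can peek at the `__task_id` field of the first sub-batch and hand it to `postprocess_fn`. A stripped-down sketch of that dispatch, with a hypothetical batch layout standing in for the full PALM feeder:

```python
# Sketch only: 'postprocess' stands in for PALM's feed_batch_process_fn.
def feed_step(batches, postprocess, dev_count=2, phase='train', is_multi=True):
    if dev_count > 1 and phase == 'train' and is_multi:
        task_id = batches[0]['__task_id'][0]   # all sub-batches share one task
    else:
        task_id = -1                           # single task: no dispatch needed
    return [postprocess(b, task_id) for b in batches]

# Two sub-batches (one per device), both tagged as task 1:
batches = [{'__task_id': [1], 'token_ids': [[2, 5, 7]]},
           {'__task_id': [1], 'token_ids': [[3, 1, 4]]}]
out = feed_step(batches, lambda b, tid: dict(b, task=tid))
print([o['task'] for o in out])   # [1, 1]
```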
```diff
@@ -93,7 +93,7 @@ class MaskLM(Head):
             param_attr=fluid.ParamAttr(
                 name=scope_name+'mask_lm_trans_fc.w_0',
                 initializer=_param_initializer),
-            bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
+            bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
 
         # transform: layer norm
         mask_trans_feat = pre_process_layer(
             mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
......
```
```diff
@@ -201,9 +201,9 @@ class MultiHeadTrainer(Trainer):
         feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
 
         if gpu_dev_count > 1:
-            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
+            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase, is_multi=True)
         else:
-            distribute_feeder_fn = iterator_fn
+            distribute_feeder_fn = iterator_fn()
 
         if phase == 'train':
             self._train_reader = distribute_feeder_fn
@@ -277,8 +277,8 @@ class MultiHeadTrainer(Trainer):
     def train_one_step(self, batch):
 
         if dev_count > 1:
-            assert isinstance(batch, list)
-            task_id = batch[0]['__task_id'][0]
+            assert isinstance(batch, tuple)
+            task_id = batch[0][0]['__task_id'][0]
         else:
             assert isinstance(batch, dict)
             task_id = batch['__task_id'][0]
......
```
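The assertion change tracks the feeder's output shape: with more than one device, `data_feeder` yields a `(batch_list, flag_list)` tuple rather than a bare list, so the task id sits one level deeper. A shape-only sketch with hypothetical batches:

```python
def extract_task_id(batch, dev_count):
    # Sketch of the lookup in train_one_step; the data below is hypothetical.
    if dev_count > 1:
        assert isinstance(batch, tuple)      # (batch_buf, flag_buf) from data_feeder
        return batch[0][0]['__task_id'][0]   # first sub-batch of batch_buf
    assert isinstance(batch, dict)           # single device: a plain feed dict
    return batch['__task_id'][0]

multi = ([{'__task_id': [0]}, {'__task_id': [0]}], [True, True])
single = {'__task_id': [0]}
print(extract_task_id(multi, 2), extract_task_id(single, 1))   # 0 0
```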
```diff
@@ -415,7 +415,7 @@ class Trainer(object):
         self._raw_iterator_fn = iterator_fn
         feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
         if gpu_dev_count > 1:
-            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
+            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
         else:
             distribute_feeder_fn = iterator_fn()
 
@@ -718,9 +718,9 @@ class Trainer(object):
             feed, mask = batch
             rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)
             num_fakes = decode_fake(len(rt_outputs[0]), mask, self._train_batch_size)
-            for _ in range(num_fakes):
-                for item in rt_outputs:
-                    item.pop()
+            if num_fakes:
+                rt_outputs = [i[:-num_fakes] for i in rt_outputs]
         else:
             feed = self._feed_batch_process_fn(batch)
             rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)
@@ -735,9 +735,8 @@ class Trainer(object):
             feed, mask = batch
             rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
             num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
-            for _ in range(num_fakes):
-                for item in rt_outputs:
-                    item.pop()
+            if num_fakes:
+                rt_outputs = [i[:-num_fakes] for i in rt_outputs]
         else:
             feed = self._pred_feed_batch_process_fn(batch)
             rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
......
```
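When the last batch cannot fill every device, PALM pads it with fake examples and `decode_fake` reports how many; the fetched outputs must then be trimmed. Replacing the nested `pop()` loop with slicing drops all fake rows at once and, unlike `pop()`, also works on numpy arrays. A sketch with hypothetical fetch results:

```python
import numpy as np

def trim_fakes(rt_outputs, num_fakes):
    # Drop the trailing rows that correspond to padded fake examples.
    if num_fakes:
        return [out[:-num_fakes] for out in rt_outputs]
    return rt_outputs

logits = np.arange(8).reshape(4, 2)   # 4 examples, the last one is padding
labels = np.array([1, 0, 1, 0])
trimmed = trim_fakes([logits, labels], num_fakes=1)
print([t.shape for t in trimmed])     # [(3, 2), (3,)]
```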
```diff
@@ -21,13 +21,20 @@ import numpy as np
 import paddle
 from paddle import fluid
 from paddle.fluid import layers
+from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
+
+dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
 
 def create_feed_batch_process_fn(net_inputs):
 
-    def feed_batch_process_fn(data):
+    def feed_batch_process_fn(data, id=-1, phase='train', is_multi=False):
         temp = {}
-        for q, var in net_inputs.items():
+        if dev_count > 1 and phase=='train' and is_multi:
+            inputs = net_inputs[id]
+        else:
+            inputs = net_inputs
+
+        for q, var in inputs.items():
             if isinstance(var, str) or isinstance(var, unicode):
                 temp[var] = data[q]
             else:
......
```
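The same task id finally selects which task's input spec maps reader fields to graph variables: under multi-GPU multi-task training, `net_inputs` holds one spec per task and is indexed by `id`; otherwise it is used directly. A minimal sketch with hypothetical field-to-variable specs standing in for PALM's `net_inputs`:

```python
# Sketch only: each spec maps a reader field name to a graph variable name.
def make_feed_fn(net_inputs, dev_count):
    def feed_fn(data, task_id=-1, phase='train', is_multi=False):
        if dev_count > 1 and phase == 'train' and is_multi:
            inputs = net_inputs[task_id]   # pick this task's input spec
        else:
            inputs = net_inputs            # shared single-task spec
        return {var: data[q] for q, var in inputs.items() if isinstance(var, str)}
    return feed_fn

specs = [{'token_ids': 'task0.token_ids'}, {'token_ids': 'task1.token_ids'}]
feed = make_feed_fn(specs, dev_count=2)
print(feed({'token_ids': [[1, 2]]}, task_id=1, is_multi=True))
# {'task1.token_ids': [[1, 2]]}
```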