提交 4cc989d2 编写于 作者: W wangxiao1021

add multi-task example

上级 b293afa7
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
# save_steps = n_steps // gpu_dev_count - batch_size
# save_steps = n_steps
save_steps = 2396
trainer.set_saver(save_steps=save_steps, save_path=save_path, save_type=save_type)
# step 8-3: start training
......
......@@ -67,7 +67,7 @@ if __name__ == '__main__':
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params, False)
# step 8-2*: set saver to save model
# save_steps = (n_steps-16) // gpu_dev_count
# save_steps = n_steps-16
save_steps = 6244
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
......
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
# save_steps = (n_steps-8) // gpu_dev_count // 4
# save_steps = (n_steps-8) // 4
save_steps = 1520
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
......
# -*- coding: utf-8 -*-
import os
import requests
import tarfile
import shutil
from tqdm import tqdm
def download(src, url):
    """Download ``url`` into the local file ``src`` with a progress bar.

    Args:
        src: Path of the file to write. An existing file is overwritten.
        url: Address of the resource to fetch.

    Returns:
        The size in bytes announced by the ``Content-Length`` header of the
        HEAD response, or the number of bytes actually written when the
        server does not send that header.

    Raises:
        requests.HTTPError: If the GET request is answered with an error
            status.
    """
    header = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/'
        '70.0.3538.67 Safari/537.36'
    }
    # A missing Content-Length only costs the progress bar its total; it is
    # not a reason to abort the download.
    announced = requests.head(url).headers.get('Content-Length')
    file_size = int(announced) if announced is not None else None
    written = 0
    resp = requests.get(url, headers=header, stream=True)
    try:
        # Fail early instead of saving an HTML error page as the archive.
        resp.raise_for_status()
        pbar = tqdm(total=file_size)
        try:
            # 'wb' rather than 'ab': the request always starts at byte 0, so
            # appending to a leftover partial file would corrupt the result.
            with open(src, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
                        written += len(chunk)
                        # The last chunk is usually shorter than 1024 bytes.
                        pbar.update(len(chunk))
        finally:
            pbar.close()
    finally:
        resp.close()
    return written if file_size is None else file_size
# Fetch the task-data archive next to this script, unpack it, and keep only
# the CMRC 2018 files under data/mrc.
abs_path = os.path.abspath(__file__)
target_dir = os.path.dirname(abs_path)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
downlaod_path = os.path.join(target_dir, "task_data_zh.tgz")

download(downlaod_path, download_url)
# The context manager closes the archive handle even if extraction fails.
# NOTE(review): extractall() trusts the member paths stored in the archive.
with tarfile.open(downlaod_path) as tar:
    tar.extractall(target_dir)
os.remove(downlaod_path)

dst_dir = os.path.join(target_dir, "data/mrc")
if not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)

# Move the CMRC 2018 files into place, then drop the rest of the archive.
cmrc_dir = os.path.join(target_dir, 'task_data', 'cmrc2018')
for file_name in os.listdir(cmrc_dir):
    shutil.move(os.path.join(cmrc_dir, file_name), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Evaluation script for CMRC 2018
version: v5
Note:
v5 formatted output, add usage description
v4 fixed segmentation issues
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
import nltk
import pdb
# split Chinese with English
def mixed_segmentation(in_str, rm_punc=False):
    """Tokenize text that mixes Chinese characters with other scripts.

    The input is lower-cased and stripped. Every character in the range
    U+4E00..U+9FA5 and every listed punctuation mark becomes a token of its
    own; the runs of characters in between are tokenized with
    ``nltk.word_tokenize``.

    Args:
        in_str: Text to segment.
        rm_punc: When true, listed punctuation marks are dropped instead of
            being emitted as tokens.

    Returns:
        The list of tokens in order of appearance.
    """
    punctuation = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』'
    ]
    tokens = []
    pending = []  # characters of the current run awaiting word tokenization

    def flush():
        # Tokenize the buffered run, if any, and start a new one.
        if pending:
            tokens.extend(nltk.word_tokenize(''.join(pending)))
            del pending[:]

    for ch in in_str.lower().strip():
        is_punc = ch in punctuation
        if is_punc and rm_punc:
            # Dropped punctuation does not end the current run.
            continue
        if is_punc or re.search(r'[\u4e00-\u9fa5]', ch):
            flush()
            tokens.append(ch)
        else:
            pending.append(ch)
    flush()
    return tokens
# remove punctuation
def remove_punctuation(in_str):
    """Return the lower-cased, stripped text without the listed punctuation.

    Args:
        in_str: Text to clean.

    Returns:
        The text with every character that equals one of the listed marks
        removed; all other characters keep their order.
    """
    punctuation = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』'
    ]
    normalized = in_str.lower().strip()
    return ''.join(ch for ch in normalized if ch not in punctuation)
# find longest common string
def find_lcs(s1, s2):
    """Find the longest contiguous run shared by two sequences.

    Args:
        s1: First sequence (string or list); the result is sliced from it.
        s2: Second sequence.

    Returns:
        A ``(run, length)`` pair. When several runs share the maximal
        length, the one found first while scanning ``s1`` is returned.
    """
    # table[r][c] is the length of the common run ending at s1[r-1], s2[c-1].
    table = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
    best_len = 0
    best_end = 0
    for r, left in enumerate(s1, 1):
        for c, right in enumerate(s2, 1):
            if left != right:
                continue
            run = table[r - 1][c - 1] + 1
            table[r][c] = run
            if run > best_len:
                best_len, best_end = run, r
    return s1[best_end - best_len:best_end], best_len
#
def evaluate(ground_truth_file, prediction_file):
    """Score predictions against a parsed CMRC 2018 style dataset.

    Args:
        ground_truth_file: Parsed dataset: a dict whose ``"data"`` entries
            hold ``"paragraphs"``, each with a list of ``"qas"`` carrying
            ``"id"`` and ``"answers"`` (dicts with a ``"text"`` field).
        prediction_file: Mapping from question id to the predicted answer.

    Returns:
        A tuple ``(f1_score, em_score, total_count, skip_count)``. Both
        scores are percentages averaged over every question, so questions
        without a prediction count as zero. An empty dataset yields
        ``(0.0, 0.0, 0, 0)``.
    """
    f1 = 0
    em = 0
    total_count = 0
    skip_count = 0
    for instances in ground_truth_file["data"]:
        for instance in instances["paragraphs"]:
            for qas in instance['qas']:
                total_count += 1
                query_id = qas['id'].strip()
                if query_id not in prediction_file:
                    print('Unanswered question: {}\n'.format(
                        query_id))
                    skip_count += 1
                    continue
                answers = [ans["text"] for ans in qas["answers"]]
                prediction = prediction_file[query_id]
                f1 += calc_f1_score(answers, prediction)
                em += calc_em_score(answers, prediction)
    if total_count == 0:
        # Nothing to average over; avoid dividing by zero.
        return 0.0, 0.0, 0, 0
    f1_score = 100.0 * f1 / total_count
    em_score = 100.0 * em / total_count
    return f1_score, em_score, total_count, skip_count
def calc_f1_score(answers, prediction):
    """Return the best token-level F1 of a prediction over the references.

    Precision and recall are derived from the longest common contiguous
    token run (``find_lcs``) between each punctuation-free segmentation.

    Args:
        answers: Reference answer strings.
        prediction: Predicted answer string.

    Returns:
        The highest F1 over all references, or ``0`` when there is no
        overlap or no reference at all.
    """
    if not answers:
        # max() of an empty list would raise; no reference means no credit.
        return 0
    # The prediction does not change across references: segment it once.
    prediction_segs = mixed_segmentation(prediction, rm_punc=True)
    f1_scores = []
    for ans in answers:
        ans_segs = mixed_segmentation(ans, rm_punc=True)
        lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1 = (2 * precision * recall) / (precision + recall)
        f1_scores.append(f1)
    return max(f1_scores)
def calc_em_score(answers, prediction):
    """Return 1 if the prediction exactly matches any reference, else 0.

    Both sides are compared after ``remove_punctuation`` has been applied.

    Args:
        answers: Reference answer strings.
        prediction: Predicted answer string.
    """
    matched = any(
        remove_punctuation(reference) == remove_punctuation(prediction)
        for reference in answers
    )
    return 1 if matched else 0
def eval_file(dataset_file, prediction_file):
    """Evaluate a prediction JSON file against a dataset JSON file.

    Args:
        dataset_file: Path of the ground-truth JSON file.
        prediction_file: Path of the JSON file mapping question ids to
            predicted answers.

    Returns:
        A tuple ``(EM, F1, AVG, TOTAL)``: the two scores returned by
        ``evaluate``, their mean, and the number of questions.
    """
    # Local import keeps this function self-contained; io.open accepts an
    # explicit encoding on both Python 2 and Python 3.
    import io

    # Read as UTF-8 explicitly instead of relying on the locale default, and
    # close both handles deterministically.
    with io.open(dataset_file, 'r', encoding='utf-8') as fin:
        ground_truth = json.load(fin)
    with io.open(prediction_file, 'r', encoding='utf-8') as fin:
        predictions = json.load(fin)
    F1, EM, TOTAL, _ = evaluate(ground_truth, predictions)
    AVG = (EM + F1) * 0.5
    return EM, F1, AVG, TOTAL
if __name__ == '__main__':
    # Score the dev-set predictions and print EM, F1 and the question count,
    # one per line. AVG is computed but not printed.
    EM, F1, AVG, TOTAL = eval_file("task_data/cmrc2018/dev.json", "predictions.json")
    for metric in (EM, F1, TOTAL):
        print(metric)
\ No newline at end of file
# -*- coding: UTF-8 -*-
import json
import os
import io
# Build two auxiliary training sets from the MRC training file:
#   data/mlm/train.tsv   - one context paragraph per row (column text_a)
#   data/match/train.tsv - question/text pairs labelled 1 (the answer) or
#                          0 (the context that precedes the answer)
# Every path is resolved against this script's directory so the files are
# read and written where the directories are created, whatever the cwd is.
abs_path = os.path.abspath(__file__)
base_dir = os.path.dirname(abs_path)
dst_dir = os.path.join(base_dir, "data/mlm/")
dst_dir2 = os.path.join(base_dir, "data/match/")
for out_dir in (dst_dir, dst_dir2):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

mrc_train_path = os.path.join(base_dir, "data/mrc/train.json")
mlm_train_path = os.path.join(dst_dir, "train.tsv")
match_train_path = os.path.join(dst_dir2, "train.tsv")
# Upper bound on the number of rows written to the match file.
max_match_rows = 14246

with io.open(mrc_train_path, "r", encoding='utf-8') as src_file:
    data = json.load(src_file)["data"]

i = 0  # rows written to the match file so far
# Opening with mode "w" creates the files, so no os.mknod call is needed
# (it is not portable and fails when the file already exists). io.open with
# an explicit encoding takes text directly on both Python 2 and Python 3.
with io.open(mlm_train_path, "w", encoding='utf-8') as f, \
        io.open(match_train_path, "w", encoding='utf-8') as f2:
    f.write(u"text_a\n")
    f2.write(u"text_a\ttext_b\tlabel\n")
    for dd in data:
        for d in dd["paragraphs"]:
            text_a_mlm = d["context"]
            f.write(text_a_mlm + u"\n")
            for qa in d["qas"]:
                text_a = qa["question"]
                # Only the first reference answer is used.
                answer = qa["answers"][0]
                text_b = answer["text"]
                start_pos = answer["answer_start"]
                # Negative sample: the context that precedes the answer.
                text_b_neg = text_a_mlm[0:start_pos]
                if len(text_b_neg) > 512:
                    # NOTE(review): this keeps 511 characters and drops the
                    # final one; kept as-is so the generated data does not
                    # change - confirm whether [-512:] was intended.
                    text_b_neg = text_b_neg[-512:-1]
                if i < max_match_rows:
                    f2.write(text_a + u"\t" + text_b + u"\t1\n")
                    f2.write(text_a + u"\t" + text_b_neg + u"\t0\n")
                    i += 2
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
    # Multi-task example: train an ERNIE backbone jointly on the CMRC 2018
    # reading-comprehension task and an auxiliary matching task, then run
    # MRC prediction with a saved checkpoint.

    # configs
    max_seqlen = 512
    batch_size = 8
    num_epochs = 8
    lr = 3e-5
    doc_stride = 128
    max_query_len = 64
    max_ans_len = 128
    weight_decay = 0.01
    print_steps = 20
    num_classes = 2
    random_seed = 1
    dropout_prob = 0.1

    vocab_path = './pretrain/ernie-zh-base/vocab.txt'
    do_lower_case = True

    train_file = './data/mrc/train.json'
    train_file_mlm = './data/mlm/train.tsv'
    train_file_match = './data/match/train.tsv'
    predict_file = './data/mrc/dev.json'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    task_name = 'cmrc2018'
    pre_params = './pretrain/ernie-zh-base/params'
    # Close the config file deterministically instead of leaking the handle.
    with open('./pretrain/ernie-zh-base/ernie_config.json') as config_file:
        config = json.load(config_file)
    input_dim = config['hidden_size']
    vocab_size = config['vocab_size']
    hidden_act = config['hidden_act']

    # ----------------------- for training -----------------------

    # step 1-1: create readers for training
    mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
    match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
    # mlm_reader = palm.reader.MaskLMReader(vocab_path, max_seqlen, seed=random_seed)

    # step 1-2: load the training data
    mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
    match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
    # mlm_reader.load_data(train_file_mlm, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in readers
    mrc_reader.register_with(ernie)
    match_reader.register_with(ernie)
    # mlm_reader.register_with(ernie)

    # step 4: create task output heads
    mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
    match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
    # MaskLM takes vocab_size as its second positional argument; omitting it
    # shifts hidden_act and dropout_prob into the wrong parameters.
    # The head is currently unused because the MLM task is disabled.
    mlm_head = palm.head.MaskLM(input_dim, vocab_size, hidden_act, dropout_prob)

    # step 5-1: create a task trainer
    trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
    # trainer_mlm = palm.Trainer("mlm", mix_ratio=0.5)
    trainer_match = palm.Trainer("match", mix_ratio=0.5)
    trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])

    # step 5-2: build forward graph with backbone and task head
    loss_var = trainer.build_forward(ernie, [mrc_head, match_head])

    # step 6-1*: use warmup
    n_steps = mrc_reader.num_examples * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)

    # step 6-2: create a optimizer
    adam = palm.optimizer.Adam(loss_var, lr, sched)

    # step 6-3: build backward
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
    trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)

    # step 8-1*: load pretrained parameters
    trainer.load_pretrain(pre_params)

    # step 8-2*: set saver to save model
    # save_steps = n_steps-8
    # NOTE(review): hard-coded save interval - confirm it matches the data.
    save_steps = 1520
    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)

    # step 8-3: start training
    trainer.train(print_steps=print_steps)

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    predict_mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case, phase='predict')

    # step 1-2: load the prediction data
    predict_mrc_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_mrc_reader.register_with(pred_ernie)

    # step 4: create the task output head
    mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')

    # step 5: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, mrc_pred_head)

    # step 6: load the checkpoint saved during training
    # NOTE(review): 12160 equals save_steps * num_epochs (1520 * 8);
    # presumably the last checkpoint written - confirm before changing either.
    pred_model_path = './outputs/ckpt.step'+str(12160)
    pred_ckpt = trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_mrc_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir="outputs/")
......@@ -67,7 +67,7 @@ if __name__ == '__main__':
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
save_steps = (n_steps-20)// gpu_dev_count
save_steps = (n_steps-20)
print('save_steps: {}'.format(save_steps))
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
......
......@@ -24,19 +24,17 @@ class MaskLM(Head):
'''
mlm
'''
def __init__(self, input_dim, vocab_size, hidden_act, initializer_range, dropout_prob=0.0, \
def __init__(self, input_dim, vocab_size, hidden_act, dropout_prob=0.0, \
param_initializer_range=0.02, phase='train'):
self._is_training = phase == 'train'
self._emb_size = input_dim
self._hidden_size = input_dim
self._dropout_prob = dropout_prob if phase == 'train' else 0.0
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=param_initializer_range)
self._preds = []
self._vocab_size = vocab_size
self._hidden_act = hidden_act
self._initializer_range = initializer_range
self._initializer_range = param_initializer_range
@property
def inputs_attrs(self):
......
......@@ -5,6 +5,7 @@ from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
from paddlepalm import Trainer
from paddlepalm.utils import reader_helper
import numpy as np
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
import time
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
......@@ -205,10 +206,10 @@ class MultiHeadTrainer(Trainer):
distribute_feeder_fn = iterator_fn
if phase == 'train':
self._train_reader = distribute_feeder_fn()
self._train_reader = distribute_feeder_fn
self._feed_batch_process_fn = feed_batch_process_fn
elif phase == 'predict':
self._predict_reader = distribute_feeder_fn()
self._predict_reader = distribute_feeder_fn
self._pred_feed_batch_process_fn = feed_batch_process_fn
def _check_finish(self, task_name, silent=False):
......
......@@ -99,7 +99,7 @@ class Reader(object):
if label_map_config:
with open(label_map_config, encoding='utf8') as f:
self.label_map = json.load(f)
self.label_map = (f)
else:
self.label_map = None
......
......@@ -373,7 +373,7 @@ class Trainer(object):
self._num_epochs = reader.num_epochs
if phase == 'train':
self._train_reader = reader
self._steps_pur_epoch = reader.num_examples // batch_size // gpu_dev_count
self._steps_pur_epoch = reader.num_examples // batch_size
shape_and_dtypes = self._shape_and_dtypes
name_to_position = self._name_to_position
if self._task_id is not None:
......@@ -387,7 +387,7 @@ class Trainer(object):
elif phase == 'predict':
self._predict_reader = reader
tail = self._num_examples % batch_size > 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 // gpu_dev_count if tail else 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
shape_and_dtypes = self._pred_shape_and_dtypes
name_to_position = self._pred_name_to_position
net_inputs = self._pred_net_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册