Commit d71b37d0 authored by: W wangxiao1021

fix bugs

Parent 1bb38efb
......@@ -32,7 +32,7 @@ label text_a
### Step 2: Train & Predict
The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
......
......@@ -21,7 +21,7 @@ python download.py
After the dataset is downloaded, you should convert the data format for training:
```shell
python process.py quora_duplicate_questions.tsv train.tsv test.tsv
python process.py data/quora_duplicate_questions.tsv data/train.tsv data/test.tsv
```
If everything goes well, a folder named `data/` will be created with all the converted data in it.
......@@ -40,7 +40,7 @@ What are the differences between the Dell Inspiron 3000, 5000, and 7000 series l
### Step 2: Train & Predict
The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
......
......@@ -39,12 +39,13 @@ Here is some example datas:
}
]
}
}
```
### Step 2: Train & Predict
The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
......
......@@ -28,8 +28,8 @@ def download(src, url):
abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
downlaod_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz")
download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
downlaod_path = os.path.join(os.path.dirname(abs_path), "dmtk_data_1.0.0.tar.gz")
target_dir = os.path.dirname(abs_path)
download(downlaod_path, download_url)
......@@ -37,14 +37,9 @@ tar = tarfile.open(downlaod_path)
tar.extractall(target_dir)
os.remove(downlaod_path)
abs_path = os.path.abspath(__file__)
dst_dir = os.path.join(os.path.dirname(abs_path), "data/mrc")
if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
os.makedirs(dst_dir)
for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)
shutil.rmtree(os.path.join(target_dir, 'task_data'))
shutil.rmtree(os.path.join(target_dir, 'data/dstc2/'))
shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Evaluation script for CMRC 2018
version: v5
Note:
v5 formatted output, add usage description
v4 fixed segmentation issues
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
import nltk
# segment text that mixes Chinese characters with English words
def mixed_segmentation(in_str, rm_punc=False):
    in_str = in_str.lower().strip()
    segs_out = []
    temp_str = ""
    sp_char = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』'
    ]
    for char in in_str:
        if rm_punc and char in sp_char:
            continue
        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
            if temp_str != "":
                ss = nltk.word_tokenize(temp_str)
                segs_out.extend(ss)
                temp_str = ""
            segs_out.append(char)
        else:
            temp_str += char
    # handling the last part
    if temp_str != "":
        ss = nltk.word_tokenize(temp_str)
        segs_out.extend(ss)
    return segs_out
# remove punctuation
def remove_punctuation(in_str):
    in_str = in_str.lower().strip()
    sp_char = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』'
    ]
    out_segs = []
    for char in in_str:
        if char in sp_char:
            continue
        else:
            out_segs.append(char)
    return ''.join(out_segs)
# find the longest common contiguous substring (dynamic programming)
def find_lcs(s1, s2):
    m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
    mmax = 0
    p = 0
    for i in range(len(s1)):
        for j in range(len(s2)):
            if s1[i] == s2[j]:
                m[i + 1][j + 1] = m[i][j] + 1
                if m[i + 1][j + 1] > mmax:
                    mmax = m[i + 1][j + 1]
                    p = i + 1
    return s1[p - mmax:p], mmax
# corpus-level evaluation: accumulate F1 and EM over all questions
def evaluate(ground_truth_file, prediction_file):
    f1 = 0
    em = 0
    total_count = 0
    skip_count = 0
    for instances in ground_truth_file["data"]:
        for instance in instances["paragraphs"]:
            context_text = instance['context'].strip()
            for qas in instance['qas']:
                total_count += 1
                query_id = qas['id'].strip()
                query_text = qas['question'].strip()
                answers = [ans["text"] for ans in qas["answers"]]
                if query_id not in prediction_file:
                    print('Unanswered question: {}\n'.format(query_id))
                    skip_count += 1
                    continue
                prediction = prediction_file[query_id]
                f1 += calc_f1_score(answers, prediction)
                em += calc_em_score(answers, prediction)
    f1_score = 100.0 * f1 / total_count
    em_score = 100.0 * em / total_count
    return f1_score, em_score, total_count, skip_count
def calc_f1_score(answers, prediction):
    f1_scores = []
    for ans in answers:
        ans_segs = mixed_segmentation(ans, rm_punc=True)
        prediction_segs = mixed_segmentation(prediction, rm_punc=True)
        lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1 = (2 * precision * recall) / (precision + recall)
        f1_scores.append(f1)
    return max(f1_scores)
def calc_em_score(answers, prediction):
    em = 0
    for ans in answers:
        ans_ = remove_punctuation(ans)
        prediction_ = remove_punctuation(prediction)
        if ans_ == prediction_:
            em = 1
            break
    return em
def eval_file(dataset_file, prediction_file):
    ground_truth_file = json.load(open(dataset_file, 'r'))
    prediction_file = json.load(open(prediction_file, 'r'))
    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
    AVG = (EM + F1) * 0.5
    return EM, F1, AVG, TOTAL
if __name__ == '__main__':
    EM, F1, AVG, TOTAL = eval_file("task_data/cmrc2018/dev.json", "predictions.json")
    print(EM)
    print(F1)
    print(TOTAL)
\ No newline at end of file
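A quick, illustrative sanity check of the span metric above (not part of the commit): it assumes the evaluation script is importable as `cmrc_evaluate` — the filename is an assumption — and that `nltk` is installed.

```python
# illustrative sketch; `cmrc_evaluate` is an assumed module name for the script above
from cmrc_evaluate import calc_f1_score, calc_em_score

answers = [u'北京大学']           # gold answer spans
prediction = u'北京大学位于北京'   # model output for the same question

# F1 uses the longest common substring of the character-segmented strings:
# LCS length 4, precision 4/8, recall 4/4 -> F1 = 2/3
print(calc_f1_score(answers, prediction))
# EM requires an exact match after punctuation removal -> 0 here
print(calc_em_score(answers, prediction))
```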
import json


def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
    """
    :param map_dir: path to the JSON label map that defines the chunk types
    :return: the label map as a dict
    """
    return json.load(open(map_dir, "r"))

def cal_chunk(total_res, total_label):
    # token-level precision/recall/F1; label id 6 is excluded from both
    # the gold and the predicted counts
    assert len(total_label) == len(total_res), 'prediction result doesn\'t match to labels'
    num_labels = 0
    num_corr = 0
    num_infers = 0
    for res, label in zip(total_res, total_label):
        assert len(res) == len(label), "prediction result doesn't match to labels"
        num_labels += sum([0 if i == 6 else 1 for i in label])
        num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))])
        num_infers += sum([0 if i == 6 else 1 for i in res])
    precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0
    recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1

def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
    label_map = load_label_map()

    total_label = []
    with open(data_dir, "r") as file:
        first_flag = True
        for line in file:
            if first_flag:
                first_flag = False
                continue
            line = line.strip("\n")
            if len(line) == 0:
                continue
            line = line.split("\t")
            if len(line) < 2:
                continue
            labels = line[1][:-1].split("\x02")
            total_label.append(labels)
    total_label = [[label_map[j] for j in i] for i in total_label]

    total_res = []
    with open(res_dir, "r") as file:
        cnt = 0
        for line in file:
            line = line.strip("\n")
            if len(line) == 0:
                continue
            try:
                res_arr = json.loads(line)
                if len(total_label[cnt]) < len(res_arr):
                    total_res.append(res_arr[1: 1 + len(total_label[cnt])])
                elif len(total_label[cnt]) == len(res_arr):
                    total_res.append(res_arr)
                else:
                    total_res.append(res_arr)
                    total_label[cnt] = total_label[cnt][: len(res_arr)]
            except Exception:
                print("json format error: {}".format(cnt))
                print(line)
            cnt += 1

    precision, recall, f1 = cal_chunk(total_res, total_label)
    print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1))

res_evaluate()
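A small toy check of `cal_chunk` (illustrative only, run in the same module as the function above); label id 6 is the excluded tag, as in the code.

```python
# two toy sequences of label ids; id 6 is skipped by cal_chunk()
gold = [[1, 6, 2], [3, 6, 6]]
pred = [[1, 6, 5], [3, 6, 6]]
p, r, f1 = cal_chunk(pred, gold)
# 3 gold non-6 labels, 3 predicted non-6 labels, 2 correct
print(p, r, f1)  # 0.666..., 0.666..., 0.666...
```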
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
    # configs
    max_seqlen = 256
    batch_size = 16
    num_epochs = 6
    print_steps = 5
    lr = 5e-5
    num_classes = 130
    random_seed = 1
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ernie-en-base/vocab.txt'
    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    pre_params = './pretrain/ernie-en-base/params'
    config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
    # step 1-2: load the prediction data
    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in the reader
    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot")
    # step 5-2: build forward graph with backbone and task head
    trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load the checkpoint to predict with
    pred_model_path = './outputs/1580822697.73-ckpt.step9282'
    pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer_seq_label.predict(print_steps=print_steps, output_dir=pred_output)
\ No newline at end of file
# -*- coding: UTF-8 -*-
import json
import os
import io
abs_path = os.path.abspath(__file__)
dst_dir = os.path.join(os.path.dirname(abs_path), "data/match/")
if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)
os.mknod("./data/match/train.tsv")

with io.open("./data/mrc/train.json", "r", encoding='utf-8') as f:
    data = json.load(f)["data"]

i = 0
with open("./data/match/train.tsv", "w") as f2:
    f2.write("text_a\ttext_b\tlabel\n")
    for dd in data:
        for d in dd["paragraphs"]:
            context = d["context"]
            for qa in d["qas"]:
                text_a = qa["question"]
                answer = qa["answers"][0]
                text_b = answer["text"]
                start_pos = answer["answer_start"]
                # the context before the answer span is used as a negative passage
                text_b_neg = context[0:start_pos]
                if len(text_b_neg) > 512:
                    text_b_neg = text_b_neg[-512:-1]
                l1 = text_a + "\t" + text_b + "\t1\n"
                l2 = text_a + "\t" + text_b_neg + "\t0\n"
                if i < 14246:
                    f2.write(l1.encode("utf-8"))
                    f2.write(l2.encode("utf-8"))
                    i += 2
import json
import os

label_new = "data/atis/atis_slot/label_map.json"
label_old = "data/atis/atis_slot/map_tag_slot_id.txt"
train_old = "data/atis/atis_slot/train.txt"
train_new = "data/atis/atis_slot/train.tsv"
dev_old = "data/atis/atis_slot/dev.txt"
dev_new = "data/atis/atis_slot/dev.tsv"
test_old = "data/atis/atis_slot/test.txt"
test_new = "data/atis/atis_slot/test.tsv"

intent_test = "data/atis/atis_intent/test.tsv"
os.rename("data/atis/atis_intent/test.txt", intent_test)
intent_train = "data/atis/atis_intent/train.tsv"
os.rename("data/atis/atis_intent/train.txt", intent_train)
intent_dev = "data/atis/atis_intent/dev.tsv"
os.rename("data/atis/atis_intent/dev.txt", intent_dev)

# prepend the header line to the intent classification files
with open(intent_dev, 'r+') as f:
    content = f.read()
    f.seek(0, 0)
    f.write("label\ttext_a\n" + content)

with open(intent_test, 'r+') as f:
    content = f.read()
    f.seek(0, 0)
    f.write("label\ttext_a\n" + content)

with open(intent_train, 'r+') as f:
    content = f.read()
    f.seek(0, 0)
    f.write("label\ttext_a\n" + content)

os.mknod(label_new)
os.mknod(train_new)
os.mknod(dev_new)
os.mknod(test_new)

tag = []
id = []
map = {}

# convert the tag-to-id mapping into a JSON label map
with open(label_old, "r") as f:
    with open(label_new, "w") as f2:
        for line in f.readlines():
            line = line.split('\t')
            tag.append(line[0])
            id.append(int(line[1][:-1]))
            map[line[1][:-1]] = line[0]
        re = {tag[i]: id[i] for i in range(len(tag))}
        re = json.dumps(re)
        f2.write(re)

# convert the slot data to tsv: tokens are joined with '\2', and label ids
# are mapped back to tag names and joined with '\2'
with open(train_old, "r") as f:
    with open(train_new, "w") as f2:
        f2.write("text_a\tlabel\n")
        for line in f.readlines():
            line = line.split('\t')
            text = line[0].split(' ')
            label = line[1].split(' ')
            for t in text:
                f2.write(t)
                f2.write('\2')
            f2.write('\t')
            for t in label:
                if t.endswith('\n'):
                    t = t[:-1]
                f2.write(map[t])
                f2.write('\2')
            f2.write('\n')

with open(test_old, "r") as f:
    with open(test_new, "w") as f2:
        f2.write("text_a\tlabel\n")
        for line in f.readlines():
            line = line.split('\t')
            text = line[0].split(' ')
            label = line[1].split(' ')
            for t in text:
                f2.write(t)
                f2.write('\2')
            f2.write('\t')
            for t in label:
                if t.endswith('\n'):
                    t = t[:-1]
                f2.write(map[t])
                f2.write('\2')
            f2.write('\n')

with open(dev_old, "r") as f:
    with open(dev_new, "w") as f2:
        f2.write("text_a\tlabel\n")
        for line in f.readlines():
            line = line.split('\t')
            text = line[0].split(' ')
            label = line[1].split(' ')
            for t in text:
                f2.write(t)
                f2.write('\2')
            f2.write('\t')
            for t in label:
                if t.endswith('\n'):
                    t = t[:-1]
                f2.write(map[t])
                f2.write('\2')
            f2.write('\n')

os.remove(label_old)
os.remove(train_old)
os.remove(test_old)
os.remove(dev_old)
\ No newline at end of file
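For reference, the converted slot-filling `.tsv` rows separate tokens and labels with the `'\x02'` character; the snippet below only illustrates that layout, and the token/label values are invented.

```python
# invented example values; only the '\x02'/tab layout mirrors the script above
tokens = ["show", "me", "flights"]
labels = ["O", "O", "B-obj"]
row = ''.join(t + '\2' for t in tokens) + '\t' + ''.join(l + '\2' for l in labels) + '\n'
print(repr(row))  # 'show\x02me\x02flights\x02\tO\x02O\x02B-obj\x02\n'
```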
......@@ -8,61 +8,61 @@ if __name__ == '__main__':
# configs
max_seqlen = 128
batch_size = 8
num_epochs = 8
lr = 3e-5
doc_stride = 128
max_query_len = 64
max_ans_len = 128
batch_size = 16
num_epochs = 20
print_steps = 5
lr = 2e-5
num_classes = 130
weight_decay = 0.01
print_steps = 1
num_classes = 2
random_seed = 1
num_classes_intent = 26
dropout_prob = 0.1
vocab_path = './pretrain/ernie-zh-base/vocab.txt'
do_lower_case = True
random_seed = 0
label_map = './data/atis/atis_slot/label_map.json'
vocab_path = './pretrain/ernie-en-base/vocab.txt'
train_file = './data/mrc/train.json'
train_file_mlm = './data/mlm/train.tsv'
train_file_match = './data/match/train.tsv'
predict_file = './data/mrc/dev.json'
train_slot = './data/atis/atis_slot/train.tsv'
train_intent = './data/atis/atis_intent/train.tsv'
predict_file = './data/atis/atis_slot/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict/'
save_type = 'ckpt'
task_name = 'cmrc2018'
pre_params = './pretrain/ernie-zh-base/params'
config = json.load(open('./pretrain/ernie-zh-base/ernie_config.json'))
pre_params = './pretrain/ernie-en-base/params'
config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
input_dim = config['hidden_size']
# ----------------------- for training -----------------------
# step 1-1: create readers for training
mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
# step 1-1: create readers for training
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data
mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
match_reader.load_data(train_intent, file_format='tsv', num_epochs=None, batch_size=batch_size)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in readers
mrc_reader.register_with(ernie)
seq_label_reader.register_with(ernie)
match_reader.register_with(ernie)
# step 4: create task output heads
mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
match_head = palm.head.Match(num_classes_intent, input_dim, dropout_prob)
# step 5-1: create a task trainer
trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
trainer_match = palm.Trainer("match", mix_ratio=0.5)
trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])
# step 5-2: build forward graph with backbone and task head
loss_var = trainer.build_forward(ernie, [mrc_head, match_head])
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_match = palm.Trainer("intent", mix_ratio=0.5)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_match])
# # step 5-2: build forward graph with backbone and task head
loss_var1 = trainer_match.build_forward(ernie, match_head)
loss_var2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward()
# step 6-1*: use warmup
n_steps = mrc_reader.num_examples * 2 * num_epochs // batch_size
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
# step 6-2: create a optimizer
......@@ -71,42 +71,13 @@ if __name__ == '__main__':
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit prepared reader and data
trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)
trainer.fit_readers_with_mixratio([seq_label_reader, match_reader], "slot", num_epochs)
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
save_steps = n_steps-batch_size
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# save_steps = int(n_steps-batch_size)
save_steps = 10
trainer_seq_label.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True)
# step 8-3: start training
trainer.train(print_steps=print_steps)
# ----------------------- for prediction -----------------------
# step 1-1: create readers for prediction
predict_mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case, phase='predict')
# step 1-2: load the training data
predict_mrc_reader.load_data(predict_file, batch_size)
# step 2: create a backbone of the model to extract text features
pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
# step 3: register the backbone in reader
predict_mrc_reader.register_with(pred_ernie)
# step 4: create the task output head
mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')
# step 5: build forward graph with backbone and task head
trainer_mrc.build_predict_forward(pred_ernie, mrc_pred_head)
# step 6: load pretrained model
pred_model_path = './outputs/ckpt.step'+str(save_steps)
pred_ckpt = trainer_mrc.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer_mrc.fit_reader(predict_mrc_reader, phase='predict')
# step 8: predict
print('predicting..')
trainer_mrc.predict(print_steps=print_steps, output_dir="outputs/")
trainer.train(print_steps=print_steps)
\ No newline at end of file
......@@ -32,7 +32,7 @@ label text_a
### Step 2: Predict
The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
......
......@@ -34,7 +34,7 @@ text_a label
### Step 2: Train & Predict
The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
......
......@@ -32,26 +32,26 @@ if __name__ == '__main__':
# ----------------------- for training -----------------------
# step 1-1: create readers for training
ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
# step 1-2: load the training data
ner_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
seq_label_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in reader
ner_reader.register_with(ernie)
seq_label_reader.register_with(ernie)
# step 4: create the task output head
ner_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
# step 5-1: create a task trainer
trainer = palm.Trainer(task_name)
# step 5-2: build forward graph with backbone and task head
loss_var = trainer.build_forward(ernie, ner_head)
loss_var = trainer.build_forward(ernie, seq_label_head)
# step 6-1*: use warmup
n_steps = ner_reader.num_examples * num_epochs // batch_size
n_steps = seq_label_reader.num_examples * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
print('total_steps: {}'.format(n_steps))
print('warmup_steps: {}'.format(warmup_steps))
......@@ -62,43 +62,43 @@ if __name__ == '__main__':
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit prepared reader and data
trainer.fit_reader(ner_reader)
trainer.fit_reader(seq_label_reader)
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
save_steps = (n_steps-20)
print('save_steps: {}'.format(save_steps))
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
trainer.train(print_steps=train_print_steps)
# # step 8-1*: load pretrained parameters
# trainer.load_pretrain(pre_params)
# # step 8-2*: set saver to save model
save_steps = 1951
# print('save_steps: {}'.format(save_steps))
# trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# # step 8-3: start training
# trainer.train(print_steps=train_print_steps)
# ----------------------- for prediction -----------------------
# step 1-1: create readers for prediction
print('prepare to predict...')
predict_ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
# step 1-2: load the training data
predict_ner_reader.load_data(predict_file, batch_size)
predict_seq_label_reader.load_data(predict_file, batch_size)
# step 2: create a backbone of the model to extract text features
pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
# step 3: register the backbone in reader
predict_ner_reader.register_with(pred_ernie)
predict_seq_label_reader.register_with(pred_ernie)
# step 4: create the task output head
ner_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')
seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, ner_pred_head)
trainer.build_predict_forward(pred_ernie, seq_label_pred_head)
# step 6: load pretrained model
pred_model_path = './outputs/ckpt.step' + str(save_steps)
pred_ckpt = trainer.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_ner_reader, phase='predict')
trainer.fit_reader(predict_seq_label_reader, phase='predict')
# step 8: predict
print('predicting..')
......
......@@ -39,7 +39,6 @@ class MaskLM(Head):
@property
def inputs_attrs(self):
reader = {
"token_ids":[[-1, -1], 'int64'],
"mask_label": [[-1], 'int64'],
"mask_pos": [[-1], 'int64'],
}
......@@ -59,21 +58,19 @@ class MaskLM(Head):
def build(self, inputs, scope_name=""):
mask_pos = inputs["reader"]["mask_pos"]
word_emb = inputs["backbone"]["embedding_table"]
enc_out = inputs["backbone"]["encoder_outputs"]
if self._is_training:
mask_label = inputs["reader"]["mask_label"]
l1 = fluid.layers.shape(inputs["reader"]["token_ids"] )[0]
# bxs = inputs["reader"]["token_ids"].shape[2].value
l2 = fluid.layers.shape(inputs["reader"]["token_ids"][0])[0]
bxs = (l1*l2).astype(np.int64)
# max_position = inputs["reader"]["batchsize_x_seqlen"] - 1
mask_label = inputs["reader"]["mask_label"]
l1 = enc_out.shape[0]
l2 = enc_out.shape[1]
bxs = fluid.layers.fill_constant(shape=[1], value=l1*l2, dtype='int64')
max_position = bxs - 1
mask_pos = fluid.layers.elementwise_min(mask_pos, max_position)
mask_pos.stop_gradient = True
word_emb = inputs["backbone"]["embedding_table"]
enc_out = inputs["backbone"]["encoder_outputs"]
emb_size = word_emb.shape[-1]
_param_initializer = fluid.initializer.TruncatedNormal(
......
......@@ -52,11 +52,12 @@ class MultiHeadTrainer(Trainer):
'input_varnames': 'self._pred_input_varname_list',
'fetch_list': 'self._pred_fetch_name_list'}
self._check_save = lambda: False
# self._check_save = lambda: False
for t in self._trainers:
t._set_multitask()
def build_forward(self, backbone, heads):
# def build_forward(self, backbone, heads):
def build_forward(self):
"""
Build forward computation graph for training, which usually built from input layer to loss node.
......@@ -67,20 +68,13 @@ class MultiHeadTrainer(Trainer):
Return:
- loss_var: a Variable object. The computational graph variable(node) of loss.
"""
if isinstance(heads, list):
head_dict = {k.name: v for k,v in zip(self._trainers, heads)}
elif isinstance(heads, dict):
head_dict = heads
else:
raise ValueError()
num_heads = len(self._trainers)
assert len(head_dict) == num_heads
for t in self._trainers:
assert t.name in head_dict, "expected: {}, exists: {}".format(t.name, head_dict.keys())
head_dict = {}
backbone = self._trainers[0]._backbone
for i in self._trainers:
assert i._task_head is not None and i._backbone is not None, "You should build forward for the {} task".format(i._name)
assert i._backbone == backbone, "The backbone for each task must be the same"
head_dict[i._name] = i._task_head
train_prog = fluid.Program()
train_init_prog = fluid.Program()
self._train_prog = train_prog
......@@ -88,27 +82,13 @@ class MultiHeadTrainer(Trainer):
def get_loss(i):
head = head_dict[self._trainers[i].name]
# loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
loss_var = self._trainers[i].build_forward(backbone, head)
return loss_var
# task_fns = {}
# for i in range(num_heads):
# def task_loss():
# task_id = i
# return lambda: get_loss(task_id)
# task_fns[i] = task_loss()
# task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}
task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}
with fluid.program_guard(train_prog, train_init_prog):
task_id_var = fluid.data(name="__task_id",shape=[1],dtype='int64')
# task_id_var = fluid.layers.fill_constant(shape=[1],dtype='int64', value=1)
# print(task_id_var.name)
loss_var = layers.switch_case(
branch_index=task_id_var,
......@@ -242,7 +222,6 @@ class MultiHeadTrainer(Trainer):
task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')}
self._trainers[task_id]._task_head.batch_postprocess(task_rt_outputs)
if print_steps > 0 and self._cur_train_step % print_steps == 0:
loss = rt_outputs[self._trainers[task_id].name+'.loss']
loss = np.mean(np.squeeze(loss)).tolist()
......@@ -257,7 +236,7 @@ class MultiHeadTrainer(Trainer):
loss, print_steps / time_cost))
time_begin = time.time()
self._check_save()
# self._check_save()
finish = self._check_finish(self._trainers[task_id].name)
if finish:
break
......@@ -287,7 +266,7 @@ class MultiHeadTrainer(Trainer):
rt_outputs = self._trainers[task_id].train_one_step(batch)
self._cur_train_step += 1
self._check_save()
# self._check_save()
return rt_outputs, task_id
# if dev_count > 1:
......
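One detail in the hunk above worth calling out: `task_fns = {i: lambda i=i: get_loss(i) ...}` binds the loop variable through a default argument. A generic, self-contained illustration of why that matters (plain Python, not PaddlePALM code):

```python
# late binding: every closure sees the final value of i
fns_late = {i: (lambda: i) for i in range(3)}
# a default argument freezes the current value of i per closure
fns_bound = {i: (lambda i=i: i) for i in range(3)}
print([f() for f in fns_late.values()])   # [2, 2, 2]
print([f() for f in fns_bound.values()])  # [0, 1, 2]
```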
......@@ -34,7 +34,6 @@ class MaskLMReader(Reader):
for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'
self._register.add('token_ids')
self._register.add('mask_pos')
if phase == 'train':
self._register.add('mask_label')
......
......@@ -99,7 +99,7 @@ class Reader(object):
if label_map_config:
with open(label_map_config, encoding='utf8') as f:
self.label_map = (f)
self.label_map = json.load(f)
else:
self.label_map = None
......
......@@ -46,7 +46,7 @@ class Trainer(object):
self._pred_reader = None
self._task_head = None
self._pred_head = None
self._train_reader = None
self._predict_reader = None
self._train_iterator = None
......@@ -54,6 +54,8 @@ class Trainer(object):
self._train_init = False
self._predict_init = False
self._train_init_prog = None
self._pred_init_prog = None
self._check_save = lambda: False
......@@ -427,6 +429,7 @@ class Trainer(object):
self._pred_feed_batch_process_fn = feed_batch_process_fn
# return distribute_feeder_fn()
def load_ckpt(self, model_path):
"""
load training checkpoint for further training or predicting.
......@@ -465,7 +468,7 @@ class Trainer(object):
strict=True)
else:
raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
def load_predict_model(self, model_path, convert=False):
"""
load pretrained models (backbone) for training.
......@@ -500,7 +503,7 @@ class Trainer(object):
convert=convert,
main_program=self._train_init_prog)
def set_saver(self, save_path, save_steps, save_type='ckpt'):
def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
"""
create a built-in saver for the trainer. The saver automatically saves a checkpoint or predict model every `save_steps` training steps.
......@@ -510,6 +513,7 @@ class Trainer(object):
save_type: a string. The type of saved model. Currently supports checkpoint (ckpt) and predict model (predict); default is ckpt. If both types need to be saved, set it to "ckpt,predict".
"""
save_type = save_type.split(',')
if 'predict' in save_type:
......@@ -534,12 +538,21 @@ class Trainer(object):
def temp_func():
if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
if self._save_predict:
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
if is_multi:
self._save(save_path, suffix='-pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
else:
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
if self._save_ckpt:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
if is_multi:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
else:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
return True
else:
return False
......@@ -600,7 +613,7 @@ class Trainer(object):
(self._cur_train_step-1) % self._steps_pur_epoch + 1 , self._steps_pur_epoch, self._cur_train_epoch,
loss, print_steps / time_cost))
time_begin = time.time()
self._check_save()
# self._check_save()
# if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
# print(cur_task.name+': train finished!')
# cur_task.save()
......@@ -727,6 +740,7 @@ class Trainer(object):
rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}
self._cur_train_step += 1
self._check_save()
self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch
return rt_outputs
......@@ -749,7 +763,7 @@ class Trainer(object):
@property
def name(self):
return self._name
@property
def num_examples(self):
return self._num_examples
......