Commit d71b37d0 authored by wangxiao1021

fix bugs

Parent 1bb38efb
@@ -32,7 +32,7 @@ label text_a
### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
...
@@ -21,7 +21,7 @@ python download.py
After the dataset is downloaded, you should convert the data format for training:
```shell
-python process.py quora_duplicate_questions.tsv train.tsv test.tsv
+python process.py data/quora_duplicate_questions.tsv data/train.tsv data/test.tsv
```
If everything goes well, there will be a folder named `data/` created with all the converted datas in it.
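If you want to double-check the conversion before training, peeking at the generated files is enough. The snippet below is a minimal sketch (it only assumes that the converted `data/train.tsv` and `data/test.tsv` are tab-separated files with a header row, which is not spelled out in the README above):

```python
# peek.py -- minimal sanity check for the converted Quora files
# (assumes tab-separated files; prints the header and the first example)
import os

for path in ("data/train.tsv", "data/test.tsv"):
    if not os.path.exists(path):
        print("missing:", path)
        continue
    with open(path, encoding="utf-8") as f:
        header = f.readline().rstrip("\n").split("\t")
        sample = f.readline().rstrip("\n").split("\t")
    print(path)
    print("  columns:", header)
    print("  example:", sample)
```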
@@ -40,7 +40,7 @@ What are the differences between the Dell Inspiron 3000, 5000, and 7000 series l
### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
...
@@ -39,12 +39,13 @@ Here is some example datas:
}
]
}
+}
```
### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
...
@@ -28,8 +28,8 @@ def download(src, url):
abs_path = os.path.abspath(__file__)
-download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
-downlaod_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz")
+download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
+downlaod_path = os.path.join(os.path.dirname(abs_path), "dmtk_data_1.0.0.tar.gz")
target_dir = os.path.dirname(abs_path)
download(downlaod_path, download_url)

@@ -37,14 +37,9 @@ tar = tarfile.open(downlaod_path)
tar.extractall(target_dir)
os.remove(downlaod_path)
-abs_path = os.path.abspath(__file__)
-dst_dir = os.path.join(os.path.dirname(abs_path), "data/mrc")
-if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
-    os.makedirs(dst_dir)
-for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
-    shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)
-shutil.rmtree(os.path.join(target_dir, 'task_data'))
+shutil.rmtree(os.path.join(target_dir, 'data/dstc2/'))
+shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
+shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
+shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
+shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
# -*- coding: utf-8 -*-
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-'''
-Evaluation script for CMRC 2018
-version: v5
-Note:
-v5 formatted output, add usage description
-v4 fixed segmentation issues
-'''
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-from __future__ import absolute_import
-from collections import Counter, OrderedDict
-import string
-import re
-import argparse
import json
-import sys
-import nltk
-import pdb
-
-
-# split Chinese with English
-def mixed_segmentation(in_str, rm_punc=False):
-    in_str = in_str.lower().strip()
-    segs_out = []
-    temp_str = ""
-    sp_char = [
-        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
-        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
-        ')', '-', '~', '『', '』'
-    ]
-    for char in in_str:
-        if rm_punc and char in sp_char:
-            continue
-        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
-            if temp_str != "":
-                ss = nltk.word_tokenize(temp_str)
-                segs_out.extend(ss)
-                temp_str = ""
-            segs_out.append(char)
-        else:
-            temp_str += char
-
-    #handling last part
-    if temp_str != "":
-        ss = nltk.word_tokenize(temp_str)
-        segs_out.extend(ss)
-
-    return segs_out
-
-
-# remove punctuation
-def remove_punctuation(in_str):
-    in_str = in_str.lower().strip()
-    sp_char = [
-        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
-        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
-        ')', '-', '~', '『', '』'
-    ]
-    out_segs = []
-    for char in in_str:
-        if char in sp_char:
-            continue
-        else:
-            out_segs.append(char)
-    return ''.join(out_segs)
-
-
-# find longest common string
-def find_lcs(s1, s2):
-    m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
-    mmax = 0
-    p = 0
-    for i in range(len(s1)):
-        for j in range(len(s2)):
-            if s1[i] == s2[j]:
-                m[i + 1][j + 1] = m[i][j] + 1
-                if m[i + 1][j + 1] > mmax:
-                    mmax = m[i + 1][j + 1]
-                    p = i + 1
-    return s1[p - mmax:p], mmax
-
-
-#
-def evaluate(ground_truth_file, prediction_file):
-    f1 = 0
-    em = 0
-    total_count = 0
-    skip_count = 0
-    for instances in ground_truth_file["data"]:
-        for instance in instances["paragraphs"]:
-            context_text = instance['context'].strip()
-            for qas in instance['qas']:
-                total_count += 1
-                query_id = qas['id'].strip()
-                query_text = qas['question'].strip()
-                answers = [ans["text"] for ans in qas["answers"]]
-
-                if query_id not in prediction_file:
-                    print('Unanswered question: {}\n'.format(
-                        query_id))
-                    skip_count += 1
-                    continue
-
-                prediction = prediction_file[query_id]
-                f1 += calc_f1_score(answers, prediction)
-                em += calc_em_score(answers, prediction)
-
-    f1_score = 100.0 * f1 / total_count
-    em_score = 100.0 * em / total_count
-    return f1_score, em_score, total_count, skip_count
-
-
-def calc_f1_score(answers, prediction):
-    f1_scores = []
-    for ans in answers:
-        ans_segs = mixed_segmentation(ans, rm_punc=True)
-        prediction_segs = mixed_segmentation(prediction, rm_punc=True)
-        lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
-        if lcs_len == 0:
-            f1_scores.append(0)
-            continue
-        precision = 1.0 * lcs_len / len(prediction_segs)
-        recall = 1.0 * lcs_len / len(ans_segs)
-        f1 = (2 * precision * recall) / (precision + recall)
-        f1_scores.append(f1)
-    return max(f1_scores)
-
-
-def calc_em_score(answers, prediction):
-    em = 0
-    for ans in answers:
-        ans_ = remove_punctuation(ans)
-        prediction_ = remove_punctuation(prediction)
-        if ans_ == prediction_:
-            em = 1
-            break
-    return em
-
-
-def eval_file(dataset_file, prediction_file):
-    ground_truth_file = json.load(open(dataset_file, 'r'))
-    prediction_file = json.load(open(prediction_file, 'r'))
-    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
-    AVG = (EM + F1) * 0.5
-    return EM, F1, AVG, TOTAL
-
-
-if __name__ == '__main__':
-    EM, F1, AVG, TOTAL = eval_file("task_data/cmrc2018/dev.json", "predictions.json")
-    print(EM)
-    print(F1)
-    print(TOTAL)
+
+
+def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
+    """
+    :param map_dir: dict indictuing chunk type
+    :return:
+    """
+    return json.load(open(map_dir, "r"))
+
+
+def cal_chunk(total_res, total_label):
+    assert len(total_label) == len(total_res), 'prediction result doesn\'t match to labels'
+    num_labels = 0
+    num_corr = 0
+    num_infers = 0
+    for res, label in zip(total_res, total_label):
+        assert len(res) == len(label), "prediction result doesn\'t match to labels"
+        num_labels += sum([0 if i == 6 else 1 for i in label])
+        num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))])
+        num_infers += sum([0 if i == 6 else 1 for i in res])
+
+    precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0
+    recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0
+    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
+
+    return precision, recall, f1
+
+
+def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
+    label_map = load_label_map()
+
+    total_label = []
+    with open(data_dir, "r") as file:
+        first_flag = True
+        for line in file:
+            if first_flag:
+                first_flag = False
+                continue
+            line = line.strip("\n")
+            if len(line) == 0:
+                continue
+            line = line.split("\t")
+            if len(line) < 2:
+                continue
+            labels = line[1][:-1].split("\x02")
+            total_label.append(labels)
+    total_label = [[label_map[j] for j in i] for i in total_label]
+
+    total_res = []
+    with open(res_dir, "r") as file:
+        cnt = 0
+        for line in file:
+            line = line.strip("\n")
+            if len(line) == 0:
+                continue
+            try:
+                res_arr = json.loads(line)
+                if len(total_label[cnt]) < len(res_arr):
+                    total_res.append(res_arr[1: 1 + len(total_label[cnt])])
+                elif len(total_label[cnt]) == len(res_arr):
+                    total_res.append(res_arr)
+                else:
+                    total_res.append(res_arr)
+                    total_label[cnt] = total_label[cnt][: len(res_arr)]
+            except:
+                print("json format error: {}".format(cnt))
+                print(line)
+            cnt += 1
+
+    precision, recall, f1 = cal_chunk(total_res, total_label)
+    print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1))
+
+res_evaluate()
\ No newline at end of file
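To make the metric in the new `evaluate.py` concrete: `cal_chunk` counts, per token, how many predicted non-ignored labels match the gold non-ignored labels, where label id 6 plays the role of the ignored "no slot" class. The following is a minimal, self-contained sketch of the same computation on toy data (the label ids are illustrative and mirror the script above rather than the real ATIS label map):

```python
# toy_chunk_metric.py -- minimal sketch of the token-level P/R/F1 used above;
# label id 6 is treated as the ignored "no slot" class, as in evaluate.py.
def cal_chunk(total_res, total_label, ignore_id=6):
    num_labels = num_corr = num_infers = 0
    for res, label in zip(total_res, total_label):
        num_labels += sum(1 for i in label if i != ignore_id)   # gold slot tokens
        num_infers += sum(1 for i in res if i != ignore_id)     # predicted slot tokens
        num_corr += sum(1 for p, g in zip(res, label)
                        if p == g and g != ignore_id)           # correctly predicted slot tokens
    precision = num_corr / num_infers if num_infers else 0.0
    recall = num_corr / num_labels if num_labels else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# one toy sentence: gold has two slot tokens, the prediction recovers one of them
gold = [[6, 3, 3, 6, 6]]
pred = [[6, 3, 6, 6, 1]]
print(cal_chunk(pred, gold))  # -> (0.5, 0.5, 0.5)
```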
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count

if __name__ == '__main__':

    # configs
    max_seqlen = 256
    batch_size = 16
    num_epochs = 6
    print_steps = 5
    lr = 5e-5
    num_classes = 130
    random_seed = 1
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ernie-en-base/vocab.txt'
    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    pre_params = './pretrain/ernie-en-base/params'
    config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
    # step 1-2: load the training data
    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot")
    # step 5-2: build forward graph with backbone and task head
    trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load pretrained model
    pred_model_path = './outputs/1580822697.73-ckpt.step9282'
    pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer_seq_label.predict(print_steps=print_steps, output_dir=pred_output)
\ No newline at end of file
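Note that `pred_model_path` in the script above is a hard-coded, timestamped checkpoint directory. When reusing the script, picking the newest matching checkpoint automatically avoids editing it by hand; a minimal sketch (the glob pattern is an assumption based on the checkpoint names used in these examples):

```python
# latest_ckpt.py -- minimal sketch: pick the newest "*ckpt.step*" directory under
# ./outputs/ instead of hard-coding a timestamped path like the one above.
import glob
import os

candidates = glob.glob("./outputs/*ckpt.step*")
if candidates:
    pred_model_path = max(candidates, key=os.path.getmtime)
    print("using checkpoint:", pred_model_path)
else:
    print("no checkpoint found under ./outputs/")
```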
-# -*- coding: UTF-8 -*-
-import json
import os
-import io
-
-abs_path = os.path.abspath(__file__)
-dst_dir = os.path.join(os.path.dirname(abs_path), "data/match/")
-if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
-    os.makedirs(dst_dir)
-os.mknod("./data/match/train.tsv")
-
-with io.open("./data/mrc/train.json", "r", encoding='utf-8') as f:
-    data = json.load(f)["data"]
-i = 0
-with open("./data/match/train.tsv","w") as f2:
-    f2.write("text_a\ttext_b\tlabel\n")
-    for dd in data:
-        for d in dd["paragraphs"]:
-            context = d["context"]
-            for qa in d["qas"]:
-                text_a = qa["question"]
-                answer = qa["answers"][0]
-                text_b = answer["text"]
-                start_pos = answer["answer_start"]
-                text_b_neg = context[0:start_pos]
-                if len(text_b_neg) > 512:
-                    text_b_neg = text_b_neg[-512:-1]
-                l1 = text_a+"\t"+text_b+"\t1\n"
-                l2 = text_a+"\t"+text_b_neg+"\t0\n"
-                if i < 14246:
-                    f2.write(l1.encode("utf-8"))
-                    f2.write(l2.encode("utf-8"))
-                    i +=2
-    f2.close()
-f.close()
+import json
+
+label_new = "data/atis/atis_slot/label_map.json"
+label_old = "data/atis/atis_slot/map_tag_slot_id.txt"
+train_old = "data/atis/atis_slot/train.txt"
+train_new = "data/atis/atis_slot/train.tsv"
+dev_old = "data/atis/atis_slot/dev.txt"
+dev_new = "data/atis/atis_slot/dev.tsv"
+test_old = "data/atis/atis_slot/test.txt"
+test_new = "data/atis/atis_slot/test.tsv"
+
+intent_test = "data/atis/atis_intent/test.tsv"
+os.rename("data/atis/atis_intent/test.txt", intent_test)
+intent_train = "data/atis/atis_intent/train.tsv"
+os.rename("data/atis/atis_intent/train.txt", intent_train)
+intent_dev = "data/atis/atis_intent/dev.tsv"
+os.rename("data/atis/atis_intent/dev.txt", intent_dev)
+
+with open(intent_dev, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n"+content)
+    f.close()
+
+with open(intent_test, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n"+content)
+    f.close()
+
+with open(intent_train, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n"+content)
+    f.close()
+
+os.mknod(label_new)
+os.mknod(train_new)
+os.mknod(dev_new)
+os.mknod(test_new)
+
+tag = []
+id = []
+map = {}
+
+with open(label_old, "r") as f:
+    with open(label_new, "w") as f2:
+        for line in f.readlines():
+            line = line.split('\t')
+            tag.append(line[0])
+            id.append(int(line[1][:-1]))
+            map[line[1][:-1]] = line[0]
+        re = {tag[i]:id[i] for i in range(len(tag))}
+        re = json.dumps(re)
+        f2.write(re)
+        f2.close()
+    f.close()
+
+with open(train_old, "r") as f:
+    with open(train_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+with open(test_old, "r") as f:
+    with open(test_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+with open(dev_old, "r") as f:
+    with open(dev_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+os.remove(label_old)
+os.remove(train_old)
+os.remove(test_old)
+os.remove(dev_old)
\ No newline at end of file
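The slot-filling files written by the new `process.py` separate tokens and labels with the `\x02` character (`'\2'`) inside each tab-separated column. A minimal sketch of reading one such line back, assuming only the `text_a<TAB>label` layout produced by the script above:

```python
# read_slot_tsv.py -- minimal sketch: parse one line of the generated
# atis_slot .tsv files, where tokens and labels are '\x02'-separated.
with open("data/atis/atis_slot/train.tsv", encoding="utf-8") as f:
    next(f)                              # skip the "text_a\tlabel" header
    line = next(f).rstrip("\n")

text_col, label_col = line.split("\t")
tokens = [t for t in text_col.split("\x02") if t]
labels = [l for l in label_col.split("\x02") if l]
assert len(tokens) == len(labels), "one label per token is expected"
print(list(zip(tokens, labels))[:5])
```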
@@ -8,61 +8,61 @@ if __name__ == '__main__':
    # configs
    max_seqlen = 128
-    batch_size = 8
-    num_epochs = 8
-    lr = 3e-5
-    doc_stride = 128
-    max_query_len = 64
-    max_ans_len = 128
+    batch_size = 16
+    num_epochs = 20
+    print_steps = 5
+    lr = 2e-5
+    num_classes = 130
    weight_decay = 0.01
-    print_steps = 1
-    num_classes = 2
-    random_seed = 1
+    num_classes_intent = 26
    dropout_prob = 0.1
-    vocab_path = './pretrain/ernie-zh-base/vocab.txt'
-    do_lower_case = True
-    train_file = './data/mrc/train.json'
-    train_file_mlm = './data/mlm/train.tsv'
-    train_file_match = './data/match/train.tsv'
-    predict_file = './data/mrc/dev.json'
+    random_seed = 0
+    label_map = './data/atis/atis_slot/label_map.json'
+    vocab_path = './pretrain/ernie-en-base/vocab.txt'
+    train_slot = './data/atis/atis_slot/train.tsv'
+    train_intent = './data/atis/atis_intent/train.tsv'
+    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
-    task_name = 'cmrc2018'
-    pre_params = './pretrain/ernie-zh-base/params'
-    config = json.load(open('./pretrain/ernie-zh-base/ernie_config.json'))
+    pre_params = './pretrain/ernie-en-base/params'
+    config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for training -----------------------

    # step 1-1: create readers for training
-    mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
+    seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
    match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
    # step 1-2: load the training data
-    mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
-    match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
+    seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
+    match_reader.load_data(train_intent, file_format='tsv', num_epochs=None, batch_size=batch_size)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in readers
-    mrc_reader.register_with(ernie)
+    seq_label_reader.register_with(ernie)
    match_reader.register_with(ernie)

    # step 4: create task output heads
-    mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
-    match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
+    seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
+    match_head = palm.head.Match(num_classes_intent, input_dim, dropout_prob)

    # step 5-1: create a task trainer
-    trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
-    trainer_match = palm.Trainer("match", mix_ratio=0.5)
-    trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])
-    # step 5-2: build forward graph with backbone and task head
-    loss_var = trainer.build_forward(ernie, [mrc_head, match_head])
+    trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
+    trainer_match = palm.Trainer("intent", mix_ratio=0.5)
+    trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_match])
+    # # step 5-2: build forward graph with backbone and task head
+    loss_var1 = trainer_match.build_forward(ernie, match_head)
+    loss_var2 = trainer_seq_label.build_forward(ernie, seq_label_head)
+    loss_var = trainer.build_forward()

    # step 6-1*: use warmup
-    n_steps = mrc_reader.num_examples * 2 * num_epochs // batch_size
+    n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
    # step 6-2: create a optimizer

@@ -71,42 +71,13 @@ if __name__ == '__main__':
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
-    trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)
+    trainer.fit_readers_with_mixratio([seq_label_reader, match_reader], "slot", num_epochs)

    # step 8-1*: load pretrained parameters
    trainer.load_pretrain(pre_params)

    # step 8-2*: set saver to save model
-    save_steps = n_steps-batch_size
-    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
+    # save_steps = int(n_steps-batch_size)
+    save_steps = 10
+    trainer_seq_label.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True)

    # step 8-3: start training
    trainer.train(print_steps=print_steps)
\ No newline at end of file

-    # ----------------------- for prediction -----------------------
-
-    # step 1-1: create readers for prediction
-    predict_mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case, phase='predict')
-    # step 1-2: load the training data
-    predict_mrc_reader.load_data(predict_file, batch_size)
-
-    # step 2: create a backbone of the model to extract text features
-    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
-
-    # step 3: register the backbone in reader
-    predict_mrc_reader.register_with(pred_ernie)
-
-    # step 4: create the task output head
-    mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')
-
-    # step 5: build forward graph with backbone and task head
-    trainer_mrc.build_predict_forward(pred_ernie, mrc_pred_head)
-
-    # step 6: load pretrained model
-    pred_model_path = './outputs/ckpt.step'+str(save_steps)
-    pred_ckpt = trainer_mrc.load_ckpt(pred_model_path)
-
-    # step 7: fit prepared reader and data
-    trainer_mrc.fit_reader(predict_mrc_reader, phase='predict')
-
-    # step 8: predict
-    print('predicting..')
-    trainer_mrc.predict(print_steps=print_steps, output_dir="outputs/")
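A note on the `mix_ratio` values used in this multi-task script: with `trainer_seq_label` at 1.0 and `trainer_match` at 0.5, the intent task is sampled half as often as the slot task, which is why the expected step count is scaled by `(1.0 + 0.5) = 1.5` relative to the slot reader alone. A small worked example (this assumes `mix_ratio` values act as proportional sampling weights, which is how the 1.5 factor in `n_steps` above is derived; the dataset size is hypothetical):

```python
# mix_ratio_steps.py -- worked example for the 1.5 scaling of n_steps above
# (assumes mix_ratio values act as proportional sampling weights).
slot_ratio, intent_ratio = 1.0, 0.5
num_epochs, batch_size = 20, 16
slot_examples = 4478                                        # hypothetical slot training-set size

slot_steps = slot_examples * num_epochs / batch_size        # steps the slot task alone would need
total_steps = slot_steps * (slot_ratio + intent_ratio)      # extra steps for the mixed-in intent task
print(int(slot_steps), int(total_steps))
print("sampling share:", slot_ratio / 1.5, intent_ratio / 1.5)  # ~0.67 slot, ~0.33 intent
```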
@@ -32,7 +32,7 @@ label text_a
### Step 2: Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
...
@@ -34,7 +34,7 @@ text_a label
### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
...
@@ -32,26 +32,26 @@ if __name__ == '__main__':
    # ----------------------- for training -----------------------

    # step 1-1: create readers for training
-    ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
+    seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
    # step 1-2: load the training data
-    ner_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
+    seq_label_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in reader
-    ner_reader.register_with(ernie)
+    seq_label_reader.register_with(ernie)

    # step 4: create the task output head
-    ner_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
+    seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)

    # step 5-1: create a task trainer
    trainer = palm.Trainer(task_name)
    # step 5-2: build forward graph with backbone and task head
-    loss_var = trainer.build_forward(ernie, ner_head)
+    loss_var = trainer.build_forward(ernie, seq_label_head)

    # step 6-1*: use warmup
-    n_steps = ner_reader.num_examples * num_epochs // batch_size
+    n_steps = seq_label_reader.num_examples * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    print('total_steps: {}'.format(n_steps))
    print('warmup_steps: {}'.format(warmup_steps))

@@ -62,43 +62,43 @@ if __name__ == '__main__':
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
-    trainer.fit_reader(ner_reader)
+    trainer.fit_reader(seq_label_reader)

-    # step 8-1*: load pretrained parameters
-    trainer.load_pretrain(pre_params)
-    # step 8-2*: set saver to save model
-    save_steps = (n_steps-20)
-    print('save_steps: {}'.format(save_steps))
-    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
-    # step 8-3: start training
-    trainer.train(print_steps=train_print_steps)
+    # # step 8-1*: load pretrained parameters
+    # trainer.load_pretrain(pre_params)
+    # # step 8-2*: set saver to save model
+    save_steps = 1951
+    # print('save_steps: {}'.format(save_steps))
+    # trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
+    # # step 8-3: start training
+    # trainer.train(print_steps=train_print_steps)

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
-    predict_ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
+    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
    # step 1-2: load the training data
-    predict_ner_reader.load_data(predict_file, batch_size)
+    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
-    predict_ner_reader.register_with(pred_ernie)
+    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
-    ner_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')
+    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5: build forward graph with backbone and task head
-    trainer.build_predict_forward(pred_ernie, ner_pred_head)
+    trainer.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load pretrained model
    pred_model_path = './outputs/ckpt.step' + str(save_steps)
    pred_ckpt = trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
-    trainer.fit_reader(predict_ner_reader, phase='predict')
+    trainer.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict
    print('predicting..')
...
@@ -39,7 +39,6 @@ class MaskLM(Head):

    @property
    def inputs_attrs(self):
        reader = {
-            "token_ids":[[-1, -1], 'int64'],
            "mask_label": [[-1], 'int64'],
            "mask_pos": [[-1], 'int64'],
            }
@@ -59,21 +58,19 @@ class MaskLM(Head):

    def build(self, inputs, scope_name=""):
        mask_pos = inputs["reader"]["mask_pos"]
+        word_emb = inputs["backbone"]["embedding_table"]
+        enc_out = inputs["backbone"]["encoder_outputs"]
        if self._is_training:
            mask_label = inputs["reader"]["mask_label"]
-            l1 = fluid.layers.shape(inputs["reader"]["token_ids"] )[0]
-            # bxs = inputs["reader"]["token_ids"].shape[2].value
-            l2 = fluid.layers.shape(inputs["reader"]["token_ids"][0])[0]
-            bxs = (l1*l2).astype(np.int64)
-            # max_position = inputs["reader"]["batchsize_x_seqlen"] - 1
+            l1 = enc_out.shape[0]
+            l2 = enc_out.shape[1]
+            bxs = fluid.layers.fill_constant(shape=[1], value=l1*l2, dtype='int64')
            max_position = bxs - 1
            mask_pos = fluid.layers.elementwise_min(mask_pos, max_position)
            mask_pos.stop_gradient = True

-        word_emb = inputs["backbone"]["embedding_table"]
-        enc_out = inputs["backbone"]["encoder_outputs"]
        emb_size = word_emb.shape[-1]

        _param_initializer = fluid.initializer.TruncatedNormal(
...
@@ -52,11 +52,12 @@ class MultiHeadTrainer(Trainer):
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

-        self._check_save = lambda: False
+        # self._check_save = lambda: False
        for t in self._trainers:
            t._set_multitask()

-    def build_forward(self, backbone, heads):
+    # def build_forward(self, backbone, heads):
+    def build_forward(self):
        """
        Build forward computation graph for training, which usually built from input layer to loss node.

@@ -67,19 +68,12 @@ class MultiHeadTrainer(Trainer):
        Return:
            - loss_var: a Variable object. The computational graph variable(node) of loss.
        """
-        if isinstance(heads, list):
-            head_dict = {k.name: v for k,v in zip(self._trainers, heads)}
-        elif isinstance(heads, dict):
-            head_dict = heads
-        else:
-            raise ValueError()
-
-        num_heads = len(self._trainers)
-        assert len(head_dict) == num_heads
-
-        for t in self._trainers:
-            assert t.name in head_dict, "expected: {}, exists: {}".format(t.name, head_dict.keys())
+        head_dict = {}
+        backbone = self._trainers[0]._backbone
+        for i in self._trainers:
+            assert i._task_head is not None and i._backbone is not None, "You should build forward for the {} task".format(i._name)
+            assert i._backbone == backbone, "The backbone for each task must be the same"
+            head_dict[i._name] = i._task_head

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()

@@ -88,27 +82,13 @@ class MultiHeadTrainer(Trainer):
        def get_loss(i):
            head = head_dict[self._trainers[i].name]
-            # loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
            loss_var = self._trainers[i].build_forward(backbone, head)
            return loss_var

-        # task_fns = {}
-        # for i in range(num_heads):
-        #     def task_loss():
-        #         task_id = i
-        #         return lambda: get_loss(task_id)
-        #     task_fns[i] = task_loss()
-
-        # task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
-
-        task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}
+        task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}

        with fluid.program_guard(train_prog, train_init_prog):
            task_id_var = fluid.data(name="__task_id",shape=[1],dtype='int64')
-            # task_id_var = fluid.layers.fill_constant(shape=[1],dtype='int64', value=1)
-            # print(task_id_var.name)

            loss_var = layers.switch_case(
                branch_index=task_id_var,

@@ -242,7 +222,6 @@ class MultiHeadTrainer(Trainer):
                task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')}
                self._trainers[task_id]._task_head.batch_postprocess(task_rt_outputs)

                if print_steps > 0 and self._cur_train_step % print_steps == 0:
                    loss = rt_outputs[self._trainers[task_id].name+'.loss']
                    loss = np.mean(np.squeeze(loss)).tolist()

@@ -257,7 +236,7 @@ class MultiHeadTrainer(Trainer):
                        loss, print_steps / time_cost))
                    time_begin = time.time()
-                self._check_save()
+                # self._check_save()
            finish = self._check_finish(self._trainers[task_id].name)
            if finish:
                break

@@ -287,7 +266,7 @@ class MultiHeadTrainer(Trainer):
        rt_outputs = self._trainers[task_id].train_one_step(batch)

        self._cur_train_step += 1
-        self._check_save()
+        # self._check_save()
        return rt_outputs, task_id

        # if dev_count > 1:
...
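The `task_fns = {i: lambda i=i: get_loss(i) ...}` line kept above (and the commented-out attempts next to it) work around Python's late binding of closure variables: without the `i=i` default argument, every lambda would see the final value of `i` and all branches of `switch_case` would build the same task. A minimal, self-contained illustration of the pitfall (plain Python, independent of PaddlePALM):

```python
# late_binding.py -- why the dict comprehension above needs "lambda i=i: ..."
fns_wrong = {i: (lambda: i) for i in range(3)}        # closes over the loop variable
fns_right = {i: (lambda i=i: i) for i in range(3)}    # freezes i via a default argument

print([f() for f in fns_wrong.values()])   # [2, 2, 2]  -- all see the last value of i
print([f() for f in fns_right.values()])   # [0, 1, 2]  -- each keeps its own i
```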
@@ -34,7 +34,6 @@ class MaskLMReader(Reader):
        for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'

-        self._register.add('token_ids')
        self._register.add('mask_pos')
        if phase == 'train':
            self._register.add('mask_label')
...
@@ -99,7 +99,7 @@ class Reader(object):
        if label_map_config:
            with open(label_map_config, encoding='utf8') as f:
-                self.label_map = (f)
+                self.label_map = json.load(f)
        else:
            self.label_map = None
...
@@ -54,6 +54,8 @@ class Trainer(object):
        self._train_init = False
        self._predict_init = False
+        self._train_init_prog = None
+        self._pred_init_prog = None

        self._check_save = lambda: False

@@ -427,6 +429,7 @@ class Trainer(object):
        self._pred_feed_batch_process_fn = feed_batch_process_fn
        # return distribute_feeder_fn()

    def load_ckpt(self, model_path):
        """
        load training checkpoint for further training or predicting.

@@ -500,7 +503,7 @@ class Trainer(object):
            convert=convert,
            main_program=self._train_init_prog)

-    def set_saver(self, save_path, save_steps, save_type='ckpt'):
+    def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
        """
        create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.

@@ -511,6 +514,7 @@ class Trainer(object):
        """
        save_type = save_type.split(',')

        if 'predict' in save_type:
            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."

@@ -534,10 +538,19 @@ class Trainer(object):
        def temp_func():
            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
                if self._save_predict:
+                    if is_multi:
+                        self._save(save_path, suffix='-pred.step'+str(self._cur_train_step))
+                        print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
+                    else:
                        self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
                        print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))

                if self._save_ckpt:
+                    if is_multi:
+                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
+                        print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
+                    else:
                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                        print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))

                return True

@@ -600,7 +613,7 @@ class Trainer(object):
                        (self._cur_train_step-1) % self._steps_pur_epoch + 1 , self._steps_pur_epoch, self._cur_train_epoch,
                        loss, print_steps / time_cost))
                    time_begin = time.time()
-                    self._check_save()
+                    # self._check_save()
                # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
                #     print(cur_task.name+': train finished!')
                #     cur_task.save()

@@ -727,6 +740,7 @@ class Trainer(object):
        rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}

        self._cur_train_step += 1
+        self._check_save()
        self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch

        return rt_outputs
...
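The trainer changes above move the `self._check_save()` call out of the logging branch and into `train_one_step`, so the saver callback registered by `set_saver` is consulted on every training step rather than only on printing steps. A stripped-down sketch of that callback pattern (illustrative only, not the PaddlePALM internals):

```python
# saver_callback.py -- stripped-down sketch of the "check save every step" pattern
class MiniTrainer:
    def __init__(self):
        self._cur_train_step = 0
        self._check_save = lambda: False          # no-op until set_saver is called

    def set_saver(self, save_steps):
        def temp_func():
            if self._cur_train_step % save_steps == 0:
                print("saving checkpoint at step", self._cur_train_step)
                return True
            return False
        self._check_save = temp_func

    def train_one_step(self):
        self._cur_train_step += 1
        self._check_save()                        # consulted on every step

t = MiniTrainer()
t.set_saver(save_steps=3)
for _ in range(7):
    t.train_one_step()                            # saves at steps 3 and 6
```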