未验证 提交 15ab54d7 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #51 from wangxiao1021/api

Update README.md
......@@ -186,17 +186,17 @@ Available pretrain items:
For more implementation details, see following demos:
- [Sentiment Classification]()
- [Quora Question Pairs matching]()
- [Tagging]()
- [SQuAD machine Reading Comprehension]().
- [Sentiment Classification](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification)
- [Quora Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching)
- [Tagging](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging)
- [SQuAD machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc).
### set saver
To save models/checkpoints and logs during training, just call `trainer.set_saver` method. More implementation details see [this]().
To save models/checkpoints and logs during training, just call `trainer.set_saver` method. More implementation details see [this](https://github.com/PaddlePaddle/PALM/tree/master/examples).
### do prediction
To do predict/evaluation after a training stage, just create another three reader, backbone and head instance with `phase='predict'` (repeat step 1~4 above). Then do predicting with `predict` method in trainer (no need to create another trainer). More implementation details see [this]().
To do predict/evaluation after a training stage, just create another three reader, backbone and head instance with `phase='predict'` (repeat step 1~4 above). Then do predicting with `predict` method in trainer (no need to create another trainer). More implementation details see [this](https://github.com/PaddlePaddle/PALM/tree/master/examples/predict).
### multi-task learning
To run with multi-task learning mode:
......@@ -212,7 +212,7 @@ The save/load and predict operations of a multi_head_trainer is the same as a tr
For more implementation details with `multi_head_trainer`, see
- [ATIS: joint training of dialogue intent recognition and slot filling]()
- [ATIS: joint training of dialogue intent recognition and slot filling](https://github.com/PaddlePaddle/PALM/tree/master/examples/multi-task)
- [MRQA: learning reading comprehension auxilarized with mask language model]() (初次发版先不用加)
......@@ -222,5 +222,4 @@ This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/P
## 许可证书
此向导由[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)贡献,受[Apache-2.0 license](https://github.com/PaddlePaddle/models/blob/develop/LICENSE)许可认证。
此向导由[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)贡献,受[Apache-2.0 license](https://github.com/PaddlePaddle/models/blob/develop/LICENSE)许可认证。
\ No newline at end of file
## Examples 4: Machine Reading Comprehension
## Example 4: Machine Reading Comprehension
This task is a machine reading comprehension task. The following sections detail model preparation, dataset preparation, and how to run the task.
### Step 1: Prepare Pre-trained Models & Datasets
......
## Example 6: Joint Training in Dialogue
This task is a slot filling task. During training, the task uses intent determination task to assist in training slot filling model. The following sections detail model preparation, dataset preparation, and how to run the task.
### Step 1: Prepare Pre-trained Models & Datasets
#### Pre-trianed Model
The pre-training model of this mission is: [ernie-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).
Make sure you have downloaded the required pre-training model in the current folder.
#### Dataset
This task uses the `Airline Travel Information System` dataset.
Download dataset:
```shell
python download.py
```
After the dataset is downloaded, you should convert the data format for training:
```shell
python process.py
```
If everything goes well, there will be a folder named `data/atis/` created with all the datas in it.
Here is some example datas:
`data/atis/atis_slot/train.tsv` :
```
text_a label
i want to fly from boston at 838 am and arrive in denver at 1110 in the morning O O O O O B-fromloc.city_name O B-depart_time.time I-depart_time.time O O O B-toloc.city_name O B-arrive_time.time O O B-arrive_time.period_of_day
what flights are available from pittsburgh to baltimore on thursday morning O O O O O B-fromloc.city_name O B-toloc.city_name O B-depart_date.day_name B-depart_time.period_of_day
what is the arrival time in san francisco for the 755 am flight leaving washington O O O B-flight_time I-flight_time O B-fromloc.city_name I-fromloc.city_name O O B-depart_time.time I-depart_time.time O O B-fromloc.city_name
cheapest airfare from tacoma to orlando B-cost_relative O O B-fromloc.city_name O B-toloc.city_name
```
`data/atis/atis_intent/train.tsv` :
```
label text_a
0 i want to fly from boston at 838 am and arrive in denver at 1110 in the morning
0 what flights are available from pittsburgh to baltimore on thursday morning
1 what is the arrival time in san francisco for the 755 am flight leaving washington
2 cheapest airfare from tacoma to orlando
```
### Step 2: Train & Predict
The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
```shell
python run.py
```
If you want to specify a specific gpu or use multiple gpus for training, please use **`CUDA_VISIBLE_DEVICES`**, for example:
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python run.py
```
Some logs will be shown below:
```
global step: 5, slot: step 3/309 (epoch 0), loss: 68.965, speed: 0.58 steps/s
global step: 10, intent: step 3/311 (epoch 0), loss: 3.407, speed: 8.76 steps/s
global step: 15, slot: step 12/309 (epoch 0), loss: 54.611, speed: 1.21 steps/s
global step: 20, intent: step 7/311 (epoch 0), loss: 3.487, speed: 10.28 steps/s
```
After the run, you can view the saved models in the `outputs/` folder.
If you want to use the trained model to predict the `atis_slot & atis_intent` data, run:
```shell
python predict-slot.py
python predict-intent.py
```
If you want to specify a specific gpu or use multiple gpus for predict, please use **`CUDA_VISIBLE_DEVICES`**, for example:
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python predict-slot.py
CUDA_VISIBLE_DEVICES=0,1,2 python predict-intent.py
```
After the run, you can view the predictions in the `outputs/predict-slot` folder and `outputs/predict-intent` folder. Here are some examples of predictions:
`atis_slot`:
```
[129, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 1, 1, 1, 1, 21, 21, 68, 129]
[129, 1, 39, 37, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 23, 3, 4, 129, 129, 129, 129, 129]
[129, 1, 39, 37, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 129, 129, 129, 129, 129, 129, 129, 129]
[129, 1, 1, 1, 1, 1, 1, 14, 15, 1, 2, 1, 5, 19, 1, 39, 37, 129, 24, 129, 129, 129]
```
`atis_intent`:
```
{"index": 0, "logits": [9.938603401184082, -0.3914794623851776, -0.050973162055015564, -1.0229418277740479, 0.04799401015043259, -0.9632213115692139, -0.6427211761474609, -1.337939739227295, -0.7969412803649902, -1.4441455602645874, -0.6339573264122009, -1.0393054485321045, -0.9242327213287354, -1.9637483358383179, 0.16733427345752716, -0.5280354619026184, -1.7195699214935303, -2.199411630630493, -1.2833174467086792, -1.3081035614013672, -1.6036226749420166, -1.8527079820632935, -2.289180040359497, -2.267214775085449, -2.2578916549682617, -2.2010505199432373], "probs": [0.999531626701355, 3.26210938510485e-05, 4.585415081237443e-05, 1.7348344044876285e-05, 5.06243304698728e-05, 1.8415948943584226e-05, 2.5373808966833167e-05, 1.266065828531282e-05, 2.174747896788176e-05, 1.1384962817828637e-05, 2.5597169951652177e-05, 1.7066764485207386e-05, 1.914815220516175e-05, 6.771284006390488e-06, 5.70411684748251e-05, 2.8457265216275118e-05, 8.644025911053177e-06, 5.349628736439627e-06, 1.3371440218179487e-05, 1.3044088518654462e-05, 9.706698619993404e-06, 7.5665011536329985e-06, 4.890325726591982e-06, 4.99892985317274e-06, 5.045753368904116e-06, 5.340866664482746e-06], "label": 0}
{"index": 1, "logits": [0.8863624930381775, -2.232290506362915, 8.191509246826172, -0.03161466494202614, -0.9149583578109741, -2.172696352005005, -0.3937145471572876, -0.3954394459724426, 1.5333592891693115, 0.8630291223526001, -0.9684226512908936, -2.722721815109253, -0.0060247331857681274, -0.9865402579307556, 1.6328885555267334, 0.3972966969013214, 0.27919167280197144, -1.4911551475524902, -0.9552251696586609, -0.9169244170188904, -0.810670793056488, -1.5118697881698608, -2.0140435695648193, -1.6299077272415161, -1.8589974641799927, -2.07601261138916], "probs": [0.0006675600307062268, 2.9517297662096098e-05, 0.9932880997657776, 0.0002665741485543549, 0.0001102013120544143, 3.132982965325937e-05, 0.00018559220188762993, 0.00018527248175814748, 0.0012749042361974716, 0.0006521637551486492, 0.00010446414671605453, 1.8075270418194123e-05, 0.0002734838053584099, 0.00010258861584588885, 0.0014083238784223795, 0.00040934717981144786, 0.00036374686169438064, 6.193659646669403e-05, 0.00010585198469925672, 0.00010998480865964666, 0.0001223145518451929, 6.0666847275570035e-05, 3.671637750812806e-05, 5.391232480178587e-05, 4.287416595616378e-05, 3.4510172554291785e-05], "label": 0}
{"index": 2, "logits": [9.789957046508789, -0.1730862706899643, -0.7198237776756287, -1.0460278987884521, 0.23521068692207336, -0.5075851678848267, -0.44724929332733154, -1.2945927381515503, -0.6984466314315796, -1.8749892711639404, -0.4631594121456146, -0.6256799697875977, -1.0252169370651245, -1.951456069946289, -0.17572557926177979, -0.6771697402000427, -1.7992591857910156, -2.1457295417785645, -1.4203097820281982, -1.4963451623916626, -1.692310094833374, -1.9219486713409424, -2.2533645629882812, -2.430952310562134, -2.3094685077667236, -2.2399914264678955], "probs": [0.9994625449180603, 4.708383130491711e-05, 2.725377635215409e-05, 1.9667899323394522e-05, 7.082601223373786e-05, 3.3697724575176835e-05, 3.579350595828146e-05, 1.5339375750045292e-05, 2.784266871458385e-05, 8.58508519741008e-06, 3.522853512549773e-05, 2.9944207199150696e-05, 2.0081495677004568e-05, 7.953084605105687e-06, 4.695970710599795e-05, 2.8441407266655006e-05, 9.26048778637778e-06, 6.548832516273251e-06, 1.3527245755540207e-05, 1.2536826943687629e-05, 1.030578732752474e-05, 8.19125762063777e-06, 5.880556273041293e-06, 4.923717369820224e-06, 5.559719284065068e-06, 5.9597273320832755e-06], "label": 0}
{"index": 3, "logits": [9.787659645080566, -0.6223222017288208, -0.03971472755074501, -1.038114070892334, 0.24018540978431702, -0.8904737830162048, -0.7114139795303345, -1.2315020561218262, -0.5120854377746582, -1.4273980855941772, -0.44618460536003113, -1.0241562128067017, -0.9727545380592346, -1.8587366342544556, 0.020689941942691803, -0.6228570342063904, -1.6020199060440063, -2.130260467529297, -1.370570421218872, -1.40530526638031, -1.6782578229904175, -1.94076669216156, -2.2038567066192627, -2.336832284927368, -2.268157720565796, -2.140028953552246], "probs": [0.9994485974311829, 3.0113611501292326e-05, 5.392447565100156e-05, 1.986949791898951e-05, 7.134198676794767e-05, 2.303065048181452e-05, 2.7546762794372626e-05, 1.6375688574044034e-05, 3.362310235388577e-05, 1.3462414244713727e-05, 3.591357381083071e-05, 2.0148761905147694e-05, 2.12115264730528e-05, 8.74570196174318e-06, 5.728216274292208e-05, 3.0097504350123927e-05, 1.1305383850412909e-05, 6.666126409982098e-06, 1.4249604646465741e-05, 1.3763145034317859e-05, 1.0475521776243113e-05, 8.056933438638225e-06, 6.193143690325087e-06, 5.422014055511681e-06, 5.807448815176031e-06, 6.601325367228128e-06], "label": 0}
```
### Step 3: Evaluate
Once you have the prediction, you can run the evaluation script to evaluate the model:
```shell
python evaluate-slot.py
python evaluate-intent.py
```
The evaluation results are as follows:
`atis_slot`:
```
precision: 0.894397728514, recall: 0.894104803493, f1: 0.894251242016
```
`atis_intent`:
```
data num: 893
precision: 0.708846584546, recall: 1.0, f1: 0.999999995
```
# -*- coding: utf-8 -*-
import json
import numpy as np
def accuracy(preds, labels):
preds = np.array(preds)
labels = np.array(labels)
return (preds == labels).mean()
def f1(preds, labels):
preds = np.array(preds)
labels = np.array(labels)
tp = np.sum((labels == '1') & (preds == '1'))
tn = np.sum((labels == '0') & (preds == '0'))
fp = np.sum((labels == '0') & (preds == '1'))
fn = np.sum((labels == '1') & (preds == '0'))
p = tp * 1.0 / (tp + fp)
r = tp * 1.0 / (tp + fn) * 1.0
f1 = (2 * p * r) / (p + r + 1e-8)
return f1
def recall(preds, labels):
preds = np.array(preds)
labels = np.array(labels)
# recall=TP/(TP+FN)
tp = np.sum((labels == '1') & (preds == '1'))
fn = np.sum((labels == '1') & (preds == '0'))
re = tp * 1.0 / (tp + fn)
return re
def res_evaluate(res_dir="./outputs/predict-intent/predictions.json", eval_phase='test'):
if eval_phase == 'test':
data_dir="./data/atis/atis_intent/test.tsv"
elif eval_phase == 'dev':
data_dir="./data/dev.tsv"
else:
assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test'
labels = []
with open(data_dir, "r") as file:
first_flag = True
for line in file:
line = line.split("\t")
label = line[0]
if label=='label':
continue
labels.append(str(label))
file.close()
preds = []
with open(res_dir, "r") as file:
for line in file.readlines():
line = json.loads(line)
pred = line['label']
preds.append(str(pred))
file.close()
assert len(labels) == len(preds), "prediction result doesn't match to labels"
print('data num: {}'.format(len(labels)))
print("precision: {}, recall: {}, f1: {}".format(accuracy(preds, labels), recall(preds, labels), f1(preds, labels)))
res_evaluate()
# -*- coding: utf-8 -*-
import json
def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
"""
:param map_dir: dict indictuing chunk type
:return:
"""
return json.load(open(map_dir, "r"))
def cal_chunk(total_res, total_label):
assert len(total_label) == len(total_res), 'prediction result doesn\'t match to labels'
num_labels = 0
num_corr = 0
num_infers = 0
for res, label in zip(total_res, total_label):
assert len(res) == len(label), "prediction result doesn\'t match to labels"
num_labels += sum([0 if i == 6 else 1 for i in label])
num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))])
num_infers += sum([0 if i == 6 else 1 for i in res])
precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0
recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
return precision, recall, f1
def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
label_map = load_label_map()
total_label = []
with open(data_dir, "r") as file:
first_flag = True
for line in file:
if first_flag:
first_flag = False
continue
line = line.strip("\n")
if len(line) == 0:
continue
line = line.split("\t")
if len(line) < 2:
continue
labels = line[1][:-1].split("\x02")
total_label.append(labels)
total_label = [[label_map[j] for j in i] for i in total_label]
total_res = []
with open(res_dir, "r") as file:
cnt = 0
for line in file:
line = line.strip("\n")
if len(line) == 0:
continue
try:
res_arr = json.loads(line)
if len(total_label[cnt]) < len(res_arr):
total_res.append(res_arr[1: 1 + len(total_label[cnt])])
elif len(total_label[cnt]) == len(res_arr):
total_res.append(res_arr)
else:
total_res.append(res_arr)
total_label[cnt] = total_label[cnt][: len(res_arr)]
except:
print("json format error: {}".format(cnt))
print(line)
cnt += 1
precision, recall, f1 = cal_chunk(total_res, total_label)
print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1))
res_evaluate()
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
# configs
max_seqlen = 256
batch_size = 16
num_epochs = 6
print_steps = 5
num_classes = 26
vocab_path = './pretrain/ernie-en-base/vocab.txt'
predict_file = './data/atis/atis_intent/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict-intent/'
save_type = 'ckpt'
random_seed = 0
pre_params = './pretrain/ernie-en-base/params'
config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
input_dim = config['hidden_size']
# ----------------------- for prediction -----------------------
# step 1-1: create readers for prediction
print('prepare to predict...')
predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict')
# step 1-2: load the training data
predict_cls_reader.load_data(predict_file, batch_size)
# step 2: create a backbone of the model to extract text features
pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
# step 3: register the backbone in reader
predict_cls_reader.register_with(pred_ernie)
# step 4: create the task output head
cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict')
# step 5-1: create a task trainer
trainer = palm.Trainer("intent")
# step 5-2: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
pred_model_path = './outputs/ckpt.step9282'
pred_ckpt = trainer.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
# step 8: predict
print('predicting..')
trainer.predict(print_steps=print_steps, output_dir=pred_output)
\ No newline at end of file
......@@ -11,15 +11,14 @@ if __name__ == '__main__':
batch_size = 16
num_epochs = 6
print_steps = 5
lr = 5e-5
num_classes = 130
random_seed = 1
label_map = './data/atis/atis_slot/label_map.json'
vocab_path = './pretrain/ernie-en-base/vocab.txt'
predict_file = './data/atis/atis_slot/test.tsv'
save_path = './outputs/'
pred_output = './outputs/predict/'
pred_output = './outputs/predict-slot/'
save_type = 'ckpt'
random_seed = 0
pre_params = './pretrain/ernie-en-base/params'
config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
......@@ -29,7 +28,7 @@ if __name__ == '__main__':
# step 1-1: create readers for prediction
print('prepare to predict...')
predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict')
# step 1-2: load the training data
predict_seq_label_reader.load_data(predict_file, batch_size)
......@@ -48,7 +47,7 @@ if __name__ == '__main__':
trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)
# step 6: load pretrained model
pred_model_path = './outputs/1580822697.73-ckpt.step9282'
pred_model_path = './outputs/ckpt.step9282'
pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
......
......@@ -35,30 +35,30 @@ if __name__ == '__main__':
# step 1-1: create readers for training
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
match_reader.load_data(train_intent, file_format='tsv', num_epochs=None, batch_size=batch_size)
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)
# step 2: create a backbone of the model to extract text features
ernie = palm.backbone.ERNIE.from_config(config)
# step 3: register the backbone in readers
seq_label_reader.register_with(ernie)
match_reader.register_with(ernie)
cls_reader.register_with(ernie)
# step 4: create task output heads
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
match_head = palm.head.Match(num_classes_intent, input_dim, dropout_prob)
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)
# step 5-1: create a task trainer
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
trainer_match = palm.Trainer("intent", mix_ratio=0.5)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_match])
trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])
# # step 5-2: build forward graph with backbone and task head
loss_var1 = trainer_match.build_forward(ernie, match_head)
loss_var2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss1 = trainer_cls.build_forward(ernie, cls_head)
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
loss_var = trainer.build_forward()
# step 6-1*: use warmup
......@@ -71,13 +71,13 @@ if __name__ == '__main__':
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
# step 7: fit prepared reader and data
trainer.fit_readers_with_mixratio([seq_label_reader, match_reader], "slot", num_epochs)
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model
# save_steps = int(n_steps-batch_size)
save_steps = 10
trainer_seq_label.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True)
save_steps = int(n_steps-batch_size) // 2
# save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
trainer.train(print_steps=print_steps)
\ No newline at end of file
## Examples 3: Tagging
## Example 3: Tagging
This task is a named entity recognition task. The following sections detail model preparation, dataset preparation, and how to run the task.
### Step 1: Prepare Pre-trained Models & Datasets
......
......@@ -77,7 +77,7 @@ if __name__ == '__main__':
# step 1-1: create readers for prediction
print('prepare to predict...')
predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict')
# step 1-2: load the training data
predict_seq_label_reader.load_data(predict_file, batch_size)
......
......@@ -111,7 +111,7 @@ class Classify(Head):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
label = 0 if self._preds[i][0] > self._preds[i][1] else 1
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._preds[i]}
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
......
......@@ -52,7 +52,7 @@ class MultiHeadTrainer(Trainer):
'input_varnames': 'self._pred_input_varname_list',
'fetch_list': 'self._pred_fetch_name_list'}
# self._check_save = lambda: False
self._check_save = lambda: False
for t in self._trainers:
t._set_multitask()
......@@ -82,7 +82,9 @@ class MultiHeadTrainer(Trainer):
def get_loss(i):
head = head_dict[self._trainers[i].name]
self._trainers[i]._lock_prog = True
loss_var = self._trainers[i].build_forward(backbone, head)
self._trainers[i]._lock_prog = False
return loss_var
task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}
......@@ -236,7 +238,7 @@ class MultiHeadTrainer(Trainer):
loss, print_steps / time_cost))
time_begin = time.time()
# self._check_save()
self._check_save()
finish = self._check_finish(self._trainers[task_id].name)
if finish:
break
......@@ -266,7 +268,7 @@ class MultiHeadTrainer(Trainer):
rt_outputs = self._trainers[task_id].train_one_step(batch)
self._cur_train_step += 1
# self._check_save()
self._check_save()
return rt_outputs, task_id
# if dev_count > 1:
......
......@@ -107,6 +107,7 @@ class Trainer(object):
'fetch_list': 'self._pred_fetch_name_list'}
self._lock = False
self._lock_prog = False
self._build_forward = False
def build_forward(self, backbone, task_head):
......@@ -161,9 +162,11 @@ class Trainer(object):
train_prog = fluid.Program()
train_init_prog = fluid.Program()
self._train_prog = train_prog
self._train_init_prog = train_init_prog
if not self._multi_task:
if not self._lock_prog:
self._train_prog = train_prog
self._train_init_prog = train_init_prog
if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog):
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
bb_output_vars = backbone.build(net_inputs)
......@@ -184,7 +187,7 @@ class Trainer(object):
task_inputs['reader'] = task_inputs_from_reader
scope = self.name+'.'
if not self._multi_task:
if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog):
with fluid.unique_name.guard(scope):
output_vars = self._build_head(task_inputs, phase='train', scope=scope)
......@@ -209,7 +212,7 @@ class Trainer(object):
# task_id_vec = layers.one_hot(task_id_var, num_instances)
# losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0)
# loss = layers.reduce_sum(task_id_vec * losses)
if not self._multi_task:
if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog):
loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])
else:
......@@ -388,8 +391,9 @@ class Trainer(object):
reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
elif phase == 'predict':
self._predict_reader = reader
tail = self._num_examples % batch_size > 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
# tail = self._num_examples % batch_size > 0
# self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0
self._pred_steps_pur_epoch = reader.num_examples // batch_size
shape_and_dtypes = self._pred_shape_and_dtypes
name_to_position = self._pred_name_to_position
net_inputs = self._pred_net_inputs
......@@ -503,7 +507,7 @@ class Trainer(object):
convert=convert,
main_program=self._train_init_prog)
def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
def set_saver(self, save_path, save_steps, save_type='ckpt'):
"""
create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.
......@@ -540,19 +544,11 @@ class Trainer(object):
if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
if self._save_predict:
if is_multi:
self._save(save_path, suffix='-pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
else:
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
if self._save_ckpt:
if is_multi:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
else:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
return True
else:
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册