From 9e99e1de9f6afe0cc222089470d6041e53d8ea24 Mon Sep 17 00:00:00 2001 From: wangxiao1021 Date: Wed, 5 Feb 2020 19:44:38 +0800 Subject: [PATCH] fix bugs --- examples/mrc/README.md | 2 +- examples/multi-task/README.md | 38 ++++++--- examples/multi-task/evaluate-intent.py | 64 +++++++++++++++ examples/multi-task/evaluate.py | 78 ------------------- examples/multi-task/predict-intent.py | 57 ++++++++++++++ .../{predict.py => predict-slot.py} | 9 +-- examples/multi-task/run.py | 26 +++---- examples/tagging/README.md | 2 +- examples/tagging/run.py | 2 +- paddlepalm/head/cls.py | 2 +- paddlepalm/multihead_trainer.py | 6 +- paddlepalm/trainer.py | 8 +- 12 files changed, 178 insertions(+), 116 deletions(-) create mode 100644 examples/multi-task/evaluate-intent.py delete mode 100644 examples/multi-task/evaluate.py create mode 100644 examples/multi-task/predict-intent.py rename examples/multi-task/{predict.py => predict-slot.py} (90%) diff --git a/examples/mrc/README.md b/examples/mrc/README.md index 1591aa6..8748c0c 100644 --- a/examples/mrc/README.md +++ b/examples/mrc/README.md @@ -1,4 +1,4 @@ -## Examples 4: Machine Reading Comprehension +## Example 4: Machine Reading Comprehension This task is a machine reading comprehension task. The following sections detail model preparation, dataset preparation, and how to run the task. ### Step 1: Prepare Pre-trained Models & Datasets diff --git a/examples/multi-task/README.md b/examples/multi-task/README.md index 62e49aa..1fd7e9a 100644 --- a/examples/multi-task/README.md +++ b/examples/multi-task/README.md @@ -1,4 +1,4 @@ -## Examples 6: Multi-Task Slot Filling +## Example 6: joint training in dialogue This task is a slot filling task. During training, the task uses intent determination task to assist in training slot filling model. The following sections detail model preparation, dataset preparation, and how to run the task. ### Step 1: Prepare Pre-trained Models & Datasets @@ -73,27 +73,37 @@ global step: 20, intent: step 7/311 (epoch 0), loss: 3.487, speed: 10.28 steps/s After the run, you can view the saved models in the `outputs/` folder. -If you want to use the trained model to predict the `atis_slot` data, run: +If you want to use the trained model to predict the `atis_slot & atis_intent` data, run: ```shell -python predict.py +python predict-slot.py +python predict-intent.py ``` If you want to specify a specific gpu or use multiple gpus for predict, please use **`CUDA_VISIBLE_DEVICES`**, for example: ```shell -CUDA_VISIBLE_DEVICES=0,1,2 python predict.py +CUDA_VISIBLE_DEVICES=0,1,2 python predict-slot.py +CUDA_VISIBLE_DEVICES=0,1,2 python predict-slot.py ``` +After the run, you can view the predictions in the `outputs/predict-slot` folder and `outputs/predict-intent` folder. Here are some examples of predictions: -After the run, you can view the predictions in the `outputs/predict` folder. Here are some examples of predictions: - +`atis_slot`: ``` - [129, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 1, 1, 1, 1, 21, 21, 68, 129] [129, 1, 39, 37, 1, 1, 1, 1, 1, 2, 1, 5, 19, 1, 23, 3, 4, 129, 129, 129, 129, 129] [129, 1, 39, 37, 1, 1, 1, 1, 1, 1, 2, 1, 5, 19, 129, 129, 129, 129, 129, 129, 129, 129] -[129, 1, 1, 1, 1, 1, 1, 14, 15, 1, 2, 1, 5, 19, 1, 39, 37, 129, 129, 129, 129, 129] +[129, 1, 1, 1, 1, 1, 1, 14, 15, 1, 2, 1, 5, 19, 1, 39, 37, 129, 24, 129, 129, 129] +``` + +`atis_intent`: +``` +{"index": 0, "logits": [9.938603401184082, -0.3914794623851776, -0.050973162055015564, -1.0229418277740479, 0.04799401015043259, -0.9632213115692139, -0.6427211761474609, -1.337939739227295, -0.7969412803649902, -1.4441455602645874, -0.6339573264122009, -1.0393054485321045, -0.9242327213287354, -1.9637483358383179, 0.16733427345752716, -0.5280354619026184, -1.7195699214935303, -2.199411630630493, -1.2833174467086792, -1.3081035614013672, -1.6036226749420166, -1.8527079820632935, -2.289180040359497, -2.267214775085449, -2.2578916549682617, -2.2010505199432373], "probs": [0.999531626701355, 3.26210938510485e-05, 4.585415081237443e-05, 1.7348344044876285e-05, 5.06243304698728e-05, 1.8415948943584226e-05, 2.5373808966833167e-05, 1.266065828531282e-05, 2.174747896788176e-05, 1.1384962817828637e-05, 2.5597169951652177e-05, 1.7066764485207386e-05, 1.914815220516175e-05, 6.771284006390488e-06, 5.70411684748251e-05, 2.8457265216275118e-05, 8.644025911053177e-06, 5.349628736439627e-06, 1.3371440218179487e-05, 1.3044088518654462e-05, 9.706698619993404e-06, 7.5665011536329985e-06, 4.890325726591982e-06, 4.99892985317274e-06, 5.045753368904116e-06, 5.340866664482746e-06], "label": 0} +{"index": 1, "logits": [0.8863624930381775, -2.232290506362915, 8.191509246826172, -0.03161466494202614, -0.9149583578109741, -2.172696352005005, -0.3937145471572876, -0.3954394459724426, 1.5333592891693115, 0.8630291223526001, -0.9684226512908936, -2.722721815109253, -0.0060247331857681274, -0.9865402579307556, 1.6328885555267334, 0.3972966969013214, 0.27919167280197144, -1.4911551475524902, -0.9552251696586609, -0.9169244170188904, -0.810670793056488, -1.5118697881698608, -2.0140435695648193, -1.6299077272415161, -1.8589974641799927, -2.07601261138916], "probs": [0.0006675600307062268, 2.9517297662096098e-05, 0.9932880997657776, 0.0002665741485543549, 0.0001102013120544143, 3.132982965325937e-05, 0.00018559220188762993, 0.00018527248175814748, 0.0012749042361974716, 0.0006521637551486492, 0.00010446414671605453, 1.8075270418194123e-05, 0.0002734838053584099, 0.00010258861584588885, 0.0014083238784223795, 0.00040934717981144786, 0.00036374686169438064, 6.193659646669403e-05, 0.00010585198469925672, 0.00010998480865964666, 0.0001223145518451929, 6.0666847275570035e-05, 3.671637750812806e-05, 5.391232480178587e-05, 4.287416595616378e-05, 3.4510172554291785e-05], "label": 0} +{"index": 2, "logits": [9.789957046508789, -0.1730862706899643, -0.7198237776756287, -1.0460278987884521, 0.23521068692207336, -0.5075851678848267, -0.44724929332733154, -1.2945927381515503, -0.6984466314315796, -1.8749892711639404, -0.4631594121456146, -0.6256799697875977, -1.0252169370651245, -1.951456069946289, -0.17572557926177979, -0.6771697402000427, -1.7992591857910156, -2.1457295417785645, -1.4203097820281982, -1.4963451623916626, -1.692310094833374, -1.9219486713409424, -2.2533645629882812, -2.430952310562134, -2.3094685077667236, -2.2399914264678955], "probs": [0.9994625449180603, 4.708383130491711e-05, 2.725377635215409e-05, 1.9667899323394522e-05, 7.082601223373786e-05, 3.3697724575176835e-05, 3.579350595828146e-05, 1.5339375750045292e-05, 2.784266871458385e-05, 8.58508519741008e-06, 3.522853512549773e-05, 2.9944207199150696e-05, 2.0081495677004568e-05, 7.953084605105687e-06, 4.695970710599795e-05, 2.8441407266655006e-05, 9.26048778637778e-06, 6.548832516273251e-06, 1.3527245755540207e-05, 1.2536826943687629e-05, 1.030578732752474e-05, 8.19125762063777e-06, 5.880556273041293e-06, 4.923717369820224e-06, 5.559719284065068e-06, 5.9597273320832755e-06], "label": 0} +{"index": 3, "logits": [9.787659645080566, -0.6223222017288208, -0.03971472755074501, -1.038114070892334, 0.24018540978431702, -0.8904737830162048, -0.7114139795303345, -1.2315020561218262, -0.5120854377746582, -1.4273980855941772, -0.44618460536003113, -1.0241562128067017, -0.9727545380592346, -1.8587366342544556, 0.020689941942691803, -0.6228570342063904, -1.6020199060440063, -2.130260467529297, -1.370570421218872, -1.40530526638031, -1.6782578229904175, -1.94076669216156, -2.2038567066192627, -2.336832284927368, -2.268157720565796, -2.140028953552246], "probs": [0.9994485974311829, 3.0113611501292326e-05, 5.392447565100156e-05, 1.986949791898951e-05, 7.134198676794767e-05, 2.303065048181452e-05, 2.7546762794372626e-05, 1.6375688574044034e-05, 3.362310235388577e-05, 1.3462414244713727e-05, 3.591357381083071e-05, 2.0148761905147694e-05, 2.12115264730528e-05, 8.74570196174318e-06, 5.728216274292208e-05, 3.0097504350123927e-05, 1.1305383850412909e-05, 6.666126409982098e-06, 1.4249604646465741e-05, 1.3763145034317859e-05, 1.0475521776243113e-05, 8.056933438638225e-06, 6.193143690325087e-06, 5.422014055511681e-06, 5.807448815176031e-06, 6.601325367228128e-06], "label": 0} + ``` ### Step 3: Evaluate @@ -101,11 +111,19 @@ After the run, you can view the predictions in the `outputs/predict` folder. Her Once you have the prediction, you can run the evaluation script to evaluate the model: ```shell -python evaluate.py +python evaluate-slot.py +python evaluate-intent.py ``` The evaluation results are as follows: +`atis_slot`: +``` +precision: 0.894397728514, recall: 0.894104803493, f1: 0.894251242016 +``` + +`atis_intent`: ``` -precision: 0.894518453811, recall: 0.894323144105, f1: 0.894420788296 +data num: 893 +precision: 0.708846584546, recall: 1.0, f1: 0.999999995 ``` diff --git a/examples/multi-task/evaluate-intent.py b/examples/multi-task/evaluate-intent.py new file mode 100644 index 0000000..7c491cd --- /dev/null +++ b/examples/multi-task/evaluate-intent.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +import json +import numpy as np + +def accuracy(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + return (preds == labels).mean() + +def f1(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + tp = np.sum((labels == '1') & (preds == '1')) + tn = np.sum((labels == '0') & (preds == '0')) + fp = np.sum((labels == '0') & (preds == '1')) + fn = np.sum((labels == '1') & (preds == '0')) + p = tp * 1.0 / (tp + fp) + r = tp * 1.0 / (tp + fn) * 1.0 + f1 = (2 * p * r) / (p + r + 1e-8) + return f1 + +def recall(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + # recall=TP/(TP+FN) + tp = np.sum((labels == '1') & (preds == '1')) + fn = np.sum((labels == '1') & (preds == '0')) + re = tp * 1.0 / (tp + fn) + return re + + +def res_evaluate(res_dir="./outputs/predict-intent/predictions.json", eval_phase='test'): + if eval_phase == 'test': + data_dir="./data/atis/atis_intent/test.tsv" + elif eval_phase == 'dev': + data_dir="./data/dev.tsv" + + else: + assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test' + + labels = [] + with open(data_dir, "r") as file: + first_flag = True + for line in file: + line = line.split("\t") + label = line[0] + if label=='label': + continue + labels.append(str(label)) + file.close() + + preds = [] + with open(res_dir, "r") as file: + for line in file.readlines(): + line = json.loads(line) + pred = line['label'] + preds.append(str(pred)) + file.close() + assert len(labels) == len(preds), "prediction result doesn't match to labels" + print('data num: {}'.format(len(labels))) + print("precision: {}, recall: {}, f1: {}".format(accuracy(preds, labels), recall(preds, labels), f1(preds, labels))) + +res_evaluate() diff --git a/examples/multi-task/evaluate.py b/examples/multi-task/evaluate.py deleted file mode 100644 index 6eea520..0000000 --- a/examples/multi-task/evaluate.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*- - -import json - - -def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"): - """ - :param map_dir: dict indictuing chunk type - :return: - """ - return json.load(open(map_dir, "r")) - - -def cal_chunk(total_res, total_label): - assert len(total_label) == len(total_res), 'prediction result doesn\'t match to labels' - num_labels = 0 - num_corr = 0 - num_infers = 0 - for res, label in zip(total_res, total_label): - assert len(res) == len(label), "prediction result doesn\'t match to labels" - num_labels += sum([0 if i == 6 else 1 for i in label]) - num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))]) - num_infers += sum([0 if i == 6 else 1 for i in res]) - - precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0 - recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0 - f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0 - - return precision, recall, f1 - - -def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"): - label_map = load_label_map() - - total_label = [] - with open(data_dir, "r") as file: - first_flag = True - for line in file: - if first_flag: - first_flag = False - continue - line = line.strip("\n") - if len(line) == 0: - continue - line = line.split("\t") - if len(line) < 2: - continue - labels = line[1][:-1].split("\x02") - total_label.append(labels) - total_label = [[label_map[j] for j in i] for i in total_label] - - total_res = [] - with open(res_dir, "r") as file: - cnt = 0 - for line in file: - line = line.strip("\n") - if len(line) == 0: - continue - try: - res_arr = json.loads(line) - - if len(total_label[cnt]) < len(res_arr): - total_res.append(res_arr[1: 1 + len(total_label[cnt])]) - elif len(total_label[cnt]) == len(res_arr): - total_res.append(res_arr) - else: - total_res.append(res_arr) - total_label[cnt] = total_label[cnt][: len(res_arr)] - except: - print("json format error: {}".format(cnt)) - print(line) - - cnt += 1 - - precision, recall, f1 = cal_chunk(total_res, total_label) - print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1)) - -res_evaluate() diff --git a/examples/multi-task/predict-intent.py b/examples/multi-task/predict-intent.py new file mode 100644 index 0000000..9e41cb3 --- /dev/null +++ b/examples/multi-task/predict-intent.py @@ -0,0 +1,57 @@ +# coding=utf-8 +import paddlepalm as palm +import json +from paddlepalm.distribute import gpu_dev_count + + +if __name__ == '__main__': + + # configs + max_seqlen = 256 + batch_size = 16 + num_epochs = 6 + print_steps = 5 + num_classes = 26 + vocab_path = './pretrain/ernie-en-base/vocab.txt' + predict_file = './data/atis/atis_intent/test.tsv' + save_path = './outputs/' + pred_output = './outputs/predict-intent/' + save_type = 'ckpt' + random_seed = 0 + + pre_params = './pretrain/ernie-en-base/params' + config = json.load(open('./pretrain/ernie-en-base/ernie_config.json')) + input_dim = config['hidden_size'] + + # ----------------------- for prediction ----------------------- + + # step 1-1: create readers for prediction + print('prepare to predict...') + predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict') + # step 1-2: load the training data + predict_cls_reader.load_data(predict_file, batch_size) + + # step 2: create a backbone of the model to extract text features + pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict') + + # step 3: register the backbone in reader + predict_cls_reader.register_with(pred_ernie) + + # step 4: create the task output head + cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict') + + # step 5-1: create a task trainer + trainer = palm.Trainer("intent") + # step 5-2: build forward graph with backbone and task head + trainer.build_predict_forward(pred_ernie, cls_pred_head) + + # step 6: load pretrained model + pred_model_path = './outputs/ckpt.step9282' + pred_ckpt = trainer.load_ckpt(pred_model_path) + + # step 7: fit prepared reader and data + trainer.fit_reader(predict_cls_reader, phase='predict') + + # step 8: predict + print('predicting..') + trainer.predict(print_steps=print_steps, output_dir=pred_output) \ No newline at end of file diff --git a/examples/multi-task/predict.py b/examples/multi-task/predict-slot.py similarity index 90% rename from examples/multi-task/predict.py rename to examples/multi-task/predict-slot.py index b0221bd..1f837ae 100644 --- a/examples/multi-task/predict.py +++ b/examples/multi-task/predict-slot.py @@ -11,15 +11,14 @@ if __name__ == '__main__': batch_size = 16 num_epochs = 6 print_steps = 5 - lr = 5e-5 num_classes = 130 - random_seed = 1 label_map = './data/atis/atis_slot/label_map.json' vocab_path = './pretrain/ernie-en-base/vocab.txt' predict_file = './data/atis/atis_slot/test.tsv' save_path = './outputs/' - pred_output = './outputs/predict/' + pred_output = './outputs/predict-slot/' save_type = 'ckpt' + random_seed = 0 pre_params = './pretrain/ernie-en-base/params' config = json.load(open('./pretrain/ernie-en-base/ernie_config.json')) @@ -29,7 +28,7 @@ if __name__ == '__main__': # step 1-1: create readers for prediction print('prepare to predict...') - predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict') + predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict') # step 1-2: load the training data predict_seq_label_reader.load_data(predict_file, batch_size) @@ -48,7 +47,7 @@ if __name__ == '__main__': trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head) # step 6: load pretrained model - pred_model_path = './outputs/1580822697.73-ckpt.step9282' + pred_model_path = './outputs/ckpt.step9282' pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path) # step 7: fit prepared reader and data diff --git a/examples/multi-task/run.py b/examples/multi-task/run.py index fe7dffc..c02a6cf 100644 --- a/examples/multi-task/run.py +++ b/examples/multi-task/run.py @@ -35,30 +35,30 @@ if __name__ == '__main__': # step 1-1: create readers for training seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) - match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed) + cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) # step 1-2: load the training data seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size) - match_reader.load_data(train_intent, file_format='tsv', num_epochs=None, batch_size=batch_size) - + cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None) + # step 2: create a backbone of the model to extract text features ernie = palm.backbone.ERNIE.from_config(config) # step 3: register the backbone in readers seq_label_reader.register_with(ernie) - match_reader.register_with(ernie) + cls_reader.register_with(ernie) # step 4: create task output heads seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) - match_head = palm.head.Match(num_classes_intent, input_dim, dropout_prob) + cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob) # step 5-1: create a task trainer trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0) - trainer_match = palm.Trainer("intent", mix_ratio=0.5) - trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_match]) + trainer_cls = palm.Trainer("intent", mix_ratio=1.0) + trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls]) # # step 5-2: build forward graph with backbone and task head - loss_var1 = trainer_match.build_forward(ernie, match_head) - loss_var2 = trainer_seq_label.build_forward(ernie, seq_label_head) + loss1 = trainer_cls.build_forward(ernie, cls_head) + loss2 = trainer_seq_label.build_forward(ernie, seq_label_head) loss_var = trainer.build_forward() # step 6-1*: use warmup @@ -71,13 +71,13 @@ if __name__ == '__main__': trainer.build_backward(optimizer=adam, weight_decay=weight_decay) # step 7: fit prepared reader and data - trainer.fit_readers_with_mixratio([seq_label_reader, match_reader], "slot", num_epochs) + trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs) # step 8-1*: load pretrained parameters trainer.load_pretrain(pre_params) # step 8-2*: set saver to save model - # save_steps = int(n_steps-batch_size) - save_steps = 10 - trainer_seq_label.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True) + save_steps = int(n_steps-batch_size) // 2 + # save_steps = 10 + trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True) # step 8-3: start training trainer.train(print_steps=print_steps) \ No newline at end of file diff --git a/examples/tagging/README.md b/examples/tagging/README.md index b7b549f..8c399f2 100644 --- a/examples/tagging/README.md +++ b/examples/tagging/README.md @@ -1,4 +1,4 @@ -## Examples 3: Tagging +## Example 3: Tagging This task is a named entity recognition task. The following sections detail model preparation, dataset preparation, and how to run the task. ### Step 1: Prepare Pre-trained Models & Datasets diff --git a/examples/tagging/run.py b/examples/tagging/run.py index c41c698..f39d8a0 100644 --- a/examples/tagging/run.py +++ b/examples/tagging/run.py @@ -77,7 +77,7 @@ if __name__ == '__main__': # step 1-1: create readers for prediction print('prepare to predict...') - predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict') + predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict') # step 1-2: load the training data predict_seq_label_reader.load_data(predict_file, batch_size) diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index e6ca016..3342dcc 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -111,7 +111,7 @@ class Classify(Head): with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): label = 0 if self._preds[i][0] > self._preds[i][1] else 1 - result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._preds[i]} + result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = json.dumps(result) writer.write(result+'\n') print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index a38e434..4555062 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -52,7 +52,7 @@ class MultiHeadTrainer(Trainer): 'input_varnames': 'self._pred_input_varname_list', 'fetch_list': 'self._pred_fetch_name_list'} - # self._check_save = lambda: False + self._check_save = lambda: False for t in self._trainers: t._set_multitask() @@ -236,7 +236,7 @@ class MultiHeadTrainer(Trainer): loss, print_steps / time_cost)) time_begin = time.time() - # self._check_save() + self._check_save() finish = self._check_finish(self._trainers[task_id].name) if finish: break @@ -266,7 +266,7 @@ class MultiHeadTrainer(Trainer): rt_outputs = self._trainers[task_id].train_one_step(batch) self._cur_train_step += 1 - # self._check_save() + self._check_save() return rt_outputs, task_id # if dev_count > 1: diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index d52ba0a..5519ad6 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -107,6 +107,7 @@ class Trainer(object): 'fetch_list': 'self._pred_fetch_name_list'} self._lock = False + self._lock_prog = False self._build_forward = False def build_forward(self, backbone, task_head): @@ -163,7 +164,7 @@ class Trainer(object): self._train_prog = train_prog self._train_init_prog = train_init_prog - if not self._multi_task: + if not self._lock_prog: with fluid.program_guard(train_prog, train_init_prog): net_inputs = reader_helper.create_net_inputs(input_attrs, async=False) bb_output_vars = backbone.build(net_inputs) @@ -184,7 +185,7 @@ class Trainer(object): task_inputs['reader'] = task_inputs_from_reader scope = self.name+'.' - if not self._multi_task: + if not self._lock_prog: with fluid.program_guard(train_prog, train_init_prog): with fluid.unique_name.guard(scope): output_vars = self._build_head(task_inputs, phase='train', scope=scope) @@ -209,7 +210,7 @@ class Trainer(object): # task_id_vec = layers.one_hot(task_id_var, num_instances) # losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) # loss = layers.reduce_sum(task_id_vec * losses) - if not self._multi_task: + if not self._lock_prog: with fluid.program_guard(train_prog, train_init_prog): loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss']) else: @@ -548,6 +549,7 @@ class Trainer(object): self._save(save_path, suffix='pred.step'+str(self._cur_train_step)) print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step))) if self._save_ckpt: + print(self._train_prog) if is_multi: fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) -- GitLab