# predict-intent.py (last committed by wangxiao1021)
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count


if __name__ == '__main__':
    # Run intent-classification prediction on the ATIS intent test set with an
    # ERNIE backbone plus a classification head restored from a checkpoint.

    # configs
    max_seqlen = 256
    batch_size = 16
    print_steps = 5
    num_classes = 26
    pretrain_dir = './pretrain/ERNIE-v2-en-base/'
    vocab_path = pretrain_dir + 'vocab.txt'
    predict_file = './data/atis/atis_intent/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict-intent/'
    random_seed = 0

    # Read the backbone config inside a context manager so the file handle is
    # closed rather than leaked.
    with open(pretrain_dir + 'ernie_config.json', encoding='utf-8') as config_file:
        config = json.load(config_file)
    # The backbone's hidden size is the input dimension of the classify head.
    input_dim = config['hidden_size']

    # -----------------------  for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict')
    # step 1-2: load the prediction data
    predict_cls_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_cls_reader.register_with(pred_ernie)

    # step 4: create the task output head
    cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer = palm.Trainer("intent")
    # step 5-2: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, cls_pred_head)

    # step 6: load checkpoint
    # NOTE(review): the step number is hard-coded; presumably it must match a
    # checkpoint written into save_path by the training run — confirm.
    pred_model_path = save_path + 'ckpt.step4641'
    trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_cls_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir=pred_output)