# predict.py (2.1 KB), committed by wangxiao1021
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count


if __name__ == '__main__':
    # Slot-filling (sequence labeling) prediction on the ATIS slot test set:
    # builds an ERNIE backbone plus a SequenceLabel head in 'predict' phase,
    # restores weights from a saved checkpoint, and runs prediction with
    # output_dir set to `pred_output`.
    import sys  # only used to read an optional checkpoint-path argument

    # configs
    max_seqlen = 256
    batch_size = 16
    print_steps = 5
    # NOTE(review): presumably must equal the number of labels in
    # label_map.json used at training time — verify.
    num_classes = 130
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ernie-en-base/vocab.txt'
    predict_file = './data/atis/atis_slot/test.tsv'
    pred_output = './outputs/predict/'

    # Use a context manager so the config file handle is closed right away
    # instead of being left for the garbage collector.
    with open('./pretrain/ernie-en-base/ernie_config.json') as config_file:
        config = json.load(config_file)
    # The task head is fed the backbone's features, so its input width is
    # the backbone's hidden size.
    input_dim = config['hidden_size']

    # -----------------------  for prediction -----------------------

    # step 1-1: create the reader for prediction
    print('prepare to predict...')
    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
    # step 1-2: load the data to run prediction on
    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot")
    # step 5-2: build forward graph with backbone and task head
    trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load the saved model checkpoint
    # The default path names one specific training run (timestamp + step), so
    # allow overriding it with the first command-line argument.
    if len(sys.argv) > 1:
        pred_model_path = sys.argv[1]
    else:
        pred_model_path = './outputs/1580822697.73-ckpt.step9282'
    trainer_seq_label.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict, writing results to pred_output
    print('predicting..')
    trainer_seq_label.predict(print_steps=print_steps, output_dir=pred_output)