predict.py 4.4 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import paddle
S
Steffy-zxf 已提交
17
import paddlenlp as ppnlp
18
import paddle.nn.functional as F
Z
Zeyu Chen 已提交
19

S
Steffy-zxf 已提交
20
from utils import load_vocab, generate_batch, preprocess_prediction_data
Z
Zeyu Chen 已提交
21 22 23 24

def _str2bool(value):
    """Parses a command-line boolean flag value without using `eval`.

    Accepts "true"/"false", "1"/"0" and "yes"/"no" (case-insensitive, surrounding
    whitespace ignored), which covers the documented "True or False" inputs.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognized boolean spelling.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected (True or False), got %r" % value)


# yapf: disable
parser = argparse.ArgumentParser(__doc__)
# `_str2bool` replaces the former `type=eval`, which would execute any Python
# expression passed on the command line.
parser.add_argument("--use_gpu", type=_str2bool, default=False, help="Whether to use GPU for prediction, input should be True or False")
parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number of a batch for prediction.")
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to vocabulary.")
parser.add_argument('--network', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.")
args = parser.parse_args()
# yapf: enable


S
Steffy-zxf 已提交
33
def predict(model, data, label_map, collate_fn, batch_size=1, pad_token_id=0):
    """
    Predicts the data labels.

    Args:
        model (obj:`paddle.nn.Layer`): A model to classify texts. It is called as
            `model(texts, seq_lens)` and must return logits whose axis 1 indexes
            the classes.
        data (obj:`List(Example)`): The processed examples to classify, in the
            form accepted by `collate_fn` (in this script, the output of
            `preprocess_prediction_data`).
        label_map(obj:`dict`): The label id (key) to label str (value) map.
        collate_fn(obj: `callable`): function to generate mini-batch data by merging
            the sample list. It is called with `pad_token_id=...` and
            `return_label=False` and must return `(texts, seq_lens)`.
        batch_size(obj:`int`, defaults to 1): The number of examples per batch.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.

    Returns:
        results(obj:`list`): The predicted label strings (values of `label_map`),
            one per example in `data`, in input order.
    """

    # Separates data into batches of `batch_size` examples.
    batches = []
    one_batch = []
    for example in data:
        one_batch.append(example)
        if len(one_batch) == batch_size:
            batches.append(one_batch)
            one_batch = []
    if one_batch:
        # The last batch whose size is less than the config batch_size setting.
        batches.append(one_batch)

    results = []
    model.eval()
    # Inference only: disable gradient tracking so the forward passes do not
    # record an autograd graph.
    with paddle.no_grad():
        for batch in batches:
            texts, seq_lens = collate_fn(
                batch, pad_token_id=pad_token_id, return_label=False)
            texts = paddle.to_tensor(texts)
            seq_lens = paddle.to_tensor(seq_lens)
            logits = model(texts, seq_lens)
            probs = F.softmax(logits, axis=1)
            idx = paddle.argmax(probs, axis=1).numpy().tolist()
            results.extend(label_map[i] for i in idx)
    return results


if __name__ == "__main__":
    # Select the device once, based on the --use_gpu flag.
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    # Loads vocab.
    vocab = load_vocab(args.vocab_path)
    label_map = {0: 'negative', 1: 'positive'}

    # Constructs the network.
    model = ppnlp.models.Senta(
        network=args.network, vocab_size=len(vocab), num_classes=len(label_map))

    # Loads model parameters.
    state_dict = paddle.load(args.params_path)
    model.set_dict(state_dict)
    print("Loaded parameters from %s" % args.params_path)

    # Pre-processes the raw texts into model inputs, then predicts their labels.
    data = [
        '这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般',
        '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片',
        '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。',
    ]
    examples = preprocess_prediction_data(data, vocab)

    results = predict(
        model,
        examples,
        label_map=label_map,
        batch_size=args.batch_size,
        collate_fn=generate_batch)

    # `predict` returns one label per example, in input order, so pair them
    # directly with the raw texts.
    for text, label in zip(data, results):
        print('Data: {} \t Label: {}'.format(text, label))