# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse

import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.data import JiebaTokenizer, Pad, Stack, Tuple, Vocab
from utils import preprocess_prediction_data

def _str2bool(value):
    """
    Parses a command-line boolean flag without evaluating arbitrary code.

    Args:
        value (obj:`str` or `bool`): The raw command-line text, e.g. `True` or `False`.

    Returns:
        flag(obj:`bool`): The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Expected True or False, got %r" % value)


# yapf: disable
# NOTE(review): the first positional parameter of ArgumentParser is `prog`, so
# `__doc__` sets the program name rather than the description -- confirm intent.
parser = argparse.ArgumentParser(__doc__)
# SECURITY: this option used `type=eval`, which executes whatever Python
# expression is given on the command line. `_str2bool` accepts the documented
# `True` / `False` inputs (case-insensitively) without running any code.
parser.add_argument("--use_gpu", type=_str2bool, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./simnet_word_dict.txt", help="The path to vocabulary.")
parser.add_argument('--network', type=str, default="lstm", help="Which network you would like to choose bow, cnn, lstm or gru ?")
parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.")
# Parsed once at import time; the script body below reads its options from `args`.
args = parser.parse_args()
# yapf: enable


def predict(model, data, label_map, batch_size=1, pad_token_id=0):
    """
    Predicts the label of every query/title pair in `data`.

    Args:
        model (obj:`paddle.nn.Layer`): A model to classify text pairs. It is called as
            `model(query_ids, title_ids, query_seq_lens, title_seq_lens)` and its
            output is treated as per-class logits along axis 1.
        data (obj:`list`): The processed examples. Each example must provide four
            fields, in this order: query word ids, title word ids, query sequence
            length and title sequence length.
        label_map(obj:`dict`): The label id (key) to label str (value) map.
        batch_size(obj:`int`, defaults to 1): The number of examples per batch.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index used to
            pad the query and title ids of a batch.

    Returns:
        results(obj:`list`): The predicted label strings, one per example, in the
            same order as `data`.
    """

    # Separates data into some batches.
    batches = [
        data[start:start + batch_size]
        for start in range(0, len(data), batch_size)
    ]

    # One collate function per example field; `Tuple` applies them field-wise.
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=pad_token_id),  # query_ids
        Pad(axis=0, pad_val=pad_token_id),  # title_ids
        Stack(dtype="int64"),  # query_seq_lens
        Stack(dtype="int64"),  # title_seq_lens
    )

    results = []
    model.eval()
    # Inference only: skip gradient tracking for the forward passes.
    with paddle.no_grad():
        for batch in batches:
            query_ids, title_ids, query_seq_lens, title_seq_lens = batchify_fn(
                batch)
            query_ids = paddle.to_tensor(query_ids)
            title_ids = paddle.to_tensor(title_ids)
            query_seq_lens = paddle.to_tensor(query_seq_lens)
            title_seq_lens = paddle.to_tensor(title_seq_lens)
            logits = model(query_ids, title_ids, query_seq_lens,
                           title_seq_lens)
            probs = F.softmax(logits, axis=1)
            # The most probable class id of each example maps to its label string.
            pred_ids = paddle.argmax(probs, axis=1).numpy().tolist()
            results.extend(label_map[pred_id] for pred_id in pred_ids)
    return results


if __name__ == "__main__":
    # Selects the device requested on the command line.
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    # Loads vocab.
    vocab = Vocab.load_vocabulary(
        args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
    tokenizer = JiebaTokenizer(vocab)
    label_map = {0: 'dissimilar', 1: 'similar'}

    # Constructs the network.
    model = ppnlp.models.SimNet(
        network=args.network, vocab_size=len(vocab), num_classes=len(label_map))

    # Loads model parameters.
    state_dict = paddle.load(args.params_path)
    model.set_dict(state_dict)
    print("Loaded parameters from %s" % args.params_path)

    # First pre-processes the prediction data and then runs prediction.
    data = [
        ['世界上什么东西最小', '世界上什么东西最小?'],
        ['光眼睛大就好看吗', '眼睛好看吗?'],
        ['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的'],
    ]
    examples = preprocess_prediction_data(data, tokenizer)
    results = predict(
        model,
        examples,
        label_map=label_map,
        batch_size=args.batch_size,
        # Falls back to index 0 when the vocabulary has no '[PAD]' entry.
        pad_token_id=vocab.token_to_idx.get('[PAD]', 0))

    # Prints each input pair next to its predicted label.
    for text, label in zip(data, results):
        print('Data: {} \t Label: {}'.format(text, label))