# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def convert_example(example, tokenizer, is_test=False):
    """
    Builds model inputs from a sequence for sequence classification tasks.
    It uses `jieba.cut` to tokenize text.

    Args:
        example(obj:`list[str]`): List of input data, containing text and label if it has a label.
        tokenizer(obj:`paddlenlp.data.JiebaTokenizer`): It uses jieba to cut the Chinese string.
        is_test(obj:`bool`, defaults to `False`): Whether the example contains a label or not.

    Returns:
        query_ids(obj:`list[int]`): The list of query ids.
        title_ids(obj:`list[int]`): The list of title ids.
        query_seq_len(obj:`int`): The length of the query sequence.
        title_seq_len(obj:`int`): The length of the title sequence.
        label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
    """

    query, title = example[0], example[1]
    query_ids = np.array(tokenizer.encode(query), dtype="int64")
    query_seq_len = np.array(len(query_ids), dtype="int64")
    title_ids = np.array(tokenizer.encode(title), dtype="int64")
    title_seq_len = np.array(len(title_ids), dtype="int64")

    if not is_test:
        label = np.array(example[-1], dtype="int64")
        return query_ids, title_ids, query_seq_len, title_seq_len, label
    else:
        return query_ids, title_ids, query_seq_len, title_seq_len
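

# A minimal usage sketch for `convert_example` (illustration only). The vocab
# file path and the example pair below are hypothetical; it assumes the
# tokenizer is built with `paddlenlp.data.Vocab.load_vocabulary` and
# `paddlenlp.data.JiebaTokenizer`, the tokenizer type named in the docstring:
#
#     from paddlenlp.data import JiebaTokenizer, Vocab
#
#     vocab = Vocab.load_vocabulary("word_dict.txt", unk_token="[UNK]", pad_token="[PAD]")
#     tokenizer = JiebaTokenizer(vocab)
#     query_ids, title_ids, query_len, title_len, label = convert_example(
#         ["世界上什么东西最小", "世界上最小的东西是什么", 1], tokenizer)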


def preprocess_prediction_data(data, tokenizer):
    """
    It processes the prediction data into the same format used in training.

    Args:
        data (obj:`List[List[str, str]]`):
            The prediction data, each element of which is a text pair.
            Each text will be tokenized by the `jieba.lcut()` function.
        tokenizer(obj:`paddlenlp.data.JiebaTokenizer`): It uses jieba to cut the Chinese string.

    Returns:
        examples (obj:`list`): The processed data, each element of which
            is a `list` object containing:

            - query_ids(obj:`list[int]`): The list of query ids.
            - title_ids(obj:`list[int]`): The list of title ids.
            - query_seq_len(obj:`int`): The length of the query sequence.
            - title_seq_len(obj:`int`): The length of the title sequence.

    """
    examples = []
    for query, title in data:
        query_ids = tokenizer.encode(query)
        title_ids = tokenizer.encode(title)
        examples.append([query_ids, title_ids, len(query_ids), len(title_ids)])
    return examples
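

if __name__ == "__main__":
    # A small, hedged demo of `preprocess_prediction_data`. The vocab file
    # path below is hypothetical; `Vocab.load_vocabulary` and `JiebaTokenizer`
    # are assumed to come from `paddlenlp.data`, the tokenizer type this
    # module's docstrings refer to.
    from paddlenlp.data import JiebaTokenizer, Vocab

    vocab = Vocab.load_vocabulary(
        "word_dict.txt", unk_token="[UNK]", pad_token="[PAD]")
    tokenizer = JiebaTokenizer(vocab)

    data = [
        ["世界上什么东西最小", "世界上最小的东西是什么"],
        ["光眼睛大就好看吗", "眼睛好看吗？"],
    ]
    examples = preprocess_prediction_data(data, tokenizer)
    for query_ids, title_ids, query_seq_len, title_seq_len in examples:
        print(query_ids, title_ids, query_seq_len, title_seq_len)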