utils.py 5.0 KB
Newer Older
S
Steffy-zxf 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jieba
import numpy as np


def load_vocab(vocab_file):
    """Reads a vocabulary file and maps every token to its line index.

    Each line holds a token, optionally followed by a tab and further
    columns, which are ignored. Indices start at 0. If the same token appears
    on several lines, the index of its last occurrence is kept.

    Args:
        vocab_file(obj:`str`): Path to a UTF-8 encoded vocabulary file.

    Returns:
        vocab(obj:`dict`): Mapping from token string to integer index.
    """
    with open(vocab_file, "r", encoding="utf-8") as reader:
        return {
            line.rstrip("\n").split("\t")[0]: line_no
            for line_no, line in enumerate(reader)
        }


def convert_ids_to_tokens(wids, inversed_vocab):
    """ Converts a sequence of token ids into a sequence of token strings,
        using the inversed vocabulary.

    Ids that are not present in `inversed_vocab` are skipped.

    Args:
        wids(obj:`list[int]`): The token ids.
        inversed_vocab(obj:`dict`): Mapping from token id to token string.

    Returns:
        tokens(obj:`list[str]`): The tokens of the known ids, in input order.
    """
    tokens = []
    for wid in wids:
        wstr = inversed_vocab.get(wid)
        # Compare with None rather than using truthiness, so that a falsy
        # token (e.g. an empty string) is not mistaken for a missing id.
        if wstr is not None:
            tokens.append(wstr)
    return tokens


def convert_tokens_to_ids(tokens, vocab):
    """ Converts a sequence of token strings into a sequence of token ids,
        using the vocabulary.

    Tokens missing from `vocab` are mapped to the id of '[UNK]'. If the
    vocabulary has no '[UNK]' entry, unknown tokens are skipped.

    Args:
        tokens(obj:`list[str]`): The tokens to convert.
        vocab(obj:`dict`): Mapping from token string to token id.

    Returns:
        ids(obj:`list[int]`): The token ids, in input order.
    """
    unk_id = vocab.get('[UNK]', None)
    ids = []
    for token in tokens:
        wid = vocab.get(token, unk_id)
        # Compare with None rather than truthiness: 0 is a valid id (e.g. the
        # first line of a vocabulary file read by `load_vocab`).
        if wid is not None:
            ids.append(wid)
    return ids


def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
    """
    Pads or truncates, in place, every text in `texts` to `max_seq_len` ids.

    Texts shorter than `max_seq_len` are right-padded with `pad_token_id`.
    Texts longer than `max_seq_len` are truncated to their first
    `max_seq_len` ids. Texts of exactly `max_seq_len` ids are left untouched.

    Args:
        texts(obj:`list`): Texts, each of which is a sequence of word ids
            (a list, or another sequence such as a 1-D `numpy.ndarray`).
            The list is modified in place; padded texts become lists.
        max_seq_len(obj:`int`): Max sequence length.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
    """
    for index, text in enumerate(texts):
        seq_len = len(text)
        if seq_len < max_seq_len:
            # Convert to a list before concatenating: with a numpy array,
            # `text + [...]` would broadcast an element-wise addition
            # instead of appending the padding.
            padding = [pad_token_id] * (max_seq_len - seq_len)
            texts[index] = list(text) + padding
        elif seq_len > max_seq_len:
            texts[index] = text[:max_seq_len]


def generate_batch(batch, pad_token_id=0, return_label=True):
    """
    Builds a batch whose texts are padded to the longest sequence length
    found in that batch.

    Args:
        batch(obj:`List[Example]`): One batch. Each entry holds the text
            (word ids) at position 0 and its true sequence length at
            position 1. When `return_label` is True, the label is read from
            the last position.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
        return_label(obj:`bool`, optional, defaults to True): Whether the
            labels are returned as well.

    Returns:
        batch(obj:`Tuple[list]`): `(texts, seq_lens, labels)` when
        `return_label` is True, otherwise `(texts, seq_lens)`. Every label
        is wrapped in a one-element list.
    """
    seq_lens = [sample[1] for sample in batch]
    longest = max(seq_lens)
    texts = [sample[0] for sample in batch]
    pad_texts_to_max_seq_len(texts, longest, pad_token_id)

    if not return_label:
        return texts, seq_lens

    labels = [[sample[-1]] for sample in batch]
    return texts, seq_lens, labels


def convert_example(example, vocab, unk_token_id=1, is_test=False):
    """
    Builds model inputs from a sequence for sequence classification tasks.
    The text is tokenized with `jieba.cut`.

    Args:
        example(obj:`list[str]`): One input sample. The text is at position 0;
            unless `is_test` is True, the label is at the last position.
        vocab(obj:`dict`): The vocabulary, mapping token strings to ids.
        unk_token_id(obj:`int`, defaults to 1): The id used for tokens that
            are not in `vocab`.
        is_test(obj:`bool`, defaults to `False`): Whether the example has no
            label.

    Returns:
        input_ids(obj:`numpy.array`, data type of int64): The token ids.
        valid_length(obj:`numpy.array`, data type of int64): A 0-d array
            holding the number of tokens.
        label(obj:`numpy.array`, data type of int64, optional): The input
            label, returned only if not `is_test`.
    """
    input_ids = [
        vocab.get(token, unk_token_id) for token in jieba.cut(example[0])
    ]
    valid_length = np.array(len(input_ids), dtype='int64')
    input_ids = np.array(input_ids, dtype='int64')

    if is_test:
        return input_ids, valid_length

    label = np.array(example[-1], dtype="int64")
    return input_ids, valid_length, label


def preprocess_prediction_data(data, vocab):
    """
    Processes the prediction data into the format used for training.

    Each text is tokenized with `jieba.cut`, the same tokenizer used by
    `convert_example`. The tokens are then converted to ids with
    `convert_tokens_to_ids`.

    Args:
        data(obj:`List[str]`): The prediction data; each element is a raw
            (untokenized) text.
        vocab(obj:`dict`): The vocabulary, mapping token strings to ids.

    Returns:
        examples(obj:`List[list]`): One `[word_ids, seq_len]` pair per text,
            where `seq_len` is the number of word ids.
    """
    examples = []
    for text in data:
        # Use the jieba tokens as-is. Joining them with spaces and splitting
        # on ' ' would turn a whitespace token into empty-string tokens,
        # which are then mapped to '[UNK]'.
        tokens = list(jieba.cut(text))
        ids = convert_tokens_to_ids(tokens, vocab)
        examples.append([ids, len(ids)])
    return examples