vocabulary.py 3.5 KB
Newer Older
D
dingsiyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
# -*- coding: utf-8 -*
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
:py:class:`Vocabulary`
"""
import collections
import unicodedata
import six

def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


class Vocabulary(object):
    """Vocabulary"""
    def __init__(self, vocab_path, unk_token):
        """
        :param vocab_path: 
        :param unk_token: 
        """
        if not vocab_path:
            raise ValueError("vocab_path can't be None")

        self.vocab_path = vocab_path
        self.unk_token = unk_token
        self.vocab_dict, self.id_dict = self.load_vocab()
        self.vocab_size = len(self.id_dict)

    def load_vocab(self):
        """
        :return:
        """
        vocab_dict = collections.OrderedDict()
        id_dict = collections.OrderedDict()
        file_vocab = open(self.vocab_path)
        for num, line in enumerate(file_vocab):
            items = convert_to_unicode(line.strip()).split("\t")
            if len(items) > 2:
                break
            token = items[0]
            if len(items) == 2:
                index = items[1]
            else:
                index = num
            token = token.strip()

            vocab_dict[token] = int(index)
            id_dict[index] = token

        return vocab_dict, id_dict

    def add_reserve_id(self):
        """
        :return:
        """
        pass

    def convert_tokens_to_ids(self, tokens):
        """
        :param tokens:
        :return:
        """
        UNK = self.vocab_dict[self.unk_token]
        output = [self.vocab_dict.get(item, UNK) for item in tokens]
        # for item in tokens:
            # output.append(self.vocab_dict.get(item, UNK))
        return output

    def convert_ids_to_tokens(self, ids):
        """
        :param ids:
        :return:
        """
        output = []
        for item in ids:
            output.append(self.id_dict.get(item, self.unk_token))
        return output

    def get_vocab_size(self):
        """
        :return:
        """
        return len(self.id_dict)

    def covert_id_to_token(self, id):
        """
        :param id:
        :return: token
        """
        return self.id_dict.get(id, self.unk_token)

    def covert_token_to_id(self, token):
        """
        :param token:
        :return: id
        """
        UNK = self.vocab_dict[self.unk_token]
        return self.vocab_dict.get(token, UNK)