# model.py (4.3 KB)
# Committed by SiMing Dai.
import os
from collections import OrderedDict

import numpy as np
from tqdm import tqdm
from paddlehub.common.logger import logger

from slda_weibo.vocab import Vocab, WordCount


class TopicModel(object):
    """Storage structure of a topic model: vocabulary plus word-topic counts.

    The model files are read eagerly by the constructor (via ``load_model``).
    """

    def __init__(self, model_dir, config):
        """
        Args:
            model_dir: the path of model directory.
            config: ModelConfig instance; this constructor reads its
                ``num_topics``, ``alpha``, ``beta``, ``type``,
                ``word_topic_file`` and ``vocab_file`` attributes.
        """
        self.__word_topic = None  # Per-term {topic_id: count} mappings.
        self.__vocab = Vocab()  # Vocab data structure of model.
        self.__num_topics = config.num_topics  # Number of topics.
        self.__alpha = config.alpha
        self.__alpha_sum = self.__alpha * self.__num_topics
        self.__beta = config.beta
        self.__beta_sum = None  # Set in load_model once the vocab size is known.
        self.__type = config.type  # Model type.
        # Accumulated count of each topic over all words in the word-topic file.
        self.__topic_sum = np.zeros(self.__num_topics, dtype="int64")
        # For each topic, the WordCount(term_id, count) entries that mention it.
        self.__topic_words = [[] for _ in range(self.__num_topics)]
        word_topic_path = os.path.join(model_dir, config.word_topic_file)
        vocab_path = os.path.join(model_dir, config.vocab_file)
        self.load_model(word_topic_path, vocab_path)

    def term_id(self, term):
        """Return the vocabulary id of ``term`` (as reported by ``Vocab.get_id``)."""
        return self.__vocab.get_id(term)

    def load_model(self, word_topic_path, vocab_path):
        """Load the vocabulary and the word-topic counts from disk.

        Args:
            word_topic_path: path of the word-topic file.
            vocab_path: path of the vocabulary file.

        Raises:
            ValueError: if the word-topic file is malformed.
        """
        # Loading vocabulary
        self.__vocab.load(vocab_path)

        self.__beta_sum = self.__beta * self.__vocab.size()
        # One {topic_id: count} dict per term id.
        self.__word_topic = [{} for _ in range(self.__vocab.size())]
        # __word_topic is rebuilt from scratch above, so the per-topic
        # accumulators must be reset too; otherwise a repeated call would
        # double-count on top of the previous load.
        self.__topic_sum = np.zeros(self.__num_topics, dtype="int64")
        self.__topic_words = [[] for _ in range(self.__num_topics)]
        self.__load_word_dict(word_topic_path)
        logger.info(
            "Model Info: #num_topics=%d #vocab_size=%d alpha=%f beta=%f" %
            (self.num_topics(), self.vocab_size(), self.alpha(), self.beta()))

    def word_topic_value(self, word_id, topic_id):
        """Return the count of ``word_id`` under ``topic_id`` (0 if absent)."""
        return self.__word_topic[word_id].get(topic_id, 0)

    def word_topic(self, term_id):
        """Return the {topic_id: count} mapping of a word, ordered by topic id."""
        return self.__word_topic[term_id]

    def topic_sum_value(self, topic_id):
        """Return the accumulated count of a single topic."""
        return self.__topic_sum[topic_id]

    def topic_sum(self):
        """Return the int64 numpy array of accumulated counts per topic."""
        return self.__topic_sum

    def num_topics(self):
        return self.__num_topics

    def vocab_size(self):
        return self.__vocab.size()

    def alpha(self):
        return self.__alpha

    def alpha_sum(self):
        """Return ``alpha * num_topics``."""
        return self.__alpha_sum

    def beta(self):
        return self.__beta

    def beta_sum(self):
        """Return ``beta * vocab_size`` (computed in ``load_model``)."""
        return self.__beta_sum

    def type(self):
        """Return the model type taken from ``config.type``."""
        return self.__type

    @staticmethod
    def _parse_word_topic_line(line, vocab_size, num_topics):
        """Parse one line of the word-topic file.

        Args:
            line: text of the form ``term_id topic_id:count topic_id:count ...``;
                fields may be separated by any run of whitespace.
            vocab_size: number of terms; valid term ids are ``[0, vocab_size)``.
            num_topics: number of topics; valid topic ids are ``[0, num_topics)``.

        Returns:
            ``(term_id, [(topic_id, count), ...])`` in file order, or ``None``
            when the line is blank.

        Raises:
            ValueError: on a non-integer field, a malformed ``topic:count``
                pair, an out-of-range id, or a negative count.
        """
        fields = line.split()
        if not fields:
            return None
        term_id = int(fields[0])
        # Explicit checks instead of assert: asserts vanish under `python -O`,
        # and a negative id would then silently index from the end of a list.
        if not 0 <= term_id < vocab_size:
            raise ValueError("Term id out of range!")
        topic_counts = []
        for field in fields[1:]:
            pair = field.split(":")
            if len(pair) != 2:
                raise ValueError("Topic count format error!")
            topic_id = int(pair[0])
            if not 0 <= topic_id < num_topics:
                raise ValueError("Topic out of range!")
            count = int(pair[1])
            if count < 0:
                raise ValueError("Topic count error!")
            topic_counts.append((topic_id, count))
        return term_id, topic_counts

    def __load_word_dict(self, word_dict_path):
        """Load the word topic parameters.

        For every line, the per-word counts are stored (re-ordered by topic
        id), added to the per-topic sums, and appended to each topic's word
        list. Blank lines are skipped.

        Raises:
            ValueError: if a line is malformed or an id is out of range.
        """
        logger.info("Loading word topic.")
        vocab_size = self.vocab_size()
        with open(word_dict_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f.readlines()):
                parsed = self._parse_word_topic_line(line, vocab_size,
                                                     self.__num_topics)
                if parsed is None:
                    continue
                term_id, topic_counts = parsed
                word_dict = self.__word_topic[term_id]
                for topic_id, count in topic_counts:
                    word_dict[topic_id] = count
                    self.__topic_sum[topic_id] += count
                    self.__topic_words[topic_id].append(
                        WordCount(term_id, count))
                # Keep each word's topics ordered by topic id.
                self.__word_topic[term_id] = OrderedDict(
                    sorted(word_dict.items()))

    def get_vocab(self):
        """Return ``vocabulary()`` of the underlying Vocab object."""
        return self.__vocab.vocabulary()

    def topic_words(self):
        """Return, per topic, the list of WordCount(term_id, count) entries."""
        return self.__topic_words