# -*- coding: UTF-8 -*-
# 作者：huanhuilong
# 标题：相似匹配2
# 描述：Word2Vec 词向量化，查询相近词

import os
import json
import jieba
from gensim.models import word2vec


class Dataset:
    """Small built-in corpus of tokenised sentences with an on-disk cache.

    The train/test split is stored as JSON in ``dataset_file`` with the keys
    ``'train'`` and ``'test'``; each value is a list of token lists.
    """

    # Number of trailing sentences held out as the test split.
    _TEST_SIZE = 5

    def __init__(self, dataset_file) -> None:
        """Remember the cache path; the splits stay ``None`` until ``load()``."""
        self.dataset_file = dataset_file
        self.train_sentences = None
        self.test_sentences = None

    def generate(self):
        """Tokenise the hard-coded sample sentences.

        Returns:
            A list with one token list per sample sentence, in source order.
        """
        raw_sentences = [
            '今天 天气 很 好',
            '今天很 好',
            '今天 天气 很 好',
            '今天 天气 很 好',
            '今天 天',
            '今天 天气 很 好',
            '今天 天气 很 好',
            '今天 天气 很 好',
            '今天 天气 很 好',
            '今天 天',
            '今天 天气 很 好',
            '气 很 好',
            '今天 天气 很 差',
            '不错',
            '还行',
        ]
        return [Dataset.cut_sentence(text) for text in raw_sentences]

    @staticmethod
    def cut_sentence(s):
        """Split ``s`` into tokens with jieba (``cut_all=True``).

        ``s`` is converted with ``str()`` first; tokens that are empty or
        whitespace-only are dropped.
        """
        return [token for token in jieba.cut(str(s), cut_all=True) if token.strip()]

    def load(self):
        """Populate ``train_sentences`` and ``test_sentences``.

        If ``dataset_file`` exists, the splits are read from it. Otherwise the
        corpus is generated, the last ``_TEST_SIZE`` sentences become the test
        split, and the result is written to ``dataset_file`` for later runs.
        """
        if os.path.exists(self.dataset_file):
            # Explicit UTF-8 so a cache containing raw Chinese text loads the
            # same way regardless of the platform's default encoding.
            with open(self.dataset_file, 'r', encoding='utf-8') as f:
                ret = json.load(f)
        else:
            sentences = self.generate()
            cut = len(sentences) - self._TEST_SIZE
            ret = {
                'train': sentences[:cut],
                'test': sentences[cut:],
            }
            # Serialise before opening the file so a serialisation failure
            # cannot leave a truncated cache file behind.
            payload = json.dumps(ret)
            with open(self.dataset_file, 'w', encoding='utf-8') as f:
                f.write(payload)

        self.train_sentences = ret['train']
        self.test_sentences = ret['test']


class Word2VecSimilar:
    """Thin wrapper around a gensim ``Word2Vec`` model for similar-word lookup.

    NOTE: the calls below use the gensim 3.x spelling (``size``, ``iter``,
    ``init_sims``); gensim 4 renamed or deprecated these, so confirm the
    installed version before upgrading.
    """

    def __init__(self, model=None) -> None:
        """Wrap an already trained/loaded model, or start empty until ``fit``."""
        self.model = model

    def fit(self, train_sentences):
        """Train a new model on ``train_sentences`` and store it on ``self``.

        Args:
            train_sentences: iterable of token lists, one per sentence.
        """
        model = word2vec.Word2Vec(
            train_sentences,
            workers=2,
            size=300,
            min_count=1,
            window=5,
            sample=1e-3,
        )
        model.init_sims(replace=True)
        self.model = model

    def predict(self, sentence):
        """Return the three words most similar to the first token of ``sentence``.

        The sentence is tokenised with ``Dataset.cut_sentence``, the model is
        trained further on that single tokenised sentence, and the first
        token is then looked up.

        Returns:
            The result of ``most_similar(..., topn=3)``, or ``None`` when the
            sentence yields no tokens or the first token is not in the
            model's vocabulary.
        """
        print('=>', sentence)
        tokens = Dataset.cut_sentence(sentence)
        print(tokens)

        if not tokens:
            # Nothing to train on and nothing to look up.
            return None

        # ``train`` expects an iterable of sentences (each a list of tokens).
        # Passing the bare token list would make every token string be
        # iterated as a "sentence" of single characters.
        self.model.train(
            [tokens],
            total_examples=1,
            epochs=self.model.iter,
        )

        try:
            return self.model.wv.most_similar(tokens[0], topn=3)
        except KeyError:
            # The token is not in the model's vocabulary.
            return None

    def save(self, model_file):
        """Persist the wrapped model to ``model_file``."""
        self.model.save(model_file)

    @staticmethod
    def load_model(model_file):
        """Load a saved model from ``model_file`` and return it wrapped."""
        model = word2vec.Word2Vec.load(model_file)
        return Word2VecSimilar(model)


if __name__ == '__main__':
    # Prepare the dataset: read from the cache file if present, otherwise
    # generated and written there by Dataset.load().
    dataset_file = '/tmp/nlp_word2vec.dataset'
    dataset = Dataset(dataset_file)
    dataset.load()

    # Obtain the model: train and save it on the first run, reuse the saved
    # file afterwards.
    model_file = '/tmp/nlp_word2vec.model'
    if not os.path.exists(model_file):
        model = Word2VecSimilar()
        model.fit(dataset.train_sentences)
        model.save(model_file)
    else:
        model = Word2VecSimilar.load_model(model_file)

    # Query: look up every token of every test sentence, in order.
    test_words = (word for tokens in dataset.test_sentences for word in tokens)
    for word in test_words:
        print(word)
        neighbours = model.predict(word)
        if not neighbours:
            continue
        print(neighbours)
