'''
@File    :   node_knowledge_mapping.py
@Time    :   2022/05/30 16:21:40
@Author  :   Lu Xin 
@Contact :   luxin@csdn.net
'''

# here put the import lib
import re
import ipdb

import pandas as pd

from treelib import Tree
from treelib import Node

from path import get_tree_dir
from path import get_index_dir
from path import get_sample_id_dir

from utils import load_json
from utils import load_markdown


class NodeKnowledgeMapping:
    """Map sections of a markdown index onto node ids of a knowledge tree.

    Typical use is ``load()`` followed by ``get_node_knowledge_mapping()``,
    which writes one CSV row per index section whose tree path resolves to a
    leaf node of the tree.

    Args:
        category: Value written to the ``category`` column of every row.
    """

    def __init__(self, category="blog") -> None:
        self.tree_name = None            # lowercased tag of the tree's root
        self.category = category
        self.tree = Tree()
        self.text_id_dict = None         # normalized leaf path -> leaf node id
        self.section_text_dict = None    # index section title -> normalized path
        self.section_sample_dict = None  # contents of the sample-id JSON file

    def load(self):
        """Load the tree, then the index, then the sample ids.

        The order matters: parsing the index needs ``self.tree_name``, which
        is set while loading the tree. Safe to call more than once.
        """
        self.__load_tree()
        self.__load_index()
        self.__load_sample_id()

    @staticmethod
    def _normalize(text):
        """Return *text* lowercased with ASCII spaces removed (the matching key form)."""
        return text.replace(" ", "").lower()

    def __construct_tree(self, tree_dict, parent):
        """Recursively add the nodes of *tree_dict* to ``self.tree`` under *parent*.

        Each key of *tree_dict* is a node's tag; its value holds the node's
        ``"node_id"`` and a ``"children"`` list of dicts with the same shape.
        """
        for node_text, node_info in tree_dict.items():
            node_id = node_info["node_id"]
            self.tree.add_node(Node(tag=node_text, identifier=node_id), parent=parent)
            for subtree_dict in node_info["children"]:
                self.__construct_tree(subtree_dict, node_id)

    def __load_tree(self):
        """Build ``self.tree`` from the tree JSON and index its leaf paths.

        Sets ``self.tree_name`` to the lowercased first top-level key and fills
        ``self.text_id_dict`` with one entry per root-to-leaf path: the
        normalized tags of every node after the root, joined by ``-``, mapped
        to the leaf's node id.
        """
        # Start from a fresh tree so that calling load() again does not fail
        # on node ids (or a root) that a previous call already added.
        self.tree = Tree()
        self.text_id_dict = {}

        tree_dict = load_json(get_tree_dir())
        self.tree_name = list(tree_dict.keys())[0].lower()
        self.__construct_tree(tree_dict, None)
        for path in self.tree.paths_to_leaves():
            # path[0] is the root; its tag is not part of the lookup key.
            text = "-".join(
                self._normalize(self.tree.get_node(node_id).tag)
                for node_id in path[1:])
            self.text_id_dict[text] = path[-1]

    @staticmethod
    def _parse_index(mk_list, tree_name):
        """Extract ``{section title: normalized tree path}`` from index lines.

        A section is a line starting with ``##`` that is immediately followed
        by a line not starting with ``##``; that following line is the
        section's tree path. The heading marks and an optional leading
        ``[...]`` are stripped from the title. The path is normalized and
        prefixed with the normalized *tree_name* unless it already starts with
        it. Paths containing "不采纳" ("not adopted") are dropped.
        """
        # Normalize the prefix exactly like the path text; otherwise a tree
        # name containing spaces could never match a space-stripped path.
        prefix = NodeKnowledgeMapping._normalize(tree_name)
        section_text_dict = {}
        index = 0
        while index < len(mk_list) - 1:
            line = mk_list[index]
            line_next = mk_list[index + 1]
            if not line.startswith("##") or line_next.startswith("##"):
                index += 1
                continue
            section = re.sub(r"^#{1,10} {1,5}", "", line)
            section = re.sub(r"^\[.*?\]", "", section).strip()
            text = NodeKnowledgeMapping._normalize(line_next)
            if not text.startswith(prefix):
                text = prefix + text
            if "不采纳" not in text:
                section_text_dict[section] = text
            # The path line is consumed together with its heading.
            index += 2
        return section_text_dict

    def __load_index(self):
        """Fill ``self.section_text_dict`` from the index markdown file."""
        # load_markdown presumably returns the file as a list of lines -- TODO confirm.
        mk_list = load_markdown(get_index_dir())
        self.section_text_dict = self._parse_index(mk_list, self.tree_name)

    def __load_sample_id(self):
        """Load ``self.section_sample_dict`` from the sample-id JSON file."""
        self.section_sample_dict = load_json(get_sample_id_dir())

    def get_node_knowledge_mapping(self, file_name):
        """Write the section-to-node mapping to *file_name* as CSV.

        One row per index section whose path is a known leaf path. Sections
        whose path is unknown are reported on stdout and skipped.
        ``sample_id`` is ``None`` for sections absent from the sample-id data.

        Returns:
            The ``pandas.DataFrame`` that was written.
        """
        columns = ["node_id", "text", "book_text", "sample_id", "tree_name", "category"]
        contents = []
        for section, text in self.section_text_dict.items():
            if text not in self.text_id_dict:
                print("路径 \"{}\" 不存在!".format(text))
                continue
            contents.append([
                self.text_id_dict[text],
                text,
                section,
                self.section_sample_dict.get(section, None),
                self.tree_name,
                self.category,
            ])

        df = pd.DataFrame(contents, columns=columns)
        df.to_csv(file_name, index=False)
        return df
        

def main(file_name="./data/mysql_update_4_top.csv", category="blog"):
    """Build the node/knowledge mapping and write it to *file_name* as CSV.

    Args:
        file_name: Destination CSV path; defaults to the previously
            hard-coded output file.
        category: Category label written on every output row.
    """
    nkm = NodeKnowledgeMapping(category=category)
    nkm.load()
    nkm.get_node_knowledge_mapping(file_name)


if __name__ == '__main__':
    main()
