import os
import re
from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter


# Common imports grouped by language key.
# NOTE(review): presumably these are prepended to generated solutions before
# they are executed, so that missing imports do not fail a test — confirm
# against the caller, which is not visible in this file.
IMPORT_HELPER = {
    # Python entries are complete import statements.
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import statistics",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    # Go entries are bare package paths (no `import` keyword or quotes).
    "go"    : [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    # C++ entries are complete #include directives.
    "cpp"   : [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
    ],
}


def read_dataset(
    data_folder: str = None,
    dataset_type: str = "humaneval",
    language_type: str = "python",
    num_shot=None,
) -> Dict:
    """Load a benchmark dataset and index it by task id.

    Args:
        data_folder: Root folder of the data. The file read is
            ``<data_folder>/<language_type>/data/humaneval_<language_type>.jsonl.gz``.
        dataset_type: Dataset name. Only names containing ``"humaneval"``
            (case-insensitive) are supported.
        language_type: Language used for both the sub-folder and the file name.
        num_shot: Only reported on stdout when not ``None``; it does not
            change what is loaded.

    Returns:
        A dict mapping each task's ``"task_id"`` to the full task record.

    Raises:
        ValueError: If ``dataset_type`` is not a supported dataset.
    """
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")

    if "humaneval" not in dataset_type.lower():
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the intended message, so raise a real exception instead.
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
    return {task["task_id"]: task for task in stream_jsonl(data_file)}


def _translation_language_header(lang: str) -> str:
    """Return the ``"<Language>:\\n"`` header line used in translation prompts."""
    if lang == "cpp":
        return "C++:\n"
    if lang == "js":
        return "JavaScript:\n"
    # str.capitalize upper-cases the first character and lower-cases the rest.
    return f"{lang}:\n".capitalize()


def read_translation_dataset(
    data_file_src: str = None,
    data_file_tgt: str = None,
    lang_src: str = None,
    lang_tgt: str = None,
    dataset_type: str = "humaneval",
) -> Dict:
    """Load a source-language dataset and attach a code-translation prompt to each task.

    The prompt for every source task is built as::

        code translation
        <source language>:
        <source declaration>
        <source canonical solution>
        <target language>:
        <target declaration>

    Args:
        data_file_src: JSONL file with the source-language tasks.
        data_file_tgt: JSONL file with the target-language tasks.
        lang_src: Source language key (``"cpp"`` and ``"js"`` get the display
            names ``C++`` and ``JavaScript``; anything else is capitalized).
        lang_tgt: Target language key, treated like ``lang_src``.
        dataset_type: Dataset name. Only names containing ``"humaneval"``
            (case-insensitive) are supported.

    Returns:
        A dict mapping each source ``"task_id"`` to its task record, with the
        ``"prompt"`` field replaced by the translation prompt.

    Raises:
        ValueError: If ``dataset_type`` is not a supported dataset.
        KeyError: If a source task has no counterpart in the target file.
    """
    if "humaneval" not in dataset_type.lower():
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the intended message, so raise a real exception instead.
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)}
    # Target tasks are matched on the part of the task id after the last "/".
    dataset_tgt = {task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt)}

    src_header = _translation_language_header(lang_src)
    tgt_header = _translation_language_header(lang_tgt)

    for task_id, sample in dataset_src.items():
        tgt_sample = dataset_tgt[task_id.split("/")[-1]]
        sample["prompt"] = (
            "code translation\n"
            + src_header
            + sample["declaration"] + "\n"
            + sample["canonical_solution"].rstrip() + "\n"
            + tgt_header
            + tgt_sample["declaration"]
        )

    return dataset_src


def process_extra_prompt(prompt: str, language_type: str = None) -> str:
    """Prefix ``prompt`` with the language tag line for ``language_type``.

    Args:
        prompt: The prompt text.
        language_type: Language key, looked up case-insensitively in
            ``LANGUAGE_TAG``.

    Returns:
        ``LANGUAGE_TAG[language] + "\\n" + prompt`` when the language has a
        tag; otherwise ``prompt`` unchanged (including when
        ``language_type`` is ``None``).
    """
    # The declared default is None; calling .lower() on it used to crash.
    if language_type is None:
        return prompt

    language = language_type.lower()
    if language in LANGUAGE_TAG:
        return LANGUAGE_TAG[language] + "\n" + prompt
    return prompt


def is_code_generation_finished(
    code: str,
    language_type: str = None,
    dataset: str = None,
):
    """
    Decide whether a generated completion already contains a whole function body.

    Only datasets whose name contains "humaneval" (case-insensitive) are
    handled; for anything else, or when ``language_type`` or ``dataset`` is
    None, the answer is False.

    * python: finished once any non-blank line starts at column 0, or the
      text contains one of the top-level markers ("\\ndef", "\\nclass",
      "\\nif", "\\n#", "\\nprint").
    * java / go / js / cpp: finished once there is exactly one more "}"
      than "{" in the text.
    """
    if language_type is None or dataset is None:
        return False
    if "humaneval" not in dataset.lower():
        return False

    language = language_type.lower()

    if language == "python":
        # A non-blank line that is not indented means the body has ended.
        left_the_body = any(
            row.strip() and row[0] not in (" ", "\t")
            for row in code.split("\n")
        )
        if left_the_body:
            return True
        top_level_markers = ("\ndef", "\nclass", "\nif", "\n#", "\nprint")
        return any(marker in code for marker in top_level_markers)

    if language in ("java", "go", "js", "cpp"):
        # One unmatched closing brace closes the function opened in the prompt.
        return code.count("}") == code.count("{") + 1

    return False


def find_method_name(language_type, sample):
    """Find the name of the method declared in ``sample["declaration"]``.

    Args:
        language_type: Language key, matched case-insensitively. Supported:
            python, cpp, java, js / javascript, go, rust.
        sample: Mapping with a ``"declaration"`` string. It is only read when
            the language is supported.

    Returns:
        The method name, or ``None`` when the language is unsupported or the
        declaration does not match the language's pattern.
    """
    # One regex per language; group 1 captures the text holding the name.
    patterns = {
        "python": r"def (.*?)\(",
        "cpp": r" (.*?)\(",
        "java": r" (.*?)\(",
        "js": r"const (.*?) ",
        "javascript": r"const (.*?) ",
        "go": r"func (.*?)\(",
        "rust": r"fn (.*?)\(",
    }

    language = language_type.lower()
    pattern = patterns.get(language)
    if pattern is None:
        return None

    match = re.search(pattern, sample["declaration"])
    if match is None:
        return None

    name = match.group(1).strip()
    if language in ("cpp", "java"):
        # The capture can include modifiers and the return type, e.g.
        # "public boolean hasCloseElements"; the name is the last token.
        tokens = name.split()
        if tokens:
            name = tokens[-1]
    return name


def extract_markdown_code(content):
    """Extract the code inside every ``` fenced block of a markdown string.

    Args:
        content: Markdown text.

    Returns:
        A list with one string per fenced block, in order of appearance.
        When the opening fence carries an info string (e.g. ```python), that
        first line is dropped and the remaining lines are kept with their
        line breaks and indentation; otherwise the block is returned with
        surrounding whitespace stripped. Empty list when there is no block.
    """
    codes = []
    # re.findall always returns a list (possibly empty), never None.
    for block in re.findall(r"```([\s\S]*?)```", content):
        if block.startswith("\n"):
            codes.append(block.strip())
            continue
        # The first line is the fence's info string, not code. The remaining
        # lines must be re-joined with newlines; joining with "" would fuse
        # the whole snippet into a single line.
        body_lines = block.split("\n")[1:]
        codes.append("\n".join(body_lines).strip("\n"))

    return codes


def _collect_braced_method(code, is_signature_line):
    """Collect lines from the first signature line until its braces balance."""
    collected = []
    opened = closed = 0
    for line in code.split("\n"):
        on_signature = is_signature_line(line)
        if on_signature or collected:
            collected.append(line)
            opened += line.count("{")
            closed += line.count("}")
        # The balance check is not applied on a signature line itself.
        if not on_signature and closed > 0 and opened == closed:
            break
    return collected


def cleanup_code(code: str, sample, language_type: str = None):
    """
    Reduce generated code to the body of the method declared in ``sample``.

    The method name is obtained with ``find_method_name``. If the generated
    code repeats that method's signature, everything before the signature,
    the signature line itself, and everything after the end of the method is
    dropped:

    * python: the method ends at the first line starting with "    return"
      after the signature.
    * java / go / cpp / js: the method ends when the counts of "{" and "}"
      seen since the signature become equal. For java an extra "}" line is
      appended (presumably to close the enclosing class — confirm with the
      prompt format).

    Args:
        code: The generated code.
        sample: Task record passed to ``find_method_name``.
        language_type: Language key, matched case-insensitively.

    Returns:
        ``code`` unchanged when ``language_type`` is None or no method name
        is found; otherwise the extracted body (or the original code when
        the signature does not occur, or the language is not handled)
        followed by a newline.
    """
    if language_type is None:
        return code

    method_name = find_method_name(language_type, sample)
    if method_name is None:
        return code

    language = language_type.lower()
    method_lines = []

    if language == "python":
        signature = f"def {method_name}"
        for line in code.split("\n"):
            if signature in line:
                method_lines.append(line)
                continue
            if not method_lines:
                # Still before the target method: a "return" belonging to an
                # earlier helper function must not stop the search.
                continue
            method_lines.append(line)
            if line.startswith("    return"):
                break

    elif language == "java":
        def is_java_signature(line):
            stripped = line.strip()
            return stripped.startswith("public") and method_name in stripped

        method_lines = _collect_braced_method(code, is_java_signature)
        if method_lines:
            method_lines.append("}")

    else:
        signature_markers = {
            "go": f"func {method_name}",
            "cpp": f" {method_name}(",
            "js": f"const {method_name}",
        }
        marker = signature_markers.get(language)
        if marker is not None:
            method_lines = _collect_braced_method(code, lambda line: marker in line)

    if not method_lines:
        return code + "\n"
    # Drop the signature line; the caller already has the declaration.
    return "\n".join(method_lines[1:]) + "\n"


def exception_reconnect(funct):
    """Decorator that retries ``funct`` once if the first call raises.

    When the first call raises an ``Exception``, its message is printed and
    the function is called a second time with the same arguments. An
    exception raised by the second call propagates to the caller.

    Args:
        funct: The callable to wrap.

    Returns:
        The wrapping function, carrying ``funct``'s name and docstring.
    """
    # Imported locally so this decorator does not depend on the module header.
    from functools import wraps

    @wraps(funct)
    def wrapper_func(*args, **kwargs):
        try:
            return funct(*args, **kwargs)
        except Exception as e:
            print(f"exception will reconnect: {str(e)}")
            return funct(*args, **kwargs)

    return wrapper_func