data_utils.py 4.7 KB
Newer Older
CSDN-Ada助手's avatar
CSDN-Ada助手 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
import os
import gzip
import json

from typing import *

# Maps a lowercase language name (or alias) to a one-line "language: <Name>"
# tag wrapped in comment markers. Several aliases intentionally share the same
# tag (e.g. "c++"/"cpp", "c#"/"csharp", "js"/"javascript"). The values are plain
# string constants; keep them byte-for-byte stable, since they are data.
LANGUAGE_TAG = {
    "c"            : "// language: C",
    "c++"          : "// language: C++",
    "cpp"          : "// language: C++",
    "c#"           : "// language: C#",
    "csharp"       : "// language: C#",
    "css"          : "/* language: CSS */",
    "cuda"         : "// language: Cuda",
    "dart"         : "// language: Dart",
    "lua"          : "// language: Lua",
    "objectivec"   : "// language: Objective-C",
    "objective-c"  : "// language: Objective-C",
    "objective-c++": "// language: Objective-C++",
    "python"       : "# language: Python",
    "perl"         : "# language: Perl",
    "prolog"       : "% language: Prolog",
    "swift"        : "// language: swift",
    "lisp"         : "; language: Lisp",
    "java"         : "// language: Java",
    "scala"        : "// language: Scala",
    "tex"          : "% language: TeX",
    "vue"          : "<!--language: Vue-->",
    "markdown"     : "<!--language: Markdown-->",
    "html"         : "<!--language: HTML-->",
    "php"          : "// language: PHP",
    "js"           : "// language: JavaScript",
    "javascript"   : "// language: JavaScript",
    "typescript"   : "// language: TypeScript",
    "go"           : "// language: Go",
    "shell"        : "# language: Shell",
    "rust"         : "// language: Rust",
    "sql"          : "-- language: SQL",
    "kotlin"       : "// language: Kotlin",
    "vb"           : "' language: Visual Basic",
    "ruby"         : "# language: Ruby",
    "pascal"       : "// language: Pascal",
    "r"            : "# language: R",
    "fortran"      : "!language: Fortran",
    "lean"         : "-- language: Lean",
    "matlab"       : "% language: Matlab",
    "delphi"       : "{language: Delphi}",
    "scheme"       : "; language: Scheme",
    "basic"        : "' language: Basic",
    "assembly"     : "; language: Assembly",
    "groovy"       : "// language: Groovy",
    "abap"         : "* language: Abap",
    "gdscript"     : "# language: GDScript",
    "haskell"      : "-- language: Haskell",
    "julia"        : "# language: Julia",
    "elixir"       : "# language: Elixir",
    "excel"        : "' language: Excel",
    "clojure"      : "; language: Clojure",
    "actionscript" : "// language: ActionScript",
    "solidity"     : "// language: Solidity",
    "powershell"   : "# language: PowerShell",
    "erlang"       : "% language: Erlang",
    "cobol"        : "// language: Cobol",
}


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Parses each jsonl line and yields it as a dictionary.

    Args:
        filename: Path of the file to read. A name ending in ``.gz`` is read
            through gzip; anything else is read as plain text. A leading ``~``
            is expanded, as ``write_jsonl`` does for the path it writes.

    Yields:
        The value decoded by ``json.loads`` for every line that contains at
        least one non-whitespace character, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    filename = os.path.expanduser(filename)
    # Decode explicitly as UTF-8: the writer always emits UTF-8 bytes, and the
    # platform default encoding (used when none is given) is not UTF-8 on
    # every system.
    if filename.endswith(".gz"):
        fp = gzip.open(filename, "rt", encoding="utf-8")
    else:
        fp = open(filename, "r", encoding="utf-8")
    with fp:
        for line in fp:
            # Skip blank / whitespace-only lines instead of failing on them.
            if line.strip():
                yield json.loads(line)


def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """
    Serialises every item of ``data`` as one JSON document per line.

    Args:
        filename: Destination path; a leading ``~`` is expanded. When the
            name ends in ``.gz`` the lines are written through gzip.
        data: Items to serialise with ``json.dumps``, consumed in order.
        append: Open the destination in append mode (``"ab"``) instead of
            truncating it (``"wb"``).
    """
    def as_line(record):
        # One JSON document per line, stored as UTF-8 bytes.
        return (json.dumps(record) + "\n").encode("utf-8")

    file_mode = "ab" if append else "wb"
    filename = os.path.expanduser(filename)
    with open(filename, file_mode) as raw:
        if not filename.endswith(".gz"):
            for record in data:
                raw.write(as_line(record))
            return
        # The compressed stream is layered over the already-open file, so
        # `file_mode` decides whether earlier content is kept or replaced.
        with gzip.GzipFile(fileobj=raw, mode="wb") as packed:
            for record in data:
                packed.write(as_line(record))
                
                
def sliding_window(
    prompt_tokens: list,
    code_tokens: list,
    seq_len: int,
    sliding_stride: int,
    minimum_code_len: int = 1,
) -> Iterable[Tuple[list, list]]:
    """
    Generate a series of (prompt, code) pairs by sliding the window over the code.

    The prompt and code are treated as one concatenated sequence. A window of
    at most ``seq_len`` tokens is moved over it in steps of ``sliding_stride``;
    each window is returned split into the part that falls in the prompt and
    the part that falls in the code.

    Args:
        prompt_tokens: Tokens of the prompt.
        code_tokens: Tokens of the code that follows the prompt.
        seq_len: Window length in tokens (expected to be positive).
        sliding_stride: Step between consecutive window starts; must be >= 1.
        minimum_code_len: The first window starts late enough to contain at
            least this many code tokens, unless the whole sequence is too
            short for that.

    Yields:
        ``(prompt_part, code_part)`` for each window position. When the last
        strided position does not land exactly on the final position, one
        extra window aligned to the end of the sequence is yielded.

    Raises:
        ValueError: If ``sliding_stride`` is smaller than 1 (raised when
            iteration starts, because this is a generator).
    """
    if sliding_stride < 1:
        # A zero stride would fail inside range() anyway; a negative one would
        # silently skip every regular window, so reject both up front.
        raise ValueError(f"sliding_stride must be >= 1, got {sliding_stride}")

    prompt_len = len(prompt_tokens)
    total_len = prompt_len + len(code_tokens)

    # Last valid window start: the window ends exactly at the sequence end.
    end_idx = max(0, total_len - seq_len)
    # First window start: at least `minimum_code_len` code tokens should be in
    # the window, but never start past the last valid position.
    start_idx = min(max(0, prompt_len - seq_len + minimum_code_len), end_idx)

    def window_at(offset):
        # Window bounds expressed relative to the start of the code. Both are
        # clamped at 0 so a window lying entirely inside the prompt gives an
        # empty code part instead of a negative (wrap-around) slice bound.
        code_lo = max(offset - prompt_len, 0)
        code_hi = max(offset - prompt_len + seq_len, 0)
        return prompt_tokens[offset:offset + seq_len], code_tokens[code_lo:code_hi]

    for offset in range(start_idx, end_idx + 1, sliding_stride):
        yield window_at(offset)

    # The stride did not land on the final position: add the end-aligned window.
    if (end_idx - start_idx) % sliding_stride != 0:
        yield window_at(end_idx)