Commit cc9d2464 authored by CSDN-Ada助手

transplant eval code

Parent 88b40af9
@@ -19,3 +19,5 @@
## Similar evaluation projects
1. Stanford's evaluation: AlpacaEval Leaderboard <https://tatsu-lab.github.io/alpaca_eval/>
2. https://github.com/the-crypt-keeper/can-ai-code
3. https://github.com/THUDM/CodeGeeX/tree/main/codegeex/benchmark
\ No newline at end of file
/*
Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
separate those group into separate strings and return the vector of those.
Separate groups are balanced (each open brace is properly closed) and not nested within each other
Ignore any spaces in the input string.
>>> separate_paren_groups("( ) (( )) (( )( ))")
{"()", "(())", "(()())"}
*/
#include<stdio.h>
#include<vector>
#include<string>
using namespace std;
vector<string> separate_paren_groups(string paren_string){
vector<string> all_parens;
string current_paren;
int level=0;
char chr;
int i;
for (i=0;i<paren_string.length();i++)
{
chr=paren_string[i];
if (chr=='(')
{
level+=1;
current_paren+=chr;
}
if (chr==')')
{
level-=1;
current_paren+=chr;
if (level==0){
all_parens.push_back(current_paren);
current_paren="";
}
}
}
return all_parens;
}
#undef NDEBUG
#include<assert.h>
bool issame(vector<string> a,vector<string>b){
if (a.size()!=b.size()) return false;
for (int i=0;i<a.size();i++)
{
if (a[i]!=b[i]) return false;
}
return true;
}
int main(){
assert (issame(separate_paren_groups("(()()) ((())) () ((())()())"),{"(()())", "((()))", "()", "((())()())"}));
assert (issame(separate_paren_groups("() (()) ((())) (((())))"), {"()", "(())", "((()))", "(((())))" }));
assert (issame(separate_paren_groups("(()(())((())))") ,{ "(()(())((())))" }));
assert (issame(separate_paren_groups("( ) (( )) (( )( ))") ,{"()", "(())", "(()())"}));
}
\ No newline at end of file
module humanEval
go 1.18
require (
github.com/go-openapi/inflect v0.19.0
github.com/stretchr/testify v1.8.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNPXu/4=
github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
{
"model_path": "llm_set/models/chatglm2-6b",
"device": "cuda:0",
"quantize": false,
"max_length": 1024,
"min_length": 10
}
\ No newline at end of file
{
"url": "http://47.74.27.238/conversation.php",
"id": 1001059396431415,
"stream": 0,
"timeout": 20
}
\ No newline at end of file
{
"prompt": "",
"max_new_tokens": 512,
"do_sample": true,
"temperature": 0.7,
"top_p": 0.95,
"typical_p": 1,
"epsilon_cutoff": 0,
"eta_cutoff": 0,
"repetition_penalty": 1.2,
"top_k": 50,
"min_length": 0,
"no_repeat_ngram_size": 0,
"num_beams": 1,
"penalty_alpha": 0,
"length_penalty": 1,
"early_stopping": false,
"mirostat_mode": 0,
"mirostat_tau": 5,
"mirostat_eta": 0.1,
"seed": -1,
"add_bos_token": true,
"truncation_length": 2048,
"ban_eos_token": false,
"skip_special_tokens": true,
"stopping_strings": []
}
\ No newline at end of file
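These generation defaults are read by the Inference base class further down (llm_set/params/default.json) and merged with each model's own params file. One way such a merge can be done is sketched below; the precedence shown (model values overriding defaults) and the chatglm2.json file name are assumptions for illustration, not the exact project code.
import json

def load_params(base_path: str, model_path: str) -> dict:
    # Load the shared generation defaults first, then let model-specific values override them.
    with open(base_path) as f:
        params = json.load(f)
    with open(model_path) as f:
        params.update(json.load(f))
    return params

# Hypothetical call, file names for illustration only:
# params = load_params("llm_set/params/default.json", "llm_set/params/chatglm2.json")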
import os
import gzip
import json
from typing import *
LANGUAGE_TAG = {
"c" : "// language: C",
"c++" : "// language: C++",
"cpp" : "// language: C++",
"c#" : "// language: C#",
"csharp" : "// language: C#",
"css" : "/* language: CSS */",
"cuda" : "// language: Cuda",
"dart" : "// language: Dart",
"lua" : "// language: Lua",
"objectivec" : "// language: Objective-C",
"objective-c" : "// language: Objective-C",
"objective-c++": "// language: Objective-C++",
"python" : "# language: Python",
"perl" : "# language: Perl",
"prolog" : f"% language: Prolog",
"swift" : "// language: swift",
"lisp" : "; language: Lisp",
"java" : "// language: Java",
"scala" : "// language: Scala",
"tex" : f"% language: TeX",
"vue" : "<!--language: Vue-->",
"markdown" : "<!--language: Markdown-->",
"html" : "<!--language: HTML-->",
"php" : "// language: PHP",
"js" : "// language: JavaScript",
"javascript" : "// language: JavaScript",
"typescript" : "// language: TypeScript",
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"vb" : "' language: Visual Basic",
"ruby" : "# language: Ruby",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
"matlab" : f"% language: Matlab",
"delphi" : "{language: Delphi}",
"scheme" : "; language: Scheme",
"basic" : "' language: Basic",
"assembly" : "; language: Assembly",
"groovy" : "// language: Groovy",
"abap" : "* language: Abap",
"gdscript" : "# language: GDScript",
"haskell" : "-- language: Haskell",
"julia" : "# language: Julia",
"elixir" : "# language: Elixir",
"excel" : "' language: Excel",
"clojure" : "; language: Clojure",
"actionscript" : "// language: ActionScript",
"solidity" : "// language: Solidity",
"powershell" : "# language: PowerShell",
"erlang" : f"% language: Erlang",
"cobol" : "// language: Cobol",
}
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Parses each jsonl line and yields it as a dictionary
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
else:
with open(filename, "r") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
"""
Writes an iterable of dictionaries to jsonl
"""
if append:
mode = "ab"
else:
mode = "wb"
filename = os.path.expanduser(filename)
if filename.endswith(".gz"):
with open(filename, mode) as fp:
with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
for x in data:
gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
else:
with open(filename, mode) as fp:
for x in data:
fp.write((json.dumps(x) + "\n").encode("utf-8"))
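A quick round-trip sketch for the two helpers above, assuming this module is importable as data_utils (the module name appears in the imports later in this commit; the output file name is only an example):
from data_utils import stream_jsonl, write_jsonl

samples = [{"task_id": "Python/0", "generation": "pass"}]
write_jsonl("demo_samples.jsonl.gz", samples)                  # the .gz suffix takes the gzip branch
assert list(stream_jsonl("demo_samples.jsonl.gz")) == samples  # parsed back line by line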
def sliding_window(
prompt_tokens: list,
code_tokens: list,
seq_len: int,
sliding_stride: int,
minimum_code_len: int = 1,
) -> Iterable[Tuple[list, list]]:
"""
Generate a series of (prompt, code) pairs by sliding the window over the code.
"""
prompt_len = len(prompt_tokens)
code_len = len(code_tokens)
total_len = prompt_len + code_len
start_idx = max(0, prompt_len - seq_len + minimum_code_len) # at least `minimum_code_len` code tokens should be in the window
end_idx = max(0, total_len - seq_len)
start_idx = min(start_idx, end_idx)
for i in range(start_idx, end_idx + 1, sliding_stride):
current_prompt = prompt_tokens[i:i + seq_len]
current_code = code_tokens[max(i - prompt_len, 0):i - prompt_len + seq_len]
yield current_prompt, current_code
if (end_idx - start_idx) % sliding_stride != 0:
current_prompt = prompt_tokens[end_idx:end_idx + seq_len]
current_code = code_tokens[max(end_idx - prompt_len, 0):end_idx - prompt_len + seq_len]
yield current_prompt, current_code
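For intuition, a tiny sketch of how the window slides over a (prompt, code) pair; each yielded pair spans at most seq_len tokens in total (toy token ids, assuming the module is importable as data_utils):
from data_utils import sliding_window

prompt_tokens = list(range(6))        # 6 prompt token ids
code_tokens = list(range(100, 110))   # 10 code token ids
for p, c in sliding_window(prompt_tokens, code_tokens, seq_len=8, sliding_stride=4):
    print(len(p), len(c))             # prints: 6 2, then 2 6, then 0 8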
import os
import sys
import fire
import json
import gzip
import regex
import numpy as np
from typing import *
from tqdm.auto import tqdm
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from utils import read_dataset, IMPORT_HELPER
from metric import estimate_pass_at_k
from execution import check_correctness
LANGUAGE_NAME = {
"cpp" : "CPP",
"go" : "Go",
"java" : "Java",
"js" : "JavaScript",
"python": "Python",
}
def process_humaneval_test(sample, problems, example_test=False):
task_id = sample["task_id"]
language = task_id.split("/")[0].lower()
prompt = sample["prompt"]
if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
test = problems[task_id]["example_test"]
else:
test = problems[task_id]["test"]
code = sample["generation"]
# Pre-process for different languages
if language == "python":
code_ = []
for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
break
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
test_string = test_setup + prompt + code + "\n" + test + "\n"
elif language == "cpp":
test_set_up = ""
for s in IMPORT_HELPER["cpp"]:
if s not in prompt:
test_set_up += s + "\n"
test_string = test_set_up + "\n" + prompt + code + "\n" + test
elif language == "java":
test_string = prompt + code + "\n" + test
elif language == "js" or language == "javascript":
test_string = prompt + code + "\n" + test
elif language == "go":
import_string = problems[task_id]["import"]
prompt = prompt.replace(import_string, "")
if example_test and "example_test" in problems[task_id]:
test = problems[task_id]["example_test"]
else:
test = problems[task_id]["test"]
test_setup = problems[task_id]["test_setup"]
other_pkgs = []
for pkg in IMPORT_HELPER["go"]:
if pkg not in test_setup:
p = pkg.split("/")[-1]
if p + "." in code:
other_pkgs.append(f"\"{pkg}\"")
if other_pkgs:
import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
else:
test_string = test_setup + "\n" + prompt + code + "\n" + test
elif language == "rust":
main = "\nfn main(){ \n } \n"
declaration = problems[task_id]["declaration"]
test_string = main + declaration + prompt + code + test
return test_string
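As a rough illustration of the Python branch (the sample and problem entries below are made up, not real HumanEval-X data): the assembled string stacks the import helpers, the prompt, the generation (truncated at the first top-level line), and the test code, and can then be handed to check_correctness for execution.
sample = {"task_id": "Python/0",
          "prompt": "def add(a, b):\n",
          "generation": "    return a + b\n"}
problems = {"Python/0": {"test": "assert add(1, 2) == 3"}}
print(process_humaneval_test(sample, problems))  # imports + prompt + generation + test, ready to run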
def stream_jsonl_all(filename: str) -> Iterable[Dict]:
results = []
if filename.endswith(".gz"):
fp = gzip.open(open(filename, "rb"), "rt")
else:
fp = open(filename, "r")
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()
return results
def evaluate_functional_correctness(
input_file: str = None,
tmp_dir: str = "./",
n_workers: int = 32,
timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz",
out_dir: str = None,
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
):
if example_test:
print("Example test...")
problems = read_dataset(problem_file,
dataset_type="humaneval")
sample_jsonl = stream_jsonl_all(input_file)
if example_test:
suffix = "_example_test.jsonl"
else:
suffix = "_results.jsonl"
if out_dir is not None:
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix))
else:
out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False
with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = []
completion_id = Counter()
n_samples = 0
results = defaultdict(list)
if test_groundtruth:
print("Testing ground truth...")
for sample in tqdm(problems.values()):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["generation"] = sample["canonical_solution"]
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
if sample["test_code"] is None:
continue
args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
future = executor.submit(check_correctness, *args)
futures.append(future)
completion_id[task_id] += 1
n_samples += 1
else:
print("Reading samples...")
for sample in tqdm(sample_jsonl):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["task_id"] = task_id
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
if sample["test_code"] is None:
continue
if "completion_id" in sample:
completion_id_ = sample["completion_id"]
else:
completion_id_ = completion_id[task_id]
args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
future = executor.submit(check_correctness, *args)
futures.append(future)
completion_id[task_id] += 1
n_samples += 1
print(completion_id)
if len(completion_id) == len(problems):
evaluate_pass_at_k = True
else:
evaluate_pass_at_k = False
print("Running test suites...")
for future in tqdm(as_completed(futures), total=len(futures)):
result = future.result()
results[result["task_id"]].append((result["completion_id"], result))
# Calculate pass@k.
total, correct = [], []
for result in results.values():
passed = [r[1]["passed"] for r in result]
total.append(len(passed))
correct.append(sum(passed))
total = np.array(total)
correct = np.array(correct)
if evaluate_pass_at_k:
ks = k
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
for k in ks if (total >= k).all()}
print(pass_at_k)
else:
print("Total:", np.sum(total))
print("Correct:", np.sum(correct))
print("Writing to: ", out_file)
if out_file.endswith(".gz"):
fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
for res in results.values():
for r in res:
fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
else:
fp = open(out_file, 'w')
for res in results.values():
for r in res:
fp.write(json.dumps(r[1]) + "\n")
fp.close()
print("Evaluation finished.")
def main():
fire.Fire(evaluate_functional_correctness)
if __name__ == "__main__":
sys.exit(main())
This diff is collapsed.
import os
import sys
import fire
import glob
def gather_output(
output_dir: str = "./output",
output_prefix: str = None,
if_remove_rank_files: int = 0,
):
if output_prefix is None:
output_list = glob.glob(output_dir + "/*")
else:
output_list = glob.glob(os.path.join(output_dir, output_prefix + "*"))
for output_file in output_list:
if "rank0" in output_file:
output_prefix_ = output_file.split("_rank0.jsonl")[0]
rank_files = glob.glob(output_prefix_ + "_rank*")
with open(output_prefix_ + ".jsonl", "w") as f_out:
for rank_file in rank_files:
with open(rank_file, "r") as f_in:
for line in f_in:
f_out.write(line)
if if_remove_rank_files:
os.remove(rank_file)
print(f"Removing {rank_file}...")
if output_prefix is None:
output_list = glob.glob(output_dir + "/*")
else:
output_list = glob.glob(os.path.join(output_dir, output_prefix + "*"))
for output_file in output_list:
if "rank" in output_file or "_unfinished" in output_file or "all" in output_file or "_result" in output_file:
continue
if "_finished" not in output_file:
continue
output_prefix_ = output_file.split("_finished.jsonl")[0]
files = [output_file, output_prefix_ + "_unfinished.jsonl"]
with open(output_prefix_ + "_all.jsonl", "w") as f_out:
for f in files:
with open(f, "r") as f_in:
for line in f_in:
f_out.write(line)
print("Gathering finished. Saved in {}".format(output_prefix_ + "_all.jsonl"))
def main():
fire.Fire(gather_output)
if __name__ == "__main__":
sys.exit(main())
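Since the script is wrapped with fire, a hedged usage sketch looks like the following; the module name gather_output, the ./output directory, and the humaneval_chatgpt prefix are placeholders for illustration.
# From the shell:
#   python gather_output.py --output_dir ./output --output_prefix humaneval_chatgpt
# or equivalently from Python:
from gather_output import gather_output
gather_output(output_dir="./output", output_prefix="humaneval_chatgpt", if_remove_rank_files=0)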
import argparse
import logging
import os
import random
import socket
import time
import json
import torch
from utils import read_dataset, process_extra_prompt
from utils import cleanup_code
logging.getLogger("torch").setLevel(logging.WARNING)
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_path",
type=str,
default="eval_set/humaneval-x"
)
parser.add_argument(
"--language_type",
type=str,
default="python"
)
parser.add_argument(
"--model_name",
type=str,
default="chatgpt",
help='supported model in [chatgpt,chatglm2].'
)
parser.add_argument(
"--output_prefix",
type=str,
default="chatgpt"
)
parser.add_argument(
"--gen_rank",
type=int,
default=0,
help="rank index appended to the output file name"
)
args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
entries = read_dataset(args.input_path, language_type=args.language_type, dataset_type="humaneval")
for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
model_inference = None
if args.model_name == "chatgpt":
from inference.chatgpt_inference import ChatGPTInference
model_inference = ChatGPTInference()
elif args.model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference
model_inference = GLMInference()
else:
print(f"{args.model_name} not supported.")
start_time = time.perf_counter()
num_finished = 0
with open(output_file_path, "w") as f:
for entry in entries.values():
prompt = entry["prompt"]
generated_tokens = model_inference.inference(prompt)
generated_code = cleanup_code(generated_tokens, language_type=args.language_type)
f.write(
json.dumps(
{
"task_id": entry['task_id'],
"prompt": prompt,
"generation": generated_code,
"scores": 0.0,
"finish": 2,
"output": generated_tokens,
}
)
+ "\n"
)
f.flush()
num_finished += 1
time_elapsed = time.perf_counter() - start_time
time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
print(
f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample",
f"remaining {len(entries) - num_finished:.4f}",
flush=True,
)
import os
import json
import torch
import logging
from transformers import AutoModel, AutoTokenizer
from inference import Inference
logger = logging.getLogger(__name__)
class GLMInference(Inference):
def __init__(self):
self.params_url = "llm_set/params/chatgpt.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
self.temperature = self.paras_dict.get("temperature")
self.quantize = self.paras_dict.get("quantize")
self.device = self.paras_dict.get("device")
self.model_path = self.paras_dict.get("model_path")
if self.device == "cpu":
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True,
ignore_mismatched_sizes=True).float()
self.DEV = torch.device('cpu')
else:
os.environ["CUDA_VISIBLE_DEVICES"] = self.device.split(":")[1]
self.DEV = torch.device(self.device)
if self.quantize:
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).quantize(8).cuda()
else:
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).cuda()
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
self.model = self.model.eval()
self.max_length = self.paras_dict.get("max_length")
self.min_length = self.paras_dict.get("min_length")
self.top_p = self.paras_dict.get("top_p")
self.top_k = self.paras_dict.get("top_k")
def get_params(self):
if not os.path.exists(self.params_url):
logger.error(f"params_url:{self.params_url} is not exists.")
content = open(self.params_url).read()
return json.loads(content)
def inference(self, message):
input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
sentence = None
with torch.no_grad():
generation_output = self.model.generate(
input_ids,
do_sample=True,
min_length=self.min_length,
max_length=self.max_length,
top_p=self.top_p,
top_k=self.top_k,
temperature=self.temperature
)
output = self.tokenizer.decode([el.item() for el in generation_output[0]])
sentence = output.split("ChatCSDN:")[1].strip()
return sentence
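A hedged usage sketch for the class above; it assumes the local chatglm2 weights and the params file referenced in the configs are available, and that the decoded output contains the "ChatCSDN:" marker that inference() splits on.
from inference.chatglm2_inference import GLMInference

glm = GLMInference()
completion = glm.inference("# language: Python\ndef add(a, b):\n")
print(completion)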
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from inference import Inference
logger = logging.getLogger(__name__)
class ChatGPTInference(Inference):
def __init__(self):
self.params_url = "llm_set/params/chatgpt.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
self.gpt_url = self.paras_dict.get("url")
self.id = self.paras_dict.get("id")
self.stream = self.paras_dict.get("stream")
self.temperature = self.paras_dict.get("temperature")
self.timeout = self.paras_dict.get("timeout")
def get_params(self):
if not os.path.exists(self.params_url):
logger.error(f"params_url:{self.params_url} is not exists.")
content = open(self.params_url).read()
return json.loads(content)
def inference(self, query_text):
query_text = quote(query_text)
param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
try:
resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
return None
if resp.status_code == 200:
resp_json = json.loads(resp.text)
return "\n".join(resp_json["message"]["content"]["parts"])
else:
return None
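And a similar sketch for the remote ChatGPT path, which reads its endpoint from llm_set/params/chatgpt.json and returns None on any request error (the prompt text is illustrative):
from inference.chatgpt_inference import ChatGPTInference

gpt = ChatGPTInference()
completion = gpt.inference("def fib(n):")
if completion is None:
    print("request failed; see the logged traceback")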
import os
import json
import logging
logger = logging.getLogger(__name__)
class Inference(object):
def __init__(self):
self.params_base_url = "llm_set/params/default.json"
self.paras_base_dict = self.get_paras_base()
def get_paras_base(self):
if not os.path.exists(self.params_base_url):
logger.error(f"params_base_url:{self.params_base_url} is not exists.")
content = open(self.params_base_url).read()
return json.loads(content)
import os
import sys
import fire
import json
import glob
import numpy as np
import pandas as pd
from collections import defaultdict
from metric import estimate_pass_at_k
error_types = {
"python" : [
"accepted",
"assertion error",
"undefined error",
"runtime error",
"syntax error",
"timeout error",
"type error",
],
"java" : [
"accepted",
"compilation error",
"assertion error",
"timeout error",
"index error",
"class cast error",
"stack overflow",
"null pointer error",
"unsupported operation",
"number format error",
"no such element",
"illegal argument",
"out of memory",
"arithmetic error",
"others",
],
"cpp" : [
"accepted",
"compilation error",
"assertion error",
"range error",
"invalid argument",
"pointer error",
"out of memory",
"package error",
"others",
],
"javascript": [
"accepted",
"assertion error",
"undefined error",
"runtime error",
"syntax error",
"timeout error",
"range error",
"type error",
],
"go" : [
"accepted",
"assertion error",
"undefined error",
"runtime error",
"syntax error",
"timeout error",
"type error",
"notused error",
],
}
def inspect_result(
input_dir: str = None,
input_file: str = None,
output_dir: str = None,
pass_at_k_outpath: str = None,
):
if input_dir is not None:
input_files = glob.glob(input_dir + "/*_results.jsonl")
else:
input_files = [input_file]
if output_dir is not None:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
else:
output_dir = "/"
pass_at_k_outs = []
for input_file in input_files:
result_stats = defaultdict(dict)
df = None
incompleted = False
with open(input_file, "r") as f:
for i, line in enumerate(f):
obj = json.loads(line)
task_id = obj["task_id"]
language_type = task_id.split("/")[0].lower()
if language_type not in error_types:
if language_type == "humaneval":
language_type = "python"
elif language_type == "C++":
language_type = "cpp"
else:
incompleted = True
break
if task_id not in result_stats.keys():
default_stats = {}
default_stats["task_id"] = task_id
default_stats["n_sample"] = 0
for k in error_types[language_type]:
default_stats[k] = 0
result_stats[task_id] = default_stats.copy()
if df is None:
df = pd.DataFrame(columns=error_types[language_type])
result_stats[task_id]["n_sample"] += 1
if "failed" in obj["result"]:
error = obj["result"]
if language_type == "python":
if "assertionerror" in error.lower():
result_stats[task_id]["assertion error"] += 1
elif "syntax" in error.lower() or "indent" in error.lower() or "literal" in error.lower():
result_stats[task_id]["syntax error"] += 1
elif "not defined" in error.lower():
result_stats[task_id]["undefined error"] += 1
elif "timeout" in error.lower():
result_stats[task_id]["timeout error"] += 1
elif "type" in error.lower():
result_stats[task_id]["type error"] += 1
else:
result_stats[task_id]["runtime error"] += 1
elif language_type == "java":
if "wrong answer" in error:
result_stats[task_id]["assertion error"] += 1
elif "compilation error" in error:
result_stats[task_id]["compilation error"] += 1
elif "time out" in error:
result_stats[task_id]["timeout error"] += 1
elif "IndexOutOfBounds" in error:
result_stats[task_id]["index error"] += 1
elif "UnsupportedOperation" in error:
result_stats[task_id]["unsupported operation"] += 1
elif "ClassCast" in error:
result_stats[task_id]["class cast error"] += 1
elif "NullPointer" in error:
result_stats[task_id]["null pointer error"] += 1
elif "NumberFormat" in error:
result_stats[task_id]["number format error"] += 1
elif "NoSuchElement" in error:
result_stats[task_id]["no such element"] += 1
elif "StackOverflow" in error:
result_stats[task_id]["stack overflow"] += 1
elif "Arithmetic" in error:
result_stats[task_id]["arithmetic error"] += 1
elif "OutOfMemory" in error:
result_stats[task_id]["out of memory"] += 1
elif "IllegalArgument" in error:
result_stats[task_id]["illegal argument"] += 1
else:
result_stats[task_id]["others"] += 1
elif language_type == "cpp":
if "compilation error" in error.lower():
result_stats[task_id]["compilation error"] += 1
elif "int main(): assertion" in error.lower():
result_stats[task_id]["assertion error"] += 1
elif "out_of_range" in error.lower():
result_stats[task_id]['range error'] += 1
elif "corrupted top size" in error.lower():
result_stats[task_id]['range error'] += 1
elif "length_error" in error.lower():
result_stats[task_id]['range error'] += 1
elif "invalid_argument" in error.lower():
result_stats[task_id]['invalid argument'] += 1
elif "invalid pointer" in error.lower():
result_stats[task_id]['pointer error'] += 1
elif "double free" in error.lower():
result_stats[task_id]['pointer error'] += 1
elif "free()" in error.lower():
result_stats[task_id]['pointer error'] += 1
elif "logic_error" in error.lower():
result_stats[task_id]['pointer error'] += 1
elif "sysmalloc: assertion" in error.lower():
result_stats[task_id]['pointer error'] += 1
elif "stack smashing" in error.lower():
result_stats[task_id]['out of memory'] += 1
elif "bad_alloc" in error.lower():
result_stats[task_id]['out of memory'] += 1
elif "terminate called after throwing an instance of" in error.lower():
result_stats[task_id]['package error'] += 1
else:
result_stats[task_id]["others"] += 1
elif language_type == "javascript":
if "Assertion failed" in error:
result_stats[task_id]["assertion error"] += 1
elif "SyntaxError" in error:
result_stats[task_id]["syntax error"] += 1
elif "ReferenceError" in error:
result_stats[task_id]["undefined error"] += 1
elif "timed out" in error:
result_stats[task_id]["timeout error"] += 1
elif "TypeError" in error:
result_stats[task_id]["type error"] += 1
elif "RangeError" in error:
result_stats[task_id]["range error"] += 1
else:
result_stats[task_id]["runtime error"] += 1
elif language_type == "go":
if "Error: \tNot equal:" in error:
result_stats[task_id]["assertion error"] += 1
elif "undefined" in error:
result_stats[task_id]["undefined error"] += 1
elif "expected" in error and "found" in error:
result_stats[task_id]["syntax error"] += 1
elif "illegal" in error:
result_stats[task_id]["syntax error"] += 1
elif "unexpected" in error:
result_stats[task_id]["syntax error"] += 1
elif "FAIL" in error:
result_stats[task_id]["runtime error"] += 1
elif "timed out" in error:
result_stats[task_id]["timeout error"] += 1
elif "not used" in error:
result_stats[task_id]['notused error'] += 1
elif "type" in error:
result_stats[task_id]['type error'] += 1
else:
incompleted = True
break
else:
if obj["passed"]:
result_stats[task_id]["accepted"] += 1
if incompleted:
print(f"Language not supported, aborted. {input_file}")
else:
try:
total, correct = [], []
for k, res in result_stats.items():
total.append(res["n_sample"])
correct.append(res["accepted"])
df_res = pd.DataFrame(res, index=[int(k.split("/")[-1])])
df = pd.concat([df, df_res], axis=0)
total = np.array(total)
correct = np.array(correct)
ks = [1, 10, 100, 1000]
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
for k in ks if (total >= k).all()}
print(pass_at_k)
pass_at_k["file"] = input_file
pass_at_k["n"] = res["n_sample"]
pass_at_k_outs.append(pass_at_k)
output_prefix = input_file.split("/")[-1].split(".jsonl")[0]
output_file = os.path.join(output_dir, output_prefix + "_stats.xlsx")
df = df.sort_index(ascending=True)
df.to_excel(output_file)
print(f"Stats saved in {output_file}")
except Exception as e:
print(e)
print(f"Data incompleted, aborted. {input_file}")
if pass_at_k_outpath is not None:
jsonl_path = os.path.join(output_dir, pass_at_k_outpath)
with open(jsonl_path, "w") as f_out:
for p in pass_at_k_outs:
f_out.write(json.dumps(p) + "\n")
print(f"Pass at k saved in {jsonl_path}")
def main():
fire.Fire(inspect_result)
if __name__ == "__main__":
sys.exit(main())
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
import itertools
import numpy as np
from typing import *
def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int
) -> np.ndarray:
"""
Estimates pass@k of each problem and returns them in an array.
"""
def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
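estimate_pass_at_k implements the unbiased HumanEval estimator pass@k = 1 - C(n-c, k) / C(n, k). A small worked example with illustrative numbers (run in this module's scope):
num_samples = [10, 10]   # n generations per problem
num_correct = [3, 0]     # c of them pass the tests
print(estimate_pass_at_k(num_samples, num_correct, k=1))  # -> [0.3 0. ]
print(estimate_pass_at_k(num_samples, num_correct, k=5))  # -> approx. [0.9167 0. ]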
import os
import zmq
import time
import torch
import random
import socket
import logging
import argparse
from typing import *
from utils import read_translation_dataset
from codegeex.megatron import get_args
from codegeex.megatron.inference import run_generation_distributed, model_provider
from codegeex.megatron.initialize import initialize_megatron
logging.getLogger("torch").setLevel(logging.WARNING)
def add_code_generate_args(parser):
"""Code generation arguments."""
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--hostfile",
type=str,
default="./hostfile",
)
group.add_argument(
"--channel-ip",
type=str,
default=None,
help="IP for ZeroMQ channel",
)
group.add_argument(
"--channel-port",
type=int,
default=5555,
help="Port for ZeroMQ channel",
)
group.add_argument(
"--master-port",
type=int,
default=5666,
)
group.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
group.add_argument(
"--greedy",
action="store_true",
default=False,
help="Use greedy sampling.",
)
group.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top p sampling.",
)
group.add_argument(
"--top-k",
type=int,
default=0,
help="Top k sampling.",
)
group.add_argument(
"--out-seq-length",
type=int,
default=1024,
help="Size of the output generated text.",
)
group.add_argument(
"--input-path",
type=str,
default="./benchmark/humaneval/HumanEval.jsonl",
help="Get input path",
)
group.add_argument(
"--num-samples",
type=int,
default=0,
help="Number of samples to generate",
)
group.add_argument(
"--recompute",
action="store_true",
help="During generation recompute all attention "
"instead of using previously computed keys/values.",
)
group.add_argument(
"--load-deepspeed",
action="store_true",
help="Load DeepSpeed checkpoint",
)
group.add_argument(
"--ws-encoding-start-id",
type=int,
default=None,
help="Start id for whitespace encoding",
)
group.add_argument(
"--ws-encoding-length",
type=int,
default=None,
help="Length of whitespace encoding",
)
group.add_argument(
"--dataset",
type=str,
default="humaneval",
)
group.add_argument(
"--samples-per-problem",
type=int,
default=200,
help="Number of samples to generate for each problem",
)
group.add_argument(
"--output-prefix",
type=str,
default="./output/humaneval",
help="Prefix for output files",
)
group.add_argument(
"--gen-node-rank",
type=int,
default=None,
)
group.add_argument(
"--gen-node-world-size",
type=int,
default=None,
)
group.add_argument(
"--gen-world-size",
type=int,
default=1,
help="Number of machines to use for generation",
)
group.add_argument(
"--gen-rank",
type=int,
default=0,
help="Machine rank for human eval generation",
)
group.add_argument(
"--extra-prompt",
type=str,
default=None,
help="Extra prompt to use for human eval generation",
)
group.add_argument(
"--verbose-interval",
type=int,
default=100,
)
group.add_argument(
"--problem-split",
type=str,
default="test",
)
group.add_argument(
"--prompt-type",
type=str,
default="notag",
)
group.add_argument(
"--num-devices-per-node",
type=int,
default=None,
)
group.add_argument(
"--return-scores",
action="store_true",
)
group.add_argument(
"--free-guidance",
action="store_true",
)
group.add_argument(
"--guide-temp",
type=float,
default=1.5,
)
group.add_argument(
"--attention-upweight",
type=float,
default=None,
)
group.add_argument(
'--bad-ids',
nargs="*",
type=int,
default=None,
help='Token ids that should never be generated',
)
group.add_argument(
"--src-path",
type=str,
default="",
help="Get source path",
)
group.add_argument(
"--tgt-path",
type=str,
default="",
help="Get target path",
)
group.add_argument(
'--language-src-type',
type=str,
default=None,
help='Identify the type of programming language',
)
group.add_argument(
'--language-tgt-type',
type=str,
default=None,
help='Identify the type of programming language to translate',
)
return parser
def main(node_rank: int, local_rank: int, master_port: int, num_devices: int):
"""Main program."""
os.environ["WORLD_SIZE"] = str(num_devices)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_ADDR"] = "0.0.0.0"
os.environ["MASTER_PORT"] = f"{master_port}"
initialize_megatron(
extra_args_provider=add_code_generate_args,
args_defaults={
"tokenizer_type": "GPT2BPETokenizer",
"no_load_rng" : True,
"no_load_optim" : True,
},
)
# set_random_seed(node_rank * num_devices + local_rank)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
world_size = args.gen_node_world_size * num_devices
args.gen_rank = num_devices * node_rank + local_rank
args.gen_world_size = world_size
print(f"Generating on rank {args.gen_rank} of {args.gen_world_size}")
# Set up model and load checkpoint.
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]
print("Building CodeGeeX model ...")
model = model_provider()
model.load_state_dict(state_dict)
model.eval()
if args.fp16 and args.ln_fp16:
model.half()
model.cuda()
# Generate samples.
run_generation_distributed(model)
print(f"(gen_rank={args.gen_rank}, rank={local_rank}) finished, waiting ...")
torch.distributed.barrier()
def server():
print(f"[ server ] starting ...", flush=True)
parser = argparse.ArgumentParser()
parser.add_argument(
"--channel-ip",
type=str,
default=None,
help="IP for ZeroMQ channel",
)
parser.add_argument(
"--channel-port",
type=int,
default=5555,
help="Port for ZeroMQ channel",
)
parser.add_argument(
"--master-port",
type=int,
default=6666,
help="Port for torch distributed",
)
parser.add_argument(
"--samples-per-problem",
type=int,
default=200,
help="Number of samples to generate for each problem",
)
parser.add_argument(
"--gen-node-world-size",
type=int,
default=1,
help="Number of machines to use for generation",
)
parser.add_argument(
"--src-path",
type=str,
default="",
help="Get source path",
)
parser.add_argument(
"--tgt-path",
type=str,
default="",
help="Get target path",
)
parser.add_argument(
"--problem-split",
type=str,
default="test",
)
parser.add_argument(
"--micro-batch-size",
type=int,
default=1,
)
parser.add_argument(
'--language-src-type',
type=str,
default=None,
help='Identify the type of programming language',
)
parser.add_argument(
'--language-tgt-type',
type=str,
default=None,
help='Identify the type of programming language to translate',
)
args = parser.parse_known_args()[0]
entries = read_translation_dataset(args.src_path,
args.tgt_path,
lang_src=args.language_src_type,
lang_tgt=args.language_tgt_type,
dataset_type="humaneval")
assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by micro_batch_size"
res = []
for entry in entries.values():
res.extend([entry] * (args.samples_per_problem // args.micro_batch_size))
random.shuffle(res)
all_entries = res
# setup zeromq channel
print(f"[ server ] starting up on port {args.channel_port}", flush=True)
context = zmq.Context()
print(f"[ server ] creating socket", flush=True)
socket = context.socket(zmq.REP)
print(f"[ server ] binding to port {args.channel_port}", flush=True)
socket.bind(f"tcp://*:{args.channel_port}")
print(
f"[ server ] loaded {len(entries)} entries, generated {len(all_entries)} samples",
flush=True,
)
remaining_entries = all_entries.copy()
running_workers = args.gen_node_world_size * torch.cuda.device_count()
num_finished = 0
print(f"[ server ] listening for requests ...", flush=True)
start_time = time.perf_counter()
while True:
# Wait for next request from client
msg = socket.recv_json()
rank = msg["rank"]
action = msg["action"]
if action == "pull":
if len(remaining_entries) == 0:
print(f"[ server ] Shutting down worker {rank}", flush=True)
socket.send_json({"task_id": None})
running_workers -= 1
if running_workers == 0:
print(f"[ server ] All workers finished", flush=True)
break
else:
entry = remaining_entries.pop()
time_elapsed = time.perf_counter() - start_time
print(f"[ server ] Sending entry {entry['task_id']} to worker {rank}", flush=True)
remaining = (
len(remaining_entries)
/ (len(all_entries) - len(remaining_entries))
* time_elapsed
)
time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished / args.micro_batch_size
print(
f"[ server ] total {len(all_entries)}, assigned {len(all_entries) - len(remaining_entries)}, "
f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample",
f"remaining {remaining:.4f}",
flush=True,
)
socket.send_json({"task_id": entry})
else:
if action == "success":
print(f"[ server ] {msg['task_id']} is finished", flush=True)
socket.send_json({"pong": 1})
else:
print(f"[ server ] {msg['task_id']} is not finished", flush=True)
remaining_entries.append(msg['task_id'])
socket.send_json({"pong": 1})
break
num_finished += 1
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser()
parser.add_argument(
"--hostfile",
type=str,
default="./hostfile",
)
parser.add_argument(
"--master-port",
type=int,
default=5666,
)
parser.add_argument(
"--seed",
type=int,
default=42,
)
args = parser.parse_known_args()[0]
print("start method: " + torch.multiprocessing.get_start_method())
processes = []
num_devices = torch.cuda.device_count()
hosts = open(args.hostfile, "r").readlines()
hosts = [host.strip() for host in hosts]
master_port = args.master_port
node_rank = None
for i in range(len(hosts)):
if hosts[i] == socket.gethostbyname(socket.gethostname()):
node_rank = i
break
assert (
node_rank is not None
), f"Could not find hostname ({socket.gethostbyname(socket.gethostname())}) in hostlist"
# launch server
if socket.gethostbyname(socket.gethostname()) == hosts[0]:
server_process = torch.multiprocessing.Process(target=server)
print(f"Launching server ...")
server_process.start()
processes.append(server_process)
for i in range(num_devices):
local_rank = i
print(f"launching local rank {i}")
p = torch.multiprocessing.Process(
target=main,
args=(node_rank, local_rank, master_port, num_devices),
)
p.start()
processes.append(p)
for p in processes:
p.join()
import os
from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG
IMPORT_HELPER = {
"python": [
"import math",
"import re",
"import sys",
"import copy",
"import datetime",
"import itertools",
"import collections",
"import heapq",
"import statistics",
"import functools",
"import hashlib",
"import numpy",
"import numpy as np",
"import string",
"from typing import *",
"from collections import *",
],
"go" : [
"math",
"strings",
"fmt",
"strconv",
"time",
"bytes",
"regexp",
"sort",
"math/rand",
"crypto/md5",
],
"cpp" : [
"#include<stdlib.h>",
"#include<algorithm>",
"#include<math.h>",
"#include<stdio.h>",
"#include<vector>",
"#include<string>",
"#include<climits>",
"#include<cstring>",
"#include<iostream>",
],
}
def read_dataset(
data_file: str = None,
dataset_type: str = "humaneval",
language_type: str = "python",
num_shot=None,
) -> Dict:
if num_shot is not None:
print(f"{num_shot}-shot setting...")
if "humaneval" in dataset_type.lower():
data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
else:
raise f"Dataset: {dataset_type} not supported."
return dataset
def read_translation_dataset(
data_file_src: str = None,
data_file_tgt: str = None,
lang_src: str = None,
lang_tgt: str = None,
dataset_type: str = "humaneval",
) -> Dict:
if "humaneval" in dataset_type.lower():
dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)}
dataset_tgt = {task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt)}
for k, sample in dataset_src.items():
prompt = "code translation\n"
if lang_src == "cpp":
prompt += "C++:\n"
elif lang_src == "js":
prompt += "JavaScript:\n"
else:
prompt += f"{lang_src}:\n".capitalize()
prompt += dataset_src[k]["declaration"] + "\n" + dataset_src[k]["canonical_solution"].rstrip() + "\n"
if lang_tgt == "cpp":
prompt += "C++:\n"
elif lang_tgt == "js":
prompt += "JavaScript:\n"
else:
prompt += f"{lang_tgt}:\n".capitalize()
prompt += dataset_tgt[k.split("/")[-1]]["declaration"]
dataset_src[k]["prompt"] = prompt
else:
raise f"Dataset: {dataset_type} not supported."
return dataset_src
def process_extra_prompt(prompt: str, language_type: str = None) -> str:
"""
Processes the extra prompt.
"""
language = language_type.lower()
if language in LANGUAGE_TAG:
extra_prompt = LANGUAGE_TAG[language] + "\n"
else:
extra_prompt = ""
return extra_prompt + prompt
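For example, with the tags from data_utils.LANGUAGE_TAG the helper simply prefixes a language hint to the prompt:
print(process_extra_prompt("def add(a, b):\n", "python"))
# # language: Python
# def add(a, b):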
def is_code_generation_finished(
code: str,
language_type: str = None,
dataset: str = None,
):
"""
Checks whether the generated code is finished.
"""
if language_type is None or dataset is None:
return False
if "humaneval" in dataset.lower():
if language_type.lower() == "python":
for line in code.split("\n"):
if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
return True
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
for w in end_words:
if w in code:
return True
elif language_type.lower() == "java":
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "go":
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "js":
if code.count("{") + 1 == code.count("}"):
return True
elif language_type.lower() == "cpp":
if code.count("{") + 1 == code.count("}"):
return True
return False
def cleanup_code(
code: str,
language_type: str = None
):
"""
Cleans up the generated code.
"""
if language_type is None:
return code
if language_type.lower() == "python":
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
elif language_type.lower() == "java":
main_pos = code.find("public static void main")
if main_pos != -1:
code = code[:main_pos] + '}'
if '}' in code:
code = code[:code.rfind('}')] + '}'
if code.count('{') + 1 == code.count('}'):
code += "\n}"
elif language_type.lower() == "go":
end_words = ["\n//", "\nfunc main("]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
if '}' in code:
code = code[:code.rfind('}')] + '}'
elif language_type.lower() == "cpp":
if '}' in code:
code = code[:code.rfind('}')] + '}'
elif language_type.lower() == "js":
if '}' in code:
code = code[:code.rfind('}')] + '}'
return code
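A small illustration of the truncation behaviour for Python generations (the raw text is invented):
raw = "    return a + b\n\nprint(add(1, 2))\n"
print(repr(cleanup_code(raw, language_type="python")))
# -> '    return a + b\n'  (everything from the trailing "\nprint" onwards is dropped)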