提交 e0bf81b9 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

fix bug

上级 cc9d2464
......@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
from execution import check_correctness
LANGUAGE_NAME = {
"cpp" : "CPP",
"go" : "Go",
"java" : "Java",
"js" : "JavaScript",
"cpp": "CPP",
"go": "Go",
"java": "Java",
"js": "JavaScript",
"python": "Python",
}
......@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
if language == "python":
code_ = []
for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
break
if line.strip().startswith("def "):
continue
if line and line[0] != ' ' and line[0] != '\t':
line = " " + line
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
......@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness(
input_file: str = None,
tmp_dir: str = "./",
n_workers: int = 32,
language_type: str = "python",
input_folder: str = "../output",
tmp_dir: str = "../output/tmp/",
n_workers: int = 3,
timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz",
out_dir: str = None,
problem_folder: str = "../eval_set/humaneval-x/",
out_dir: str = "../output/",
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
model_name: str = "chatgpt"
):
if example_test:
print("Example test...")
problems = read_dataset(problem_file,
dataset_type="humaneval")
input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
sample_jsonl = stream_jsonl_all(input_file)
problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
if example_test:
suffix = "_example_test.jsonl"
else:
......@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
else:
out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False
with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = []
......@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
for sample in tqdm(sample_jsonl):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
......@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1
n_samples += 1
print(completion_id)
if len(completion_id) == len(problems):
evaluate_pass_at_k = True
else:
......
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
import contextlib
import faulthandler
import io
......@@ -41,20 +61,7 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> Non
out.write(jout)
def check_correctness(
task_id: str,
sample: dict,
language_type: str,
timeout: float = 3.0,
tmp_dir: str = None,
completion_id: Optional[int] = None,
) -> Dict:
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
"""
def unsafe_execute(tmp_dir):
def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
random_id = random.uniform(1, 1000)
if "python" in language_type.lower():
with create_tempdir():
......@@ -108,8 +115,8 @@ def check_correctness(
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir)
open(f"main_test.go", 'w').write(sample["test_code"])
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "main_test.go"), 'w').write(sample["test_code"])
try:
exec_result = None
with time_limit(timeout):
......@@ -122,7 +129,7 @@ def check_correctness(
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", os.path.join(tmp_dir, "main_test.go")], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
......@@ -153,8 +160,8 @@ def check_correctness(
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir)
open(f"test.js", 'w').write(sample["test_code"])
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "test.js"), 'w').write(sample["test_code"])
try:
exec_result = None
with time_limit(timeout):
......@@ -167,7 +174,7 @@ def check_correctness(
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
exec_result = subprocess.run(["node", os.path.join(tmp_dir, "test.js")], timeout=timeout, capture_output=True)
if exec_result.stderr.decode():
err = exec_result.stderr.decode()
......@@ -192,14 +199,16 @@ def check_correctness(
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir)
open(f"test.cpp", 'w').write(sample["test_code"])
if "162" in task_id:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
timeout=timeout,
capture_output=True)
else:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout,
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "test.cpp"), 'w').write(sample["test_code"])
# if "162" in task_id:
# compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp"), "-lcrypto", "-lssl"],
# timeout=timeout,
# capture_output=True)
# else:
# compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
# capture_output=True)
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
capture_output=True)
if compilation_result.returncode != 0:
if compilation_result.stderr:
......@@ -257,10 +266,10 @@ def check_correctness(
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)
with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
#temporal file name
with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
# temporal file name
file_prefix = sample["task_id"].lower().replace("/", "_")
file_name:str = file_prefix +RUST_EXT
file_name: str = file_prefix + RUST_EXT
os.rename(f.name, os.path.join(RUST_BIN, file_name))
......@@ -292,9 +301,8 @@ def check_correctness(
# 0 means success
if returned_val_compilation == 0:
#Execution pipeline
cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path
# Execution pipeline
cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
returned_val_execution = os.system(cargo_test)
if returned_val_execution == 0:
......@@ -305,7 +313,6 @@ def check_correctness(
else:
result.append(f"failed: compilation error")
elif "java" in language_type.lower():
assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
......@@ -318,7 +325,7 @@ def check_correctness(
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir)
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
res = "failed: unknown error"
compile_returncode = -1
......@@ -344,7 +351,7 @@ def check_correctness(
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
res = "passed"
elif exec_result.returncode == 1:
......@@ -357,13 +364,26 @@ def check_correctness(
except BaseException as e:
res = f"failed: {e}"
result.append(res)
shutil.rmtree(tmp_dir)
def check_correctness(
task_id: str,
sample: dict,
language_type: str,
timeout: float = 3.0,
tmp_dir: str = None,
completion_id: Optional[int] = None,
) -> Dict:
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
"""
manager = multiprocessing.Manager()
result = manager.list()
p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,))
p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir, language_type, timeout, sample, result, task_id,))
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
......@@ -373,38 +393,19 @@ def check_correctness(
result.append("timed out")
return {
"task_id" : task_id,
"task_id": task_id,
"completion_id": completion_id,
"test_code" : sample["test_code"],
"prompt" : sample["prompt"],
"generation" : sample["generation"],
"result" : result[0],
"passed" : result[0] == "passed",
"finish" : -1 if "finish" not in sample else sample["finish"],
"file" : "" if "file" not in sample else sample["file"],
"output" : [] if "output" not in sample else sample["output"],
"test_code": sample["test_code"],
"prompt": sample["prompt"],
"generation": sample["generation"],
"result": result[0],
"passed": result[0] == "passed",
"finish": -1 if "finish" not in sample else sample["finish"],
"file": "" if "file" not in sample else sample["file"],
"output": [] if "output" not in sample else sample["output"],
}
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
@contextlib.contextmanager
def time_limit(seconds: float):
def signal_handler(signum, frame):
......
......@@ -19,7 +19,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_path",
type=str,
default="eval_set/humaneval-x"
default="../eval_set/humaneval-x"
)
parser.add_argument(
"--language_type",
......@@ -30,18 +30,18 @@ if __name__ == "__main__":
"--model_name",
type=str,
default="chatgpt",
help='supported model in [chatgpt,chatglm2].'
help='supported models in [chatgpt, chatglm2, bbt].'
)
parser.add_argument(
"--output_prefix",
type=str,
default="chatgpt"
default="../output/humaneval"
)
args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval")
entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
......@@ -52,6 +52,9 @@ if __name__ == "__main__":
elif args.model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference
model_inference = GLMInference()
elif args.model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
else:
print(f"{args.model_name} not supported.")
......@@ -60,8 +63,16 @@ if __name__ == "__main__":
with open(output_file_path, "w") as f:
for entry in entries.values():
prompt = entry["prompt"]
retry_times = 3
generated_tokens = model_inference.inference(prompt)
generated_code = cleanup_code(generated_code, language_type=args.language_type)
while generated_tokens is None and retry_times > 0:
time.sleep(10)
generated_tokens = model_inference.inference(prompt)
retry_times -= 1
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
f.write(
json.dumps(
{
......@@ -79,11 +90,11 @@ if __name__ == "__main__":
num_finished += 1
time_elapsed = time.perf_counter() - start_time
time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished
time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
print(
f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample",
f"speed {time_per_sample:.4f}s/sample",
f"remaining {len(entries) - num_finished:.4f}",
flush=True,
)
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class BBTInference(Inference):
def __init__(self):
super(BBTInference, self).__init__()
self.params_url = "../llm_set/params/bbt.json"
self.paras_dict = self.get_params()
self.paras_base_dict.update(self.paras_dict)
self.paras_dict = self.paras_base_dict
self.url = self.paras_dict.get("url")
self.user_id = self.paras_dict.get("user_id")
self.timeout = self.paras_dict.get("timeout")
self.modelConfig = {
'temperature': self.paras_dict.get("temperature"),
'top_p': self.paras_dict.get("top_p"),
# 'top_k': 80,
# 'do_sample': True,
# 'num_beams': 4,
# 'no_repeat_ngram_size': 4,
# 'repetition_penalty': 1.1,
# 'length_penalty': 1,
'max_new_tokens': self.paras_dict.get("max_new_tokens"),
}
def get_params(self):
if not os.path.exists(self.params_url):
logger.error(f"params_url:{self.params_url} is not exists.")
content = open(self.params_url).read()
return json.loads(content)
@exception_reconnect
def inference(self, query_text):
query_text = "补全以下方法体:\n" + query_text
try:
data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
resp = requests.post(self.url, data=data)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
return None
if resp.status_code == 200:
resp_json = json.loads(resp.text)
return resp_json["msg"]
else:
return None
......@@ -3,15 +3,16 @@ import json
import torch
import logging
from transformers import AutoModel, AutoTokenizer
from inference import Inference
from .inference import Inference
logger = logging.getLogger(__name__)
class GLMInference(Inference):
def __int__(self):
self.params_url = "llm_set/params/chatgpt.json"
def __init__(self):
super(GLMInference, self).__init__()
self.params_url = "../llm_set/params/chatglm2.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
......
......@@ -2,20 +2,23 @@
import os
import json
import logging
import quote
from urllib.parse import quote
import requests
import traceback
from inference import Inference
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class ChatGPTInference(Inference):
def __int__(self):
self.params_url = "llm_set/params/chatgpt.json"
def __init__(self):
super(ChatGPTInference, self).__init__()
self.params_url = "../llm_set/params/chatgpt.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
self.paras_base_dict.update(self.paras_dict)
self.paras_dict = self.paras_base_dict
self.gpt_url = self.paras_dict.get("url")
self.id = self.paras_dict.get("id")
self.stream = self.paras_dict.get("stream")
......@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
content = open(self.params_url).read()
return json.loads(content)
@exception_reconnect
def inference(self, query_text):
query_text = "补全以下方法体:\n" + query_text
query_text = quote(query_text)
param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
try:
resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
except Exception as e:
......
......@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
class Inference(object):
def __int__(self):
self.params_base_url = "llm_set/params/default.json"
def __init__(self):
self.params_base_url = "../llm_set/params/default.json"
self.paras_base_dict = self.get_paras_base()
def get_paras_base(self):
......
import os
import re
from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter
IMPORT_HELPER = {
......@@ -49,7 +51,7 @@ IMPORT_HELPER = {
def read_dataset(
data_file: str = None,
data_folder: str = None,
dataset_type: str = "humaneval",
language_type: str = "python",
num_shot=None,
......@@ -57,7 +59,7 @@ def read_dataset(
if num_shot is not None:
print(f"{num_shot}-shot setting...")
if "humaneval" in dataset_type.lower():
data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
else:
raise f"Dataset: {dataset_type} not supported."
......@@ -147,41 +149,191 @@ def is_code_generation_finished(
return False
def cleanup_code(
code: str,
language_type: str = None
):
def find_method_name(language_type, sample):
"""查找方法名"""
if language_type.lower() == "python":
declaration = sample["declaration"]
ret = re.search("def (.*?)\\(", declaration)
if ret is None:
return None
return ret.group(1).strip()
elif language_type.lower() == "cpp":
declaration = sample["declaration"]
ret = re.search(" (.*?)\\(", declaration)
if ret is None:
return None
method_name = ret.group(1).strip()
if " " in method_name:
return method_name[1]
return method_name
elif language_type.lower() == "java":
declaration = sample["declaration"]
ret = re.search(" (.*?)\\(", declaration)
if ret is None:
return None
method_name = ret.group(1).strip()
if " " in method_name:
return method_name[1]
return method_name
elif language_type.lower() in ["js", "javascript"]:
declaration = sample["declaration"]
ret = re.search("const (.*?) ", declaration)
if ret is None:
return None
method_name = ret.group(1).strip()
return method_name
elif language_type.lower() == "go":
declaration = sample["declaration"]
ret = re.search("func (.*?)\\(", declaration)
if ret is None:
return None
return ret.group(1).strip()
elif language_type == "rust":
declaration = sample["declaration"]
ret = re.search("fn (.*?)\\(", declaration)
if ret is None:
return None
return ret.group(1).strip()
else:
return None
def extract_markdown_code(content):
"""提取markdown中的代码,即"```"中的内容"""
codes = []
rets = re.findall("```([\\s\\S]*?)```", content)
if rets is None:
return codes
for ret in rets:
if not ret.startswith("\n"):
lines = ret.split("\n")
codes.append("".join(lines[1:]))
else:
codes.append(ret.strip())
return codes
def cleanup_code(code: str, sample, language_type: str = None):
"""
Cleans up the generated code.
"""
if language_type is None:
return code
method_name = find_method_name(language_type, sample)
if method_name is None:
return code
method_body = code
if language_type.lower() == "python":
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
method_lines = []
for line in code.split("\n"):
if f"def {method_name}" in line:
method_lines.append(line)
continue
if method_lines:
method_lines.append(line)
if line.startswith(" return"):
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "java":
main_pos = code.find("public static void main")
if main_pos != -1:
code = code[:main_pos] + '}'
if '}' in code:
code = code[:code.rfind('}')] + '}'
if code.count('{') + 1 == code.count('}'):
code += "\n}"
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
new_line = line.strip()
counter = Counter(new_line)
if new_line.startswith("public") and method_name in new_line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_lines.append("}")
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "go":
end_words = ["\n//", "\nfunc main("]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"func {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "cpp":
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f" {method_name}(" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "js":
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"const {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
return code
if method_lines:
method_body = "\n".join(method_lines[1:])
return method_body + "\n"
def exception_reconnect(funct):
"""异常重连"""
def wrapper_func(*args, **kwargs):
try:
return funct(*args, **kwargs)
except Exception as e:
print(f"exception will reconnect: {str(e)}")
return funct(*args, **kwargs)
return wrapper_func
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册