提交 e0bf81b9 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

fix bug

上级 cc9d2464
...@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k ...@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
from execution import check_correctness from execution import check_correctness
LANGUAGE_NAME = { LANGUAGE_NAME = {
"cpp" : "CPP", "cpp": "CPP",
"go" : "Go", "go": "Go",
"java" : "Java", "java": "Java",
"js" : "JavaScript", "js": "JavaScript",
"python": "Python", "python": "Python",
} }
...@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False): ...@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
if language == "python": if language == "python":
code_ = [] code_ = []
for line in code.split("\n"): for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'): if line.strip().startswith("def "):
break continue
if line and line[0] != ' ' and line[0] != '\t':
line = " " + line
code_.append(line) code_.append(line)
code = "\n".join(code_) code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n" test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
...@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]: ...@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness( def evaluate_functional_correctness(
input_file: str = None, language_type: str = "python",
tmp_dir: str = "./", input_folder: str = "../output",
n_workers: int = 32, tmp_dir: str = "../output/tmp/",
n_workers: int = 3,
timeout: float = 500.0, timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz", problem_folder: str = "../eval_set/humaneval-x/",
out_dir: str = None, out_dir: str = "../output/",
k: List[int] = [1, 10, 100], k: List[int] = [1, 10, 100],
test_groundtruth: bool = False, test_groundtruth: bool = False,
example_test: bool = False, example_test: bool = False,
model_name: str = "chatgpt"
): ):
if example_test: if example_test:
print("Example test...") print("Example test...")
problems = read_dataset(problem_file, input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
dataset_type="humaneval")
sample_jsonl = stream_jsonl_all(input_file) sample_jsonl = stream_jsonl_all(input_file)
problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
if example_test: if example_test:
suffix = "_example_test.jsonl" suffix = "_example_test.jsonl"
else: else:
...@@ -125,14 +129,6 @@ def evaluate_functional_correctness( ...@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
else: else:
out_file = os.path.join(input_file.replace(".jsonl", suffix)) out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False
with ThreadPoolExecutor(max_workers=n_workers) as executor: with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = [] futures = []
...@@ -162,14 +158,6 @@ def evaluate_functional_correctness( ...@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
for sample in tqdm(sample_jsonl): for sample in tqdm(sample_jsonl):
task_id = sample["task_id"] task_id = sample["task_id"]
lang = task_id.split("/")[0].lower() lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript": if lang == "javascript":
lang = "js" lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
...@@ -187,7 +175,6 @@ def evaluate_functional_correctness( ...@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1 completion_id[task_id] += 1
n_samples += 1 n_samples += 1
print(completion_id)
if len(completion_id) == len(problems): if len(completion_id) == len(problems):
evaluate_pass_at_k = True evaluate_pass_at_k = True
else: else:
......
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
import contextlib import contextlib
import faulthandler import faulthandler
import io import io
...@@ -41,20 +61,7 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> Non ...@@ -41,20 +61,7 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> Non
out.write(jout) out.write(jout)
def check_correctness( def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
task_id: str,
sample: dict,
language_type: str,
timeout: float = 3.0,
tmp_dir: str = None,
completion_id: Optional[int] = None,
) -> Dict:
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
"""
def unsafe_execute(tmp_dir):
random_id = random.uniform(1, 1000) random_id = random.uniform(1, 1000)
if "python" in language_type.lower(): if "python" in language_type.lower():
with create_tempdir(): with create_tempdir():
...@@ -108,8 +115,8 @@ def check_correctness( ...@@ -108,8 +115,8 @@ def check_correctness(
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir)
os.chdir(tmp_dir) # os.chdir(tmp_dir)
open(f"main_test.go", 'w').write(sample["test_code"]) open(os.path.join(tmp_dir, "main_test.go"), 'w').write(sample["test_code"])
try: try:
exec_result = None exec_result = None
with time_limit(timeout): with time_limit(timeout):
...@@ -122,7 +129,7 @@ def check_correctness( ...@@ -122,7 +129,7 @@ def check_correctness(
# does not perform destructive actions on their host or network. # does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions, # Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk: # uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True) exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", os.path.join(tmp_dir, "main_test.go")], timeout=timeout, capture_output=True)
if exec_result.returncode == 0: if exec_result.returncode == 0:
result.append("passed") result.append("passed")
...@@ -153,8 +160,8 @@ def check_correctness( ...@@ -153,8 +160,8 @@ def check_correctness(
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir)
os.chdir(tmp_dir) # os.chdir(tmp_dir)
open(f"test.js", 'w').write(sample["test_code"]) open(os.path.join(tmp_dir, "test.js"), 'w').write(sample["test_code"])
try: try:
exec_result = None exec_result = None
with time_limit(timeout): with time_limit(timeout):
...@@ -167,7 +174,7 @@ def check_correctness( ...@@ -167,7 +174,7 @@ def check_correctness(
# does not perform destructive actions on their host or network. # does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions, # Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk: # uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True) exec_result = subprocess.run(["node", os.path.join(tmp_dir, "test.js")], timeout=timeout, capture_output=True)
if exec_result.stderr.decode(): if exec_result.stderr.decode():
err = exec_result.stderr.decode() err = exec_result.stderr.decode()
...@@ -192,14 +199,16 @@ def check_correctness( ...@@ -192,14 +199,16 @@ def check_correctness(
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir)
os.chdir(tmp_dir) # os.chdir(tmp_dir)
open(f"test.cpp", 'w').write(sample["test_code"]) open(os.path.join(tmp_dir, "test.cpp"), 'w').write(sample["test_code"])
if "162" in task_id: # if "162" in task_id:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"], # compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp"), "-lcrypto", "-lssl"],
timeout=timeout, # timeout=timeout,
capture_output=True) # capture_output=True)
else: # else:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout, # compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
# capture_output=True)
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
capture_output=True) capture_output=True)
if compilation_result.returncode != 0: if compilation_result.returncode != 0:
if compilation_result.stderr: if compilation_result.stderr:
...@@ -257,10 +266,10 @@ def check_correctness( ...@@ -257,10 +266,10 @@ def check_correctness(
os.makedirs(RUST_SRC, exist_ok=True) os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True) os.makedirs(RUST_BIN, exist_ok=True)
with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f: with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
#temporal file name # temporal file name
file_prefix = sample["task_id"].lower().replace("/", "_") file_prefix = sample["task_id"].lower().replace("/", "_")
file_name:str = file_prefix +RUST_EXT file_name: str = file_prefix + RUST_EXT
os.rename(f.name, os.path.join(RUST_BIN, file_name)) os.rename(f.name, os.path.join(RUST_BIN, file_name))
...@@ -292,9 +301,8 @@ def check_correctness( ...@@ -292,9 +301,8 @@ def check_correctness(
# 0 means success # 0 means success
if returned_val_compilation == 0: if returned_val_compilation == 0:
# Execution pipeline
#Execution pipeline cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path
returned_val_execution = os.system(cargo_test) returned_val_execution = os.system(cargo_test)
if returned_val_execution == 0: if returned_val_execution == 0:
...@@ -305,7 +313,6 @@ def check_correctness( ...@@ -305,7 +313,6 @@ def check_correctness(
else: else:
result.append(f"failed: compilation error") result.append(f"failed: compilation error")
elif "java" in language_type.lower(): elif "java" in language_type.lower():
assert tmp_dir is not None, "Java should be evaluated in a temporary dir." assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
...@@ -318,7 +325,7 @@ def check_correctness( ...@@ -318,7 +325,7 @@ def check_correctness(
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir)
os.chdir(tmp_dir) # os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"]) open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
res = "failed: unknown error" res = "failed: unknown error"
compile_returncode = -1 compile_returncode = -1
...@@ -344,7 +351,7 @@ def check_correctness( ...@@ -344,7 +351,7 @@ def check_correctness(
# does not perform destructive actions on their host or network. # does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions, # Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk: # uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True) exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0: if exec_result.returncode == 0:
res = "passed" res = "passed"
elif exec_result.returncode == 1: elif exec_result.returncode == 1:
...@@ -357,13 +364,26 @@ def check_correctness( ...@@ -357,13 +364,26 @@ def check_correctness(
except BaseException as e: except BaseException as e:
res = f"failed: {e}" res = f"failed: {e}"
result.append(res) result.append(res)
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
def check_correctness(
task_id: str,
sample: dict,
language_type: str,
timeout: float = 3.0,
tmp_dir: str = None,
completion_id: Optional[int] = None,
) -> Dict:
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
"""
manager = multiprocessing.Manager() manager = multiprocessing.Manager()
result = manager.list() result = manager.list()
p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,)) p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir, language_type, timeout, sample, result, task_id,))
p.start() p.start()
p.join(timeout=timeout + 1) p.join(timeout=timeout + 1)
if p.is_alive(): if p.is_alive():
...@@ -373,38 +393,19 @@ def check_correctness( ...@@ -373,38 +393,19 @@ def check_correctness(
result.append("timed out") result.append("timed out")
return { return {
"task_id" : task_id, "task_id": task_id,
"completion_id": completion_id, "completion_id": completion_id,
"test_code" : sample["test_code"], "test_code": sample["test_code"],
"prompt" : sample["prompt"], "prompt": sample["prompt"],
"generation" : sample["generation"], "generation": sample["generation"],
"result" : result[0], "result": result[0],
"passed" : result[0] == "passed", "passed": result[0] == "passed",
"finish" : -1 if "finish" not in sample else sample["finish"], "finish": -1 if "finish" not in sample else sample["finish"],
"file" : "" if "file" not in sample else sample["file"], "file": "" if "file" not in sample else sample["file"],
"output" : [] if "output" not in sample else sample["output"], "output": [] if "output" not in sample else sample["output"],
} }
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
@contextlib.contextmanager @contextlib.contextmanager
def time_limit(seconds: float): def time_limit(seconds: float):
def signal_handler(signum, frame): def signal_handler(signum, frame):
......
...@@ -19,7 +19,7 @@ if __name__ == "__main__": ...@@ -19,7 +19,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--input_path", "--input_path",
type=str, type=str,
default="eval_set/humaneval-x" default="../eval_set/humaneval-x"
) )
parser.add_argument( parser.add_argument(
"--language_type", "--language_type",
...@@ -30,18 +30,18 @@ if __name__ == "__main__": ...@@ -30,18 +30,18 @@ if __name__ == "__main__":
"--model_name", "--model_name",
type=str, type=str,
default="chatgpt", default="chatgpt",
help='supported model in [chatgpt,chatglm2].' help='supported models in [chatgpt, chatglm2, bbt].'
) )
parser.add_argument( parser.add_argument(
"--output_prefix", "--output_prefix",
type=str, type=str,
default="chatgpt" default="../output/humaneval"
) )
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl" output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval") entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
for entry in entries.values(): for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type) entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
...@@ -52,6 +52,9 @@ if __name__ == "__main__": ...@@ -52,6 +52,9 @@ if __name__ == "__main__":
elif args.model_name == "chatglm2": elif args.model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference from inference.chatglm2_inference import GLMInference
model_inference = GLMInference() model_inference = GLMInference()
elif args.model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
else: else:
print(f"{args.model_name} not supported.") print(f"{args.model_name} not supported.")
...@@ -60,8 +63,16 @@ if __name__ == "__main__": ...@@ -60,8 +63,16 @@ if __name__ == "__main__":
with open(output_file_path, "w") as f: with open(output_file_path, "w") as f:
for entry in entries.values(): for entry in entries.values():
prompt = entry["prompt"] prompt = entry["prompt"]
retry_times = 3
generated_tokens = model_inference.inference(prompt) generated_tokens = model_inference.inference(prompt)
generated_code = cleanup_code(generated_code, language_type=args.language_type) while generated_tokens is None and retry_times > 0:
time.sleep(10)
generated_tokens = model_inference.inference(prompt)
retry_times -= 1
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
f.write( f.write(
json.dumps( json.dumps(
{ {
...@@ -79,11 +90,11 @@ if __name__ == "__main__": ...@@ -79,11 +90,11 @@ if __name__ == "__main__":
num_finished += 1 num_finished += 1
time_elapsed = time.perf_counter() - start_time time_elapsed = time.perf_counter() - start_time
time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
print( print(
f"finished {num_finished}, " f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}", f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample", f"speed {time_per_sample:.4f}s/sample",
f"remaining {len(entries) - num_finished:.4f}", f"remaining {len(entries) - num_finished:.4f}",
flush=True, flush=True,
) )
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class BBTInference(Inference):
    """Inference backend that sends prompts to a remote BBT model HTTP service."""

    def __init__(self):
        super(BBTInference, self).__init__()
        # Path is relative to the directory the benchmark scripts run from.
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        # Model-specific params override the base defaults loaded by Inference.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            # 'top_k': 80,
            # 'do_sample': True,
            # 'num_beams': 4,
            # 'no_repeat_ngram_size': 4,
            # 'repetition_penalty': 1.1,
            # 'length_penalty': 1,
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Load and parse the JSON parameter file at ``self.params_url``."""
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
        # Context manager closes the handle deterministically (the original
        # left the file object open until garbage collection).
        with open(self.params_url) as f:
            return json.loads(f.read())

    @exception_reconnect
    def inference(self, query_text):
        """POST the prompt to the BBT service; return generated text or None on failure."""
        # Prefix asks the model (in Chinese) to complete the method body.
        query_text = "补全以下方法体:\n" + query_text
        try:
            data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
            # Honor the configured timeout (it was loaded but never used);
            # a timeout raises and is handled by the except below.
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
            return None
        if resp.status_code == 200:
            resp_json = json.loads(resp.text)
            return resp_json["msg"]
        else:
            return None
...@@ -3,15 +3,16 @@ import json ...@@ -3,15 +3,16 @@ import json
import torch import torch
import logging import logging
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from inference import Inference from .inference import Inference
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GLMInference(Inference): class GLMInference(Inference):
def __int__(self): def __init__(self):
self.params_url = "llm_set/params/chatgpt.json" super(GLMInference, self).__init__()
self.params_url = "../llm_set/params/chatglm2.json"
self.paras_dict = self.get_params() self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict) self.paras_dict.update(self.paras_base_dict)
......
...@@ -2,20 +2,23 @@ ...@@ -2,20 +2,23 @@
import os import os
import json import json
import logging import logging
import quote from urllib.parse import quote
import requests import requests
import traceback import traceback
from inference import Inference from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ChatGPTInference(Inference): class ChatGPTInference(Inference):
def __int__(self): def __init__(self):
self.params_url = "llm_set/params/chatgpt.json" super(ChatGPTInference, self).__init__()
self.params_url = "../llm_set/params/chatgpt.json"
self.paras_dict = self.get_params() self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict) self.paras_base_dict.update(self.paras_dict)
self.paras_dict = self.paras_base_dict
self.gpt_url = self.paras_dict.get("url") self.gpt_url = self.paras_dict.get("url")
self.id = self.paras_dict.get("id") self.id = self.paras_dict.get("id")
self.stream = self.paras_dict.get("stream") self.stream = self.paras_dict.get("stream")
...@@ -28,9 +31,11 @@ class ChatGPTInference(Inference): ...@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
content = open(self.params_url).read() content = open(self.params_url).read()
return json.loads(content) return json.loads(content)
@exception_reconnect
def inference(self, query_text): def inference(self, query_text):
query_text = "补全以下方法体:\n" + query_text
query_text = quote(query_text) query_text = quote(query_text)
param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}" param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
try: try:
resp = requests.get(self.gpt_url + param_str, timeout=self.timeout) resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
except Exception as e: except Exception as e:
......
...@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__) ...@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
class Inference(object): class Inference(object):
def __int__(self): def __init__(self):
self.params_base_url = "llm_set/params/default.json" self.params_base_url = "../llm_set/params/default.json"
self.paras_base_dict = self.get_paras_base() self.paras_base_dict = self.get_paras_base()
def get_paras_base(self): def get_paras_base(self):
......
import os import os
import re
from typing import * from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter
IMPORT_HELPER = { IMPORT_HELPER = {
...@@ -49,7 +51,7 @@ IMPORT_HELPER = { ...@@ -49,7 +51,7 @@ IMPORT_HELPER = {
def read_dataset( def read_dataset(
data_file: str = None, data_folder: str = None,
dataset_type: str = "humaneval", dataset_type: str = "humaneval",
language_type: str = "python", language_type: str = "python",
num_shot=None, num_shot=None,
...@@ -57,7 +59,7 @@ def read_dataset( ...@@ -57,7 +59,7 @@ def read_dataset(
if num_shot is not None: if num_shot is not None:
print(f"{num_shot}-shot setting...") print(f"{num_shot}-shot setting...")
if "humaneval" in dataset_type.lower(): if "humaneval" in dataset_type.lower():
data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz") data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)} dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
else: else:
raise f"Dataset: {dataset_type} not supported." raise f"Dataset: {dataset_type} not supported."
...@@ -147,41 +149,191 @@ def is_code_generation_finished( ...@@ -147,41 +149,191 @@ def is_code_generation_finished(
return False return False
def find_method_name(language_type, sample):
    """Extract the entry-point method name from a sample's ``declaration`` string.

    Returns None when the language is unsupported or no name can be found.
    """
    # One capture pattern per supported language; cpp/java capture the text
    # between the first space and the first "(", which may include the
    # return type as well as the name.
    patterns = {
        "python": r"def (.*?)\(",
        "cpp": r" (.*?)\(",
        "java": r" (.*?)\(",
        "js": r"const (.*?) ",
        "javascript": r"const (.*?) ",
        "go": r"func (.*?)\(",
        "rust": r"fn (.*?)\(",
    }
    pattern = patterns.get(language_type.lower())
    if pattern is None:
        return None
    ret = re.search(pattern, sample["declaration"])
    if ret is None:
        return None
    method_name = ret.group(1).strip()
    # Bug fix: when the capture includes the return type (e.g. "static int foo"),
    # the original returned method_name[1] — the second *character*. The method
    # name is the last whitespace-separated token.
    if " " in method_name:
        return method_name.split(" ")[-1]
    return method_name
def extract_markdown_code(content):
    """Extract code snippets from markdown fenced blocks (the text between ``` pairs).

    Returns a list of code strings, with any leading language tag
    (e.g. ```python) removed and surrounding whitespace stripped.
    """
    codes = []
    # re.findall never returns None, so no None check is needed.
    rets = re.findall("```([\\s\\S]*?)```", content)
    for ret in rets:
        if not ret.startswith("\n"):
            # First line is a language tag (e.g. "python"); drop it.
            # Bug fix: join the remaining lines with "\n" — the original used
            # "".join, which collapsed multi-line code onto a single line.
            lines = ret.split("\n")
            codes.append("\n".join(lines[1:]).strip())
        else:
            codes.append(ret.strip())
    return codes
def cleanup_code(code: str, sample, language_type: str = None):
""" """
Cleans up the generated code. Cleans up the generated code.
""" """
if language_type is None: if language_type is None:
return code return code
method_name = find_method_name(language_type, sample)
if method_name is None:
return code
method_body = code
if language_type.lower() == "python": if language_type.lower() == "python":
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"] method_lines = []
for w in end_words: for line in code.split("\n"):
if w in code: if f"def {method_name}" in line:
code = code[:code.rfind(w)] method_lines.append(line)
continue
if method_lines:
method_lines.append(line)
if line.startswith(" return"):
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "java": elif language_type.lower() == "java":
main_pos = code.find("public static void main") method_lines = []
if main_pos != -1: bracket_left = 0
code = code[:main_pos] + '}' bracket_right = 0
if '}' in code: for line in code.split("\n"):
code = code[:code.rfind('}')] + '}' new_line = line.strip()
if code.count('{') + 1 == code.count('}'): counter = Counter(new_line)
code += "\n}" if new_line.startswith("public") and method_name in new_line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_lines.append("}")
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "go": elif language_type.lower() == "go":
end_words = ["\n//", "\nfunc main("] method_lines = []
for w in end_words: bracket_left = 0
if w in code: bracket_right = 0
code = code[:code.rfind(w)] for line in code.split("\n"):
if '}' in code: counter = Counter(line)
code = code[:code.rfind('}')] + '}' if f"func {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "cpp": elif language_type.lower() == "cpp":
if '}' in code: method_lines = []
code = code[:code.rfind('}')] + '}' bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f" {method_name}(" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "js": elif language_type.lower() == "js":
if '}' in code: method_lines = []
code = code[:code.rfind('}')] + '}' bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"const {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
return code if method_lines:
method_body = "\n".join(method_lines[1:])
return method_body + "\n"
def exception_reconnect(funct):
    """Decorator: if the wrapped call raises, log and retry it exactly once.

    The second attempt's exception (if any) propagates to the caller.
    """
    # Function-scope import so this block needs no change to the module's
    # import section.
    import functools

    @functools.wraps(funct)  # preserve __name__/__doc__ of the wrapped function
    def wrapper_func(*args, **kwargs):
        try:
            return funct(*args, **kwargs)
        except Exception as e:
            print(f"exception will reconnect: {str(e)}")
            return funct(*args, **kwargs)
    return wrapper_func
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册