提交 e0bf81b9 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

fix bug

上级 cc9d2464
......@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
from execution import check_correctness
LANGUAGE_NAME = {
"cpp" : "CPP",
"go" : "Go",
"java" : "Java",
"js" : "JavaScript",
"cpp": "CPP",
"go": "Go",
"java": "Java",
"js": "JavaScript",
"python": "Python",
}
......@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
if language == "python":
code_ = []
for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
break
if line.strip().startswith("def "):
continue
if line and line[0] != ' ' and line[0] != '\t':
line = " " + line
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
......@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness(
input_file: str = None,
tmp_dir: str = "./",
n_workers: int = 32,
language_type: str = "python",
input_folder: str = "../output",
tmp_dir: str = "../output/tmp/",
n_workers: int = 3,
timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz",
out_dir: str = None,
problem_folder: str = "../eval_set/humaneval-x/",
out_dir: str = "../output/",
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
model_name: str = "chatgpt"
):
if example_test:
print("Example test...")
problems = read_dataset(problem_file,
dataset_type="humaneval")
input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
sample_jsonl = stream_jsonl_all(input_file)
problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
if example_test:
suffix = "_example_test.jsonl"
else:
......@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
else:
out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False
with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = []
......@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
for sample in tqdm(sample_jsonl):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
......@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1
n_samples += 1
print(completion_id)
if len(completion_id) == len(problems):
evaluate_pass_at_k = True
else:
......
此差异已折叠。
......@@ -19,7 +19,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_path",
type=str,
default="eval_set/humaneval-x"
default="../eval_set/humaneval-x"
)
parser.add_argument(
"--language_type",
......@@ -30,18 +30,18 @@ if __name__ == "__main__":
"--model_name",
type=str,
default="chatgpt",
help='supported model in [chatgpt,chatglm2].'
help='supported models in [chatgpt, chatglm2, bbt].'
)
parser.add_argument(
"--output_prefix",
type=str,
default="chatgpt"
default="../output/humaneval"
)
args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval")
entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
......@@ -52,6 +52,9 @@ if __name__ == "__main__":
elif args.model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference
model_inference = GLMInference()
elif args.model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
else:
print(f"{args.model_name} not supported.")
......@@ -60,8 +63,16 @@ if __name__ == "__main__":
with open(output_file_path, "w") as f:
for entry in entries.values():
prompt = entry["prompt"]
retry_times = 3
generated_tokens = model_inference.inference(prompt)
generated_code = cleanup_code(generated_code, language_type=args.language_type)
while generated_tokens is None and retry_times > 0:
time.sleep(10)
generated_tokens = model_inference.inference(prompt)
retry_times -= 1
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
f.write(
json.dumps(
{
......@@ -79,11 +90,11 @@ if __name__ == "__main__":
num_finished += 1
time_elapsed = time.perf_counter() - start_time
time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished
time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
print(
f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample",
f"speed {time_per_sample:.4f}s/sample",
f"remaining {len(entries) - num_finished:.4f}",
flush=True,
)
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class BBTInference(Inference):
    """Inference client that sends completion prompts to the BBT HTTP endpoint.

    Model-specific settings are loaded from the JSON file at ``params_url``
    and merged over the base parameters held in ``self.paras_base_dict``
    (model-specific values win on key conflicts).
    """

    def __init__(self):
        super(BBTInference, self).__init__()
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        # Overlay the model-specific values on the base parameters, then use
        # the merged dict as the effective configuration.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            # 'top_k': 80,
            # 'do_sample': True,
            # 'num_beams': 4,
            # 'no_repeat_ngram_size': 4,
            # 'repetition_penalty': 1.1,
            # 'length_penalty': 1,
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Load and return the model parameters stored as JSON at ``params_url``.

        A missing file is logged and then surfaces as the ``FileNotFoundError``
        raised by ``open``.
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
        # Context manager so the handle is closed even if JSON parsing fails.
        with open(self.params_url, encoding="utf-8") as params_file:
            return json.load(params_file)

    @exception_reconnect
    def inference(self, query_text):
        """POST the prompt to the BBT service and return the generated text.

        :param query_text: prompt to complete; an instruction prefix is prepended.
        :return: the ``"msg"`` field of the JSON response on HTTP 200; ``None``
            when the request fails or a non-200 status is returned.
        """
        query_text = "补全以下方法体:\n" + query_text
        try:
            data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
            # Honour the configured timeout so a stalled server cannot hang the
            # run; a value of None keeps the previous wait-forever behaviour.
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
            return None

        if resp.status_code == 200:
            resp_json = json.loads(resp.text)
            return resp_json["msg"]
        return None
......@@ -3,15 +3,16 @@ import json
import torch
import logging
from transformers import AutoModel, AutoTokenizer
from inference import Inference
from .inference import Inference
logger = logging.getLogger(__name__)
class GLMInference(Inference):
def __int__(self):
self.params_url = "llm_set/params/chatgpt.json"
def __init__(self):
super(GLMInference, self).__init__()
self.params_url = "../llm_set/params/chatglm2.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
......
......@@ -2,20 +2,23 @@
import os
import json
import logging
import quote
from urllib.parse import quote
import requests
import traceback
from inference import Inference
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class ChatGPTInference(Inference):
def __int__(self):
self.params_url = "llm_set/params/chatgpt.json"
def __init__(self):
super(ChatGPTInference, self).__init__()
self.params_url = "../llm_set/params/chatgpt.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
self.paras_base_dict.update(self.paras_dict)
self.paras_dict = self.paras_base_dict
self.gpt_url = self.paras_dict.get("url")
self.id = self.paras_dict.get("id")
self.stream = self.paras_dict.get("stream")
......@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
content = open(self.params_url).read()
return json.loads(content)
@exception_reconnect
def inference(self, query_text):
query_text = "补全以下方法体:\n" + query_text
query_text = quote(query_text)
param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
try:
resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
except Exception as e:
......
......@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
class Inference(object):
def __int__(self):
self.params_base_url = "llm_set/params/default.json"
def __init__(self):
self.params_base_url = "../llm_set/params/default.json"
self.paras_base_dict = self.get_paras_base()
def get_paras_base(self):
......
import os
import re
from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter
IMPORT_HELPER = {
......@@ -49,7 +51,7 @@ IMPORT_HELPER = {
def read_dataset(
data_file: str = None,
data_folder: str = None,
dataset_type: str = "humaneval",
language_type: str = "python",
num_shot=None,
......@@ -57,7 +59,7 @@ def read_dataset(
if num_shot is not None:
print(f"{num_shot}-shot setting...")
if "humaneval" in dataset_type.lower():
data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
else:
raise f"Dataset: {dataset_type} not supported."
......@@ -147,41 +149,191 @@ def is_code_generation_finished(
return False
def cleanup_code(
code: str,
language_type: str = None
):
# Regex (one capture group) that pulls the method name out of a declaration,
# keyed by lower-cased language name.
_METHOD_NAME_PATTERNS = {
    "python": r"def (.*?)\(",
    "cpp": r" (.*?)\(",
    "java": r" (.*?)\(",
    "js": r"const (.*?) ",
    "javascript": r"const (.*?) ",
    "go": r"func (.*?)\(",
    "rust": r"fn (.*?)\(",
}

# Languages whose pattern may also capture modifiers / return types that
# precede the name (e.g. "public boolean foo", "long long foo").
_TYPE_PREFIXED_LANGUAGES = ("cpp", "java")


def find_method_name(language_type, sample):
    """Find the name of the method declared in ``sample["declaration"]``.

    :param language_type: language key, matched case-insensitively against
        python, cpp, java, js / javascript, go and rust.
    :param sample: mapping with a ``"declaration"`` string; only read when the
        language is supported.
    :return: the method name, or ``None`` when the language is unsupported or
        the declaration does not match the language's pattern.
    """
    language = language_type.lower()
    pattern = _METHOD_NAME_PATTERNS.get(language)
    if pattern is None:
        return None

    match = re.search(pattern, sample["declaration"])
    if match is None:
        return None

    method_name = match.group(1).strip()
    if language in _TYPE_PREFIXED_LANGUAGES and " " in method_name:
        # The capture still holds modifiers / the return type; the identifier
        # is the last whitespace-separated token (previously only the second
        # character of the capture was returned).
        return method_name.split()[-1]
    return method_name
def extract_markdown_code(content):
    """Extract the code inside every ``` fenced block of a Markdown string.

    :param content: Markdown text that may contain fenced code blocks.
    :return: list with one string per fenced block, in order of appearance.
        When the opening fence is followed by text on the same line (such as a
        language tag), that first line is dropped and the remaining lines are
        returned with their line breaks intact; otherwise the block content is
        returned stripped of surrounding whitespace.
    """
    codes = []
    for block in re.findall(r"```([\s\S]*?)```", content):
        if block.startswith("\n"):
            codes.append(block.strip())
        else:
            # Drop the first line (the fence's info string) and keep the rest
            # verbatim; the lines used to be concatenated without newlines,
            # which fused the code into a single line.
            _, _, body = block.partition("\n")
            codes.append(body)
    return codes
def cleanup_code(code: str, sample, language_type: str = None):
"""
Extract the body of the target method from model-generated code.

The method name is obtained from find_method_name(language_type, sample).
For "python", "java", "go", "cpp" and "js" (compared lower-cased) the
generated text is first truncated, then scanned line by line for the line
that declares the method; the lines collected after that declaration line
are joined with newlines to form the result.

:param code: generated text to clean up.
:param sample: passed to find_method_name to determine the method name.
:param language_type: language key selecting the clean-up branch.
:return: ``code`` unchanged when language_type is None or no method name is
found. Otherwise the extracted lines followed by a newline; when no
declaration line was collected, or no branch matches the language, the
original untruncated ``code`` followed by a newline.
"""
# NOTE(review): indentation is missing in this copy of the file, so the nesting
# described in the comments below is presumed from the statement order -- confirm
# against the real source before relying on it.
if language_type is None:
return code
method_name = find_method_name(language_type, sample)
if method_name is None:
return code
# Fallback result: captured before `code` is truncated below.
method_body = code
if language_type.lower() == "python":
# For each marker present, cut the text at the marker's last occurrence.
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
method_lines = []
for line in code.split("\n"):
# The declaration line starts the capture.
if f"def {method_name}" in line:
method_lines.append(line)
continue
if method_lines:
method_lines.append(line)
# Presumably stops at the first `return` at this indent -- the literal's
# leading whitespace may have been collapsed in this copy; TODO confirm.
if line.startswith(" return"):
break
if method_lines:
# Drop the declaration line itself; keep only what follows it.
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "java":
# Cut the text where a `main` method starts and append a closing brace.
main_pos = code.find("public static void main")
if main_pos != -1:
code = code[:main_pos] + '}'
# Discard anything after the last closing brace.
if '}' in code:
code = code[:code.rfind('}')] + '}'
if code.count('{') + 1 == code.count('}'):
code += "\n}"
method_lines = []
# Running totals of "{" and "}" seen since the declaration line.
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
new_line = line.strip()
counter = Counter(new_line)
if new_line.startswith("public") and method_name in new_line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
# Braces balance again: stop scanning.
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
# An extra "}" is appended after the captured lines -- presumably to close
# the enclosing class; TODO confirm.
method_lines.append("}")
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "go":
# For each marker present, cut the text at the marker's last occurrence.
end_words = ["\n//", "\nfunc main("]
for w in end_words:
if w in code:
code = code[:code.rfind(w)]
# Discard anything after the last closing brace.
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"func {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
# Braces balance again: stop scanning.
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "cpp":
# Discard anything after the last closing brace.
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f" {method_name}(" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
# Braces balance again: stop scanning.
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
# NOTE(review): only "js" is handled here, while find_method_name also accepts
# "javascript"; that spelling falls through and returns the text untouched.
elif language_type.lower() == "js":
# Discard anything after the last closing brace.
if '}' in code:
code = code[:code.rfind('}')] + '}'
method_lines = []
bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"const {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
# Braces balance again: stop scanning.
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
return method_body + "\n"
def exception_reconnect(funct):
    """Decorator that retries ``funct`` once when the first call raises.

    The first exception is reported on stdout and the call is repeated with
    the same arguments; an exception raised by the second attempt propagates
    to the caller.

    :param funct: callable to wrap.
    :return: wrapper with the same call signature as ``funct``.
    """
    import functools

    # Preserve the wrapped callable's name and docstring on the wrapper.
    @functools.wraps(funct)
    def wrapper_func(*args, **kwargs):
        try:
            return funct(*args, **kwargs)
        except Exception as e:
            print(f"exception will reconnect: {str(e)}")
            return funct(*args, **kwargs)

    return wrapper_func
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册