提交 e0bf81b9 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

fix bug

上级 cc9d2464
...@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k ...@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
from execution import check_correctness from execution import check_correctness
LANGUAGE_NAME = { LANGUAGE_NAME = {
"cpp" : "CPP", "cpp": "CPP",
"go" : "Go", "go": "Go",
"java" : "Java", "java": "Java",
"js" : "JavaScript", "js": "JavaScript",
"python": "Python", "python": "Python",
} }
...@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False): ...@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
if language == "python": if language == "python":
code_ = [] code_ = []
for line in code.split("\n"): for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'): if line.strip().startswith("def "):
break continue
if line and line[0] != ' ' and line[0] != '\t':
line = " " + line
code_.append(line) code_.append(line)
code = "\n".join(code_) code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n" test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
...@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]: ...@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness( def evaluate_functional_correctness(
input_file: str = None, language_type: str = "python",
tmp_dir: str = "./", input_folder: str = "../output",
n_workers: int = 32, tmp_dir: str = "../output/tmp/",
n_workers: int = 3,
timeout: float = 500.0, timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz", problem_folder: str = "../eval_set/humaneval-x/",
out_dir: str = None, out_dir: str = "../output/",
k: List[int] = [1, 10, 100], k: List[int] = [1, 10, 100],
test_groundtruth: bool = False, test_groundtruth: bool = False,
example_test: bool = False, example_test: bool = False,
model_name: str = "chatgpt"
): ):
if example_test: if example_test:
print("Example test...") print("Example test...")
problems = read_dataset(problem_file, input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
dataset_type="humaneval")
sample_jsonl = stream_jsonl_all(input_file) sample_jsonl = stream_jsonl_all(input_file)
problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
if example_test: if example_test:
suffix = "_example_test.jsonl" suffix = "_example_test.jsonl"
else: else:
...@@ -125,14 +129,6 @@ def evaluate_functional_correctness( ...@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
else: else:
out_file = os.path.join(input_file.replace(".jsonl", suffix)) out_file = os.path.join(input_file.replace(".jsonl", suffix))
if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True
if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False
with ThreadPoolExecutor(max_workers=n_workers) as executor: with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = [] futures = []
...@@ -162,14 +158,6 @@ def evaluate_functional_correctness( ...@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
for sample in tqdm(sample_jsonl): for sample in tqdm(sample_jsonl):
task_id = sample["task_id"] task_id = sample["task_id"]
lang = task_id.split("/")[0].lower() lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript": if lang == "javascript":
lang = "js" lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
...@@ -187,7 +175,6 @@ def evaluate_functional_correctness( ...@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1 completion_id[task_id] += 1
n_samples += 1 n_samples += 1
print(completion_id)
if len(completion_id) == len(problems): if len(completion_id) == len(problems):
evaluate_pass_at_k = True evaluate_pass_at_k = True
else: else:
......
此差异已折叠。
...@@ -19,7 +19,7 @@ if __name__ == "__main__": ...@@ -19,7 +19,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--input_path", "--input_path",
type=str, type=str,
default="eval_set/humaneval-x" default="../eval_set/humaneval-x"
) )
parser.add_argument( parser.add_argument(
"--language_type", "--language_type",
...@@ -30,18 +30,18 @@ if __name__ == "__main__": ...@@ -30,18 +30,18 @@ if __name__ == "__main__":
"--model_name", "--model_name",
type=str, type=str,
default="chatgpt", default="chatgpt",
help='supported model in [chatgpt,chatglm2].' help='supported models in [chatgpt, chatglm2, bbt].'
) )
parser.add_argument( parser.add_argument(
"--output_prefix", "--output_prefix",
type=str, type=str,
default="chatgpt" default="../output/humaneval"
) )
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl" output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval") entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
for entry in entries.values(): for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type) entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
...@@ -52,6 +52,9 @@ if __name__ == "__main__": ...@@ -52,6 +52,9 @@ if __name__ == "__main__":
elif args.model_name == "chatglm2": elif args.model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference from inference.chatglm2_inference import GLMInference
model_inference = GLMInference() model_inference = GLMInference()
elif args.model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
else: else:
print(f"{args.model_name} not supported.") print(f"{args.model_name} not supported.")
...@@ -60,8 +63,16 @@ if __name__ == "__main__": ...@@ -60,8 +63,16 @@ if __name__ == "__main__":
with open(output_file_path, "w") as f: with open(output_file_path, "w") as f:
for entry in entries.values(): for entry in entries.values():
prompt = entry["prompt"] prompt = entry["prompt"]
retry_times = 3
generated_tokens = model_inference.inference(prompt) generated_tokens = model_inference.inference(prompt)
generated_code = cleanup_code(generated_code, language_type=args.language_type) while generated_tokens is None and retry_times > 0:
time.sleep(10)
generated_tokens = model_inference.inference(prompt)
retry_times -= 1
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
f.write( f.write(
json.dumps( json.dumps(
{ {
...@@ -79,11 +90,11 @@ if __name__ == "__main__": ...@@ -79,11 +90,11 @@ if __name__ == "__main__":
num_finished += 1 num_finished += 1
time_elapsed = time.perf_counter() - start_time time_elapsed = time.perf_counter() - start_time
time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
print( print(
f"finished {num_finished}, " f"finished {num_finished}, "
f"elapsed {time_elapsed:.4f}", f"elapsed {time_elapsed:.4f}",
f"speed {time_per_sampple:.4f}s/sample", f"speed {time_per_sample:.4f}s/sample",
f"remaining {len(entries) - num_finished:.4f}", f"remaining {len(entries) - num_finished:.4f}",
flush=True, flush=True,
) )
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__)
class BBTInference(Inference):
    """Inference client that sends completion prompts to the BBT HTTP endpoint.

    Settings are read from the JSON file at ``params_url`` and merged over the
    base parameters held in ``self.paras_base_dict``.
    """

    def __init__(self):
        super(BBTInference, self).__init__()
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        # Model-specific values override the base values; the merged dict is
        # then used as the single source of settings below.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict

        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        # Only these generation settings are forwarded with each request.
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Load the model parameters stored as JSON at ``self.params_url``.

        Returns:
            The decoded JSON content of the parameter file.

        Raises:
            FileNotFoundError: if the file does not exist (an error is logged
                before the open attempt).
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
        # Context manager so the handle is closed even if decoding fails.
        with open(self.params_url) as params_file:
            return json.load(params_file)

    @exception_reconnect
    def inference(self, query_text):
        """Post ``query_text`` to the model endpoint and return its reply.

        Args:
            query_text: the prompt; an instruction line (in Chinese, asking
                the model to complete the method body) is prepended to it.

        Returns:
            The ``"msg"`` field of the JSON response on HTTP 200, or ``None``
            when the request raises or the status code is not 200.
        """
        query_text = "补全以下方法体:\n" + query_text
        try:
            data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
            # Apply the configured timeout so a stalled server cannot block
            # the whole generation run indefinitely.
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
            return None

        if resp.status_code != 200:
            return None
        resp_json = json.loads(resp.text)
        return resp_json["msg"]
...@@ -3,15 +3,16 @@ import json ...@@ -3,15 +3,16 @@ import json
import torch import torch
import logging import logging
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from inference import Inference from .inference import Inference
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GLMInference(Inference): class GLMInference(Inference):
def __int__(self): def __init__(self):
self.params_url = "llm_set/params/chatgpt.json" super(GLMInference, self).__init__()
self.params_url = "../llm_set/params/chatglm2.json"
self.paras_dict = self.get_params() self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict) self.paras_dict.update(self.paras_base_dict)
......
...@@ -2,20 +2,23 @@ ...@@ -2,20 +2,23 @@
import os import os
import json import json
import logging import logging
import quote from urllib.parse import quote
import requests import requests
import traceback import traceback
from inference import Inference from .inference import Inference
from utils import exception_reconnect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ChatGPTInference(Inference): class ChatGPTInference(Inference):
def __int__(self): def __init__(self):
self.params_url = "llm_set/params/chatgpt.json" super(ChatGPTInference, self).__init__()
self.params_url = "../llm_set/params/chatgpt.json"
self.paras_dict = self.get_params() self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict) self.paras_base_dict.update(self.paras_dict)
self.paras_dict = self.paras_base_dict
self.gpt_url = self.paras_dict.get("url") self.gpt_url = self.paras_dict.get("url")
self.id = self.paras_dict.get("id") self.id = self.paras_dict.get("id")
self.stream = self.paras_dict.get("stream") self.stream = self.paras_dict.get("stream")
...@@ -28,9 +31,11 @@ class ChatGPTInference(Inference): ...@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
content = open(self.params_url).read() content = open(self.params_url).read()
return json.loads(content) return json.loads(content)
@exception_reconnect
def inference(self, query_text): def inference(self, query_text):
query_text = "补全以下方法体:\n" + query_text
query_text = quote(query_text) query_text = quote(query_text)
param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}" param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
try: try:
resp = requests.get(self.gpt_url + param_str, timeout=self.timeout) resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
except Exception as e: except Exception as e:
......
...@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__) ...@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
class Inference(object): class Inference(object):
def __int__(self): def __init__(self):
self.params_base_url = "llm_set/params/default.json" self.params_base_url = "../llm_set/params/default.json"
self.paras_base_dict = self.get_paras_base() self.paras_base_dict = self.get_paras_base()
def get_paras_base(self): def get_paras_base(self):
......
import os import os
import re
from typing import * from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter
IMPORT_HELPER = { IMPORT_HELPER = {
...@@ -49,7 +51,7 @@ IMPORT_HELPER = { ...@@ -49,7 +51,7 @@ IMPORT_HELPER = {
def read_dataset( def read_dataset(
data_file: str = None, data_folder: str = None,
dataset_type: str = "humaneval", dataset_type: str = "humaneval",
language_type: str = "python", language_type: str = "python",
num_shot=None, num_shot=None,
...@@ -57,7 +59,7 @@ def read_dataset( ...@@ -57,7 +59,7 @@ def read_dataset(
if num_shot is not None: if num_shot is not None:
print(f"{num_shot}-shot setting...") print(f"{num_shot}-shot setting...")
if "humaneval" in dataset_type.lower(): if "humaneval" in dataset_type.lower():
data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz") data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
dataset = {task["task_id"]: task for task in stream_jsonl(data_file)} dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
else: else:
raise f"Dataset: {dataset_type} not supported." raise f"Dataset: {dataset_type} not supported."
...@@ -147,41 +149,191 @@ def is_code_generation_finished( ...@@ -147,41 +149,191 @@ def is_code_generation_finished(
return False return False
# One regex per supported language; group 1 captures the text that holds the
# method name inside a declaration.
_METHOD_NAME_PATTERNS = {
    "python": r"def (.*?)\(",
    "cpp": r" (.*?)\(",
    "java": r" (.*?)\(",
    "js": r"const (.*?) ",
    "javascript": r"const (.*?) ",
    "go": r"func (.*?)\(",
    "rust": r"fn (.*?)\(",
}


def find_method_name(language_type, sample):
    """Find the name of the method declared in ``sample["declaration"]``.

    Args:
        language_type: language identifier, matched case-insensitively
            (python, cpp, java, js/javascript, go, rust).
        sample: mapping with a ``"declaration"`` string. It is only read when
            the language is supported.

    Returns:
        The method name, or ``None`` when the language is unsupported or no
        declaration matches.
    """
    language = language_type.lower()
    pattern = _METHOD_NAME_PATTERNS.get(language)
    if pattern is None:
        return None

    match = re.search(pattern, sample["declaration"])
    if match is None:
        return None

    method_name = match.group(1).strip()
    # For cpp/java the capture starts at the first space on the line, so it
    # may still carry modifiers and the return type ("public boolean foo").
    # The name itself is the last whitespace-separated token.
    if language in ("cpp", "java") and " " in method_name:
        return method_name.split()[-1]
    return method_name
def extract_markdown_code(content):
    """Extract the code inside every triple-backtick fence in ``content``.

    Args:
        content: markdown text that may contain fenced code blocks.

    Returns:
        A list with one string per fence, in order of appearance. When the
        opening fence is followed by text on the same line (e.g. a language
        tag), that first line is dropped and the remaining lines are kept
        joined by newlines. Otherwise the fenced text is returned stripped.
        The list is empty when there is no fence.
    """
    codes = []
    for fenced in re.findall(r"```([\s\S]*?)```", content):
        if fenced.startswith("\n"):
            codes.append(fenced.strip())
        else:
            # The first line holds the language tag; keep the rest and
            # preserve line breaks so the code is not fused onto one line.
            lines = fenced.split("\n")
            codes.append("\n".join(lines[1:]))
    return codes
def cleanup_code(code: str, sample, language_type: str = None):
""" """
Cleans up the generated code. Cleans up the generated code.
""" """
if language_type is None: if language_type is None:
return code return code
method_name = find_method_name(language_type, sample)
if method_name is None:
return code
method_body = code
if language_type.lower() == "python": if language_type.lower() == "python":
end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"] method_lines = []
for w in end_words: for line in code.split("\n"):
if w in code: if f"def {method_name}" in line:
code = code[:code.rfind(w)] method_lines.append(line)
continue
if method_lines:
method_lines.append(line)
if line.startswith(" return"):
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "java": elif language_type.lower() == "java":
main_pos = code.find("public static void main") method_lines = []
if main_pos != -1: bracket_left = 0
code = code[:main_pos] + '}' bracket_right = 0
if '}' in code: for line in code.split("\n"):
code = code[:code.rfind('}')] + '}' new_line = line.strip()
if code.count('{') + 1 == code.count('}'): counter = Counter(new_line)
code += "\n}" if new_line.startswith("public") and method_name in new_line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_lines.append("}")
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "go": elif language_type.lower() == "go":
end_words = ["\n//", "\nfunc main("] method_lines = []
for w in end_words: bracket_left = 0
if w in code: bracket_right = 0
code = code[:code.rfind(w)] for line in code.split("\n"):
if '}' in code: counter = Counter(line)
code = code[:code.rfind('}')] + '}' if f"func {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "cpp": elif language_type.lower() == "cpp":
if '}' in code: method_lines = []
code = code[:code.rfind('}')] + '}' bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f" {method_name}(" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
elif language_type.lower() == "js": elif language_type.lower() == "js":
if '}' in code: method_lines = []
code = code[:code.rfind('}')] + '}' bracket_left = 0
bracket_right = 0
for line in code.split("\n"):
counter = Counter(line)
if f"const {method_name}" in line:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
continue
if method_lines:
method_lines.append(line)
bracket_left += counter["{"]
bracket_right += counter["}"]
if bracket_left == bracket_right and bracket_right > 0:
break
if method_lines:
method_body = "\n".join(method_lines[1:])
return method_body + "\n"
def exception_reconnect(funct):
    """Decorator that retries ``funct`` once when it raises.

    The first exception is printed and the call is repeated with the same
    arguments. An exception raised by the second attempt propagates to the
    caller.

    Args:
        funct: the callable to wrap.

    Returns:
        A wrapper with the same call signature that keeps ``funct``'s
        metadata (name, docstring).
    """
    import functools

    @functools.wraps(funct)
    def wrapper_func(*args, **kwargs):
        try:
            return funct(*args, **kwargs)
        except Exception as e:
            print(f"exception will reconnect: {str(e)}")
            return funct(*args, **kwargs)

    return wrapper_func
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册