Commit e0bf81b9 authored by CSDN-Ada助手

fix bug

Parent cc9d2464
@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
from execution import check_correctness

LANGUAGE_NAME = {
    "cpp": "CPP",
    "go": "Go",
    "java": "Java",
    "js": "JavaScript",
    "python": "Python",
}
@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
    if language == "python":
        code_ = []
        for line in code.split("\n"):
            if line.strip().startswith("def "):
                continue
            if line and line[0] != ' ' and line[0] != '\t':
                line = "    " + line
            code_.append(line)
        code = "\n".join(code_)
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness(
        language_type: str = "python",
        input_folder: str = "../output",
        tmp_dir: str = "../output/tmp/",
        n_workers: int = 3,
        timeout: float = 500.0,
        problem_folder: str = "../eval_set/humaneval-x/",
        out_dir: str = "../output/",
        k: List[int] = [1, 10, 100],
        test_groundtruth: bool = False,
        example_test: bool = False,
        model_name: str = "chatgpt"
):
    if example_test:
        print("Example test...")

    input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
    sample_jsonl = stream_jsonl_all(input_file)
    problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")

    if example_test:
        suffix = "_example_test.jsonl"
    else:
@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
    else:
        out_file = os.path.join(input_file.replace(".jsonl", suffix))

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = []
@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
        for sample in tqdm(sample_jsonl):
            task_id = sample["task_id"]
            lang = task_id.split("/")[0].lower()
            if lang == "javascript":
                lang = "js"
            tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
            completion_id[task_id] += 1
            n_samples += 1

        if len(completion_id) == len(problems):
            evaluate_pass_at_k = True
        else:
......
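The evaluation above reports pass@k for k = 1, 10, 100 via `estimate_pass_at_k` (imported from `metric`). That helper is conventionally the unbiased estimator pass@k = 1 - C(n-c, k) / C(n, k) for n generated samples of which c pass the tests; a minimal sketch of that computation (the actual signature in `metric.py` may differ):

    import numpy as np

    def pass_at_k(n: int, c: int, k: int) -> float:
        # Unbiased pass@k: 1 - C(n - c, k) / C(n, k), evaluated as a stable running product.
        if n - c < k:
            return 1.0  # every size-k subset contains at least one passing sample
        return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

    # Example: 200 samples for one task, 37 of which pass.
    print(pass_at_k(200, 37, k=10))  # ~= 0.88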
# Copyright (c) OpenAI (https://openai.com)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
import contextlib
import faulthandler
import io
@@ -41,75 +61,162 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None
    out.write(jout)
def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
    """Run the sample's test_code for the given language in this worker process and append the outcome to result."""
    random_id = random.uniform(1, 1000)
    if "python" in language_type.lower():
        with create_tempdir():
            # These system calls are needed when cleaning up tempdir.
            import os
            import shutil
            rmtree = shutil.rmtree
            rmdir = os.rmdir
            chdir = os.chdir

            # Disable functionalities that can make destructive changes to the test.
            reliability_guard()

            try:
                exec_globals = {}
                with swallow_io():
                    with time_limit(timeout):
                        # WARNING
                        # This program exists to execute untrusted model-generated code. Although
                        # it is highly unlikely that model-generated code will do something overtly
                        # malicious in response to this test suite, model-generated code may act
                        # destructively due to a lack of model capability or alignment.
                        # Users are strongly encouraged to sandbox this evaluation suite so that it
                        # does not perform destructive actions on their host or network.
                        # Once you have read this disclaimer and taken appropriate precautions,
                        # uncomment the following line and proceed at your own risk:
                        exec(sample["test_code"], exec_globals)
                result.append("passed")
            except TimeoutException:
                result.append("timed out")
            except AssertionError as e:
                result.append(f"failed: AssertionError")
            except BaseException as e:
                result.append(f"failed: {e}")

            # Needed for cleaning up.
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir
    elif "go" in language_type.lower():
        assert tmp_dir is not None, "Go should be evaluated in a dir where necessary module files installed."

        import os
        import shutil

        if "tmp" not in tmp_dir:
            tmp_dir = os.path.join(tmp_dir, "tmp")
        tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)
        # os.chdir(tmp_dir)
        open(os.path.join(tmp_dir, "main_test.go"), 'w').write(sample["test_code"])

        try:
            exec_result = None
            with time_limit(timeout):
                # WARNING
                # This program exists to execute untrusted model-generated code. Although
                # it is highly unlikely that model-generated code will do something overtly
                # malicious in response to this test suite, model-generated code may act
                # destructively due to a lack of model capability or alignment.
                # Users are strongly encouraged to sandbox this evaluation suite so that it
                # does not perform destructive actions on their host or network.
                # Once you have read this disclaimer and taken appropriate precautions,
                # uncomment the following line and proceed at your own risk:
                exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", os.path.join(tmp_dir, "main_test.go")], timeout=timeout, capture_output=True)

            if exec_result.returncode == 0:
                result.append("passed")
            else:
                if exec_result.stderr:
                    try:
                        err = exec_result.stderr.decode()
                    except:
                        err = exec_result.stderr
                else:
                    try:
                        err = exec_result.stdout.decode()
                    except:
                        err = exec_result.stdout
                result.append(f"failed: {err}")
        except TimeoutException:
            result.append("timed out")

        shutil.rmtree(tmp_dir)
elif "js" in language_type.lower():
import os
import shutil
if "tmp" not in tmp_dir:
tmp_dir = os.path.join(tmp_dir, "tmp")
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "test.js"), 'w').write(sample["test_code"])
try:
exec_result = None
with time_limit(timeout):
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["node", os.path.join(tmp_dir, "test.js")], timeout=timeout, capture_output=True)
if exec_result.stderr.decode():
err = exec_result.stderr.decode()
result.append(f"failed: {err}")
elif exec_result.stdout.decode():
err = exec_result.stdout.decode()
result.append(f"failed: {err}")
else:
result.append("passed")
except TimeoutException:
result.append("timed out")
shutil.rmtree(tmp_dir)
elif "cpp" in language_type.lower():
import os
import shutil
if "tmp" not in tmp_dir:
tmp_dir = os.path.join(tmp_dir, "tmp")
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
# os.chdir(tmp_dir)
open(os.path.join(tmp_dir, "test.cpp"), 'w').write(sample["test_code"])
# if "162" in task_id:
# compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp"), "-lcrypto", "-lssl"],
# timeout=timeout,
# capture_output=True)
# else:
# compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
# capture_output=True)
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
capture_output=True)
if compilation_result.returncode != 0:
if compilation_result.stderr:
err = compilation_result.stderr.decode()
else:
err = compilation_result.stdout.decode()
result.append(f"failed: compilation error: {err}")
else:
try: try:
exec_result = None exec_result = None
with time_limit(timeout): with time_limit(timeout):
@@ -122,7 +229,7 @@ def check_correctness(
                    # does not perform destructive actions on their host or network.
                    # Once you have read this disclaimer and taken appropriate precautions,
                    # uncomment the following line and proceed at your own risk:
                    exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
                if exec_result.returncode == 0:
                    result.append("passed")
@@ -138,232 +245,145 @@ def check_correctness(
                    except:
                        err = exec_result.stdout
                    result.append(f"failed: {err}")
            except TimeoutException:
                result.append("timed out")

        shutil.rmtree(tmp_dir)
elif "js" in language_type.lower(): elif "rust" in language_type.lower():
import os import os
import shutil
WD: str = os.path.dirname(os.path.abspath(__file__))
RUST_DIR: str = os.path.join(WD, "rust")
RUST_SRC: str = os.path.join(RUST_DIR, "src")
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
RUST_EXT: str = ".rs"
# Create mandatory tmp directories
os.makedirs(RUST_TMP_DIR, exist_ok=True)
os.makedirs(RUST_LOGS, exist_ok=True)
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
# temporal file name
file_prefix = sample["task_id"].lower().replace("/", "_")
file_name: str = file_prefix + RUST_EXT
os.rename(f.name, os.path.join(RUST_BIN, file_name))
# Sample to pure Rust function
rust_code: str = sample["test_code"]
# dump the rust source code in the target temporal file
f.write(rust_code.encode('utf-8'))
# Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
os.chdir(RUST_DIR)
# Two possible outcomes
# Pass OR Fail compilation
log_filename: str = file_prefix + ".jsonl"
log_path: str = os.path.join(RUST_LOGS, log_filename)
cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
# Compilation build status
returned_val_compilation: int
# Overwrite file content
if os.path.exists(log_path):
if(file_size := os.path.getsize(log_path)) >= 0:
os.remove(log_path)
returned_val_compilation = os.system(cargo_check)
if "tmp" not in tmp_dir: else:
tmp_dir = os.path.join(tmp_dir, "tmp") returned_val_compilation = os.system(cargo_check)
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir) # 0 means success
open(f"test.js", 'w').write(sample["test_code"]) if returned_val_compilation == 0:
try: # Execution pipeline
exec_result = None cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
with time_limit(timeout): returned_val_execution = os.system(cargo_test)
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
if exec_result.stderr.decode(): if returned_val_execution == 0:
err = exec_result.stderr.decode() result.append("passed")
result.append(f"failed: {err}")
elif exec_result.stdout.decode():
err = exec_result.stdout.decode()
result.append(f"failed: {err}")
else:
result.append("passed")
except TimeoutException:
result.append("timed out")
shutil.rmtree(tmp_dir)
elif "cpp" in language_type.lower():
import os
import shutil
if "tmp" not in tmp_dir:
tmp_dir = os.path.join(tmp_dir, "tmp")
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
os.chdir(tmp_dir)
open(f"test.cpp", 'w').write(sample["test_code"])
if "162" in task_id:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
timeout=timeout,
capture_output=True)
else: else:
compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout, result.append(f"failed: execution error")
capture_output=True)
if compilation_result.returncode != 0:
if compilation_result.stderr:
err = compilation_result.stderr.decode()
else:
err = compilation_result.stdout.decode()
result.append(f"failed: compilation error: {err}")
else:
try:
exec_result = None
with time_limit(timeout):
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0: else:
result.append("passed") result.append(f"failed: compilation error")
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except:
err = exec_result.stdout
result.append(f"failed: {err}")
except TimeoutException:
result.append("timed out")
shutil.rmtree(tmp_dir)
elif "rust" in language_type.lower():
import os
WD: str = os.path.dirname(os.path.abspath(__file__))
RUST_DIR: str = os.path.join(WD, "rust")
RUST_SRC: str = os.path.join(RUST_DIR, "src")
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
RUST_EXT: str = ".rs"
# Create mandatory tmp directories
os.makedirs(RUST_TMP_DIR, exist_ok=True)
os.makedirs(RUST_LOGS, exist_ok=True)
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)
with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
#temporal file name
file_prefix = sample["task_id"].lower().replace("/", "_")
file_name:str = file_prefix +RUST_EXT
os.rename(f.name, os.path.join(RUST_BIN, file_name))
# Sample to pure Rust function
rust_code: str = sample["test_code"]
# dump the rust source code in the target temporal file
f.write(rust_code.encode('utf-8'))
# Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
os.chdir(RUST_DIR)
# Two possible outcomes
# Pass OR Fail compilation
log_filename: str = file_prefix + ".jsonl"
log_path: str = os.path.join(RUST_LOGS, log_filename)
cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
# Compilation build status
returned_val_compilation: int
# Overwrite file content
if os.path.exists(log_path):
if(file_size := os.path.getsize(log_path)) >= 0:
os.remove(log_path)
returned_val_compilation = os.system(cargo_check)
else:
returned_val_compilation = os.system(cargo_check)
# 0 means success elif "java" in language_type.lower():
if returned_val_compilation == 0: assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
#Execution pipeline import os
cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path import shutil
returned_val_execution = os.system(cargo_test)
if returned_val_execution == 0:
result.append("passed")
else:
result.append(f"failed: execution error")
else:
result.append(f"failed: compilation error")
if "tmp" not in tmp_dir:
tmp_dir = os.path.join(tmp_dir, "tmp")
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
elif "java" in language_type.lower(): # os.chdir(tmp_dir)
assert tmp_dir is not None, "Java should be evaluated in a temporary dir." open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
res = "failed: unknown error"
compile_returncode = -1
for _ in range(5):
try:
compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5,
capture_output=True)
compile_returncode = compilation_result.returncode
break
except subprocess.TimeoutExpired as e:
continue
if compile_returncode != 0:
res = "failed: compilation error"
else:
exec_result = None
try:
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
res = "passed"
elif exec_result.returncode == 1:
if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
res = "failed: wrong answer"
else:
res = f"failed: {exec_result.stderr.decode()}"
except subprocess.TimeoutExpired as e:
res = "time out"
except BaseException as e:
res = f"failed: {e}"
result.append(res)
shutil.rmtree(tmp_dir)
import os
import shutil
if "tmp" not in tmp_dir: def check_correctness(
tmp_dir = os.path.join(tmp_dir, "tmp") task_id: str,
tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}") sample: dict,
if not os.path.exists(tmp_dir): language_type: str,
os.makedirs(tmp_dir) timeout: float = 3.0,
tmp_dir: str = None,
os.chdir(tmp_dir) completion_id: Optional[int] = None,
open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"]) ) -> Dict:
res = "failed: unknown error" """
compile_returncode = -1 Evaluates the functional correctness of a completion by running the test
for _ in range(5): suite provided in the problem.
try: """
compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5,
capture_output=True)
compile_returncode = compilation_result.returncode
break
except subprocess.TimeoutExpired as e:
continue
if compile_returncode != 0:
res = "failed: compilation error"
else:
exec_result = None
try:
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
res = "passed"
elif exec_result.returncode == 1:
if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
res = "failed: wrong answer"
else:
res = f"failed: {exec_result.stderr.decode()}"
except subprocess.TimeoutExpired as e:
res = "time out"
except BaseException as e:
res = f"failed: {e}"
result.append(res)
shutil.rmtree(tmp_dir)
    manager = multiprocessing.Manager()
    result = manager.list()

    p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir, language_type, timeout, sample, result, task_id,))
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
@@ -373,38 +393,19 @@ def check_correctness(
        result.append("timed out")

    return {
        "task_id": task_id,
        "completion_id": completion_id,
        "test_code": sample["test_code"],
        "prompt": sample["prompt"],
        "generation": sample["generation"],
        "result": result[0],
        "passed": result[0] == "passed",
        "finish": -1 if "finish" not in sample else sample["finish"],
        "file": "" if "file" not in sample else sample["file"],
        "output": [] if "output" not in sample else sample["output"],
    }
@contextlib.contextmanager
def time_limit(seconds: float):
    def signal_handler(signum, frame):
......
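For context, `check_correctness` above runs `unsafe_execute` in a separate process and falls back to "timed out" if that process never reports a result. A minimal, hypothetical call for a Python task might look like the sketch below (the sample contents are invented; only the dict keys mirror the ones used above):

    # Hypothetical sample: a HumanEval-style task plus one model completion.
    sample = {
        "task_id": "Python/0",
        "prompt": "def add(a, b):\n",
        "generation": "    return a + b\n",
        # test_code is what actually gets executed: prompt + generation + tests.
        "test_code": "def add(a, b):\n    return a + b\n\nassert add(1, 2) == 3\n",
    }

    outcome = check_correctness(
        task_id=sample["task_id"],
        sample=sample,
        language_type="python",
        timeout=3.0,
        tmp_dir="../output/tmp/",
    )
    print(outcome["result"], outcome["passed"])  # e.g. "passed" True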
@@ -19,7 +19,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--input_path",
        type=str,
        default="../eval_set/humaneval-x"
    )
    parser.add_argument(
        "--language_type",
@@ -30,18 +30,18 @@ if __name__ == "__main__":
        "--model_name",
        type=str,
        default="chatgpt",
        help='supported models in [chatgpt, chatglm2, bbt].'
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        default="../output/humaneval"
    )

    args = parser.parse_known_args()[0]

    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"

    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
    for entry in entries.values():
        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
@@ -52,6 +52,9 @@ if __name__ == "__main__":
    elif args.model_name == "chatglm2":
        from inference.chatglm2_inference import GLMInference
        model_inference = GLMInference()
    elif args.model_name == "bbt":
        from inference.bbt_inference import BBTInference
        model_inference = BBTInference()
    else:
        print(f"{args.model_name} not supported.")
@@ -60,8 +63,16 @@ if __name__ == "__main__":
    with open(output_file_path, "w") as f:
        for entry in entries.values():
            prompt = entry["prompt"]
            retry_times = 3
            generated_tokens = model_inference.inference(prompt)
            while generated_tokens is None and retry_times > 0:
                time.sleep(10)
                generated_tokens = model_inference.inference(prompt)
                retry_times -= 1
            if generated_tokens is None:
                print(f"task_id: {entry['task_id']} generate failed!")
                continue
            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
            f.write(
                json.dumps(
                    {
@@ -79,11 +90,11 @@ if __name__ == "__main__":
            num_finished += 1
            time_elapsed = time.perf_counter() - start_time
            time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
            print(
                f"finished {num_finished}, "
                f"elapsed {time_elapsed:.4f}",
                f"speed {time_per_sample:.4f}s/sample",
                f"remaining {len(entries) - num_finished:.4f}",
                flush=True,
            )
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback

from .inference import Inference
from utils import exception_reconnect

logger = logging.getLogger(__name__)


class BBTInference(Inference):
    def __init__(self):
        super(BBTInference, self).__init__()
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            # 'top_k': 80,
            # 'do_sample': True,
            # 'num_beams': 4,
            # 'no_repeat_ngram_size': 4,
            # 'repetition_penalty': 1.1,
            # 'length_penalty': 1,
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        if not os.path.exists(self.params_url):
            logger.error(f"params_url: {self.params_url} does not exist.")
        content = open(self.params_url).read()
        return json.loads(content)

    @exception_reconnect
    def inference(self, query_text):
        query_text = "补全以下方法体:\n" + query_text  # "Complete the following method body:"
        try:
            data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
            resp = requests.post(self.url, data=data)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
            return None

        if resp.status_code == 200:
            resp_json = json.loads(resp.text)
            return resp_json["msg"]
        else:
            return None
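`BBTInference` reads its settings from `../llm_set/params/bbt.json`. Judging from the `get()` calls above, that file needs at least `url`, `user_id`, `timeout`, `temperature`, `top_p` and `max_new_tokens`; the sketch below writes such a file with placeholder values (the real endpoint and credentials are not part of this commit):

    import json

    # Illustrative only: the keys match what BBTInference looks up, the values are made up.
    bbt_params = {
        "url": "http://example.com/bbt/api",  # placeholder endpoint
        "user_id": "your-user-id",            # placeholder credential
        "timeout": 120,
        "temperature": 0.2,
        "top_p": 0.95,
        "max_new_tokens": 512,
    }
    with open("../llm_set/params/bbt.json", "w") as f:
        json.dump(bbt_params, f, indent=2, ensure_ascii=False)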
@@ -3,15 +3,16 @@ import json
import torch
import logging
from transformers import AutoModel, AutoTokenizer

from .inference import Inference

logger = logging.getLogger(__name__)


class GLMInference(Inference):

    def __init__(self):
        super(GLMInference, self).__init__()
        self.params_url = "../llm_set/params/chatglm2.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)
......
@@ -2,20 +2,23 @@
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback

from .inference import Inference
from utils import exception_reconnect

logger = logging.getLogger(__name__)


class ChatGPTInference(Inference):
    def __init__(self):
        super(ChatGPTInference, self).__init__()
        self.params_url = "../llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.gpt_url = self.paras_dict.get("url")
        self.id = self.paras_dict.get("id")
        self.stream = self.paras_dict.get("stream")
@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
        content = open(self.params_url).read()
        return json.loads(content)

    @exception_reconnect
    def inference(self, query_text):
        query_text = "补全以下方法体:\n" + query_text  # "Complete the following method body:"
        query_text = quote(query_text)
        param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
        try:
            resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
        except Exception as e:
......
@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)

class Inference(object):

    def __init__(self):
        self.params_base_url = "../llm_set/params/default.json"
        self.paras_base_dict = self.get_paras_base()

    def get_paras_base(self):
......
import os
import re
from typing import *

from data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter

IMPORT_HELPER = {
@@ -49,7 +51,7 @@ IMPORT_HELPER = {

def read_dataset(
        data_folder: str = None,
        dataset_type: str = "humaneval",
        language_type: str = "python",
        num_shot=None,
@@ -57,7 +59,7 @@ def read_dataset(
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")
    if "humaneval" in dataset_type.lower():
        data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    else:
        raise ValueError(f"Dataset: {dataset_type} not supported.")
@@ -147,41 +149,191 @@ def is_code_generation_finished(
    return False
def find_method_name(language_type, sample):
    """Find the target method name from the task's declaration."""
    if language_type.lower() == "python":
        declaration = sample["declaration"]
        ret = re.search("def (.*?)\\(", declaration)
        if ret is None:
            return None
        return ret.group(1).strip()
    elif language_type.lower() == "cpp":
        declaration = sample["declaration"]
        ret = re.search(" (.*?)\\(", declaration)
        if ret is None:
            return None
        method_name = ret.group(1).strip()
        if " " in method_name:
            return method_name.split(" ")[-1]  # keep only the identifier after any type keywords
        return method_name
    elif language_type.lower() == "java":
        declaration = sample["declaration"]
        ret = re.search(" (.*?)\\(", declaration)
        if ret is None:
            return None
        method_name = ret.group(1).strip()
        if " " in method_name:
            return method_name.split(" ")[-1]  # keep only the identifier after any modifiers/return type
        return method_name
    elif language_type.lower() in ["js", "javascript"]:
        declaration = sample["declaration"]
        ret = re.search("const (.*?) ", declaration)
        if ret is None:
            return None
        method_name = ret.group(1).strip()
        return method_name
    elif language_type.lower() == "go":
        declaration = sample["declaration"]
        ret = re.search("func (.*?)\\(", declaration)
        if ret is None:
            return None
        return ret.group(1).strip()
    elif language_type == "rust":
        declaration = sample["declaration"]
        ret = re.search("fn (.*?)\\(", declaration)
        if ret is None:
            return None
        return ret.group(1).strip()
    else:
        return None


def extract_markdown_code(content):
    """Extract code from markdown, i.e. the content between ``` fences."""
    codes = []
    rets = re.findall("```([\\s\\S]*?)```", content)
    if not rets:
        return codes
    for ret in rets:
        if not ret.startswith("\n"):
            # First line is a language tag (e.g. ```python); drop it.
            lines = ret.split("\n")
            codes.append("\n".join(lines[1:]))
        else:
            codes.append(ret.strip())
    return codes


def cleanup_code(code: str, sample, language_type: str = None):
    """
    Cleans up the generated code.
    """
    if language_type is None:
        return code

    method_name = find_method_name(language_type, sample)
    if method_name is None:
        return code
    method_body = code

    if language_type.lower() == "python":
        method_lines = []
        for line in code.split("\n"):
            if f"def {method_name}" in line:
                method_lines.append(line)
                continue
            if method_lines:
                method_lines.append(line)
                if line.startswith("    return"):
                    break
        if method_lines:
            method_body = "\n".join(method_lines[1:])
    elif language_type.lower() == "java":
        method_lines = []
        bracket_left = 0
        bracket_right = 0
        for line in code.split("\n"):
            new_line = line.strip()
            counter = Counter(new_line)
            if new_line.startswith("public") and method_name in new_line:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                continue
            if method_lines:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                if bracket_left == bracket_right and bracket_right > 0:
                    break
        if method_lines:
            method_lines.append("}")
            method_body = "\n".join(method_lines[1:])
    elif language_type.lower() == "go":
        method_lines = []
        bracket_left = 0
        bracket_right = 0
        for line in code.split("\n"):
            counter = Counter(line)
            if f"func {method_name}" in line:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                continue
            if method_lines:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                if bracket_left == bracket_right and bracket_right > 0:
                    break
        if method_lines:
            method_body = "\n".join(method_lines[1:])
    elif language_type.lower() == "cpp":
        method_lines = []
        bracket_left = 0
        bracket_right = 0
        for line in code.split("\n"):
            counter = Counter(line)
            if f" {method_name}(" in line:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                continue
            if method_lines:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                if bracket_left == bracket_right and bracket_right > 0:
                    break
        if method_lines:
            method_body = "\n".join(method_lines[1:])
    elif language_type.lower() == "js":
        method_lines = []
        bracket_left = 0
        bracket_right = 0
        for line in code.split("\n"):
            counter = Counter(line)
            if f"const {method_name}" in line:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                continue
            if method_lines:
                method_lines.append(line)
                bracket_left += counter["{"]
                bracket_right += counter["}"]
                if bracket_left == bracket_right and bracket_right > 0:
                    break
        if method_lines:
            method_body = "\n".join(method_lines[1:])

    return method_body + "\n"


def exception_reconnect(funct):
    """Retry the wrapped call once after an exception (e.g. to survive a dropped connection)."""
    def wrapper_func(*args, **kwargs):
        try:
            return funct(*args, **kwargs)
        except Exception as e:
            print(f"exception will reconnect: {str(e)}")
            return funct(*args, **kwargs)
    return wrapper_func
\ No newline at end of file
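To see how `find_method_name` and `cleanup_code` fit together: the task's `declaration` supplies the method name, and `cleanup_code` keeps only that method's body from the raw completion. A small, made-up Python example:

    entry = {"declaration": "def has_close_elements(numbers, threshold):\n"}
    raw_completion = (
        "def has_close_elements(numbers, threshold):\n"
        "    for i, a in enumerate(numbers):\n"
        "        for j, b in enumerate(numbers):\n"
        "            if i != j and abs(a - b) < threshold:\n"
        "                return True\n"
        "    return False\n"
        "\n"
        "print(has_close_elements([1.0, 2.0], 0.5))\n"
    )

    body = cleanup_code(raw_completion, entry, language_type="python")
    # Only the indented method body survives; the trailing print() call is dropped.
    print(body)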