From e0bf81b9347716de45990340014cb69d1b8f9bb9 Mon Sep 17 00:00:00 2001
From: chenlong
Date: Mon, 24 Jul 2023 11:15:54 +0800
Subject: [PATCH] fix bug

---
 src/evaluate_humaneval_x.py         |  47 +--
 src/execution.py                    | 603 ++++++++++++++--------------
 src/generate_humaneval_x.py         |  27 +-
 src/inference/bbt_inference.py      |  58 +++
 src/inference/chatglm2_inference.py |   7 +-
 src/inference/chatgpt_inference.py  |  17 +-
 src/inference/inference.py          |   4 +-
 src/utils.py                        | 208 ++++++++--
 8 files changed, 593 insertions(+), 378 deletions(-)
 create mode 100644 src/inference/bbt_inference.py

diff --git a/src/evaluate_humaneval_x.py b/src/evaluate_humaneval_x.py
index 4aa2424..f8f0cf8 100644
--- a/src/evaluate_humaneval_x.py
+++ b/src/evaluate_humaneval_x.py
@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
 from execution import check_correctness
 
 LANGUAGE_NAME = {
-    "cpp"   : "CPP",
-    "go"    : "Go",
-    "java"  : "Java",
-    "js"    : "JavaScript",
+    "cpp": "CPP",
+    "go": "Go",
+    "java": "Java",
+    "js": "JavaScript",
     "python": "Python",
 }
 
@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
     if language == "python":
         code_ = []
         for line in code.split("\n"):
-            if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
-                break
+            if line.strip().startswith("def "):
+                continue
+            if line and line[0] != ' ' and line[0] != '\t':
+                line = "    " + line
             code_.append(line)
         code = "\n".join(code_)
         test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
 
 
 def evaluate_functional_correctness(
-        input_file: str = None,
-        tmp_dir: str = "./",
-        n_workers: int = 32,
+        language_type: str = "python",
+        input_folder: str = "../output",
+        tmp_dir: str = "../output/tmp/",
+        n_workers: int = 3,
         timeout: float = 500.0,
-        problem_file: str = "../data/humaneval_python.jsonl.gz",
-        out_dir: str = None,
+        problem_folder: str = "../eval_set/humaneval-x/",
+        out_dir: str = "../output/",
         k: List[int] = [1, 10, 100],
         test_groundtruth: bool = False,
         example_test: bool = False,
+        model_name: str = "chatgpt"
 ):
     if example_test:
         print("Example test...")
 
-    problems = read_dataset(problem_file,
-                            dataset_type="humaneval")
+    input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
     sample_jsonl = stream_jsonl_all(input_file)
+    problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
 
     if example_test:
         suffix = "_example_test.jsonl"
     else:
@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
     else:
         out_file = os.path.join(input_file.replace(".jsonl", suffix))
 
-    if "/codegeex/benchmark/humaneval-x/" in input_file:
-        test_groundtruth = True
-
-    if "-to-" in input_file:
-        translation_mode = True
-    else:
-        translation_mode = False
-
     with ThreadPoolExecutor(max_workers=n_workers) as executor:
 
         futures = []
@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
         for sample in tqdm(sample_jsonl):
             task_id = sample["task_id"]
             lang = task_id.split("/")[0].lower()
-            if translation_mode:
-                task_id = sample["task_id"].split("/")[-1]
-                lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
-                for l in LANGUAGE_NAME:
-                    if l in lang:
-                        lang = l
-                        break
-                task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
             if lang == "javascript":
                 lang = "js"
             tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
             completion_id[task_id] += 1
             n_samples += 1
 
-    print(completion_id)
     if len(completion_id) == len(problems):
         evaluate_pass_at_k = True
     else:
diff --git a/src/execution.py b/src/execution.py
index cbdf14f..7dfc645 100644
--- a/src/execution.py
+++ b/src/execution.py
@@ -1,3 +1,23 @@
+# Copyright (c) OpenAI (https://openai.com)
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ============================================================================
 import contextlib
 import faulthandler
 import io
@@ -41,75 +61,162 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None:
     out.write(jout)
 
 
-def check_correctness(
-        task_id: str,
-        sample: dict,
-        language_type: str,
-        timeout: float = 3.0,
-        tmp_dir: str = None,
-        completion_id: Optional[int] = None,
-) -> Dict:
-    """
-    Evaluates the functional correctness of a completion by running the test
-    suite provided in the problem.
-    """
-
-    def unsafe_execute(tmp_dir):
-        random_id = random.uniform(1, 1000)
-        if "python" in language_type.lower():
-            with create_tempdir():
-
-                # These system calls are needed when cleaning up tempdir.
-                import os
-                import shutil
-                rmtree = shutil.rmtree
-                rmdir = os.rmdir
-                chdir = os.chdir
-
-                # Disable functionalities that can make destructive changes to the test.
-                reliability_guard()
-
-                try:
-                    exec_globals = {}
-                    with swallow_io():
-                        with time_limit(timeout):
-                            # WARNING
-                            # This program exists to execute untrusted model-generated code. Although
-                            # it is highly unlikely that model-generated code will do something overtly
-                            # malicious in response to this test suite, model-generated code may act
-                            # destructively due to a lack of model capability or alignment.
-                            # Users are strongly encouraged to sandbox this evaluation suite so that it
-                            # does not perform destructive actions on their host or network.
-                            # Once you have read this disclaimer and taken appropriate precautions,
-                            # uncomment the following line and proceed at your own risk:
-                            exec(sample["test_code"], exec_globals)
-                    result.append("passed")
-                except TimeoutException:
-                    result.append("timed out")
-                except AssertionError as e:
-                    result.append(f"failed: AssertionError")
-                except BaseException as e:
-                    result.append(f"failed: {e}")
-
-                # Needed for cleaning up.
-                shutil.rmtree = rmtree
-                os.rmdir = rmdir
-                os.chdir = chdir
-
-        elif "go" in language_type.lower():
-            assert tmp_dir is not None, "Go should be evaluated in a dir where necessary module files installed."
+def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
+    random_id = random.uniform(1, 1000)
+    if "python" in language_type.lower():
+        with create_tempdir():
+            # These system calls are needed when cleaning up tempdir.
             import os
             import shutil
+            rmtree = shutil.rmtree
+            rmdir = os.rmdir
+            chdir = os.chdir
 
-            # Disable functionalities that can make destructive changes to the test.
+            # Disable functionalities that can make destructive changes to the test.
             reliability_guard()
 
+            try:
+                exec_globals = {}
+                with swallow_io():
+                    with time_limit(timeout):
+                        # WARNING
+                        # This program exists to execute untrusted model-generated code. Although
+                        # it is highly unlikely that model-generated code will do something overtly
+                        # malicious in response to this test suite, model-generated code may act
+                        # destructively due to a lack of model capability or alignment.
+                        # Users are strongly encouraged to sandbox this evaluation suite so that it
+                        # does not perform destructive actions on their host or network.
+                        # Once you have read this disclaimer and taken appropriate precautions,
+                        # uncomment the following line and proceed at your own risk:
+                        exec(sample["test_code"], exec_globals)
+                result.append("passed")
+            except TimeoutException:
+                result.append("timed out")
+            except AssertionError as e:
+                result.append(f"failed: AssertionError")
+            except BaseException as e:
+                result.append(f"failed: {e}")
+
+            # Needed for cleaning up.
+            shutil.rmtree = rmtree
+            os.rmdir = rmdir
+            os.chdir = chdir
+
+    elif "go" in language_type.lower():
+        assert tmp_dir is not None, "Go should be evaluated in a dir where necessary module files installed."
+
+        import os
+        import shutil
+
+        if "tmp" not in tmp_dir:
+            tmp_dir = os.path.join(tmp_dir, "tmp")
+        tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "main_test.go"), 'w').write(sample["test_code"])
+        try:
+            exec_result = None
+            with time_limit(timeout):
+                # WARNING
+                # This program exists to execute untrusted model-generated code. Although
+                # it is highly unlikely that model-generated code will do something overtly
+                # malicious in response to this test suite, model-generated code may act
+                # destructively due to a lack of model capability or alignment.
+                # Users are strongly encouraged to sandbox this evaluation suite so that it
+                # does not perform destructive actions on their host or network.
+                # Once you have read this disclaimer and taken appropriate precautions,
+                # uncomment the following line and proceed at your own risk:
+                exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", os.path.join(tmp_dir, "main_test.go")], timeout=timeout, capture_output=True)
+
+            if exec_result.returncode == 0:
+                result.append("passed")
+            else:
+                if exec_result.stderr:
+                    try:
+                        err = exec_result.stderr.decode()
+                    except:
+                        err = exec_result.stderr
+                else:
+                    try:
+                        err = exec_result.stdout.decode()
+                    except:
+                        err = exec_result.stdout
+                result.append(f"failed: {err}")
+
+        except TimeoutException:
+            result.append("timed out")
+
+        shutil.rmtree(tmp_dir)
+    elif "js" in language_type.lower():
+        import os
+        import shutil
+
+        if "tmp" not in tmp_dir:
+            tmp_dir = os.path.join(tmp_dir, "tmp")
+        tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "test.js"), 'w').write(sample["test_code"])
+        try:
+            exec_result = None
+            with time_limit(timeout):
+                # WARNING
+                # This program exists to execute untrusted model-generated code. Although
+                # it is highly unlikely that model-generated code will do something overtly
+                # malicious in response to this test suite, model-generated code may act
+                # destructively due to a lack of model capability or alignment.
+                # Users are strongly encouraged to sandbox this evaluation suite so that it
+                # does not perform destructive actions on their host or network.
+                # Once you have read this disclaimer and taken appropriate precautions,
+                # uncomment the following line and proceed at your own risk:
+                exec_result = subprocess.run(["node", os.path.join(tmp_dir, "test.js")], timeout=timeout, capture_output=True)
+
+            if exec_result.stderr.decode():
+                err = exec_result.stderr.decode()
+                result.append(f"failed: {err}")
+            elif exec_result.stdout.decode():
+                err = exec_result.stdout.decode()
+                result.append(f"failed: {err}")
+            else:
+                result.append("passed")
+
+        except TimeoutException:
+            result.append("timed out")
+
+        shutil.rmtree(tmp_dir)
+    elif "cpp" in language_type.lower():
+        import os
+        import shutil
+
+        if "tmp" not in tmp_dir:
+            tmp_dir = os.path.join(tmp_dir, "tmp")
+        tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "test.cpp"), 'w').write(sample["test_code"])
+        # if "162" in task_id:
+        #     compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp"), "-lcrypto", "-lssl"],
+        #                                         timeout=timeout,
+        #                                         capture_output=True)
+        # else:
+        #     compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
+        #                                         capture_output=True)
+        compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
+                                            capture_output=True)
+        if compilation_result.returncode != 0:
+            if compilation_result.stderr:
+                err = compilation_result.stderr.decode()
+            else:
+                err = compilation_result.stdout.decode()
+            result.append(f"failed: compilation error: {err}")
+        else:
             try:
                 exec_result = None
                 with time_limit(timeout):
@@ -122,7 +229,7 @@ def check_correctness(
                     # does not perform destructive actions on their host or network.
                     # Once you have read this disclaimer and taken appropriate precautions,
                     # uncomment the following line and proceed at your own risk:
-                    exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
+                    exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
 
                 if exec_result.returncode == 0:
                     result.append("passed")
                 else:
                     if exec_result.stderr:
                         try:
                             err = exec_result.stderr.decode()
                         except:
                             err = exec_result.stderr
                     else:
                         try:
                             err = exec_result.stdout.decode()
                         except:
                             err = exec_result.stdout
                     result.append(f"failed: {err}")
-
             except TimeoutException:
                 result.append("timed out")
 
-            shutil.rmtree(tmp_dir)
-        elif "js" in language_type.lower():
-            import os
-            import shutil
+        shutil.rmtree(tmp_dir)
+    elif "rust" in language_type.lower():
+        import os
+
+        WD: str = os.path.dirname(os.path.abspath(__file__))
+        RUST_DIR: str = os.path.join(WD, "rust")
+        RUST_SRC: str = os.path.join(RUST_DIR, "src")
+        RUST_BIN: str = os.path.join(RUST_SRC, "bin")
+        RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
+        RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
+        RUST_EXT: str = ".rs"
+
+        # Create mandatory tmp directories
+        os.makedirs(RUST_TMP_DIR, exist_ok=True)
+        os.makedirs(RUST_LOGS, exist_ok=True)
+        os.makedirs(RUST_SRC, exist_ok=True)
+        os.makedirs(RUST_BIN, exist_ok=True)
+
+        with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
+            # temporal file name
+            file_prefix = sample["task_id"].lower().replace("/", "_")
+            file_name: str = file_prefix + RUST_EXT
+
+            os.rename(f.name, os.path.join(RUST_BIN, file_name))
+
+            # Sample to pure Rust function
+            rust_code: str = sample["test_code"]
+
+            # dump the rust source code in the target temporal file
+            f.write(rust_code.encode('utf-8'))
+
+        # Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
+        os.chdir(RUST_DIR)
+
+        # Two possible outcomes
+        # Pass OR Fail compilation
+        log_filename: str = file_prefix + ".jsonl"
+        log_path: str = os.path.join(RUST_LOGS, log_filename)
+        cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
+        # Compilation build status
+        returned_val_compilation: int
+
+        # Overwrite file content
+        if os.path.exists(log_path):
+            if(file_size := os.path.getsize(log_path)) >= 0:
+                os.remove(log_path)
+            returned_val_compilation = os.system(cargo_check)
 
-            if "tmp" not in tmp_dir:
-                tmp_dir = os.path.join(tmp_dir, "tmp")
-            tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
-            if not os.path.exists(tmp_dir):
-                os.makedirs(tmp_dir)
+        else:
+            returned_val_compilation = os.system(cargo_check)
 
-            os.chdir(tmp_dir)
-            open(f"test.js", 'w').write(sample["test_code"])
-            try:
-                exec_result = None
-                with time_limit(timeout):
-                    # WARNING
-                    # This program exists to execute untrusted model-generated code. Although
-                    # it is highly unlikely that model-generated code will do something overtly
-                    # malicious in response to this test suite, model-generated code may act
-                    # destructively due to a lack of model capability or alignment.
-                    # Users are strongly encouraged to sandbox this evaluation suite so that it
-                    # does not perform destructive actions on their host or network.
-                    # Once you have read this disclaimer and taken appropriate precautions,
-                    # uncomment the following line and proceed at your own risk:
-                    exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
+        # 0 means success
+        if returned_val_compilation == 0:
+            # Execution pipeline
+            cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
+            returned_val_execution = os.system(cargo_test)
 
-                if exec_result.stderr.decode():
-                    err = exec_result.stderr.decode()
-                    result.append(f"failed: {err}")
-                elif exec_result.stdout.decode():
-                    err = exec_result.stdout.decode()
-                    result.append(f"failed: {err}")
-                else:
-                    result.append("passed")
-
-            except TimeoutException:
-                result.append("timed out")
-
-            shutil.rmtree(tmp_dir)
-        elif "cpp" in language_type.lower():
-            import os
-            import shutil
-
-            if "tmp" not in tmp_dir:
-                tmp_dir = os.path.join(tmp_dir, "tmp")
-            tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
-            if not os.path.exists(tmp_dir):
-                os.makedirs(tmp_dir)
-
-            os.chdir(tmp_dir)
-            open(f"test.cpp", 'w').write(sample["test_code"])
-            if "162" in task_id:
-                compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
-                                                    timeout=timeout,
-                                                    capture_output=True)
+            if returned_val_execution == 0:
+                result.append("passed")
             else:
-                compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout,
-                                                    capture_output=True)
-            if compilation_result.returncode != 0:
-                if compilation_result.stderr:
-                    err = compilation_result.stderr.decode()
-                else:
-                    err = compilation_result.stdout.decode()
-                result.append(f"failed: compilation error: {err}")
-            else:
-                try:
-                    exec_result = None
-                    with time_limit(timeout):
-                        # WARNING
-                        # This program exists to execute untrusted model-generated code. Although
-                        # it is highly unlikely that model-generated code will do something overtly
-                        # malicious in response to this test suite, model-generated code may act
-                        # destructively due to a lack of model capability or alignment.
-                        # Users are strongly encouraged to sandbox this evaluation suite so that it
-                        # does not perform destructive actions on their host or network.
-                        # Once you have read this disclaimer and taken appropriate precautions,
-                        # uncomment the following line and proceed at your own risk:
-                        exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
+                result.append(f"failed: execution error")
 
-                    if exec_result.returncode == 0:
-                        result.append("passed")
-                    else:
-                        if exec_result.stderr:
-                            try:
-                                err = exec_result.stderr.decode()
-                            except:
-                                err = exec_result.stderr
-                        else:
-                            try:
-                                err = exec_result.stdout.decode()
-                            except:
-                                err = exec_result.stdout
-                        result.append(f"failed: {err}")
-                except TimeoutException:
-                    result.append("timed out")
-
-            shutil.rmtree(tmp_dir)
-        elif "rust" in language_type.lower():
-            import os
-
-            WD: str = os.path.dirname(os.path.abspath(__file__))
-            RUST_DIR: str = os.path.join(WD, "rust")
-            RUST_SRC: str = os.path.join(RUST_DIR, "src")
-            RUST_BIN: str = os.path.join(RUST_SRC, "bin")
-            RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
-            RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
-            RUST_EXT: str = ".rs"
-
-            # Create mandatory tmp directories
-            os.makedirs(RUST_TMP_DIR, exist_ok=True)
-            os.makedirs(RUST_LOGS, exist_ok=True)
-            os.makedirs(RUST_SRC, exist_ok=True)
-            os.makedirs(RUST_BIN, exist_ok=True)
-
-            with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
-                #temporal file name
-                file_prefix = sample["task_id"].lower().replace("/", "_")
-                file_name:str =  file_prefix +RUST_EXT
-
-                os.rename(f.name, os.path.join(RUST_BIN, file_name))
-
-                # Sample to pure Rust function
-                rust_code: str = sample["test_code"]
-
-                # dump the rust source code in the target temporal file
-                f.write(rust_code.encode('utf-8'))
-
-            # Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
-            os.chdir(RUST_DIR)
-
-            # Two possible outcomes
-            # Pass OR Fail compilation
-            log_filename: str = file_prefix + ".jsonl"
-            log_path: str = os.path.join(RUST_LOGS, log_filename)
-            cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
-            # Compilation build status
-            returned_val_compilation: int
-
-            # Overwrite file content
-            if os.path.exists(log_path):
-                if(file_size := os.path.getsize(log_path)) >= 0:
-                    os.remove(log_path)
-                returned_val_compilation = os.system(cargo_check)
-
-            else:
-                returned_val_compilation = os.system(cargo_check)
+        else:
+            result.append(f"failed: compilation error")
 
-            # 0 means success
-            if returned_val_compilation == 0:
+    elif "java" in language_type.lower():
+        assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
 
-                #Execution pipeline
-                cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path
-                returned_val_execution = os.system(cargo_test)
-
-                if returned_val_execution == 0:
-                    result.append("passed")
-                else:
-                    result.append(f"failed: execution error")
-
-            else:
-                result.append(f"failed: compilation error")
+        import os
+        import shutil
+        if "tmp" not in tmp_dir:
+            tmp_dir = os.path.join(tmp_dir, "tmp")
+        tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
 
-        elif "java" in language_type.lower():
-            assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
+        res = "failed: unknown error"
+        compile_returncode = -1
+        for _ in range(5):
+            try:
+                compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5,
+                                                    capture_output=True)
+                compile_returncode = compilation_result.returncode
+                break
+            except subprocess.TimeoutExpired as e:
+                continue
+        if compile_returncode != 0:
+            res = "failed: compilation error"
+        else:
+            exec_result = None
+            try:
+                # WARNING
+                # This program exists to execute untrusted model-generated code. Although
+                # it is highly unlikely that model-generated code will do something overtly
+                # malicious in response to this test suite, model-generated code may act
+                # destructively due to a lack of model capability or alignment.
+                # Users are strongly encouraged to sandbox this evaluation suite so that it
+                # does not perform destructive actions on their host or network.
+                # Once you have read this disclaimer and taken appropriate precautions,
+                # uncomment the following line and proceed at your own risk:
+                exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
+                if exec_result.returncode == 0:
+                    res = "passed"
+                elif exec_result.returncode == 1:
+                    if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
+                        res = "failed: wrong answer"
+                    else:
+                        res = f"failed: {exec_result.stderr.decode()}"
+            except subprocess.TimeoutExpired as e:
+                res = "time out"
+            except BaseException as e:
+                res = f"failed: {e}"
+        result.append(res)
+        shutil.rmtree(tmp_dir)
 
-            import os
-            import shutil
-            if "tmp" not in tmp_dir:
-                tmp_dir = os.path.join(tmp_dir, "tmp")
-            tmp_dir = os.path.join(tmp_dir, f"{task_id.replace('/', '-')}-{random_id}")
-            if not os.path.exists(tmp_dir):
-                os.makedirs(tmp_dir)
-
-            os.chdir(tmp_dir)
-            open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
-            res = "failed: unknown error"
-            compile_returncode = -1
-            for _ in range(5):
-                try:
-                    compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5,
-                                                        capture_output=True)
-                    compile_returncode = compilation_result.returncode
-                    break
-                except subprocess.TimeoutExpired as e:
-                    continue
-            if compile_returncode != 0:
-                res = "failed: compilation error"
-            else:
-                exec_result = None
-                try:
-                    # WARNING
-                    # This program exists to execute untrusted model-generated code. Although
-                    # it is highly unlikely that model-generated code will do something overtly
-                    # malicious in response to this test suite, model-generated code may act
-                    # destructively due to a lack of model capability or alignment.
-                    # Users are strongly encouraged to sandbox this evaluation suite so that it
-                    # does not perform destructive actions on their host or network.
-                    # Once you have read this disclaimer and taken appropriate precautions,
-                    # uncomment the following line and proceed at your own risk:
-                    # exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
-                    if exec_result.returncode == 0:
-                        res = "passed"
-                    elif exec_result.returncode == 1:
-                        if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
-                            res = "failed: wrong answer"
-                        else:
-                            res = f"failed: {exec_result.stderr.decode()}"
-                except subprocess.TimeoutExpired as e:
-                    res = "time out"
-                except BaseException as e:
-                    res = f"failed: {e}"
-            result.append(res)
-
-            shutil.rmtree(tmp_dir)
 
+def check_correctness(
+        task_id: str,
+        sample: dict,
+        language_type: str,
+        timeout: float = 3.0,
+        tmp_dir: str = None,
+        completion_id: Optional[int] = None,
+) -> Dict:
+    """
+    Evaluates the functional correctness of a completion by running the test
+    suite provided in the problem.
+    """
 
     manager = multiprocessing.Manager()
     result = manager.list()
 
-    p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,))
+    p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir, language_type, timeout, sample, result, task_id,))
     p.start()
     p.join(timeout=timeout + 1)
     if p.is_alive():
@@ -373,38 +393,19 @@ def check_correctness(
         result.append("timed out")
 
     return {
-        "task_id"      : task_id,
+        "task_id": task_id,
         "completion_id": completion_id,
-        "test_code"    : sample["test_code"],
-        "prompt"       : sample["prompt"],
-        "generation"   : sample["generation"],
-        "result"       : result[0],
-        "passed"       : result[0] == "passed",
-        "finish"       : -1 if "finish" not in sample else sample["finish"],
-        "file"         : "" if "file" not in sample else sample["file"],
-        "output"       : [] if "output" not in sample else sample["output"],
+        "test_code": sample["test_code"],
+        "prompt": sample["prompt"],
+        "generation": sample["generation"],
+        "result": result[0],
+        "passed": result[0] == "passed",
+        "finish": -1 if "finish" not in sample else sample["finish"],
+        "file": "" if "file" not in sample else sample["file"],
+        "output": [] if "output" not in sample else sample["output"],
     }
 
 
-# Copyright (c) OpenAI (https://openai.com)
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ============================================================================
 @contextlib.contextmanager
 def time_limit(seconds: float):
     def signal_handler(signum, frame):
diff --git a/src/generate_humaneval_x.py b/src/generate_humaneval_x.py
index 056ed44..cdffae7 100644
--- a/src/generate_humaneval_x.py
+++ b/src/generate_humaneval_x.py
@@ -19,7 +19,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_path",
         type=str,
-        default="eval_set/humaneval-x"
+        default="../eval_set/humaneval-x"
     )
     parser.add_argument(
         "--language_type",
@@ -30,18 +30,18 @@ if __name__ == "__main__":
         "--model_name",
         type=str,
         default="chatgpt",
-        help='supported model in [chatgpt,chatglm2].'
+        help='supported models in [chatgpt, chatglm2, bbt].'
     )
     parser.add_argument(
         "--output_prefix",
         type=str,
-        default="chatgpt"
+        default="../output/humaneval"
     )
     args = parser.parse_known_args()[0]
 
-    output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
+    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
 
-    entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval")
+    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
     for entry in entries.values():
         entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
 
@@ -52,6 +52,9 @@ if __name__ == "__main__":
     elif args.model_name == "chatglm2":
         from inference.chatglm2_inference import GLMInference
         model_inference = GLMInference()
+    elif args.model_name == "bbt":
+        from inference.bbt_inference import BBTInference
+        model_inference = BBTInference()
     else:
         print(f"{args.model_name} not supported.")
 
@@ -60,8 +63,16 @@ if __name__ == "__main__":
     with open(output_file_path, "w") as f:
         for entry in entries.values():
             prompt = entry["prompt"]
+            retry_times = 3
             generated_tokens = model_inference.inference(prompt)
-            generated_code = cleanup_code(generated_code, language_type=args.language_type)
+            while generated_tokens is None and retry_times > 0:
+                time.sleep(10)
+                generated_tokens = model_inference.inference(prompt)
+                retry_times -= 1
+            if generated_tokens is None:
+                print(f"task_id: {entry['task_id']} generate failed!")
+                continue
+            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
             f.write(
                 json.dumps(
                     {
@@ -79,11 +90,11 @@ if __name__ == "__main__":
             num_finished += 1
 
             time_elapsed = time.perf_counter() - start_time
-            time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished
+            time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
             print(
                 f"finished {num_finished}, "
                 f"elapsed {time_elapsed:.4f}",
-                f"speed {time_per_sampple:.4f}s/sample",
+                f"speed {time_per_sample:.4f}s/sample",
                 f"remaining {len(entries) - num_finished:.4f}",
                 flush=True,
             )
diff --git a/src/inference/bbt_inference.py b/src/inference/bbt_inference.py
new file mode 100644
index 0000000..e26ade5
--- /dev/null
+++ b/src/inference/bbt_inference.py
@@ -0,0 +1,58 @@
+
+import os
+import json
+import logging
+from urllib.parse import quote
+import requests
+import traceback
+from .inference import Inference
+from utils import exception_reconnect
+
+
+logger = logging.getLogger(__name__)
+
+
+class BBTInference(Inference):
+    def __init__(self):
+        super(BBTInference, self).__init__()
+        self.params_url = "../llm_set/params/bbt.json"
+        self.paras_dict = self.get_params()
+        self.paras_base_dict.update(self.paras_dict)
+        self.paras_dict = self.paras_base_dict
self.paras_dict.get("url") + self.user_id = self.paras_dict.get("user_id") + self.timeout = self.paras_dict.get("timeout") + self.modelConfig = { + 'temperature': self.paras_dict.get("temperature"), + 'top_p': self.paras_dict.get("top_p"), + # 'top_k': 80, + # 'do_sample': True, + # 'num_beams': 4, + # 'no_repeat_ngram_size': 4, + # 'repetition_penalty': 1.1, + # 'length_penalty': 1, + 'max_new_tokens': self.paras_dict.get("max_new_tokens"), + } + + def get_params(self): + if not os.path.exists(self.params_url): + logger.error(f"params_url:{self.params_url} is not exists.") + content = open(self.params_url).read() + return json.loads(content) + + @exception_reconnect + def inference(self, query_text): + query_text = "补全以下方法体:\n" + query_text + try: + data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)} + resp = requests.post(self.url, data=data) + except Exception as e: + logger.error(traceback.format_exc()) + logger.error(e) + return None + + if resp.status_code == 200: + resp_json = json.loads(resp.text) + return resp_json["msg"] + else: + return None diff --git a/src/inference/chatglm2_inference.py b/src/inference/chatglm2_inference.py index c7669b7..4b1edf1 100644 --- a/src/inference/chatglm2_inference.py +++ b/src/inference/chatglm2_inference.py @@ -3,15 +3,16 @@ import json import torch import logging from transformers import AutoModel, AutoTokenizer -from inference import Inference +from .inference import Inference logger = logging.getLogger(__name__) class GLMInference(Inference): - def __int__(self): - self.params_url = "llm_set/params/chatgpt.json" + def __init__(self): + super(GLMInference, self).__init__() + self.params_url = "../llm_set/params/chatglm2.json" self.paras_dict = self.get_params() self.paras_dict.update(self.paras_base_dict) diff --git a/src/inference/chatgpt_inference.py b/src/inference/chatgpt_inference.py index 87928f4..f8503f4 100644 --- a/src/inference/chatgpt_inference.py +++ b/src/inference/chatgpt_inference.py @@ -2,20 +2,23 @@ import os import json import logging -import quote +from urllib.parse import quote import requests import traceback -from inference import Inference +from .inference import Inference +from utils import exception_reconnect logger = logging.getLogger(__name__) class ChatGPTInference(Inference): - def __int__(self): - self.params_url = "llm_set/params/chatgpt.json" + def __init__(self): + super(ChatGPTInference, self).__init__() + self.params_url = "../llm_set/params/chatgpt.json" self.paras_dict = self.get_params() - self.paras_dict.update(self.paras_base_dict) + self.paras_base_dict.update(self.paras_dict) + self.paras_dict = self.paras_base_dict self.gpt_url = self.paras_dict.get("url") self.id = self.paras_dict.get("id") self.stream = self.paras_dict.get("stream") @@ -28,9 +31,11 @@ class ChatGPTInference(Inference): content = open(self.params_url).read() return json.loads(content) + @exception_reconnect def inference(self, query_text): + query_text = "补全以下方法体:\n" + query_text query_text = quote(query_text) - param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}" + param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}" try: resp = requests.get(self.gpt_url + param_str, timeout=self.timeout) except Exception as e: diff --git a/src/inference/inference.py b/src/inference/inference.py index c27d3b6..31e4ed0 100644 --- a/src/inference/inference.py +++ b/src/inference/inference.py @@ -8,8 +8,8 @@ logger = 
@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
 
 
 class Inference(object):
-    def __int__(self):
-        self.params_base_url = "llm_set/params/default.json"
+    def __init__(self):
+        self.params_base_url = "../llm_set/params/default.json"
         self.paras_base_dict = self.get_paras_base()
 
     def get_paras_base(self):
diff --git a/src/utils.py b/src/utils.py
index 998cefa..2043e70 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,6 +1,8 @@
 import os
+import re
 from typing import *
 from data_utils import stream_jsonl, LANGUAGE_TAG
+from collections import Counter
 
 
 IMPORT_HELPER = {
@@ -49,7 +51,7 @@ IMPORT_HELPER = {
 
 
 def read_dataset(
-        data_file: str = None,
+        data_folder: str = None,
         dataset_type: str = "humaneval",
         language_type: str = "python",
         num_shot=None,
@@ -57,7 +59,7 @@ def read_dataset(
     if num_shot is not None:
         print(f"{num_shot}-shot setting...")
     if "humaneval" in dataset_type.lower():
-        data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
+        data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
         dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
     else:
         raise f"Dataset: {dataset_type} not supported."
@@ -147,41 +149,191 @@ def is_code_generation_finished(
     return False
 
 
-def cleanup_code(
-    code: str,
-    language_type: str = None
-):
+def find_method_name(language_type, sample):
+    """Find the target method name from the task declaration."""
+    if language_type.lower() == "python":
+        declaration = sample["declaration"]
+        ret = re.search("def (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    elif language_type.lower() == "cpp":
+        declaration = sample["declaration"]
+        ret = re.search(" (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        if " " in method_name:
+            return method_name[1]
+        return method_name
+    elif language_type.lower() == "java":
+        declaration = sample["declaration"]
+        ret = re.search(" (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        if " " in method_name:
+            return method_name[1]
+        return method_name
+    elif language_type.lower() in ["js", "javascript"]:
+        declaration = sample["declaration"]
+        ret = re.search("const (.*?) ", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        return method_name
+    elif language_type.lower() == "go":
+        declaration = sample["declaration"]
+        ret = re.search("func (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    elif language_type == "rust":
+        declaration = sample["declaration"]
+        ret = re.search("fn (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    else:
+        return None
+
+
+def extract_markdown_code(content):
+    """Extract code from markdown, i.e. the content between "```" fences."""
+    codes = []
+    rets = re.findall("```([\\s\\S]*?)```", content)
+    if rets is None:
+        return codes
+    for ret in rets:
+        if not ret.startswith("\n"):
+            lines = ret.split("\n")
+            codes.append("".join(lines[1:]))
+        else:
+            codes.append(ret.strip())
+
+    return codes
+
+
+def cleanup_code(code: str, sample, language_type: str = None):
     """
     Cleans up the generated code.
""" if language_type is None: return code + method_name = find_method_name(language_type, sample) + if method_name is None: + return code + + method_body = code if language_type.lower() == "python": - end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"] - for w in end_words: - if w in code: - code = code[:code.rfind(w)] + method_lines = [] + for line in code.split("\n"): + if f"def {method_name}" in line: + method_lines.append(line) + continue + if method_lines: + method_lines.append(line) + if line.startswith(" return"): + break + if method_lines: + method_body = "\n".join(method_lines[1:]) + elif language_type.lower() == "java": - main_pos = code.find("public static void main") - if main_pos != -1: - code = code[:main_pos] + '}' - if '}' in code: - code = code[:code.rfind('}')] + '}' - if code.count('{') + 1 == code.count('}'): - code += "\n}" + method_lines = [] + bracket_left = 0 + bracket_right = 0 + for line in code.split("\n"): + new_line = line.strip() + counter = Counter(new_line) + if new_line.startswith("public") and method_name in new_line: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + continue + if method_lines: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + if bracket_left == bracket_right and bracket_right > 0: + break + + if method_lines: + method_lines.append("}") + method_body = "\n".join(method_lines[1:]) elif language_type.lower() == "go": - end_words = ["\n//", "\nfunc main("] - for w in end_words: - if w in code: - code = code[:code.rfind(w)] - if '}' in code: - code = code[:code.rfind('}')] + '}' + method_lines = [] + bracket_left = 0 + bracket_right = 0 + for line in code.split("\n"): + counter = Counter(line) + if f"func {method_name}" in line: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + continue + if method_lines: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + if bracket_left == bracket_right and bracket_right > 0: + break + + if method_lines: + method_body = "\n".join(method_lines[1:]) + elif language_type.lower() == "cpp": - if '}' in code: - code = code[:code.rfind('}')] + '}' + method_lines = [] + bracket_left = 0 + bracket_right = 0 + for line in code.split("\n"): + counter = Counter(line) + if f" {method_name}(" in line: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + continue + if method_lines: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + if bracket_left == bracket_right and bracket_right > 0: + break + + if method_lines: + method_body = "\n".join(method_lines[1:]) + elif language_type.lower() == "js": - if '}' in code: - code = code[:code.rfind('}')] + '}' + method_lines = [] + bracket_left = 0 + bracket_right = 0 + for line in code.split("\n"): + counter = Counter(line) + if f"const {method_name}" in line: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + continue + if method_lines: + method_lines.append(line) + bracket_left += counter["{"] + bracket_right += counter["}"] + if bracket_left == bracket_right and bracket_right > 0: + break + + if method_lines: + method_body = "\n".join(method_lines[1:]) + + return method_body + "\n" + + +def exception_reconnect(funct): + """异常重连""" + def wrapper_func(*args, **kwargs): + try: + return funct(*args, **kwargs) + except Exception as e: + print(f"exception will reconnect: 
{str(e)}") + return funct(*args, **kwargs) - return code + return wrapper_func \ No newline at end of file -- GitLab