diff --git a/README.md b/README.md
index 0e2d942776dec82b7aa5b116648664386e7296f5..cb994f67db70d9b163928e8b5788ceb4812dd26a 100644
--- a/README.md
+++ b/README.md
@@ -33,12 +33,12 @@ pass@k := 𝔼[1 − C(n−c, k) / C(n, k)], n = 200
 
 Samples are stored as JSON lists in codegeex/benchmark/humaneval-x/[LANG]/data/humaneval_[LANG].jsonl.gz; each sample contains six fields:
 
-task_id: target language and ID of the problem. The language is one of ["Python", "Java", "JavaScript", "CPP", "Go"].
-prompt: function declaration plus description, used for code generation.
-declaration: function declaration only, used for code translation.
-canonical_solution: hand-written reference solution.
-test: hidden test cases, used for evaluation.
-example_test: public test cases shown in the prompt, used for evaluation.
+task_id: target language and ID of the problem. The language is one of ["Python", "Java", "JavaScript", "CPP", "Go"].
+prompt: function declaration plus description, used for code generation.
+declaration: function declaration only, used for code translation.
+canonical_solution: hand-written reference solution.
+test: hidden test cases, used for evaluation.
+example_test: public test cases shown in the prompt, used for evaluation.
 
 Evaluating the generated code requires compiling and running it in several languages. The per-language dependencies and package versions we use are listed below:
 
@@ -67,41 +67,41 @@ example_test: public test cases shown in the prompt, used for evaluation.
 ## Run commands
 
 Below is an example of using chatgpt to generate Python test data:
 
-python generate_humaneval_x.py --input_path ../eval_set/humaneval-x
-                               --language_type python
-                               --model_name chatgpt
-                               --output_prefix ../output/humaneval
+python main.py --task_type generate
+               --input_path eval_set/humaneval-x
+               --language_type python
+               --model_name chatgpt
+               --output_prefix output/humaneval
 
 Evaluation example:
 
-python evaluate_humaneval_x.py --language_type python
-                               --input_folder ../output
-                               --tmp_dir ../output/tmp/
-                               --n_workers 3
-                               --timeout 500.0
-                               --problem_folder ../eval_set/humaneval-x/
-                               --out_dir ../output/
-                               --k [1, 10, 100]
-                               --test_groundtruth False
-                               --example_test False
-                               --model_name chatgpt
+python main.py --task_type evaluate
+               --language_type python
+               --input_path output
+               --tmp_dir output/tmp/
+               --n_workers 1
+               --timeout 10.0
+               --problem_folder eval_set/humaneval-x/
+               --output_prefix output/
+               --model_name chatgpt
 
 ## Test results
 
 Limited by model inference speed, only the pass@1 metric has been tested so far.
 
-|              | python | java   | cpp    | js     | go      |
-|--------------|--------|--------|--------|--------|---------|
-| chatgpt      | 64.02% | 15.85% | 26.22% | 47.00% | 31.70%  |
-| bbt-7B       | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0%      |
-| bbt-13B      | 2.49%  | 0%     | 1.90%  | 1.83%  | 0.61%   |
-| chatglm2-6B  | 7.93%  | 5.45%  | 0.61%  | 6.70%  | 1.83%   |
-| codegeex2-6B | 29.90% | 27.43% | 6.70%  | 24.40% | 17.68%  |
-| llama2-7B    | 5.49%  | 8.54%  | 1.22%  | 3.66%  | 6.10%   |
+|              | python | java   | cpp    | js     | go     |
+|--------------|--------|--------|--------|--------|--------|
+| chatgpt      | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
+| bbt-7B       | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0.00%  |
+| bbt-13B      | 2.49%  | 0.00%  | 1.90%  | 1.83%  | 0.61%  |
+| chatglm2-6B  | 7.93%  | 5.45%  | 0.61%  | 6.70%  | 1.83%  |
+| codegeex2-6B | 29.90% | 27.43% | 6.70%  | 24.40% | 17.68% |
+| llama2-7B    | 5.49%  | 8.54%  | 1.22%  | 3.66%  | 6.10%  |
+| baichuan-7B  | 7.93%  | 1.83%  | 0.00%  | 6.71%  | 6.71%  |
+
 
 ## TODO
 
-1. Test more open-source models, e.g. Baichuan, llama2, rwkv.
-2. Test the models' pass@10 and pass@100 metrics.
-3. Code-translation tasks are not yet supported; the corresponding data also still needs to be built.
+1. Test the models' pass@10 and pass@100 metrics.
+2. Code-translation tasks are not yet supported; the corresponding data also still needs to be built.
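The pass@1 figures in the table above come from the unbiased pass@k estimator quoted at the top of the README; the repository's own implementation is `estimate_pass_at_k` in `src/metric.py`, imported further down in `src/evaluate_humaneval_x.py`. A minimal sketch of that estimator follows (the function name and the sample counts in the example call are illustrative, not taken from the repo):

```python
import numpy as np


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k = 1 - C(n-c, k) / C(n, k), for n samples with c correct.

    Computed as a running product so the binomial coefficients never overflow.
    """
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


# Hypothetical task: 200 samples drawn, 128 of them pass the hidden tests.
print(pass_at_k(n=200, c=128, k=1))   # 0.64 -> reported as 64.00%
print(pass_at_k(n=200, c=128, k=10))  # larger, since any of 10 draws may succeed
```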
diff --git a/llm_set/params/baichuan.json b/llm_set/params/baichuan.json
new file mode 100644
index 0000000000000000000000000000000000000000..5b45fb5185ee0e132c7b7e46976d7c2e2bd1edf0
--- /dev/null
+++ b/llm_set/params/baichuan.json
@@ -0,0 +1,7 @@
+{
+  "model_path": "../llm_set/models/baichuan-7b",
+  "device": "cuda:0",
+  "quantize": false,
+  "max_length": 1024,
+  "min_length": 10
+}
\ No newline at end of file
diff --git a/main.py b/main.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..15dc2c0c1ca798ea8a3b0846b8103fa776873788 100644
--- a/main.py
+++ b/main.py
@@ -0,0 +1,95 @@
+
+
+import argparse
+
+
+def remove_none_params(params_dict):
+    params = {}
+    for key in params_dict:
+        if params_dict[key] is not None:
+            params[key] = params_dict[key]
+    return params
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--task_type",
+        type=str,
+        required=True,
+        help='supported tasks. [generate, translate, evaluate]'
+    )
+    parser.add_argument(
+        "--input_path",
+        type=str,
+    )
+    parser.add_argument(
+        "--language_type",
+        type=str,
+        default="python"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="chatgpt",
+        help='supported models: [chatgpt, chatglm2, bbt, codegeex2, baichuan].'
+    )
+    parser.add_argument(
+        "--output_prefix",
+        type=str
+    )
+    parser.add_argument(
+        "--problem_folder",
+        type=str,
+        default="eval_set/humaneval-x/",
+        help='for evaluate, problem path.'
+    )
+    parser.add_argument(
+        "--tmp_dir",
+        type=str,
+        default="output/tmp/",
+        help='for evaluate, temp files path.'
+    )
+    parser.add_argument(
+        "--n_workers",
+        type=int,
+        default=1,
+        help='for evaluate, number of processes.'
+    )
+    parser.add_argument(
+        "--timeout",
+        type=float,
+        default=10.0,
+        help='for evaluate, single testing timeout time.'
+    )
+    args = parser.parse_known_args()[0]
+
+    task_type = args.task_type
+    if task_type == "generate":
+        from src.generate_humaneval_x import generate
+        params = remove_none_params({
+            "input_path": args.input_path,
+            "language_type": args.language_type,
+            "model_name": args.model_name,
+            "output_prefix": args.output_prefix
+        })
+        generate(**params)
+    elif task_type == "translate":
+        print("Not implemented.")
+
+    elif task_type == "evaluate":
+        from src.evaluate_humaneval_x import evaluate_functional_correctness
+        params = remove_none_params({
+            "input_folder": args.input_path,
+            "language_type": args.language_type,
+            "model_name": args.model_name,
+            "out_dir": args.output_prefix,
+            "problem_folder": args.problem_folder,
+            "tmp_dir": args.tmp_dir,
+            "n_workers": args.n_workers,
+            "timeout": args.timeout
+        })
+
+        evaluate_functional_correctness(**params)
+    else:
+        print(f"task_type: {task_type} not supported.")
diff --git a/src/evaluate_humaneval_x.py b/src/evaluate_humaneval_x.py
index 5ea019ab4b272626806597c94b641b72c175a6b9..71839ea3aa62ce07d224592d831412f318ac2130 100644
--- a/src/evaluate_humaneval_x.py
+++ b/src/evaluate_humaneval_x.py
@@ -11,9 +11,9 @@ from tqdm.auto import tqdm
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-from utils import read_dataset, IMPORT_HELPER
-from metric import estimate_pass_at_k
-from execution import check_correctness
+from src.utils import read_dataset, IMPORT_HELPER
+from src.metric import estimate_pass_at_k
+from src.execution import check_correctness
 
 LANGUAGE_NAME = {
     "cpp": "CPP",
@@ -102,12 +102,12 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
 
 def evaluate_functional_correctness(
         language_type: str = "python",
-        input_folder: str = "../output",
-        tmp_dir: str = "../output/tmp/",
+        input_folder: str = "output",
+        tmp_dir: str = "output/tmp/",
         n_workers: int = 3,
-        timeout: float = 500.0,
-        problem_folder: str = "../eval_set/humaneval-x/",
-        out_dir: str = "../output/",
+        timeout: float = 10.0,
+        problem_folder: str = "eval_set/humaneval-x/",
+        out_dir: str = "output/",
         k: List[int] = [1, 10, 100],
         test_groundtruth: bool = False,
         example_test: bool = False,
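Because the old argparse blocks are now plain function parameters, the two entry points that `main.py` dispatches to can also be driven directly from Python. A minimal sketch that mirrors the `main.py` dispatch, assuming the repository root is the working directory, the default `eval_set/` and `output/` paths from the README exist, and the baichuan weights referenced by `llm_set/params/baichuan.json` are available locally:

```python
from src.generate_humaneval_x import generate
from src.evaluate_humaneval_x import evaluate_functional_correctness

# Equivalent to: python main.py --task_type generate --model_name baichuan ...
generate(
    input_path="eval_set/humaneval-x",
    language_type="python",
    model_name="baichuan",
    output_prefix="output/humaneval",
)

# Equivalent to: python main.py --task_type evaluate --model_name baichuan ...
evaluate_functional_correctness(
    language_type="python",
    input_folder="output",
    tmp_dir="output/tmp/",
    n_workers=1,
    timeout=10.0,
    problem_folder="eval_set/humaneval-x/",
    out_dir="output/",
    model_name="baichuan",
)
```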
diff --git a/src/generate_batch.sh b/src/generate_batch.sh
index e5491e8d716fb57497a0f869f5c8304346060555..acd8a023425d0069bbff90ea917a60ee9de910c7 100644
--- a/src/generate_batch.sh
+++ b/src/generate_batch.sh
@@ -1,5 +1,4 @@
-python generate_humaneval_x.py --model_name bbt --language_type python
-python generate_humaneval_x.py --model_name bbt --language_type java
-python generate_humaneval_x.py --model_name bbt --language_type cpp
-python generate_humaneval_x.py --model_name bbt --language_type js
-python generate_humaneval_x.py --model_name bbt --language_type go
\ No newline at end of file
+python generate_humaneval_x.py --model_name baichuan --language_type python
+python generate_humaneval_x.py --model_name baichuan --language_type cpp
+python generate_humaneval_x.py --model_name baichuan --language_type js
+python generate_humaneval_x.py --model_name baichuan --language_type go
\ No newline at end of file
diff --git a/src/generate_humaneval_x.py b/src/generate_humaneval_x.py
index 9e47a37d549cc6b5fa6a30ebc383590d379661ec..8d303bdd46860571eec7f241c3ff54f8d8d57915 100644
--- a/src/generate_humaneval_x.py
+++ b/src/generate_humaneval_x.py
@@ -1,8 +1,5 @@
-import argparse
+
 import logging
-import os
-import random
-import socket
 import time
 import json
 import torch
@@ -13,56 +10,39 @@ from utils import cleanup_code
 
 logging.getLogger("torch").setLevel(logging.WARNING)
 
-if __name__ == "__main__":
+def generate(input_path="eval_set/humaneval-x",
+             language_type="python",
+             model_name="chatgpt",
+             output_prefix="output/humaneval"):
     torch.multiprocessing.set_start_method("spawn")
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--input_path",
-        type=str,
-        default="../eval_set/humaneval-x"
-    )
-    parser.add_argument(
-        "--language_type",
-        type=str,
-        default="python"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default="chatgpt",
-        help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
-    )
-    parser.add_argument(
-        "--output_prefix",
-        type=str,
-        default="../output/humaneval"
-    )
-    args = parser.parse_known_args()[0]
 
-    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
+    output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"
 
-    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
+    entries = read_dataset(data_folder=input_path, language_type=language_type, dataset_type="humaneval")
     for entry in entries.values():
-        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
+        entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)
 
     model_inference = None
-    if args.model_name == "chatgpt":
+    if model_name == "chatgpt":
         from inference.chatgpt_inference import ChatGPTInference
         model_inference = ChatGPTInference()
-    elif args.model_name == "chatglm2":
+    elif model_name == "chatglm2":
         from inference.chatglm2_inference import GLMInference
         model_inference = GLMInference()
-    elif args.model_name == "bbt":
+    elif model_name == "bbt":
         from inference.bbt_inference import BBTInference
         model_inference = BBTInference()
-    elif args.model_name == "codegeex2":
+    elif model_name == "codegeex2":
         from inference.codegeex2_inference import CodeGeex2Inference
         model_inference = CodeGeex2Inference()
-    elif args.model_name == "llama2":
+    elif model_name == "llama2":
         from inference.llama2_inference import LLAMA2Inference
         model_inference = LLAMA2Inference()
+    elif model_name == "baichuan":
+        from inference.baichuan_inference import BaiChuanInference
+        model_inference = BaiChuanInference()
     else:
-        print(f"{args.model_name} not supported.")
+        print(f"{model_name} not supported.")
 
     start_time = time.perf_counter()
     num_finished = 0
@@ -78,7 +58,7 @@ if __name__ == "__main__":
             if generated_tokens is None:
                 print(f"task_id: {entry['task_id']} generate failed!")
                 continue
-            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
+            generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
             f.write(
                 json.dumps(
                     {
diff --git a/src/inference/baichuan_inference.py b/src/inference/baichuan_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ebec2404bd679dfb4dac9f44a06cd0cee4c2e0
--- /dev/null
+++ b/src/inference/baichuan_inference.py
@@ -0,0 +1,60 @@
+
+
+import os
+import json
+import torch
+import logging
+from .inference import Inference
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+class BaiChuanInference(Inference):
+
+    def __init__(self):
+        super(BaiChuanInference, self).__init__()
+        self.params_url = "../llm_set/params/baichuan.json"
+        self.paras_dict = self.get_params()
+        self.paras_dict.update(self.paras_base_dict)
+
+        self.temperature = self.paras_dict.get("temperature")
+        self.quantize = self.paras_dict.get("quantize")
+        self.device = self.paras_dict.get("device")
+        self.model_path = self.paras_dict.get("model_path")
+        self.DEV = torch.device(self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto",
+                                                          trust_remote_code=True,
+                                                          low_cpu_mem_usage=True,
+                                                          torch_dtype=torch.float16)
+        self.max_length = self.paras_dict.get("max_length")
+        self.min_length = self.paras_dict.get("min_length")
+        self.top_p = self.paras_dict.get("top_p")
+        self.top_k = self.paras_dict.get("top_k")
+
+    def get_params(self):
+        if not os.path.exists(self.params_url):
+            logger.error(f"params_url:{self.params_url} does not exist.")
+        content = open(self.params_url).read()
+        return json.loads(content)
+
+    def inference(self, message):
+        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
+
+        sentence = None
+        with torch.no_grad():
+            generation_output = self.model.generate(
+                input_ids,
+                do_sample=True,
+                min_length=self.min_length,
+                max_length=self.max_length,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                temperature=self.temperature
+            )
+        output = self.tokenizer.decode([el.item() for el in generation_output[0]])
+        sentence = output.strip()
+        return sentence
diff --git a/src/utils.py b/src/utils.py
index 2043e7027011f55e69728a8e5185b0f487bcf6d5..59dc3ef527bef4021cf4061ab2b5f138fbd4b861 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,7 +1,7 @@
 import os
 import re
 from typing import *
-from data_utils import stream_jsonl, LANGUAGE_TAG
+from src.data_utils import stream_jsonl, LANGUAGE_TAG
 from collections import Counter
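The new backend can also be smoke-tested on its own. A minimal sketch, assuming the interpreter is launched from `src/` (so that the relative `../llm_set/params/baichuan.json` path hard-coded above resolves) and the baichuan-7b weights are present at the configured `model_path`; the prompt string is purely illustrative:

```python
from inference.baichuan_inference import BaiChuanInference

# Reads ../llm_set/params/baichuan.json, then loads the tokenizer and fp16 model.
model = BaiChuanInference()

# generate() feeds each processed HumanEval-X prompt through inference() like this.
prompt = 'def add(a, b):\n    """Return the sum of a and b."""\n'
completion = model.inference(prompt)
print(completion)
```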