diff --git a/README.md b/README.md
index 0e2d942776dec82b7aa5b116648664386e7296f5..cb994f67db70d9b163928e8b5788ceb4812dd26a 100644
--- a/README.md
+++ b/README.md
@@ -33,12 +33,12 @@ pass@k := 𝔼[1 − C(n−c, k) / C(n, k)], n=200
Samples are stored in JSON Lines format at codegeex/benchmark/humaneval-x/[LANG]/data/humaneval_[LANG].jsonl.gz, and each sample contains 6 fields (a small reading sketch follows the list):
-task_id: the target language and ID of the problem. The language is one of ["Python", "Java", "JavaScript", "CPP", "Go"].
-prompt: the function declaration and description, used for code generation.
-declaration: the function declaration only, used for code translation.
-canonical_solution: a hand-written reference solution.
-test: hidden test cases, used for evaluation.
-example_test: public test cases shown in the prompt, used for evaluation.
+task_id: the target language and ID of the problem. The language is one of ["Python", "Java", "JavaScript", "CPP", "Go"].
+prompt: the function declaration and description, used for code generation.
+declaration: the function declaration only, used for code translation.
+canonical_solution: a hand-written reference solution.
+test: hidden test cases, used for evaluation.
+example_test: public test cases shown in the prompt, used for evaluation.
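+
+A minimal sketch (illustrative only; it assumes the Python split and lowercase language names in the path) for loading one record and previewing these fields:
+
+```python
+import gzip
+import json
+
+path = "codegeex/benchmark/humaneval-x/python/data/humaneval_python.jsonl.gz"
+with gzip.open(path, "rt", encoding="utf-8") as f:
+    sample = json.loads(f.readline())  # one JSON object per line
+for key in ["task_id", "prompt", "declaration", "canonical_solution", "test", "example_test"]:
+    print(key, "->", str(sample[key])[:60])  # short preview of each field
+```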
Evaluating the generated code requires compiling and running it in several languages. The per-language dependencies and package versions we use are listed below:
@@ -67,41 +67,41 @@ example_test: public test cases shown in the prompt, used for evaluation.
## Running commands
The following example uses chatgpt to generate samples for the Python tasks:
-python generate_humaneval_x.py --input_path ../eval_set/humaneval-x
- --language_type python
- --model_name chatgpt
- --output_prefix ../output/humaneval
+python main.py --task_type generate \
+    --input_path eval_set/humaneval-x \
+    --language_type python \
+    --model_name chatgpt \
+    --output_prefix output/humaneval
An evaluation example:
-python evaluate_humaneval_x.py --language_type python
- --input_folder ../output
- --tmp_dir ../output/tmp/
- --n_workers 3
- --timeout 500.0
- --problem_folder ../eval_set/humaneval-x/
- --out_dir ../output/
- --k [1, 10, 100]
- --test_groundtruth False
- --example_test False
- --model_name chatgpt
+python main.py --task_type evaluate \
+    --language_type python \
+    --input_path output \
+    --tmp_dir output/tmp/ \
+    --n_workers 1 \
+    --timeout 10.0 \
+    --problem_folder eval_set/humaneval-x/ \
+    --output_prefix output/ \
+    --model_name chatgpt
## Test results
Limited by model inference speed, only the pass@1 metric has been measured so far.
-| | python | java | cpp | js | go |
-|--------------|--------|--------|--------|--------|---------|
-| chatgpt | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
-| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0% |
-| bbt-13B | 2.49% | 0% | 1.90% | 1.83% | 0.61% |
-| chatglm2-6B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
-| codegeex2-6B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
-| llama2-7B | 5.49% | 8.54% | 1.22% | 3.66% | 6.10% |
+| | python | java | cpp | js | go |
+|--------------|--------|--------|--------|--------|--------|
+| chatgpt | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
+| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0.00% |
+| bbt-13B | 2.49% | 0.00% | 1.90% | 1.83% | 0.61% |
+| chatglm2-6B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
+| codegeex2-6B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
+| llama2-7B | 5.49% | 8.54% | 1.22% | 3.66% | 6.10% |
+| baichuan-7B | 7.93% | 1.83% | 0.00% | 6.71% | 6.71% |
+
## TODO
-1. Test more open-source models, e.g. Baichuan, llama2, rwkv.
-2. Measure pass@10 and pass@100 for the models.
-3. Code-translation tasks are not yet supported, and the corresponding data still needs to be constructed.
+1. Measure pass@10 and pass@100 for the models (see the pass@k sketch below).
+2. Code-translation tasks are not yet supported, and the corresponding data still needs to be constructed.
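+
+For reference, pass@k is computed with the unbiased estimator pass@k := 𝔼[1 − C(n−c, k) / C(n, k)] over n generated samples of which c pass. A minimal numpy sketch of that estimator (illustrative only; the repository's src/metric.py provides estimate_pass_at_k for the actual evaluation):
+
+```python
+import numpy as np
+
+def pass_at_k(n: int, c: int, k: int) -> float:
+    """Unbiased estimate of pass@k for n samples, c of which are correct."""
+    if n - c < k:
+        return 1.0  # every size-k subset contains at least one correct sample
+    # 1 - C(n-c, k) / C(n, k), evaluated as a numerically stable running product
+    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
+
+print(pass_at_k(n=200, c=10, k=1))   # ~0.05
+print(pass_at_k(n=200, c=10, k=10))  # chance that at least one of 10 drawn samples passes
+```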
diff --git a/llm_set/params/baichuan.json b/llm_set/params/baichuan.json
new file mode 100644
index 0000000000000000000000000000000000000000..5b45fb5185ee0e132c7b7e46976d7c2e2bd1edf0
--- /dev/null
+++ b/llm_set/params/baichuan.json
@@ -0,0 +1,7 @@
+{
+ "model_path": "../llm_set/models/baichuan-7b",
+ "device": "cuda:0",
+ "quantize": false,
+ "max_length": 1024,
+ "min_length": 10
+}
\ No newline at end of file
diff --git a/main.py b/main.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..15dc2c0c1ca798ea8a3b0846b8103fa776873788 100644
--- a/main.py
+++ b/main.py
@@ -0,0 +1,95 @@
+
+
+import argparse
+
+
+def remove_none_params(params_dict):
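+    """Return a copy of params_dict without None-valued entries, so callee defaults apply."""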
+ params = {}
+ for key in params_dict:
+ if params_dict[key] is not None:
+ params[key] = params_dict[key]
+ return params
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--task_type",
+ type=str,
+ required=True,
+        help='supported tasks: [generate, translate, evaluate].'
+ )
+ parser.add_argument(
+ "--input_path",
+ type=str,
+ )
+ parser.add_argument(
+ "--language_type",
+ type=str,
+ default="python"
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="chatgpt",
+        help='supported models: [chatgpt, chatglm2, bbt, codegeex2, llama2, baichuan].'
+ )
+ parser.add_argument(
+ "--output_prefix",
+ type=str
+ )
+ parser.add_argument(
+ "--problem_folder",
+ type=str,
+ default="eval_set/humaneval-x/",
+        help='for evaluate: path to the problem set.'
+ )
+ parser.add_argument(
+ "--tmp_dir",
+ type=str,
+ default="output/tmp/",
+        help='for evaluate: directory for temporary files.'
+ )
+ parser.add_argument(
+ "--n_workers",
+ type=int,
+ default=1,
+        help='for evaluate: number of worker processes.'
+ )
+ parser.add_argument(
+ "--timeout",
+ type=float,
+ default=10.0,
+        help='for evaluate: timeout in seconds for a single test.'
+ )
+ args = parser.parse_known_args()[0]
+
+ task_type = args.task_type
+ if task_type == "generate":
+ from src.generate_humaneval_x import generate
+ params = remove_none_params({
+ "input_path": args.input_path,
+ "language_type": args.language_type,
+ "model_name": args.model_name,
+ "output_prefix": args.output_prefix
+ })
+ generate(**params)
+ elif task_type == "translate":
+ print("Not implemented.")
+
+ elif task_type == "evaluate":
+ from src.evaluate_humaneval_x import evaluate_functional_correctness
+ params = remove_none_params({
+ "input_folder": args.input_path,
+ "language_type": args.language_type,
+ "model_name": args.model_name,
+ "out_dir": args.output_prefix,
+ "problem_folder": args.problem_folder,
+ "tmp_dir": args.tmp_dir,
+ "n_workers": args.n_workers,
+ "timeout": args.timeout
+ })
+
+ evaluate_functional_correctness(**params)
+ else:
+ print(f"task_type: {task_type} not supported.")
diff --git a/src/evaluate_humaneval_x.py b/src/evaluate_humaneval_x.py
index 5ea019ab4b272626806597c94b641b72c175a6b9..71839ea3aa62ce07d224592d831412f318ac2130 100644
--- a/src/evaluate_humaneval_x.py
+++ b/src/evaluate_humaneval_x.py
@@ -11,9 +11,9 @@ from tqdm.auto import tqdm
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
-from utils import read_dataset, IMPORT_HELPER
-from metric import estimate_pass_at_k
-from execution import check_correctness
+from src.utils import read_dataset, IMPORT_HELPER
+from src.metric import estimate_pass_at_k
+from src.execution import check_correctness
LANGUAGE_NAME = {
"cpp": "CPP",
@@ -102,12 +102,12 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness(
language_type: str = "python",
- input_folder: str = "../output",
- tmp_dir: str = "../output/tmp/",
+ input_folder: str = "output",
+ tmp_dir: str = "output/tmp/",
n_workers: int = 3,
- timeout: float = 500.0,
- problem_folder: str = "../eval_set/humaneval-x/",
- out_dir: str = "../output/",
+ timeout: float = 10.0,
+ problem_folder: str = "eval_set/humaneval-x/",
+ out_dir: str = "output/",
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
diff --git a/src/generate_batch.sh b/src/generate_batch.sh
index e5491e8d716fb57497a0f869f5c8304346060555..acd8a023425d0069bbff90ea917a60ee9de910c7 100644
--- a/src/generate_batch.sh
+++ b/src/generate_batch.sh
@@ -1,5 +1,4 @@
-python generate_humaneval_x.py --model_name bbt --language_type python
-python generate_humaneval_x.py --model_name bbt --language_type java
-python generate_humaneval_x.py --model_name bbt --language_type cpp
-python generate_humaneval_x.py --model_name bbt --language_type js
-python generate_humaneval_x.py --model_name bbt --language_type go
\ No newline at end of file
+# run from the repository root (generate_humaneval_x.py no longer exposes a CLI; main.py does)
+python main.py --task_type generate --model_name baichuan --language_type python
+python main.py --task_type generate --model_name baichuan --language_type cpp
+python main.py --task_type generate --model_name baichuan --language_type js
+python main.py --task_type generate --model_name baichuan --language_type go
\ No newline at end of file
diff --git a/src/generate_humaneval_x.py b/src/generate_humaneval_x.py
index 9e47a37d549cc6b5fa6a30ebc383590d379661ec..8d303bdd46860571eec7f241c3ff54f8d8d57915 100644
--- a/src/generate_humaneval_x.py
+++ b/src/generate_humaneval_x.py
@@ -1,8 +1,5 @@
-import argparse
+
import logging
-import os
-import random
-import socket
import time
import json
import torch
@@ -13,56 +10,39 @@ from utils import cleanup_code
logging.getLogger("torch").setLevel(logging.WARNING)
-if __name__ == "__main__":
+def generate(input_path="eval_set/humaneval-x",
+ language_type="python",
+ model_name="chatgpt",
+ output_prefix="output/humaneval"):
torch.multiprocessing.set_start_method("spawn")
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--input_path",
- type=str,
- default="../eval_set/humaneval-x"
- )
- parser.add_argument(
- "--language_type",
- type=str,
- default="python"
- )
- parser.add_argument(
- "--model_name",
- type=str,
- default="chatgpt",
- help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
- )
- parser.add_argument(
- "--output_prefix",
- type=str,
- default="../output/humaneval"
- )
- args = parser.parse_known_args()[0]
- output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
+ output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"
- entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
+ entries = read_dataset(data_folder=input_path, language_type=language_type, dataset_type="humaneval")
for entry in entries.values():
- entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
+ entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)
model_inference = None
- if args.model_name == "chatgpt":
+ if model_name == "chatgpt":
from inference.chatgpt_inference import ChatGPTInference
model_inference = ChatGPTInference()
- elif args.model_name == "chatglm2":
+ elif model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference
model_inference = GLMInference()
- elif args.model_name == "bbt":
+ elif model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
- elif args.model_name == "codegeex2":
+ elif model_name == "codegeex2":
from inference.codegeex2_inference import CodeGeex2Inference
model_inference = CodeGeex2Inference()
- elif args.model_name == "llama2":
+ elif model_name == "llama2":
from inference.llama2_inference import LLAMA2Inference
model_inference = LLAMA2Inference()
+ elif model_name == "baichuan":
+ from inference.baichuan_inference import BaiChuanInference
+ model_inference = BaiChuanInference()
else:
- print(f"{args.model_name} not supported.")
+ print(f"{model_name} not supported.")
start_time = time.perf_counter()
num_finished = 0
@@ -78,7 +58,7 @@ if __name__ == "__main__":
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
- generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
+ generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
f.write(
json.dumps(
{
diff --git a/src/inference/baichuan_inference.py b/src/inference/baichuan_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ebec2404bd679dfb4dac9f44a06cd0cee4c2e0
--- /dev/null
+++ b/src/inference/baichuan_inference.py
@@ -0,0 +1,60 @@
+
+
+import os
+import json
+import torch
+import logging
+from .inference import Inference
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+class BaiChuanInference(Inference):
+
+ def __init__(self):
+ super(BaiChuanInference, self).__init__()
+ self.params_url = "../llm_set/params/baichuan.json"
+ self.paras_dict = self.get_params()
+ self.paras_dict.update(self.paras_base_dict)
+
+ self.temperature = self.paras_dict.get("temperature")
+ self.quantize = self.paras_dict.get("quantize")
+ self.device = self.paras_dict.get("device")
+ self.model_path = self.paras_dict.get("model_path")
+ self.DEV = torch.device(self.device)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto",
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ torch_dtype=torch.float16)
+ self.max_length = self.paras_dict.get("max_length")
+ self.min_length = self.paras_dict.get("min_length")
+ self.top_p = self.paras_dict.get("top_p")
+ self.top_k = self.paras_dict.get("top_k")
+
+ def get_params(self):
+ if not os.path.exists(self.params_url):
+            logger.error(f"params_url: {self.params_url} does not exist.")
+ content = open(self.params_url).read()
+ return json.loads(content)
+
+ def inference(self, message):
+ input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
+
+ sentence = None
+ with torch.no_grad():
+ generation_output = self.model.generate(
+ input_ids,
+ do_sample=True,
+ min_length=self.min_length,
+ max_length=self.max_length,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ temperature=self.temperature
+ )
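+        # note: generation_output[0] includes the prompt tokens, so the decoded text contains the prompt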
+ output = self.tokenizer.decode([el.item() for el in generation_output[0]])
+ sentence = output.strip()
+ return sentence
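+
+
+if __name__ == "__main__":
+    # Illustrative smoke test only: it assumes ../llm_set/params/baichuan.json and the
+    # baichuan-7b weights it points to are reachable from the current working directory.
+    model = BaiChuanInference()
+    print(model.inference("def add(a, b):\n    # return the sum of a and b\n"))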
diff --git a/src/utils.py b/src/utils.py
index 2043e7027011f55e69728a8e5185b0f487bcf6d5..59dc3ef527bef4021cf4061ab2b5f138fbd4b861 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,7 +1,7 @@
import os
import re
from typing import *
-from data_utils import stream_jsonl, LANGUAGE_TAG
+from src.data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter