提交 41eaf869 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

add baichuan

上级 40c01206
......@@ -67,22 +67,21 @@ example_test: 提示中出现的公开测例,用于评测。
## 运行命令
下面是一个使用chatgpt来生成python语言测试数据的样例:
python generate_humaneval_x.py --input_path ../eval_set/humaneval-x
python main.py --task_type generate
--input_path eval_set/humaneval-x
--language_type python
--model_name chatgpt
--output_prefix ../output/humaneval
--output_prefix output/humaneval
评估样例:
python evaluate_humaneval_x.py --language_type python
--input_folder ../output
--tmp_dir ../output/tmp/
--n_workers 3
--timeout 500.0
--problem_folder ../eval_set/humaneval-x/
--out_dir ../output/
--k [1, 10, 100]
--test_groundtruth False
--example_test False
python main.py --task_type evaluate
--language_type python
--input_path output
--tmp_dir output/tmp/
--n_workers 1
--timeout 10.0
--problem_folder eval_set/humaneval-x/
--output_prefix output/
--model_name chatgpt
## 测试结果
......@@ -90,18 +89,19 @@ python evaluate_humaneval_x.py --language_type python
受限于模型推理速度,目前只测试了pass@1指标。
| | python | java | cpp | js | go |
|--------------|--------|--------|--------|--------|---------|
|--------------|--------|--------|--------|--------|--------|
| chatgpt | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0% |
| bbt-13B | 2.49% | 0% | 1.90% | 1.83% | 0.61% |
| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0.00% |
| bbt-13B | 2.49% | 0.00% | 1.90% | 1.83% | 0.61% |
| chatglm2-6B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
| codegeex2-6B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
| llama2-7B | 5.49% | 8.54% | 1.22% | 3.66% | 6.10% |
| baichuan-7B | 7.93% | 1.83% | 0.00% | 6.71% | 6.71% |
## TODO
1、测试更多开源模型,例如百川,llama2,rwkv。
2、测试模型的pass@10和pass@100指标。
3、代码翻译类任务还没有适配,同时也需要构造相关的数据。
1、测试模型的pass@10和pass@100指标。
2、代码翻译类任务还没有适配,同时也需要构造相关的数据。
{
"model_path": "../llm_set/models/baichuan-7b",
"device": "cuda:0",
"quantize": false,
"max_length": 1024,
"min_length": 10
}
\ No newline at end of file
import argparse
def remove_none_params(params_dict):
params = {}
for key in params_dict:
if params_dict[key] is not None:
params[key] = params_dict[key]
return params
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--task_type",
type=str,
required=True,
help='supported tasks. [generate, translate, evaluate]'
)
parser.add_argument(
"--input_path",
type=str,
)
parser.add_argument(
"--language_type",
type=str,
default="python"
)
parser.add_argument(
"--model_name",
type=str,
default="chatgpt",
help='supported models: [chatgpt, chatglm2, bbt, codegeex2, baichuan].'
)
parser.add_argument(
"--output_prefix",
type=str
)
parser.add_argument(
"--problem_folder",
type=str,
default="eval_set/humaneval-x/",
help='for evaluate, problem path.'
)
parser.add_argument(
"--tmp_dir",
type=str,
default="output/tmp/",
help='for evaluate, temp files path.'
)
parser.add_argument(
"--n_workers",
type=int,
default=1,
help='for evaluate, number of processes.'
)
parser.add_argument(
"--timeout",
type=float,
default=10.0,
help='for evaluate, single testing timeout time.'
)
args = parser.parse_known_args()[0]
task_type = args.task_type
if task_type == "generate":
from src.generate_humaneval_x import generate
params = remove_none_params({
"input_path": args.input_path,
"language_type": args.language_type,
"model_name": args.model_name,
"output_prefix": args.output_prefix
})
generate(**params)
elif task_type == "translate":
print("Not implemented.")
elif task_type == "evaluate":
from src.evaluate_humaneval_x import evaluate_functional_correctness
params = remove_none_params({
"input_folder": args.input_path,
"language_type": args.language_type,
"model_name": args.model_name,
"out_dir": args.output_prefix,
"problem_folder": args.problem_folder,
"tmp_dir": args.tmp_dir,
"n_workers": args.n_workers,
"timeout": args.timeout
})
evaluate_functional_correctness(**params)
else:
print(f"task_type: {task_type} not supported.")
......@@ -11,9 +11,9 @@ from tqdm.auto import tqdm
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from utils import read_dataset, IMPORT_HELPER
from metric import estimate_pass_at_k
from execution import check_correctness
from src.utils import read_dataset, IMPORT_HELPER
from src.metric import estimate_pass_at_k
from src.execution import check_correctness
LANGUAGE_NAME = {
"cpp": "CPP",
......@@ -102,12 +102,12 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def evaluate_functional_correctness(
language_type: str = "python",
input_folder: str = "../output",
tmp_dir: str = "../output/tmp/",
input_folder: str = "output",
tmp_dir: str = "output/tmp/",
n_workers: int = 3,
timeout: float = 500.0,
problem_folder: str = "../eval_set/humaneval-x/",
out_dir: str = "../output/",
timeout: float = 10.0,
problem_folder: str = "eval_set/humaneval-x/",
out_dir: str = "output/",
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
......
python generate_humaneval_x.py --model_name bbt --language_type python
python generate_humaneval_x.py --model_name bbt --language_type java
python generate_humaneval_x.py --model_name bbt --language_type cpp
python generate_humaneval_x.py --model_name bbt --language_type js
python generate_humaneval_x.py --model_name bbt --language_type go
\ No newline at end of file
python generate_humaneval_x.py --model_name baichuan --language_type python
python generate_humaneval_x.py --model_name baichuan --language_type cpp
python generate_humaneval_x.py --model_name baichuan --language_type js
python generate_humaneval_x.py --model_name baichuan --language_type go
\ No newline at end of file
import argparse
import logging
import os
import random
import socket
import time
import json
import torch
......@@ -13,56 +10,39 @@ from utils import cleanup_code
logging.getLogger("torch").setLevel(logging.WARNING)
if __name__ == "__main__":
def generate(input_path="eval_set/humaneval-x",
language_type="python",
model_name="chatgpt",
output_prefix="output/humaneval"):
torch.multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_path",
type=str,
default="../eval_set/humaneval-x"
)
parser.add_argument(
"--language_type",
type=str,
default="python"
)
parser.add_argument(
"--model_name",
type=str,
default="chatgpt",
help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
)
parser.add_argument(
"--output_prefix",
type=str,
default="../output/humaneval"
)
args = parser.parse_known_args()[0]
output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"
entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
entries = read_dataset(data_folder=input_path, language_type=language_type, dataset_type="humaneval")
for entry in entries.values():
entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)
model_inference = None
if args.model_name == "chatgpt":
if model_name == "chatgpt":
from inference.chatgpt_inference import ChatGPTInference
model_inference = ChatGPTInference()
elif args.model_name == "chatglm2":
elif model_name == "chatglm2":
from inference.chatglm2_inference import GLMInference
model_inference = GLMInference()
elif args.model_name == "bbt":
elif model_name == "bbt":
from inference.bbt_inference import BBTInference
model_inference = BBTInference()
elif args.model_name == "codegeex2":
elif model_name == "codegeex2":
from inference.codegeex2_inference import CodeGeex2Inference
model_inference = CodeGeex2Inference()
elif args.model_name == "llama2":
elif model_name == "llama2":
from inference.llama2_inference import LLAMA2Inference
model_inference = LLAMA2Inference()
elif model_name == "baichuan":
from inference.baichuan_inference import BaiChuanInference
model_inference = BaiChuanInference()
else:
print(f"{args.model_name} not supported.")
print(f"{model_name} not supported.")
start_time = time.perf_counter()
num_finished = 0
......@@ -78,7 +58,7 @@ if __name__ == "__main__":
if generated_tokens is None:
print(f"task_id: {entry['task_id']} generate failed!")
continue
generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
f.write(
json.dumps(
{
......
import os
import json
import torch
import logging
from .inference import Inference
from transformers import AutoModelForCausalLM, AutoTokenizer
logger = logging.getLogger(__name__)
class BaiChuanInference(Inference):
def __init__(self):
super(BaiChuanInference, self).__init__()
self.params_url = "../llm_set/params/baichuan.json"
self.paras_dict = self.get_params()
self.paras_dict.update(self.paras_base_dict)
self.temperature = self.paras_dict.get("temperature")
self.quantize = self.paras_dict.get("quantize")
self.device = self.paras_dict.get("device")
self.model_path = self.paras_dict.get("model_path")
self.DEV = torch.device(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.float16)
self.max_length = self.paras_dict.get("max_length")
self.min_length = self.paras_dict.get("min_length")
self.top_p = self.paras_dict.get("top_p")
self.top_k = self.paras_dict.get("top_k")
def get_params(self):
if not os.path.exists(self.params_url):
logger.error(f"params_url:{self.params_url} is not exists.")
content = open(self.params_url).read()
return json.loads(content)
def inference(self, message):
input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
sentence = None
with torch.no_grad():
generation_output = self.model.generate(
input_ids,
do_sample=True,
min_length=self.min_length,
max_length=self.max_length,
top_p=self.top_p,
top_k=self.top_k,
temperature=self.temperature
)
output = self.tokenizer.decode([el.item() for el in generation_output[0]])
sentence = output.strip()
return sentence
import os
import re
from typing import *
from data_utils import stream_jsonl, LANGUAGE_TAG
from src.data_utils import stream_jsonl, LANGUAGE_TAG
from collections import Counter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册