Commit 41eaf869, authored by CSDN-Ada助手

add baichuan

Parent: 40c01206
@@ -67,22 +67,21 @@ example_test: public test cases that appear in the prompt, used for evaluation.
 ## Run commands
 Below is an example of using chatgpt to generate Python test data:
-python generate_humaneval_x.py --input_path ../eval_set/humaneval-x
+python main.py --task_type generate
+--input_path eval_set/humaneval-x
 --language_type python
 --model_name chatgpt
---output_prefix ../output/humaneval
+--output_prefix output/humaneval
 Evaluation example:
-python evaluate_humaneval_x.py --language_type python
---input_folder ../output
---tmp_dir ../output/tmp/
---n_workers 3
---timeout 500.0
---problem_folder ../eval_set/humaneval-x/
---out_dir ../output/
---k [1, 10, 100]
---test_groundtruth False
---example_test False
+python main.py --task_type evaluate
+--language_type python
+--input_path output
+--tmp_dir output/tmp/
+--n_workers 1
+--timeout 10.0
+--problem_folder eval_set/humaneval-x/
+--output_prefix output/
 --model_name chatgpt
 ## Test results
@@ -90,18 +89,19 @@ python evaluate_humaneval_x.py --language_type python
 Limited by model inference speed, only the pass@1 metric has been tested so far.
 |              | python | java   | cpp    | js     | go     |
 |--------------|--------|--------|--------|--------|--------|
 | chatgpt      | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
-| bbt-7B       | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0%     |
-| bbt-13B      | 2.49%  | 0%     | 1.90%  | 1.83%  | 0.61%  |
+| bbt-7B       | 0.61%  | 1.83%  | 1.22%  | 1.83%  | 0.00%  |
+| bbt-13B      | 2.49%  | 0.00%  | 1.90%  | 1.83%  | 0.61%  |
 | chatglm2-6B  | 7.93%  | 5.45%  | 0.61%  | 6.70%  | 1.83%  |
 | codegeex2-6B | 29.90% | 27.43% | 6.70%  | 24.40% | 17.68% |
 | llama2-7B    | 5.49%  | 8.54%  | 1.22%  | 3.66%  | 6.10%  |
+| baichuan-7B  | 7.93%  | 1.83%  | 0.00%  | 6.71%  | 6.71%  |
 ## TODO
-1. Test more open-source models, e.g. Baichuan, llama2, rwkv.
-2. Test the models' pass@10 and pass@100 metrics.
-3. Code-translation tasks are not yet adapted; the corresponding data also needs to be constructed.
+1. Test the models' pass@10 and pass@100 metrics.
+2. Code-translation tasks are not yet adapted; the corresponding data also needs to be constructed.
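
For reference on the metrics above: pass@k (the TODO mentions pass@10 and pass@100) is normally computed with the unbiased estimator from the HumanEval paper, which is presumably what `estimate_pass_at_k`, imported in `src/evaluate_humaneval_x.py`, implements. A minimal sketch, assuming NumPy, where `n` samples are drawn per task and `c` of them pass all tests:

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one task: n samples drawn, c of them correct."""
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    # 1 - C(n-c, k) / C(n, k), computed as a numerically stable running product
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
```

With a single sample per task (the pass@1 setting reported here), this reduces to the fraction of tasks whose one completion passes its tests.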
{
    "model_path": "../llm_set/models/baichuan-7b",
    "device": "cuda:0",
    "quantize": false,
    "max_length": 1024,
    "min_length": 10
}
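
This appears to be the new parameter file `llm_set/params/baichuan.json` that the `BaiChuanInference` backend (shown later) reads through its `params_url`: `model_path` points at the local Baichuan-7B checkpoint, `device` and `quantize` select the runtime setup, and `max_length`/`min_length` bound generation. Sampling settings such as `temperature`, `top_p` and `top_k` seem to come from the shared base parameters merged in the backend's `__init__`. The next new file appears to be `main.py`, the unified entry point that the updated README commands invoke.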
import argparse


def remove_none_params(params_dict):
    """Drop keys whose value is None so that downstream defaults still apply."""
    params = {}
    for key in params_dict:
        if params_dict[key] is not None:
            params[key] = params_dict[key]
    return params


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        help='supported tasks: [generate, translate, evaluate]'
    )
    parser.add_argument(
        "--input_path",
        type=str,
    )
    parser.add_argument(
        "--language_type",
        type=str,
        default="python"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="chatgpt",
        help='supported models: [chatgpt, chatglm2, bbt, codegeex2, llama2, baichuan].'
    )
    parser.add_argument(
        "--output_prefix",
        type=str
    )
    parser.add_argument(
        "--problem_folder",
        type=str,
        default="eval_set/humaneval-x/",
        help='for evaluate, problem path.'
    )
    parser.add_argument(
        "--tmp_dir",
        type=str,
        default="output/tmp/",
        help='for evaluate, temp files path.'
    )
    parser.add_argument(
        "--n_workers",
        type=int,
        default=1,
        help='for evaluate, number of processes.'
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=10.0,
        help='for evaluate, timeout for a single test run.'
    )
    args = parser.parse_known_args()[0]

    # Dispatch to the generation or evaluation module based on the task type.
    task_type = args.task_type
    if task_type == "generate":
        from src.generate_humaneval_x import generate
        params = remove_none_params({
            "input_path": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "output_prefix": args.output_prefix
        })
        generate(**params)
    elif task_type == "translate":
        print("Not implemented.")
    elif task_type == "evaluate":
        from src.evaluate_humaneval_x import evaluate_functional_correctness
        params = remove_none_params({
            "input_folder": args.input_path,
            "language_type": args.language_type,
            "model_name": args.model_name,
            "out_dir": args.output_prefix,
            "problem_folder": args.problem_folder,
            "tmp_dir": args.tmp_dir,
            "n_workers": args.n_workers,
            "timeout": args.timeout
        })
        evaluate_functional_correctness(**params)
    else:
        print(f"task_type: {task_type} not supported.")
@@ -11,9 +11,9 @@ from tqdm.auto import tqdm
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from utils import read_dataset, IMPORT_HELPER
-from metric import estimate_pass_at_k
-from execution import check_correctness
+from src.utils import read_dataset, IMPORT_HELPER
+from src.metric import estimate_pass_at_k
+from src.execution import check_correctness
 LANGUAGE_NAME = {
     "cpp": "CPP",
@@ -102,12 +102,12 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
 def evaluate_functional_correctness(
     language_type: str = "python",
-    input_folder: str = "../output",
-    tmp_dir: str = "../output/tmp/",
+    input_folder: str = "output",
+    tmp_dir: str = "output/tmp/",
     n_workers: int = 3,
-    timeout: float = 500.0,
-    problem_folder: str = "../eval_set/humaneval-x/",
-    out_dir: str = "../output/",
+    timeout: float = 10.0,
+    problem_folder: str = "eval_set/humaneval-x/",
+    out_dir: str = "output/",
     k: List[int] = [1, 10, 100],
     test_groundtruth: bool = False,
     example_test: bool = False,
     ...
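
A minimal sketch of calling the evaluator directly with the new defaults; it mirrors the keyword arguments `main.py` forwards and assumes a matching `*_finished.jsonl` sample file already exists under `output/`:

```python
from src.evaluate_humaneval_x import evaluate_functional_correctness

evaluate_functional_correctness(
    language_type="python",
    model_name="chatgpt",   # forwarded by main.py; presumably selects which samples to score
    input_folder="output",
    problem_folder="eval_set/humaneval-x/",
    tmp_dir="output/tmp/",
    out_dir="output/",
    n_workers=1,
    timeout=10.0,
)
```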
-python generate_humaneval_x.py --model_name bbt --language_type python
-python generate_humaneval_x.py --model_name bbt --language_type java
-python generate_humaneval_x.py --model_name bbt --language_type cpp
-python generate_humaneval_x.py --model_name bbt --language_type js
-python generate_humaneval_x.py --model_name bbt --language_type go
+python generate_humaneval_x.py --model_name baichuan --language_type python
+python generate_humaneval_x.py --model_name baichuan --language_type cpp
+python generate_humaneval_x.py --model_name baichuan --language_type js
+python generate_humaneval_x.py --model_name baichuan --language_type go
-import argparse
 import logging
-import os
-import random
-import socket
 import time
 import json
 import torch
@@ -13,56 +10,39 @@ from utils import cleanup_code
 logging.getLogger("torch").setLevel(logging.WARNING)
-if __name__ == "__main__":
+def generate(input_path="eval_set/humaneval-x",
+             language_type="python",
+             model_name="chatgpt",
+             output_prefix="output/humaneval"):
     torch.multiprocessing.set_start_method("spawn")
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--input_path",
-        type=str,
-        default="../eval_set/humaneval-x"
-    )
-    parser.add_argument(
-        "--language_type",
-        type=str,
-        default="python"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default="chatgpt",
-        help='supported models in [chatgpt, chatglm2, bbt, codegeex2].'
-    )
-    parser.add_argument(
-        "--output_prefix",
-        type=str,
-        default="../output/humaneval"
-    )
-    args = parser.parse_known_args()[0]
-    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"
-    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
+    output_file_path = output_prefix + f"_{model_name}_{language_type}_finished.jsonl"
+    entries = read_dataset(data_folder=input_path, language_type=language_type, dataset_type="humaneval")
     for entry in entries.values():
-        entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)
+        entry["prompt"] = process_extra_prompt(entry["prompt"], language_type)
     model_inference = None
-    if args.model_name == "chatgpt":
+    if model_name == "chatgpt":
         from inference.chatgpt_inference import ChatGPTInference
         model_inference = ChatGPTInference()
-    elif args.model_name == "chatglm2":
+    elif model_name == "chatglm2":
         from inference.chatglm2_inference import GLMInference
         model_inference = GLMInference()
-    elif args.model_name == "bbt":
+    elif model_name == "bbt":
         from inference.bbt_inference import BBTInference
         model_inference = BBTInference()
-    elif args.model_name == "codegeex2":
+    elif model_name == "codegeex2":
         from inference.codegeex2_inference import CodeGeex2Inference
         model_inference = CodeGeex2Inference()
-    elif args.model_name == "llama2":
+    elif model_name == "llama2":
         from inference.llama2_inference import LLAMA2Inference
         model_inference = LLAMA2Inference()
+    elif model_name == "baichuan":
+        from inference.baichuan_inference import BaiChuanInference
+        model_inference = BaiChuanInference()
     else:
-        print(f"{args.model_name} not supported.")
+        print(f"{model_name} not supported.")
     start_time = time.perf_counter()
     num_finished = 0
@@ -78,7 +58,7 @@ if __name__ == "__main__":
             if generated_tokens is None:
                 print(f"task_id: {entry['task_id']} generate failed!")
                 continue
-            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
+            generated_code = cleanup_code(generated_tokens, entry, language_type=language_type)
             f.write(
                 json.dumps(
                     {
     ...
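
The refactor replaces the old `__main__` block with a callable `generate()` so that `main.py` can dispatch to it. A minimal sketch of invoking it directly for the new backend (argument names follow the signature above):

```python
from src.generate_humaneval_x import generate

# On success this writes output/humaneval_baichuan_python_finished.jsonl.
generate(input_path="eval_set/humaneval-x",
         language_type="python",
         model_name="baichuan",
         output_prefix="output/humaneval")
```

The `BaiChuanInference` class it selects is added in the new inference module shown next.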
import os
import json
import torch
import logging

from .inference import Inference
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class BaiChuanInference(Inference):

    def __init__(self):
        super(BaiChuanInference, self).__init__()
        self.params_url = "../llm_set/params/baichuan.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")
        self.DEV = torch.device(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto",
                                                          trust_remote_code=True,
                                                          low_cpu_mem_usage=True,
                                                          torch_dtype=torch.float16)
        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

    def get_params(self):
        if not os.path.exists(self.params_url):
            logger.error(f"params_url: {self.params_url} does not exist.")
        content = open(self.params_url).read()
        return json.loads(content)

    def inference(self, message):
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
        sentence = None
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature
            )
            output = self.tokenizer.decode([el.item() for el in generation_output[0]])
            sentence = output.strip()
        return sentence
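
A quick smoke test of the new backend, assuming the Baichuan-7B checkpoint and `llm_set/params/baichuan.json` are in place (paths are resolved relative to the working directory, exactly as in the class above):

```python
from inference.baichuan_inference import BaiChuanInference

model = BaiChuanInference()  # loads the tokenizer and fp16 model from model_path
completion = model.inference("def fibonacci(n):")
print(completion)
```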
 import os
 import re
 from typing import *
-from data_utils import stream_jsonl, LANGUAGE_TAG
+from src.data_utils import stream_jsonl, LANGUAGE_TAG
 from collections import Counter
 ...