Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CSDN 技术社区
llm-coding-eval
提交
41eaf869
L
llm-coding-eval
项目概览
CSDN 技术社区
/
llm-coding-eval
通知
19
Star
3
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
1
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
L
llm-coding-eval
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
41eaf869
编写于
7月 27, 2023
作者:
CSDN-Ada助手
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add baichuan
上级
40c01206
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
225 addition
and
84 deletion
+225
-84
README.md
README.md
+32
-32
llm_set/params/baichuan.json
llm_set/params/baichuan.json
+7
-0
main.py
main.py
+95
-0
src/evaluate_humaneval_x.py
src/evaluate_humaneval_x.py
+8
-8
src/generate_batch.sh
src/generate_batch.sh
+4
-5
src/generate_humaneval_x.py
src/generate_humaneval_x.py
+18
-38
src/inference/baichuan_inference.py
src/inference/baichuan_inference.py
+60
-0
src/utils.py
src/utils.py
+1
-1
未找到文件。
README.md
浏览文件 @
41eaf869
...
...
@@ -67,22 +67,21 @@ example_test: 提示中出现的公开测例,用于评测。
## 运行命令
下面是一个使用chatgpt来生成python语言测试数据的样例:
python generate_humaneval_x.py --input_path ../eval_set/humaneval-x
python main.py --task_type generate
--input_path eval_set/humaneval-x
--language_type python
--model_name chatgpt
--output_prefix ../
output/humaneval
--output_prefix
output/humaneval
评估样例:
python evaluate_humaneval_x.py --language_type python
--input_folder ../output
--tmp_dir ../output/tmp/
--n_workers 3
--timeout 500.0
--problem_folder ../eval_set/humaneval-x/
--out_dir ../output/
--k [1, 10, 100]
--test_groundtruth False
--example_test False
python main.py --task_type evaluate
--language_type python
--input_path output
--tmp_dir output/tmp/
--n_workers 1
--timeout 10.0
--problem_folder eval_set/humaneval-x/
--output_prefix output/
--model_name chatgpt
## 测试结果
...
...
@@ -90,18 +89,19 @@ python evaluate_humaneval_x.py --language_type python
受限于模型推理速度,目前只测试了pass@1指标。
| | python | java | cpp | js | go |
|--------------|--------|--------|--------|--------|--------
-
|
|--------------|--------|--------|--------|--------|--------|
| chatgpt | 64.02% | 15.85% | 26.22% | 47.00% | 31.70% |
| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0
%
|
| bbt-13B | 2.49% | 0
% | 1.90% | 1.83% | 0.61%
|
| bbt-7B | 0.61% | 1.83% | 1.22% | 1.83% | 0
.00%
|
| bbt-13B | 2.49% | 0
.00% | 1.90% | 1.83% | 0.61%
|
| chatglm2-6B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
| codegeex2-6B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
| llama2-7B | 5.49% | 8.54% | 1.22% | 3.66% | 6.10% |
| baichuan-7B | 7.93% | 1.83% | 0.00% | 6.71% | 6.71% |
## TODO
1、测试更多开源模型,例如百川,llama2,rwkv。
2、测试模型的pass@10和pass@100指标。
3、代码翻译类任务还没有适配,同时也需要构造相关的数据。
1、测试模型的pass@10和pass@100指标。
2、代码翻译类任务还没有适配,同时也需要构造相关的数据。
llm_set/params/baichuan.json
0 → 100644
浏览文件 @
41eaf869
{
"model_path"
:
"../llm_set/models/baichuan-7b"
,
"device"
:
"cuda:0"
,
"quantize"
:
false
,
"max_length"
:
1024
,
"min_length"
:
10
}
\ No newline at end of file
main.py
浏览文件 @
41eaf869
import
argparse
def
remove_none_params
(
params_dict
):
params
=
{}
for
key
in
params_dict
:
if
params_dict
[
key
]
is
not
None
:
params
[
key
]
=
params_dict
[
key
]
return
params
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--task_type"
,
type
=
str
,
required
=
True
,
help
=
'supported tasks. [generate, translate, evaluate]'
)
parser
.
add_argument
(
"--input_path"
,
type
=
str
,
)
parser
.
add_argument
(
"--language_type"
,
type
=
str
,
default
=
"python"
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
default
=
"chatgpt"
,
help
=
'supported models: [chatgpt, chatglm2, bbt, codegeex2, baichuan].'
)
parser
.
add_argument
(
"--output_prefix"
,
type
=
str
)
parser
.
add_argument
(
"--problem_folder"
,
type
=
str
,
default
=
"eval_set/humaneval-x/"
,
help
=
'for evaluate, problem path.'
)
parser
.
add_argument
(
"--tmp_dir"
,
type
=
str
,
default
=
"output/tmp/"
,
help
=
'for evaluate, temp files path.'
)
parser
.
add_argument
(
"--n_workers"
,
type
=
int
,
default
=
1
,
help
=
'for evaluate, number of processes.'
)
parser
.
add_argument
(
"--timeout"
,
type
=
float
,
default
=
10.0
,
help
=
'for evaluate, single testing timeout time.'
)
args
=
parser
.
parse_known_args
()[
0
]
task_type
=
args
.
task_type
if
task_type
==
"generate"
:
from
src.generate_humaneval_x
import
generate
params
=
remove_none_params
({
"input_path"
:
args
.
input_path
,
"language_type"
:
args
.
language_type
,
"model_name"
:
args
.
model_name
,
"output_prefix"
:
args
.
output_prefix
})
generate
(
**
params
)
elif
task_type
==
"translate"
:
print
(
"Not implemented."
)
elif
task_type
==
"evaluate"
:
from
src.evaluate_humaneval_x
import
evaluate_functional_correctness
params
=
remove_none_params
({
"input_folder"
:
args
.
input_path
,
"language_type"
:
args
.
language_type
,
"model_name"
:
args
.
model_name
,
"out_dir"
:
args
.
output_prefix
,
"problem_folder"
:
args
.
problem_folder
,
"tmp_dir"
:
args
.
tmp_dir
,
"n_workers"
:
args
.
n_workers
,
"timeout"
:
args
.
timeout
})
evaluate_functional_correctness
(
**
params
)
else
:
print
(
f
"task_type:
{
task_type
}
not supported."
)
src/evaluate_humaneval_x.py
浏览文件 @
41eaf869
...
...
@@ -11,9 +11,9 @@ from tqdm.auto import tqdm
from
collections
import
defaultdict
from
concurrent.futures
import
ThreadPoolExecutor
,
as_completed
from
utils
import
read_dataset
,
IMPORT_HELPER
from
metric
import
estimate_pass_at_k
from
execution
import
check_correctness
from
src.
utils
import
read_dataset
,
IMPORT_HELPER
from
src.
metric
import
estimate_pass_at_k
from
src.
execution
import
check_correctness
LANGUAGE_NAME
=
{
"cpp"
:
"CPP"
,
...
...
@@ -102,12 +102,12 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def
evaluate_functional_correctness
(
language_type
:
str
=
"python"
,
input_folder
:
str
=
"
../
output"
,
tmp_dir
:
str
=
"
../
output/tmp/"
,
input_folder
:
str
=
"output"
,
tmp_dir
:
str
=
"output/tmp/"
,
n_workers
:
int
=
3
,
timeout
:
float
=
50
0.0
,
problem_folder
:
str
=
"
../
eval_set/humaneval-x/"
,
out_dir
:
str
=
"
../
output/"
,
timeout
:
float
=
1
0.0
,
problem_folder
:
str
=
"eval_set/humaneval-x/"
,
out_dir
:
str
=
"output/"
,
k
:
List
[
int
]
=
[
1
,
10
,
100
],
test_groundtruth
:
bool
=
False
,
example_test
:
bool
=
False
,
...
...
src/generate_batch.sh
浏览文件 @
41eaf869
python generate_humaneval_x.py
--model_name
bbt
--language_type
python
python generate_humaneval_x.py
--model_name
bbt
--language_type
java
python generate_humaneval_x.py
--model_name
bbt
--language_type
cpp
python generate_humaneval_x.py
--model_name
bbt
--language_type
js
python generate_humaneval_x.py
--model_name
bbt
--language_type
go
\ No newline at end of file
python generate_humaneval_x.py
--model_name
baichuan
--language_type
python
python generate_humaneval_x.py
--model_name
baichuan
--language_type
cpp
python generate_humaneval_x.py
--model_name
baichuan
--language_type
js
python generate_humaneval_x.py
--model_name
baichuan
--language_type
go
\ No newline at end of file
src/generate_humaneval_x.py
浏览文件 @
41eaf869
import
argparse
import
logging
import
os
import
random
import
socket
import
time
import
json
import
torch
...
...
@@ -13,56 +10,39 @@ from utils import cleanup_code
logging
.
getLogger
(
"torch"
).
setLevel
(
logging
.
WARNING
)
if
__name__
==
"__main__"
:
def
generate
(
input_path
=
"eval_set/humaneval-x"
,
language_type
=
"python"
,
model_name
=
"chatgpt"
,
output_prefix
=
"output/humaneval"
):
torch
.
multiprocessing
.
set_start_method
(
"spawn"
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--input_path"
,
type
=
str
,
default
=
"../eval_set/humaneval-x"
)
parser
.
add_argument
(
"--language_type"
,
type
=
str
,
default
=
"python"
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
default
=
"chatgpt"
,
help
=
'supported models in [chatgpt, chatglm2, bbt, codegeex2].'
)
parser
.
add_argument
(
"--output_prefix"
,
type
=
str
,
default
=
"../output/humaneval"
)
args
=
parser
.
parse_known_args
()[
0
]
output_file_path
=
args
.
output_prefix
+
f
"_
{
args
.
model_name
}
_
{
args
.
language_type
}
_finished.jsonl"
output_file_path
=
output_prefix
+
f
"_
{
model_name
}
_
{
language_type
}
_finished.jsonl"
entries
=
read_dataset
(
data_folder
=
args
.
input_path
,
language_type
=
args
.
language_type
,
dataset_type
=
"humaneval"
)
entries
=
read_dataset
(
data_folder
=
input_path
,
language_type
=
language_type
,
dataset_type
=
"humaneval"
)
for
entry
in
entries
.
values
():
entry
[
"prompt"
]
=
process_extra_prompt
(
entry
[
"prompt"
],
args
.
language_type
)
entry
[
"prompt"
]
=
process_extra_prompt
(
entry
[
"prompt"
],
language_type
)
model_inference
=
None
if
args
.
model_name
==
"chatgpt"
:
if
model_name
==
"chatgpt"
:
from
inference.chatgpt_inference
import
ChatGPTInference
model_inference
=
ChatGPTInference
()
elif
args
.
model_name
==
"chatglm2"
:
elif
model_name
==
"chatglm2"
:
from
inference.chatglm2_inference
import
GLMInference
model_inference
=
GLMInference
()
elif
args
.
model_name
==
"bbt"
:
elif
model_name
==
"bbt"
:
from
inference.bbt_inference
import
BBTInference
model_inference
=
BBTInference
()
elif
args
.
model_name
==
"codegeex2"
:
elif
model_name
==
"codegeex2"
:
from
inference.codegeex2_inference
import
CodeGeex2Inference
model_inference
=
CodeGeex2Inference
()
elif
args
.
model_name
==
"llama2"
:
elif
model_name
==
"llama2"
:
from
inference.llama2_inference
import
LLAMA2Inference
model_inference
=
LLAMA2Inference
()
elif
model_name
==
"baichuan"
:
from
inference.baichuan_inference
import
BaiChuanInference
model_inference
=
BaiChuanInference
()
else
:
print
(
f
"
{
args
.
model_name
}
not supported."
)
print
(
f
"
{
model_name
}
not supported."
)
start_time
=
time
.
perf_counter
()
num_finished
=
0
...
...
@@ -78,7 +58,7 @@ if __name__ == "__main__":
if
generated_tokens
is
None
:
print
(
f
"task_id:
{
entry
[
'task_id'
]
}
generate failed!"
)
continue
generated_code
=
cleanup_code
(
generated_tokens
,
entry
,
language_type
=
args
.
language_type
)
generated_code
=
cleanup_code
(
generated_tokens
,
entry
,
language_type
=
language_type
)
f
.
write
(
json
.
dumps
(
{
...
...
src/inference/baichuan_inference.py
0 → 100644
浏览文件 @
41eaf869
import
os
import
json
import
torch
import
logging
from
.inference
import
Inference
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
logger
=
logging
.
getLogger
(
__name__
)
class
BaiChuanInference
(
Inference
):
def
__init__
(
self
):
super
(
BaiChuanInference
,
self
).
__init__
()
self
.
params_url
=
"../llm_set/params/baichuan.json"
self
.
paras_dict
=
self
.
get_params
()
self
.
paras_dict
.
update
(
self
.
paras_base_dict
)
self
.
temperature
=
self
.
paras_dict
.
get
(
"temperature"
)
self
.
quantize
=
self
.
paras_dict
.
get
(
"quantize"
)
self
.
device
=
self
.
paras_dict
.
get
(
"device"
)
self
.
model_path
=
self
.
paras_dict
.
get
(
"model_path"
)
self
.
DEV
=
torch
.
device
(
self
.
device
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
model_path
,
trust_remote_code
=
True
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
model_path
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
low_cpu_mem_usage
=
True
,
torch_dtype
=
torch
.
float16
)
self
.
max_length
=
self
.
paras_dict
.
get
(
"max_length"
)
self
.
min_length
=
self
.
paras_dict
.
get
(
"min_length"
)
self
.
top_p
=
self
.
paras_dict
.
get
(
"top_p"
)
self
.
top_k
=
self
.
paras_dict
.
get
(
"top_k"
)
def
get_params
(
self
):
if
not
os
.
path
.
exists
(
self
.
params_url
):
logger
.
error
(
f
"params_url:
{
self
.
params_url
}
is not exists."
)
content
=
open
(
self
.
params_url
).
read
()
return
json
.
loads
(
content
)
def
inference
(
self
,
message
):
input_ids
=
self
.
tokenizer
.
encode
(
message
,
return_tensors
=
"pt"
).
to
(
self
.
DEV
)
sentence
=
None
with
torch
.
no_grad
():
generation_output
=
self
.
model
.
generate
(
input_ids
,
do_sample
=
True
,
min_length
=
self
.
min_length
,
max_length
=
self
.
max_length
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
temperature
=
self
.
temperature
)
output
=
self
.
tokenizer
.
decode
([
el
.
item
()
for
el
in
generation_output
[
0
]])
sentence
=
output
.
strip
()
return
sentence
src/utils.py
浏览文件 @
41eaf869
import
os
import
re
from
typing
import
*
from
data_utils
import
stream_jsonl
,
LANGUAGE_TAG
from
src.
data_utils
import
stream_jsonl
,
LANGUAGE_TAG
from
collections
import
Counter
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录