Commit e0bf81b9
Project: CSDN 技术社区 / llm-coding-eval
Authored on Jul 24, 2023 by CSDN-Ada助手
Commit message: fix bug
Parent commit: cc9d2464

Showing 8 changed files with 593 additions and 378 deletions (+593 / -378).
Changed files:
  src/evaluate_humaneval_x.py          +17  -30
  src/execution.py                     +302 -301
  src/generate_humaneval_x.py          +19  -8
  src/inference/bbt_inference.py       +58  -0   (new file)
  src/inference/chatglm2_inference.py  +4   -3
  src/inference/chatgpt_inference.py   +11  -6
  src/inference/inference.py           +2   -2
  src/utils.py                         +180 -28
src/evaluate_humaneval_x.py  (+17 / -30)

@@ -16,10 +16,10 @@ from metric import estimate_pass_at_k
 from execution import check_correctness

 LANGUAGE_NAME = {
     "cpp": "CPP",
     "go": "Go",
     "java": "Java",
     "js": "JavaScript",
     "python": "Python",
 }

@@ -39,8 +39,10 @@ def process_humaneval_test(sample, problems, example_test=False):
     if language == "python":
         code_ = []
         for line in code.split("\n"):
-            if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
-                break
+            if line.strip().startswith("def "):
+                continue
+            if line and line[0] != ' ' and line[0] != '\t':
+                line = "    " + line
             code_.append(line)
         code = "\n".join(code_)
         test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"

@@ -97,23 +99,25 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]:
 def evaluate_functional_correctness(
-        input_file: str = None,
-        tmp_dir: str = "./",
-        n_workers: int = 32,
+        language_type: str = "python",
+        input_folder: str = "../output",
+        tmp_dir: str = "../output/tmp/",
+        n_workers: int = 3,
         timeout: float = 500.0,
-        problem_file: str = "../data/humaneval_python.jsonl.gz",
-        out_dir: str = None,
+        problem_folder: str = "../eval_set/humaneval-x/",
+        out_dir: str = "../output/",
         k: List[int] = [1, 10, 100],
         test_groundtruth: bool = False,
         example_test: bool = False,
+        model_name: str = "chatgpt",
 ):
     if example_test:
         print("Example test...")
-    problems = read_dataset(problem_file,
-                            dataset_type="humaneval")
+    input_file = f"{input_folder}/humaneval_{model_name}_{language_type}_finished.jsonl"
     sample_jsonl = stream_jsonl_all(input_file)
+    problems = read_dataset(data_folder=problem_folder, language_type=language_type, dataset_type="humaneval")
     if example_test:
         suffix = "_example_test.jsonl"
     else:

@@ -125,14 +129,6 @@ def evaluate_functional_correctness(
     else:
         out_file = os.path.join(input_file.replace(".jsonl", suffix))
-    if "/codegeex/benchmark/humaneval-x/" in input_file:
-        test_groundtruth = True
-    if "-to-" in input_file:
-        translation_mode = True
-    else:
-        translation_mode = False
     with ThreadPoolExecutor(max_workers=n_workers) as executor:
         futures = []

@@ -162,14 +158,6 @@ def evaluate_functional_correctness(
         for sample in tqdm(sample_jsonl):
             task_id = sample["task_id"]
             lang = task_id.split("/")[0].lower()
-            if translation_mode:
-                task_id = sample["task_id"].split("/")[-1]
-                lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
-                for l in LANGUAGE_NAME:
-                    if l in lang:
-                        lang = l
-                        break
-                task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
             if lang == "javascript":
                 lang = "js"
             tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")

@@ -187,7 +175,6 @@ def evaluate_functional_correctness(
             completion_id[task_id] += 1
             n_samples += 1
-    print(completion_id)
     if len(completion_id) == len(problems):
         evaluate_pass_at_k = True
     else:
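For context, the reworked Python branch of process_humaneval_test no longer truncates a completion at its first top-level line; it skips a repeated "def" header and re-indents stray top-level lines so they stay inside the function under test. A minimal standalone sketch of that loop (toy completion string, not from the repo; the four-space indent is assumed from the diff):

```python
# Toy illustration of the new completion-normalisation loop (assumed 4-space indent).
completion = (
    "def has_close_elements(numbers, threshold):\n"  # model repeated the header
    "result = False\n"                               # body line emitted at column 0
    "    return result"
)

code_ = []
for line in completion.split("\n"):
    if line.strip().startswith("def "):   # drop the duplicated function header
        continue
    if line and line[0] != ' ' and line[0] != '\t':
        line = "    " + line              # pull top-level lines back into the body
    code_.append(line)

print("\n".join(code_))
# prints:
#     result = False
#     return result
```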
src/execution.py  (+302 / -301)

+# Copyright (c) OpenAI (https://openai.com)
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ============================================================================
 import contextlib
 import faulthandler
 import io

@@ -41,20 +61,7 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> Non
         out.write(jout)

-def check_correctness(
-        task_id: str,
-        sample: dict,
-        language_type: str,
-        timeout: float = 3.0,
-        tmp_dir: str = None,
-        completion_id: Optional[int] = None,
-) -> Dict:
-    """
-    Evaluates the functional correctness of a completion by running the test
-    suite provided in the problem.
-    """
-    def unsafe_execute(tmp_dir):
+def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
     random_id = random.uniform(1, 1000)
     if "python" in language_type.lower():
         with create_tempdir():

@@ -108,8 +115,8 @@ def check_correctness(
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)
-        os.chdir(tmp_dir)
-        open(f"main_test.go", 'w').write(sample["test_code"])
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "main_test.go"), 'w').write(sample["test_code"])
         try:
             exec_result = None
             with time_limit(timeout):

@@ -122,7 +129,7 @@ def check_correctness(
                 # does not perform destructive actions on their host or network.
                 # Once you have read this disclaimer and taken appropriate precautions,
                 # uncomment the following line and proceed at your own risk:
-                exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
+                exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", os.path.join(tmp_dir, "main_test.go")], timeout=timeout, capture_output=True)
             if exec_result.returncode == 0:
                 result.append("passed")

@@ -153,8 +160,8 @@ def check_correctness(
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)
-        os.chdir(tmp_dir)
-        open(f"test.js", 'w').write(sample["test_code"])
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "test.js"), 'w').write(sample["test_code"])
         try:
             exec_result = None
             with time_limit(timeout):

@@ -167,7 +174,7 @@ def check_correctness(
                 # does not perform destructive actions on their host or network.
                 # Once you have read this disclaimer and taken appropriate precautions,
                 # uncomment the following line and proceed at your own risk:
-                exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
+                exec_result = subprocess.run(["node", os.path.join(tmp_dir, "test.js")], timeout=timeout, capture_output=True)
             if exec_result.stderr.decode():
                 err = exec_result.stderr.decode()

@@ -192,14 +199,16 @@ def check_correctness(
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)
-        os.chdir(tmp_dir)
-        open(f"test.cpp", 'w').write(sample["test_code"])
-        if "162" in task_id:
-            compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
-                                                timeout=timeout,
-                                                capture_output=True)
-        else:
-            compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout,
-                                                capture_output=True)
+        # os.chdir(tmp_dir)
+        open(os.path.join(tmp_dir, "test.cpp"), 'w').write(sample["test_code"])
+        # if "162" in task_id:
+        #     compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp"), "-lcrypto", "-lssl"],
+        #                                         timeout=timeout,
+        #                                         capture_output=True)
+        # else:
+        #     compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout,
+        #                                         capture_output=True)
+        compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", os.path.join(tmp_dir, "test.cpp")], timeout=timeout, capture_output=True)
         if compilation_result.returncode != 0:
             if compilation_result.stderr:

@@ -257,10 +266,10 @@ def check_correctness(
             os.makedirs(RUST_SRC, exist_ok=True)
             os.makedirs(RUST_BIN, exist_ok=True)
             with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f:
                 # temporal file name
                 file_prefix = sample["task_id"].lower().replace("/", "_")
                 file_name: str = file_prefix + RUST_EXT
                 os.rename(f.name, os.path.join(RUST_BIN, file_name))

@@ -292,9 +301,8 @@ def check_correctness(
             # 0 means success
             if returned_val_compilation == 0:
                 # Execution pipeline
                 cargo_test: str = "cargo test --bin " + file_prefix + " --message-format json >> " + log_path
                 returned_val_execution = os.system(cargo_test)
                 if returned_val_execution == 0:

@@ -305,7 +313,6 @@ def check_correctness(
             else:
                 result.append(f"failed: compilation error")
     elif "java" in language_type.lower():
         assert tmp_dir is not None, "Java should be evaluated in a temporary dir."

@@ -318,7 +325,7 @@ def check_correctness(
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)
-        os.chdir(tmp_dir)
+        # os.chdir(tmp_dir)
         open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"])
         res = "failed: unknown error"
         compile_returncode = -1

@@ -344,7 +351,7 @@ def check_correctness(
                 # does not perform destructive actions on their host or network.
                 # Once you have read this disclaimer and taken appropriate precautions,
                 # uncomment the following line and proceed at your own risk:
-                # exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
+                exec_result = subprocess.run([f'java', '-cp', tmp_dir, 'Main'], timeout=timeout, capture_output=True)
             if exec_result.returncode == 0:
                 res = "passed"
             elif exec_result.returncode == 1:

@@ -357,13 +364,26 @@ def check_correctness(
         except BaseException as e:
             res = f"failed: {e}"
         result.append(res)
         shutil.rmtree(tmp_dir)

+def check_correctness(
+        task_id: str,
+        sample: dict,
+        language_type: str,
+        timeout: float = 3.0,
+        tmp_dir: str = None,
+        completion_id: Optional[int] = None,
+) -> Dict:
+    """
+    Evaluates the functional correctness of a completion by running the test
+    suite provided in the problem.
+    """
     manager = multiprocessing.Manager()
     result = manager.list()
-    p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir,))
+    p = multiprocessing.Process(target=unsafe_execute, args=(tmp_dir, language_type, timeout, sample, result, task_id,))
     p.start()
     p.join(timeout=timeout + 1)
     if p.is_alive():

@@ -373,38 +393,19 @@ def check_correctness(
         result.append("timed out")

     return {
         "task_id": task_id,
         "completion_id": completion_id,
         "test_code": sample["test_code"],
         "prompt": sample["prompt"],
         "generation": sample["generation"],
         "result": result[0],
         "passed": result[0] == "passed",
         "finish": -1 if "finish" not in sample else sample["finish"],
         "file": "" if "file" not in sample else sample["file"],
         "output": [] if "output" not in sample else sample["output"],
     }

-# Copyright (c) OpenAI (https://openai.com)
-#   ... (same MIT license notice as now placed at the top of the file) ...
-# ============================================================================
 @contextlib.contextmanager
 def time_limit(seconds: float):
     def signal_handler(signum, frame):
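The core change in src/execution.py is structural: unsafe_execute is no longer a closure inside check_correctness but a module-level function that receives everything it needs as arguments, which is what lets multiprocessing.Process run it in a child process and enforce the timeout from outside. A minimal sketch of that pattern, with the actual compile-and-run work replaced by a placeholder:

```python
# Sketch of the refactored execution pattern from this commit (placeholder body;
# the real unsafe_execute compiles and runs the generated code per language).
import multiprocessing


def unsafe_execute(tmp_dir, language_type, timeout, sample, result, task_id):
    # All state arrives as arguments, so the function can be pickled and
    # executed in a separate process; results flow back via the shared list.
    result.append("passed")


def check_correctness(task_id, sample, language_type, timeout=3.0, tmp_dir=None):
    manager = multiprocessing.Manager()
    result = manager.list()
    p = multiprocessing.Process(
        target=unsafe_execute,
        args=(tmp_dir, language_type, timeout, sample, result, task_id),
    )
    p.start()
    p.join(timeout=timeout + 1)   # give the child one extra second
    if p.is_alive():
        p.kill()
    if not result:
        result.append("timed out")
    return {"task_id": task_id, "result": result[0], "passed": result[0] == "passed"}


if __name__ == "__main__":
    print(check_correctness("Python/0", {"test_code": ""}, "python"))
```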
src/generate_humaneval_x.py  (+19 / -8)

@@ -19,7 +19,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_path",
         type=str,
-        default="eval_set/humaneval-x"
+        default="../eval_set/humaneval-x"
     )
     parser.add_argument(
         "--language_type",

@@ -30,18 +30,18 @@ if __name__ == "__main__":
         "--model_name",
         type=str,
         default="chatgpt",
-        help='supported model in [chatgpt,chatglm2].'
+        help='supported models in [chatgpt, chatglm2, bbt].'
     )
     parser.add_argument(
         "--output_prefix",
         type=str,
-        default="chatgpt"
+        default="../output/humaneval"
     )
     args = parser.parse_known_args()[0]

-    output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl"
+    output_file_path = args.output_prefix + f"_{args.model_name}_{args.language_type}_finished.jsonl"

-    entries = read_dataset(args.input_path, args.language_type, dataset_type="humaneval")
+    entries = read_dataset(data_folder=args.input_path, language_type=args.language_type, dataset_type="humaneval")
     for entry in entries.values():
         entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type)

@@ -52,6 +52,9 @@ if __name__ == "__main__":
     elif args.model_name == "chatglm2":
         from inference.chatglm2_inference import GLMInference
         model_inference = GLMInference()
+    elif args.model_name == "bbt":
+        from inference.bbt_inference import BBTInference
+        model_inference = BBTInference()
     else:
         print(f"{args.model_name} not supported.")

@@ -60,8 +63,16 @@ if __name__ == "__main__":
     with open(output_file_path, "w") as f:
         for entry in entries.values():
             prompt = entry["prompt"]
+            retry_times = 3
             generated_tokens = model_inference.inference(prompt)
-            generated_code = cleanup_code(generated_code, language_type=args.language_type)
+            while generated_tokens is None and retry_times > 0:
+                time.sleep(10)
+                generated_tokens = model_inference.inference(prompt)
+                retry_times -= 1
+            if generated_tokens is None:
+                print(f"task_id: {entry['task_id']} generate failed!")
+                continue
+            generated_code = cleanup_code(generated_tokens, entry, language_type=args.language_type)
             f.write(
                 json.dumps(
                     {

@@ -79,11 +90,11 @@ if __name__ == "__main__":
             num_finished += 1
             time_elapsed = time.perf_counter() - start_time
-            time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished
+            time_per_sample = 0.0 if num_finished == 0 else time_elapsed / num_finished
             print(
                 f"finished {num_finished}, "
                 f"elapsed {time_elapsed:.4f}",
-                f"speed {time_per_sampple:.4f}s/sample",
+                f"speed {time_per_sample:.4f}s/sample",
                 f"remaining {len(entries) - num_finished:.4f}",
                 flush=True,
            )
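The generation loop now retries a failed inference call instead of passing a None result straight into cleanup_code. A standalone sketch of the added retry logic (the flaky_inference stub stands in for the real inference clients; the real code sleeps 10 s between attempts):

```python
import time

_calls = {"n": 0}


def flaky_inference(prompt):
    # Stand-in for ChatGPTInference/GLMInference/BBTInference.inference:
    # fail on the first call, succeed on the second.
    _calls["n"] += 1
    return None if _calls["n"] == 1 else "    return a + b"


generated_tokens = flaky_inference("def add(a, b):")
retry_times = 3
while generated_tokens is None and retry_times > 0:
    time.sleep(1)  # the commit waits 10 seconds between retries
    generated_tokens = flaky_inference("def add(a, b):")
    retry_times -= 1

if generated_tokens is None:
    print("generate failed, skipping this task")
else:
    print(generated_tokens)
```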
src/inference/bbt_inference.py  (new file, +58 / -0)

import os
import json
import logging
from urllib.parse import quote
import requests
import traceback

from .inference import Inference
from utils import exception_reconnect

logger = logging.getLogger(__name__)


class BBTInference(Inference):
    def __init__(self):
        super(BBTInference, self).__init__()
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            # 'top_k': 80,
            # 'do_sample': True,
            # 'num_beams': 4,
            # 'no_repeat_ngram_size': 4,
            # 'repetition_penalty': 1.1,
            # 'length_penalty': 1,
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        if not os.path.exists(self.params_url):
            logger.error(f"params_url: {self.params_url} is not exists.")
        content = open(self.params_url).read()
        return json.loads(content)

    @exception_reconnect
    def inference(self, query_text):
        query_text = "补全以下方法体:\n" + query_text
        try:
            data = {'content': query_text, 'user_id': self.user_id, 'config': json.dumps(self.modelConfig)}
            resp = requests.post(self.url, data=data)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
            return None
        if resp.status_code == 200:
            resp_json = json.loads(resp.text)
            return resp_json["msg"]
        else:
            return None
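The new BBTInference class reads its settings from ../llm_set/params/bbt.json, which is not part of this commit. Based only on the keys the constructor looks up, a minimal params file could be generated with a sketch like the following (all values are placeholders, not the project's real configuration):

```python
# Hypothetical helper to create ../llm_set/params/bbt.json with the keys
# BBTInference reads: url, user_id, timeout, temperature, top_p, max_new_tokens.
import json
import os

params = {
    "url": "http://example.com/bbt/api",  # placeholder endpoint
    "user_id": "your-user-id",            # placeholder credential
    "timeout": 60,
    "temperature": 0.2,
    "top_p": 0.95,
    "max_new_tokens": 512,
}

os.makedirs("../llm_set/params", exist_ok=True)
with open("../llm_set/params/bbt.json", "w", encoding="utf-8") as f:
    json.dump(params, f, ensure_ascii=False, indent=2)
```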
src/inference/chatglm2_inference.py  (+4 / -3)

@@ -3,15 +3,16 @@ import json
 import torch
 import logging
 from transformers import AutoModel, AutoTokenizer
-from inference import Inference
+from .inference import Inference

 logger = logging.getLogger(__name__)


 class GLMInference(Inference):

-    def __int__(self):
-        self.params_url = "llm_set/params/chatgpt.json"
+    def __init__(self):
+        super(GLMInference, self).__init__()
+        self.params_url = "../llm_set/params/chatglm2.json"
         self.paras_dict = self.get_params()
         self.paras_dict.update(self.paras_base_dict)
src/inference/chatgpt_inference.py  (+11 / -6)

@@ -2,20 +2,23 @@
 import os
 import json
 import logging
-import quote
+from urllib.parse import quote
 import requests
 import traceback
-from inference import Inference
+from .inference import Inference
+from utils import exception_reconnect

 logger = logging.getLogger(__name__)


 class ChatGPTInference(Inference):
-    def __int__(self):
-        self.params_url = "llm_set/params/chatgpt.json"
+    def __init__(self):
+        super(ChatGPTInference, self).__init__()
+        self.params_url = "../llm_set/params/chatgpt.json"
         self.paras_dict = self.get_params()
-        self.paras_dict.update(self.paras_base_dict)
+        self.paras_base_dict.update(self.paras_dict)
+        self.paras_dict = self.paras_base_dict
         self.gpt_url = self.paras_dict.get("url")
         self.id = self.paras_dict.get("id")
         self.stream = self.paras_dict.get("stream")

@@ -28,9 +31,11 @@ class ChatGPTInference(Inference):
         content = open(self.params_url).read()
         return json.loads(content)

+    @exception_reconnect
     def inference(self, query_text):
+        query_text = "补全以下方法体:\n" + query_text
         query_text = quote(query_text)
-        param_str = f"?id={self.gpt_url}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
+        param_str = f"?id={self.id}&stream={self.stream}&temperature={self.temperature}&q={query_text}"
         try:
             resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
         except Exception as e:
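One of the fixes above is in the query string: the old code interpolated self.gpt_url into the id field, so every request carried the service URL as its id; the new code passes self.id. A small sketch of the URL the fixed inference method builds (placeholder values, since the real ones come from ../llm_set/params/chatgpt.json):

```python
# Illustration of the fixed param_str construction (placeholder config values).
from urllib.parse import quote

gpt_url = "http://example.com/chatgpt"   # placeholder service endpoint
id_, stream, temperature = "my-app-id", "false", 0.2

query_text = quote("补全以下方法体:\n" + "def add(a, b):")
param_str = f"?id={id_}&stream={stream}&temperature={temperature}&q={query_text}"
print(gpt_url + param_str)
```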
src/inference/inference.py  (+2 / -2)

@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
 class Inference(object):
-    def __int__(self):
-        self.params_base_url = "llm_set/params/default.json"
+    def __init__(self):
+        self.params_base_url = "../llm_set/params/default.json"
         self.paras_base_dict = self.get_paras_base()

     def get_paras_base(self):
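The same rename from __int__ to __init__ appears in the Inference base class and in the ChatGPT and ChatGLM2 subclasses. It matters because a method named __int__ is not a constructor, so the old classes were instantiated without ever setting their attributes. A minimal demonstration of the difference:

```python
# Why the __int__ -> __init__ rename is a real bug fix, not cosmetics.
class Broken:
    def __int__(self):                     # typo: never called on construction
        self.params_url = "llm_set/params/chatgpt.json"


class Fixed:
    def __init__(self):                    # proper constructor
        self.params_url = "../llm_set/params/chatgpt.json"


b = Broken()
print(hasattr(b, "params_url"))            # False: attribute was never set
f = Fixed()
print(f.params_url)                        # ../llm_set/params/chatgpt.json
```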
src/utils.py  (+180 / -28)

 import os
+import re
 from typing import *

 from data_utils import stream_jsonl, LANGUAGE_TAG
+from collections import Counter

 IMPORT_HELPER = {

@@ -49,7 +51,7 @@ IMPORT_HELPER = {
 def read_dataset(
-    data_file: str = None,
+    data_folder: str = None,
     dataset_type: str = "humaneval",
     language_type: str = "python",
     num_shot=None,

@@ -57,7 +59,7 @@ def read_dataset(
     if num_shot is not None:
         print(f"{num_shot}-shot setting...")
     if "humaneval" in dataset_type.lower():
-        data_file = os.path.join(data_file, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
+        data_file = os.path.join(data_folder, language_type, "data", f"humaneval_{language_type}.jsonl.gz")
         dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
     else:
         raise f"Dataset: {dataset_type} not supported."

@@ -147,41 +149,191 @@ def is_code_generation_finished(
     return False


-def cleanup_code(
-    code: str,
-    language_type: str = None
-):
+def find_method_name(language_type, sample):
+    """查找方法名"""
+    if language_type.lower() == "python":
+        declaration = sample["declaration"]
+        ret = re.search("def (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    elif language_type.lower() == "cpp":
+        declaration = sample["declaration"]
+        ret = re.search(" (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        if " " in method_name:
+            return method_name[1]
+        return method_name
+    elif language_type.lower() == "java":
+        declaration = sample["declaration"]
+        ret = re.search(" (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        if " " in method_name:
+            return method_name[1]
+        return method_name
+    elif language_type.lower() in ["js", "javascript"]:
+        declaration = sample["declaration"]
+        ret = re.search("const (.*?) ", declaration)
+        if ret is None:
+            return None
+        method_name = ret.group(1).strip()
+        return method_name
+    elif language_type.lower() == "go":
+        declaration = sample["declaration"]
+        ret = re.search("func (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    elif language_type == "rust":
+        declaration = sample["declaration"]
+        ret = re.search("fn (.*?)\\(", declaration)
+        if ret is None:
+            return None
+        return ret.group(1).strip()
+    else:
+        return None
+
+
+def extract_markdown_code(content):
+    """提取markdown中的代码,即"```"中的内容"""
+    codes = []
+    rets = re.findall("```([\\s\\S]*?)```", content)
+    if rets is None:
+        return codes
+    for ret in rets:
+        if not ret.startswith("\n"):
+            lines = ret.split("\n")
+            codes.append("".join(lines[1:]))
+        else:
+            codes.append(ret.strip())
+    return codes
+
+
+def cleanup_code(code: str, sample, language_type: str = None):
     """
     Cleans up the generated code.
     """
     if language_type is None:
         return code

+    method_name = find_method_name(language_type, sample)
+    if method_name is None:
+        return code
+    method_body = code
     if language_type.lower() == "python":
-        end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"]
-        for w in end_words:
-            if w in code:
-                code = code[:code.rfind(w)]
+        method_lines = []
+        for line in code.split("\n"):
+            if f"def {method_name}" in line:
+                method_lines.append(line)
+                continue
+            if method_lines:
+                method_lines.append(line)
+            if line.startswith("    return"):
+                break
+        if method_lines:
+            method_body = "\n".join(method_lines[1:])
     elif language_type.lower() == "java":
-        main_pos = code.find("public static void main")
-        if main_pos != -1:
-            code = code[:main_pos] + '}'
-        if '}' in code:
-            code = code[:code.rfind('}')] + '}'
-        if code.count('{') + 1 == code.count('}'):
-            code += "\n}"
+        method_lines = []
+        bracket_left = 0
+        bracket_right = 0
+        for line in code.split("\n"):
+            new_line = line.strip()
+            counter = Counter(new_line)
+            if new_line.startswith("public") and method_name in new_line:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+                continue
+            if method_lines:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+            if bracket_left == bracket_right and bracket_right > 0:
+                break
+        if method_lines:
+            method_lines.append("}")
+            method_body = "\n".join(method_lines[1:])
     elif language_type.lower() == "go":
-        end_words = ["\n//", "\nfunc main("]
-        for w in end_words:
-            if w in code:
-                code = code[:code.rfind(w)]
-        if '}' in code:
-            code = code[:code.rfind('}')] + '}'
+        method_lines = []
+        bracket_left = 0
+        bracket_right = 0
+        for line in code.split("\n"):
+            counter = Counter(line)
+            if f"func {method_name}" in line:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+                continue
+            if method_lines:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+            if bracket_left == bracket_right and bracket_right > 0:
+                break
+        if method_lines:
+            method_body = "\n".join(method_lines[1:])
     elif language_type.lower() == "cpp":
-        if '}' in code:
-            code = code[:code.rfind('}')] + '}'
+        method_lines = []
+        bracket_left = 0
+        bracket_right = 0
+        for line in code.split("\n"):
+            counter = Counter(line)
+            if f"{method_name}(" in line:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+                continue
+            if method_lines:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+            if bracket_left == bracket_right and bracket_right > 0:
+                break
+        if method_lines:
+            method_body = "\n".join(method_lines[1:])
     elif language_type.lower() == "js":
-        if '}' in code:
-            code = code[:code.rfind('}')] + '}'
-    return code
+        method_lines = []
+        bracket_left = 0
+        bracket_right = 0
+        for line in code.split("\n"):
+            counter = Counter(line)
+            if f"const {method_name}" in line:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+                continue
+            if method_lines:
+                method_lines.append(line)
+                bracket_left += counter["{"]
+                bracket_right += counter["}"]
+            if bracket_left == bracket_right and bracket_right > 0:
+                break
+        if method_lines:
+            method_body = "\n".join(method_lines[1:])
+    return method_body + "\n"
+
+
+def exception_reconnect(funct):
+    """异常重连"""
+    def wrapper_func(*args, **kwargs):
+        try:
+            return funct(*args, **kwargs)
+        except Exception as e:
+            print(f"exception will reconnect: {str(e)}")
+            return funct(*args, **kwargs)
+    return wrapper_func
\ No newline at end of file
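A brief usage sketch of the new helpers in src/utils.py. It assumes the script is run from the src directory so that utils and its data_utils dependency import cleanly; the sample dict below only mimics the shape of a HumanEval-X record.

```python
# Hedged usage sketch of the helpers added in this commit.
from utils import find_method_name, extract_markdown_code, cleanup_code

sample = {"declaration": "def has_close_elements(numbers, threshold):\n"}
print(find_method_name("python", sample))        # -> has_close_elements

# Pull fenced code out of a markdown-style model reply.
reply = (
    "Here is the code:\n"
    "```python\n"
    "def has_close_elements(numbers, threshold):\n"
    "    return False\n"
    "```"
)
print(extract_markdown_code(reply))

# Trim a completion down to the body of the target method.
generated = (
    "def has_close_elements(numbers, threshold):\n"
    "    return False\n"
    "print('extra text the model appended')"
)
print(cleanup_code(generated, sample, language_type="python"))
```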