Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
b14bb89c
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
5
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b14bb89c
编写于
2月 20, 2020
作者:
W
wangxiao1021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add evaluate-slot.py
上级
ad7d3792
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
78 addition
and
0 deletion
+78
-0
examples/multi-task/evaluate-slot.py
examples/multi-task/evaluate-slot.py
+78
-0
未找到文件。
examples/multi-task/evaluate-slot.py
0 → 100644
浏览文件 @
b14bb89c
# -*- coding: utf-8 -*-
import
json
def
load_label_map
(
map_dir
=
"./data/atis/atis_slot/label_map.json"
):
"""
:param map_dir: dict indictuing chunk type
:return:
"""
return
json
.
load
(
open
(
map_dir
,
"r"
))
def
cal_chunk
(
total_res
,
total_label
):
assert
len
(
total_label
)
==
len
(
total_res
),
"prediction result doesn't match to labels, {}, {}"
.
format
(
len
(
total_res
),
len
(
total_label
))
num_labels
=
0
num_corr
=
0
num_infers
=
0
for
res
,
label
in
zip
(
total_res
,
total_label
):
assert
len
(
res
)
==
len
(
label
),
"prediction result doesn't match to labels, {}, {}"
.
format
(
len
(
res
),
len
(
label
))
num_labels
+=
sum
([
0
if
i
==
6
else
1
for
i
in
label
])
num_corr
+=
sum
([
1
if
label
[
i
]
==
res
[
i
]
and
label
[
i
]
!=
6
else
0
for
i
in
range
(
len
(
label
))])
num_infers
+=
sum
([
0
if
i
==
6
else
1
for
i
in
res
])
precision
=
num_corr
*
1.0
/
num_infers
if
num_infers
>
0
else
0.0
recall
=
num_corr
*
1.0
/
num_labels
if
num_labels
>
0
else
0.0
f1
=
2
*
precision
*
recall
/
(
precision
+
recall
)
if
precision
+
recall
>
0
else
0.0
return
precision
,
recall
,
f1
def
res_evaluate
(
res_dir
=
"./outputs/predict-slot/predictions.json"
,
data_dir
=
"./data/atis/atis_slot/test.tsv"
):
label_map
=
load_label_map
()
total_label
=
[]
with
open
(
data_dir
,
"r"
)
as
file
:
first_flag
=
True
for
line
in
file
:
if
first_flag
:
first_flag
=
False
continue
line
=
line
.
strip
(
"
\n
"
)
if
len
(
line
)
==
0
:
continue
line
=
line
.
split
(
"
\t
"
)
if
len
(
line
)
<
2
:
continue
labels
=
line
[
1
][:
-
1
].
split
(
"
\x02
"
)
total_label
.
append
(
labels
)
total_label
=
[[
label_map
[
j
]
for
j
in
i
]
for
i
in
total_label
]
total_res
=
[]
with
open
(
res_dir
,
"r"
)
as
file
:
cnt
=
0
for
line
in
file
:
line
=
line
.
strip
(
"
\n
"
)
if
len
(
line
)
==
0
:
continue
try
:
res_arr
=
json
.
loads
(
line
)
if
len
(
total_label
[
cnt
])
<
len
(
res_arr
):
total_res
.
append
(
res_arr
[
1
:
1
+
len
(
total_label
[
cnt
])])
elif
len
(
total_label
[
cnt
])
==
len
(
res_arr
):
total_res
.
append
(
res_arr
)
else
:
total_res
.
append
(
res_arr
)
total_label
[
cnt
]
=
total_label
[
cnt
][:
len
(
res_arr
)]
except
:
print
(
"json format error: {}"
.
format
(
cnt
))
print
(
line
)
cnt
+=
1
precision
,
recall
,
f1
=
cal_chunk
(
total_res
,
total_label
)
print
(
"precision: {}, recall: {}, f1: {}"
.
format
(
precision
,
recall
,
f1
))
res_evaluate
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录