Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
dbd22a7e
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
4
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dbd22a7e
编写于
2月 20, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
2月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #64 from wangxiao1021/api
add evaluate-slot.py
上级
3f0591a5
073d3319
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
82 addition
and
5 deletion
+82
-5
examples/multi-task/evaluate-intent.py
examples/multi-task/evaluate-intent.py
+1
-1
examples/multi-task/evaluate-slot.py
examples/multi-task/evaluate-slot.py
+78
-0
paddlepalm/reader/utils/reader4ernie.py
paddlepalm/reader/utils/reader4ernie.py
+3
-4
未找到文件。
examples/multi-task/evaluate-intent.py
浏览文件 @
dbd22a7e
...
...
@@ -15,7 +15,7 @@ def f1(preds, labels):
tn
=
np
.
sum
((
labels
==
'0'
)
&
(
preds
==
'0'
))
fp
=
np
.
sum
((
labels
==
'0'
)
&
(
preds
==
'1'
))
fn
=
np
.
sum
((
labels
==
'1'
)
&
(
preds
==
'0'
))
p
=
tp
*
1.0
/
(
tp
+
fp
)
p
=
tp
*
1.0
/
(
tp
+
fp
)
*
1.0
r
=
tp
*
1.0
/
(
tp
+
fn
)
*
1.0
f1
=
(
2
*
p
*
r
)
/
(
p
+
r
+
1e-8
)
return
f1
...
...
examples/multi-task/evaluate-slot.py
0 → 100644
浏览文件 @
dbd22a7e
# -*- coding: utf-8 -*-
import
json
def
load_label_map
(
map_dir
=
"./data/atis/atis_slot/label_map.json"
):
"""
:param map_dir: dict indictuing chunk type
:return:
"""
return
json
.
load
(
open
(
map_dir
,
"r"
))
def
cal_chunk
(
total_res
,
total_label
):
assert
len
(
total_label
)
==
len
(
total_res
),
"prediction result doesn't match to labels, {}, {}"
.
format
(
len
(
total_res
),
len
(
total_label
))
num_labels
=
0
num_corr
=
0
num_infers
=
0
for
res
,
label
in
zip
(
total_res
,
total_label
):
assert
len
(
res
)
==
len
(
label
),
"prediction result doesn't match to labels, {}, {}"
.
format
(
len
(
res
),
len
(
label
))
num_labels
+=
sum
([
0
if
i
==
6
else
1
for
i
in
label
])
num_corr
+=
sum
([
1
if
label
[
i
]
==
res
[
i
]
and
label
[
i
]
!=
6
else
0
for
i
in
range
(
len
(
label
))])
num_infers
+=
sum
([
0
if
i
==
6
else
1
for
i
in
res
])
precision
=
num_corr
*
1.0
/
num_infers
if
num_infers
>
0
else
0.0
recall
=
num_corr
*
1.0
/
num_labels
if
num_labels
>
0
else
0.0
f1
=
2
*
precision
*
recall
/
(
precision
+
recall
)
if
precision
+
recall
>
0
else
0.0
return
precision
,
recall
,
f1
def
res_evaluate
(
res_dir
=
"./outputs/predict-slot/predictions.json"
,
data_dir
=
"./data/atis/atis_slot/test.tsv"
):
label_map
=
load_label_map
()
total_label
=
[]
with
open
(
data_dir
,
"r"
)
as
file
:
first_flag
=
True
for
line
in
file
:
if
first_flag
:
first_flag
=
False
continue
line
=
line
.
strip
(
"
\n
"
)
if
len
(
line
)
==
0
:
continue
line
=
line
.
split
(
"
\t
"
)
if
len
(
line
)
<
2
:
continue
labels
=
line
[
1
][:
-
1
].
split
(
"
\x02
"
)
total_label
.
append
(
labels
)
total_label
=
[[
label_map
[
j
]
for
j
in
i
]
for
i
in
total_label
]
total_res
=
[]
with
open
(
res_dir
,
"r"
)
as
file
:
cnt
=
0
for
line
in
file
:
line
=
line
.
strip
(
"
\n
"
)
if
len
(
line
)
==
0
:
continue
try
:
res_arr
=
json
.
loads
(
line
)
if
len
(
total_label
[
cnt
])
<
len
(
res_arr
):
total_res
.
append
(
res_arr
[
1
:
1
+
len
(
total_label
[
cnt
])])
elif
len
(
total_label
[
cnt
])
==
len
(
res_arr
):
total_res
.
append
(
res_arr
)
else
:
total_res
.
append
(
res_arr
)
total_label
[
cnt
]
=
total_label
[
cnt
][:
len
(
res_arr
)]
except
:
print
(
"json format error: {}"
.
format
(
cnt
))
print
(
line
)
cnt
+=
1
precision
,
recall
,
f1
=
cal_chunk
(
total_res
,
total_label
)
print
(
"precision: {}, recall: {}, f1: {}"
.
format
(
precision
,
recall
,
f1
))
res_evaluate
()
paddlepalm/reader/utils/reader4ernie.py
浏览文件 @
dbd22a7e
...
...
@@ -293,10 +293,9 @@ class Reader(object):
if
to_append
:
batch_records
.
append
(
record
)
else
:
ds
=
[
's'
]
*
7
for
piece
in
palm
.
distribute
.
yield_pieces
(
\
self
.
_pad_batch_records
(
batch_records
),
ds
,
batch_size
):
batch_pad_records
=
self
.
_pad_batch_records
(
batch_records
)
ds
=
[
's'
]
*
len
(
batch_pad_records
)
for
piece
in
palm
.
distribute
.
yield_pieces
(
batch_pad_records
,
ds
,
batch_size
):
yield
piece
batch_records
,
max_len
=
[
record
],
len
(
record
.
token_ids
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录