Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
fa449c72
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fa449c72
编写于
11月 06, 2019
作者:
1024的传说
提交者:
pkpk
11月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dialogue_general_understanding python3 (#3887)
上级
9a10a366
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
73 addition
and
67 deletion
+73
-67
PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/evaluation.py
...Dialogue/dialogue_general_understanding/dgu/evaluation.py
+73
-67
未找到文件。
PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/evaluation.py
浏览文件 @
fa449c72
...
@@ -22,26 +22,27 @@ class EvalDA(object):
...
@@ -22,26 +22,27 @@ class EvalDA(object):
"""
"""
evaluate da testset, swda|mrda
evaluate da testset, swda|mrda
"""
"""
def
__init__
(
self
,
task_name
,
pred
,
refer
):
def
__init__
(
self
,
task_name
,
pred
,
refer
):
"""
"""
predict file
predict file
"""
"""
self
.
pred_file
=
pred
self
.
pred_file
=
pred
self
.
refer_file
=
refer
self
.
refer_file
=
refer
def
load_data
(
self
):
def
load_data
(
self
):
"""
"""
load reference label and predict label
load reference label and predict label
"""
"""
pred_label
=
[]
pred_label
=
[]
refer_label
=
[]
refer_label
=
[]
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
]
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
]
refer_label
.
append
(
int
(
label
))
refer_label
.
append
(
int
(
label
))
idx
=
0
idx
=
0
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
continue
continue
...
@@ -49,15 +50,15 @@ class EvalDA(object):
...
@@ -49,15 +50,15 @@ class EvalDA(object):
pred_label
.
append
(
tag_id
)
pred_label
.
append
(
tag_id
)
return
pred_label
,
refer_label
return
pred_label
,
refer_label
def
evaluate
(
self
):
def
evaluate
(
self
):
"""
"""
calculate acc metrics
calculate acc metrics
"""
"""
pred_label
,
refer_label
=
self
.
load_data
()
pred_label
,
refer_label
=
self
.
load_data
()
common_num
=
0
common_num
=
0
total_num
=
len
(
pred_label
)
total_num
=
len
(
pred_label
)
for
i
in
range
(
total_num
):
for
i
in
range
(
total_num
):
if
pred_label
[
i
]
==
refer_label
[
i
]:
if
pred_label
[
i
]
==
refer_label
[
i
]:
common_num
+=
1
common_num
+=
1
acc
=
float
(
common_num
)
/
total_num
acc
=
float
(
common_num
)
/
total_num
return
acc
return
acc
...
@@ -67,26 +68,27 @@ class EvalATISIntent(object):
...
@@ -67,26 +68,27 @@ class EvalATISIntent(object):
"""
"""
evaluate da testset, swda|mrda
evaluate da testset, swda|mrda
"""
"""
def
__init__
(
self
,
pred
,
refer
):
def
__init__
(
self
,
pred
,
refer
):
"""
"""
predict file
predict file
"""
"""
self
.
pred_file
=
pred
self
.
pred_file
=
pred
self
.
refer_file
=
refer
self
.
refer_file
=
refer
def
load_data
(
self
):
def
load_data
(
self
):
"""
"""
load reference label and predict label
load reference label and predict label
"""
"""
pred_label
=
[]
pred_label
=
[]
refer_label
=
[]
refer_label
=
[]
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
0
]
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
0
]
refer_label
.
append
(
int
(
label
))
refer_label
.
append
(
int
(
label
))
idx
=
0
idx
=
0
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
continue
continue
...
@@ -94,45 +96,46 @@ class EvalATISIntent(object):
...
@@ -94,45 +96,46 @@ class EvalATISIntent(object):
pred_label
.
append
(
tag_id
)
pred_label
.
append
(
tag_id
)
return
pred_label
,
refer_label
return
pred_label
,
refer_label
def
evaluate
(
self
):
def
evaluate
(
self
):
"""
"""
calculate acc metrics
calculate acc metrics
"""
"""
pred_label
,
refer_label
=
self
.
load_data
()
pred_label
,
refer_label
=
self
.
load_data
()
common_num
=
0
common_num
=
0
total_num
=
len
(
pred_label
)
total_num
=
len
(
pred_label
)
for
i
in
range
(
total_num
):
for
i
in
range
(
total_num
):
if
pred_label
[
i
]
==
refer_label
[
i
]:
if
pred_label
[
i
]
==
refer_label
[
i
]:
common_num
+=
1
common_num
+=
1
acc
=
float
(
common_num
)
/
total_num
acc
=
float
(
common_num
)
/
total_num
return
acc
return
acc
class
EvalATISSlot
(
object
):
class
EvalATISSlot
(
object
):
"""
"""
evaluate atis slot
evaluate atis slot
"""
"""
def
__init__
(
self
,
pred
,
refer
):
def
__init__
(
self
,
pred
,
refer
):
"""
"""
pred file
pred file
"""
"""
self
.
pred_file
=
pred
self
.
pred_file
=
pred
self
.
refer_file
=
refer
self
.
refer_file
=
refer
def
load_data
(
self
):
def
load_data
(
self
):
"""
"""
load reference label and predict label
load reference label and predict label
"""
"""
pred_label
=
[]
pred_label
=
[]
refer_label
=
[]
refer_label
=
[]
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
labels
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
].
split
()
labels
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
].
split
()
labels
=
[
int
(
l
)
for
l
in
labels
]
labels
=
[
int
(
l
)
for
l
in
labels
]
refer_label
.
append
(
labels
)
refer_label
.
append
(
labels
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
if
len
(
line
.
split
(
'
\t
'
))
!=
2
or
not
line
[
0
].
isdigit
():
if
len
(
line
.
split
(
'
\t
'
))
!=
2
or
not
line
[
0
].
isdigit
():
continue
continue
labels
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
].
split
()[
1
:]
labels
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
1
].
split
()[
1
:]
labels
=
[
int
(
l
)
for
l
in
labels
]
labels
=
[
int
(
l
)
for
l
in
labels
]
...
@@ -140,15 +143,15 @@ class EvalATISSlot(object):
...
@@ -140,15 +143,15 @@ class EvalATISSlot(object):
pred_label_equal
=
[]
pred_label_equal
=
[]
refer_label_equal
=
[]
refer_label_equal
=
[]
assert
len
(
refer_label
)
==
len
(
pred_label
)
assert
len
(
refer_label
)
==
len
(
pred_label
)
for
i
in
range
(
len
(
refer_label
)):
for
i
in
range
(
len
(
refer_label
)):
num
=
len
(
refer_label
[
i
])
num
=
len
(
refer_label
[
i
])
refer_label_equal
.
extend
(
refer_label
[
i
])
refer_label_equal
.
extend
(
refer_label
[
i
])
pred_label
[
i
]
=
pred_label
[
i
][:
num
]
pred_label
[
i
]
=
pred_label
[
i
][:
num
]
pred_label_equal
.
extend
(
pred_label
[
i
])
pred_label_equal
.
extend
(
pred_label
[
i
])
return
pred_label_equal
,
refer_label_equal
return
pred_label_equal
,
refer_label_equal
def
evaluate
(
self
):
def
evaluate
(
self
):
"""
"""
evaluate f1_micro score
evaluate f1_micro score
"""
"""
...
@@ -156,13 +159,13 @@ class EvalATISSlot(object):
...
@@ -156,13 +159,13 @@ class EvalATISSlot(object):
tp
=
dict
()
tp
=
dict
()
fn
=
dict
()
fn
=
dict
()
fp
=
dict
()
fp
=
dict
()
for
i
in
range
(
len
(
refer_label
)):
for
i
in
range
(
len
(
refer_label
)):
if
refer_label
[
i
]
==
pred_label
[
i
]:
if
refer_label
[
i
]
==
pred_label
[
i
]:
if
refer_label
[
i
]
not
in
tp
:
if
refer_label
[
i
]
not
in
tp
:
tp
[
refer_label
[
i
]]
=
0
tp
[
refer_label
[
i
]]
=
0
tp
[
refer_label
[
i
]]
+=
1
tp
[
refer_label
[
i
]]
+=
1
else
:
else
:
if
pred_label
[
i
]
not
in
fp
:
if
pred_label
[
i
]
not
in
fp
:
fp
[
pred_label
[
i
]]
=
0
fp
[
pred_label
[
i
]]
=
0
fp
[
pred_label
[
i
]]
+=
1
fp
[
pred_label
[
i
]]
+=
1
if
refer_label
[
i
]
not
in
fn
:
if
refer_label
[
i
]
not
in
fn
:
...
@@ -170,17 +173,17 @@ class EvalATISSlot(object):
...
@@ -170,17 +173,17 @@ class EvalATISSlot(object):
fn
[
refer_label
[
i
]]
+=
1
fn
[
refer_label
[
i
]]
+=
1
results
=
[
"label precision recall"
]
results
=
[
"label precision recall"
]
for
i
in
range
(
0
,
130
):
for
i
in
range
(
0
,
130
):
if
i
not
in
tp
:
if
i
not
in
tp
:
results
.
append
(
" %s: 0.0 0.0"
%
i
)
results
.
append
(
" %s: 0.0 0.0"
%
i
)
continue
continue
if
i
in
fp
:
if
i
in
fp
:
precision
=
float
(
tp
[
i
])
/
(
tp
[
i
]
+
fp
[
i
])
precision
=
float
(
tp
[
i
])
/
(
tp
[
i
]
+
fp
[
i
])
else
:
else
:
precision
=
1.0
precision
=
1.0
if
i
in
fn
:
if
i
in
fn
:
recall
=
float
(
tp
[
i
])
/
(
tp
[
i
]
+
fn
[
i
])
recall
=
float
(
tp
[
i
])
/
(
tp
[
i
]
+
fn
[
i
])
else
:
else
:
recall
=
1.0
recall
=
1.0
results
.
append
(
" %s: %.4f %.4f"
%
(
i
,
precision
,
recall
))
results
.
append
(
" %s: %.4f %.4f"
%
(
i
,
precision
,
recall
))
tp_total
=
sum
(
tp
.
values
())
tp_total
=
sum
(
tp
.
values
())
...
@@ -193,32 +196,33 @@ class EvalATISSlot(object):
...
@@ -193,32 +196,33 @@ class EvalATISSlot(object):
return
"
\n
"
.
join
(
results
)
return
"
\n
"
.
join
(
results
)
class
EvalUDC
(
object
):
class
EvalUDC
(
object
):
"""
"""
evaluate udc
evaluate udc
"""
"""
def
__init__
(
self
,
pred
,
refer
):
def
__init__
(
self
,
pred
,
refer
):
"""
"""
predict file
predict file
"""
"""
self
.
pred_file
=
pred
self
.
pred_file
=
pred
self
.
refer_file
=
refer
self
.
refer_file
=
refer
def
load_data
(
self
):
def
load_data
(
self
):
"""
"""
load reference label and predict label
load reference label and predict label
"""
"""
data
=
[]
data
=
[]
refer_label
=
[]
refer_label
=
[]
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
0
]
label
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)[
0
]
refer_label
.
append
(
label
)
refer_label
.
append
(
label
)
idx
=
0
idx
=
0
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
elems
=
line
.
rstrip
(
'
\n
'
).
split
(
'
\t
'
)
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
if
len
(
elems
)
!=
2
or
not
elems
[
0
].
isdigit
():
continue
continue
match_prob
=
elems
[
1
]
match_prob
=
elems
[
1
]
data
.
append
((
float
(
match_prob
),
int
(
refer_label
[
idx
])))
data
.
append
((
float
(
match_prob
),
int
(
refer_label
[
idx
])))
...
@@ -230,8 +234,8 @@ class EvalUDC(object):
...
@@ -230,8 +234,8 @@ class EvalUDC(object):
calculate precision in recall n
calculate precision in recall n
"""
"""
pos_score
=
data
[
ind
][
0
]
pos_score
=
data
[
ind
][
0
]
curr
=
data
[
ind
:
ind
+
m
]
curr
=
data
[
ind
:
ind
+
m
]
curr
=
sorted
(
curr
,
key
=
lambda
x
:
x
[
0
],
reverse
=
True
)
curr
=
sorted
(
curr
,
key
=
lambda
x
:
x
[
0
],
reverse
=
True
)
if
curr
[
n
-
1
][
0
]
<=
pos_score
:
if
curr
[
n
-
1
][
0
]
<=
pos_score
:
return
1
return
1
...
@@ -241,20 +245,20 @@ class EvalUDC(object):
...
@@ -241,20 +245,20 @@ class EvalUDC(object):
"""
"""
calculate udc data
calculate udc data
"""
"""
data
=
self
.
load_data
()
data
=
self
.
load_data
()
assert
len
(
data
)
%
10
==
0
assert
len
(
data
)
%
10
==
0
p_at_1_in_2
=
0.0
p_at_1_in_2
=
0.0
p_at_1_in_10
=
0.0
p_at_1_in_10
=
0.0
p_at_2_in_10
=
0.0
p_at_2_in_10
=
0.0
p_at_5_in_10
=
0.0
p_at_5_in_10
=
0.0
length
=
len
(
data
)
/
10
length
=
int
(
len
(
data
)
/
10
)
for
i
in
range
(
0
,
length
):
for
i
in
range
(
0
,
length
):
ind
=
i
*
10
ind
=
i
*
10
assert
data
[
ind
][
1
]
==
1
assert
data
[
ind
][
1
]
==
1
p_at_1_in_2
+=
self
.
get_p_at_n_in_m
(
data
,
1
,
2
,
ind
)
p_at_1_in_2
+=
self
.
get_p_at_n_in_m
(
data
,
1
,
2
,
ind
)
p_at_1_in_10
+=
self
.
get_p_at_n_in_m
(
data
,
1
,
10
,
ind
)
p_at_1_in_10
+=
self
.
get_p_at_n_in_m
(
data
,
1
,
10
,
ind
)
p_at_2_in_10
+=
self
.
get_p_at_n_in_m
(
data
,
2
,
10
,
ind
)
p_at_2_in_10
+=
self
.
get_p_at_n_in_m
(
data
,
2
,
10
,
ind
)
...
@@ -262,13 +266,14 @@ class EvalUDC(object):
...
@@ -262,13 +266,14 @@ class EvalUDC(object):
metrics_out
=
[
p_at_1_in_2
/
length
,
p_at_1_in_10
/
length
,
\
metrics_out
=
[
p_at_1_in_2
/
length
,
p_at_1_in_10
/
length
,
\
p_at_2_in_10
/
length
,
p_at_5_in_10
/
length
]
p_at_2_in_10
/
length
,
p_at_5_in_10
/
length
]
return
metrics_out
return
metrics_out
class
EvalDSTC2
(
object
):
class
EvalDSTC2
(
object
):
"""
"""
evaluate dst testset, dstc2
evaluate dst testset, dstc2
"""
"""
def
__init__
(
self
,
task_name
,
pred
,
refer
):
def
__init__
(
self
,
task_name
,
pred
,
refer
):
"""
"""
predict file
predict file
...
@@ -277,39 +282,39 @@ class EvalDSTC2(object):
...
@@ -277,39 +282,39 @@ class EvalDSTC2(object):
self
.
pred_file
=
pred
self
.
pred_file
=
pred
self
.
refer_file
=
refer
self
.
refer_file
=
refer
def
load_data
(
self
):
def
load_data
(
self
):
"""
"""
load reference label and predict label
load reference label and predict label
"""
"""
pred_label
=
[]
pred_label
=
[]
refer_label
=
[]
refer_label
=
[]
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
refer_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
line
=
line
.
strip
(
'
\n
'
)
line
=
line
.
strip
(
'
\n
'
)
labels
=
[
int
(
l
)
for
l
in
line
.
split
(
'
\t
'
)[
-
1
].
split
()]
labels
=
[
int
(
l
)
for
l
in
line
.
split
(
'
\t
'
)[
-
1
].
split
()]
labels
=
sorted
(
list
(
set
(
labels
)))
labels
=
sorted
(
list
(
set
(
labels
)))
refer_label
.
append
(
" "
.
join
([
str
(
l
)
for
l
in
labels
]))
refer_label
.
append
(
" "
.
join
([
str
(
l
)
for
l
in
labels
]))
all_pred
=
[]
all_pred
=
[]
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
fr
=
io
.
open
(
self
.
pred_file
,
'r'
,
encoding
=
"utf8"
)
for
line
in
fr
:
for
line
in
fr
:
line
=
line
.
strip
(
'
\n
'
)
line
=
line
.
strip
(
'
\n
'
)
all_pred
.
append
(
line
)
all_pred
.
append
(
line
)
all_pred
=
all_pred
[
len
(
all_pred
)
-
len
(
refer_label
):]
all_pred
=
all_pred
[
len
(
all_pred
)
-
len
(
refer_label
):]
for
line
in
all_pred
:
for
line
in
all_pred
:
labels
=
[
int
(
l
)
for
l
in
line
.
split
(
'
\t
'
)[
-
1
].
split
()]
labels
=
[
int
(
l
)
for
l
in
line
.
split
(
'
\t
'
)[
-
1
].
split
()]
labels
=
sorted
(
list
(
set
(
labels
)))
labels
=
sorted
(
list
(
set
(
labels
)))
pred_label
.
append
(
" "
.
join
([
str
(
l
)
for
l
in
labels
]))
pred_label
.
append
(
" "
.
join
([
str
(
l
)
for
l
in
labels
]))
return
pred_label
,
refer_label
return
pred_label
,
refer_label
def
evaluate
(
self
):
def
evaluate
(
self
):
"""
"""
calculate joint acc && overall acc
calculate joint acc && overall acc
"""
"""
overall_all
=
0.0
overall_all
=
0.0
correct_joint
=
0
correct_joint
=
0
pred_label
,
refer_label
=
self
.
load_data
()
pred_label
,
refer_label
=
self
.
load_data
()
for
i
in
range
(
len
(
refer_label
)):
for
i
in
range
(
len
(
refer_label
)):
if
refer_label
[
i
]
!=
pred_label
[
i
]:
if
refer_label
[
i
]
!=
pred_label
[
i
]:
continue
continue
correct_joint
+=
1
correct_joint
+=
1
joint_all
=
float
(
correct_joint
)
/
len
(
refer_label
)
joint_all
=
float
(
correct_joint
)
/
len
(
refer_label
)
...
@@ -317,9 +322,9 @@ class EvalDSTC2(object):
...
@@ -317,9 +322,9 @@ class EvalDSTC2(object):
return
metrics_out
return
metrics_out
def
evaluate
(
task_name
,
pred_file
,
refer_file
):
def
evaluate
(
task_name
,
pred_file
,
refer_file
):
"""evaluate task metrics"""
"""evaluate task metrics"""
if
task_name
.
lower
()
==
'udc'
:
if
task_name
.
lower
()
==
'udc'
:
eval_inst
=
EvalUDC
(
pred_file
,
refer_file
)
eval_inst
=
EvalUDC
(
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"MATCHING TASK: %s metrics in testset: "
%
task_name
)
print
(
"MATCHING TASK: %s metrics in testset: "
%
task_name
)
...
@@ -328,45 +333,46 @@ def evaluate(task_name, pred_file, refer_file):
...
@@ -328,45 +333,46 @@ def evaluate(task_name, pred_file, refer_file):
print
(
"R2@10: %s"
%
eval_metrics
[
2
])
print
(
"R2@10: %s"
%
eval_metrics
[
2
])
print
(
"R5@10: %s"
%
eval_metrics
[
3
])
print
(
"R5@10: %s"
%
eval_metrics
[
3
])
elif
task_name
.
lower
()
in
[
'swda'
,
'mrda'
]:
elif
task_name
.
lower
()
in
[
'swda'
,
'mrda'
]:
eval_inst
=
EvalDA
(
task_name
.
lower
(),
pred_file
,
refer_file
)
eval_inst
=
EvalDA
(
task_name
.
lower
(),
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"DA TASK: %s metrics in testset: "
%
task_name
)
print
(
"DA TASK: %s metrics in testset: "
%
task_name
)
print
(
"ACC: %s"
%
eval_metrics
)
print
(
"ACC: %s"
%
eval_metrics
)
elif
task_name
.
lower
()
==
'atis_intent'
:
elif
task_name
.
lower
()
==
'atis_intent'
:
eval_inst
=
EvalATISIntent
(
pred_file
,
refer_file
)
eval_inst
=
EvalATISIntent
(
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"INTENTION TASK: %s metrics in testset: "
%
task_name
)
print
(
"INTENTION TASK: %s metrics in testset: "
%
task_name
)
print
(
"ACC: %s"
%
eval_metrics
)
print
(
"ACC: %s"
%
eval_metrics
)
elif
task_name
.
lower
()
==
'atis_slot'
:
elif
task_name
.
lower
()
==
'atis_slot'
:
eval_inst
=
EvalATISSlot
(
pred_file
,
refer_file
)
eval_inst
=
EvalATISSlot
(
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"SLOT FILLING TASK: %s metrics in testset: "
%
task_name
)
print
(
"SLOT FILLING TASK: %s metrics in testset: "
%
task_name
)
print
(
eval_metrics
)
print
(
eval_metrics
)
elif
task_name
.
lower
()
in
[
'dstc2'
,
'dstc2_asr'
]:
elif
task_name
.
lower
()
in
[
'dstc2'
,
'dstc2_asr'
]:
eval_inst
=
EvalDSTC2
(
task_name
.
lower
(),
pred_file
,
refer_file
)
eval_inst
=
EvalDSTC2
(
task_name
.
lower
(),
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"DST TASK: %s metrics in testset: "
%
task_name
)
print
(
"DST TASK: %s metrics in testset: "
%
task_name
)
print
(
"JOINT ACC: %s"
%
eval_metrics
[
0
])
print
(
"JOINT ACC: %s"
%
eval_metrics
[
0
])
elif
task_name
.
lower
()
==
"multi-woz"
:
elif
task_name
.
lower
()
==
"multi-woz"
:
eval_inst
=
EvalMultiWoz
(
pred_file
,
refer_file
)
eval_inst
=
EvalMultiWoz
(
pred_file
,
refer_file
)
eval_metrics
=
eval_inst
.
evaluate
()
eval_metrics
=
eval_inst
.
evaluate
()
print
(
"DST TASK: %s metrics in testset: "
%
task_name
)
print
(
"DST TASK: %s metrics in testset: "
%
task_name
)
print
(
"JOINT ACC: %s"
%
eval_metrics
[
0
])
print
(
"JOINT ACC: %s"
%
eval_metrics
[
0
])
print
(
"OVERALL ACC: %s"
%
eval_metrics
[
1
])
print
(
"OVERALL ACC: %s"
%
eval_metrics
[
1
])
else
:
else
:
print
(
"task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
)
print
(
"task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
if
len
(
sys
.
argv
[
1
:])
<
3
:
if
len
(
sys
.
argv
[
1
:])
<
3
:
print
(
"please input task_name predict_file reference_file"
)
print
(
"please input task_name predict_file reference_file"
)
task_name
=
sys
.
argv
[
1
]
task_name
=
sys
.
argv
[
1
]
pred_file
=
sys
.
argv
[
2
]
pred_file
=
sys
.
argv
[
2
]
refer_file
=
sys
.
argv
[
3
]
refer_file
=
sys
.
argv
[
3
]
evaluate
(
task_name
,
pred_file
,
refer_file
)
evaluate
(
task_name
,
pred_file
,
refer_file
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录