提交 fa449c72 编写于 作者: 1024的传说's avatar 1024的传说 提交者: pkpk

fix dialogue_general_understanding python3 (#3887)

上级 9a10a366
......@@ -22,6 +22,7 @@ class EvalDA(object):
"""
evaluate da testset, swda|mrda
"""
def __init__(self, task_name, pred, refer):
"""
predict file
......@@ -67,6 +68,7 @@ class EvalATISIntent(object):
"""
evaluate da testset, swda|mrda
"""
def __init__(self, pred, refer):
"""
predict file
......@@ -112,6 +114,7 @@ class EvalATISSlot(object):
"""
evaluate atis slot
"""
def __init__(self, pred, refer):
"""
pred file
......@@ -143,7 +146,7 @@ class EvalATISSlot(object):
for i in range(len(refer_label)):
num = len(refer_label[i])
refer_label_equal.extend(refer_label[i])
pred_label[i] = pred_label[i][: num]
pred_label[i] = pred_label[i][:num]
pred_label_equal.extend(pred_label[i])
return pred_label_equal, refer_label_equal
......@@ -197,6 +200,7 @@ class EvalUDC(object):
"""
evaluate udc
"""
def __init__(self, pred, refer):
"""
predict file
......@@ -230,8 +234,8 @@ class EvalUDC(object):
calculate precision in recall n
"""
pos_score = data[ind][0]
curr = data[ind: ind + m]
curr = sorted(curr, key = lambda x: x[0], reverse = True)
curr = data[ind:ind + m]
curr = sorted(curr, key=lambda x: x[0], reverse=True)
if curr[n - 1][0] <= pos_score:
return 1
......@@ -249,7 +253,7 @@ class EvalUDC(object):
p_at_2_in_10 = 0.0
p_at_5_in_10 = 0.0
length = len(data)/10
length = int(len(data) / 10)
for i in range(0, length):
ind = i * 10
......@@ -269,6 +273,7 @@ class EvalDSTC2(object):
"""
evaluate dst testset, dstc2
"""
def __init__(self, task_name, pred, refer):
"""
predict file
......@@ -357,7 +362,9 @@ def evaluate(task_name, pred_file, refer_file):
print("JOINT ACC: %s" % eval_metrics[0])
print("OVERALL ACC: %s" % eval_metrics[1])
else:
print("task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]")
print(
"task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
)
if __name__ == "__main__":
......@@ -368,5 +375,4 @@ if __name__ == "__main__":
pred_file = sys.argv[2]
refer_file = sys.argv[3]
evaluate(task_name, pred_file, refer_file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册