提交 fa449c72 编写于 作者: 1024的传说's avatar 1024的传说 提交者: pkpk

fix dialogue_general_understanding python3 (#3887)

上级 9a10a366
...@@ -22,6 +22,7 @@ class EvalDA(object): ...@@ -22,6 +22,7 @@ class EvalDA(object):
""" """
evaluate da testset, swda|mrda evaluate da testset, swda|mrda
""" """
def __init__(self, task_name, pred, refer): def __init__(self, task_name, pred, refer):
""" """
predict file predict file
...@@ -67,6 +68,7 @@ class EvalATISIntent(object): ...@@ -67,6 +68,7 @@ class EvalATISIntent(object):
""" """
evaluate da testset, swda|mrda evaluate da testset, swda|mrda
""" """
def __init__(self, pred, refer): def __init__(self, pred, refer):
""" """
predict file predict file
...@@ -112,6 +114,7 @@ class EvalATISSlot(object): ...@@ -112,6 +114,7 @@ class EvalATISSlot(object):
""" """
evaluate atis slot evaluate atis slot
""" """
def __init__(self, pred, refer): def __init__(self, pred, refer):
""" """
pred file pred file
...@@ -143,7 +146,7 @@ class EvalATISSlot(object): ...@@ -143,7 +146,7 @@ class EvalATISSlot(object):
for i in range(len(refer_label)): for i in range(len(refer_label)):
num = len(refer_label[i]) num = len(refer_label[i])
refer_label_equal.extend(refer_label[i]) refer_label_equal.extend(refer_label[i])
pred_label[i] = pred_label[i][: num] pred_label[i] = pred_label[i][:num]
pred_label_equal.extend(pred_label[i]) pred_label_equal.extend(pred_label[i])
return pred_label_equal, refer_label_equal return pred_label_equal, refer_label_equal
...@@ -197,6 +200,7 @@ class EvalUDC(object): ...@@ -197,6 +200,7 @@ class EvalUDC(object):
""" """
evaluate udc evaluate udc
""" """
def __init__(self, pred, refer): def __init__(self, pred, refer):
""" """
predict file predict file
...@@ -230,8 +234,8 @@ class EvalUDC(object): ...@@ -230,8 +234,8 @@ class EvalUDC(object):
calculate precision in recall n calculate precision in recall n
""" """
pos_score = data[ind][0] pos_score = data[ind][0]
curr = data[ind: ind + m] curr = data[ind:ind + m]
curr = sorted(curr, key = lambda x: x[0], reverse = True) curr = sorted(curr, key=lambda x: x[0], reverse=True)
if curr[n - 1][0] <= pos_score: if curr[n - 1][0] <= pos_score:
return 1 return 1
...@@ -249,7 +253,7 @@ class EvalUDC(object): ...@@ -249,7 +253,7 @@ class EvalUDC(object):
p_at_2_in_10 = 0.0 p_at_2_in_10 = 0.0
p_at_5_in_10 = 0.0 p_at_5_in_10 = 0.0
length = len(data)/10 length = int(len(data) / 10)
for i in range(0, length): for i in range(0, length):
ind = i * 10 ind = i * 10
...@@ -269,6 +273,7 @@ class EvalDSTC2(object): ...@@ -269,6 +273,7 @@ class EvalDSTC2(object):
""" """
evaluate dst testset, dstc2 evaluate dst testset, dstc2
""" """
def __init__(self, task_name, pred, refer): def __init__(self, task_name, pred, refer):
""" """
predict file predict file
...@@ -357,7 +362,9 @@ def evaluate(task_name, pred_file, refer_file): ...@@ -357,7 +362,9 @@ def evaluate(task_name, pred_file, refer_file):
print("JOINT ACC: %s" % eval_metrics[0]) print("JOINT ACC: %s" % eval_metrics[0])
print("OVERALL ACC: %s" % eval_metrics[1]) print("OVERALL ACC: %s" % eval_metrics[1])
else: else:
print("task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]") print(
"task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -368,5 +375,4 @@ if __name__ == "__main__": ...@@ -368,5 +375,4 @@ if __name__ == "__main__":
pred_file = sys.argv[2] pred_file = sys.argv[2]
refer_file = sys.argv[3] refer_file = sys.argv[3]
evaluate(task_name, pred_file, refer_file) evaluate(task_name, pred_file, refer_file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册