Commit fa449c72 authored by 1024的传说, committed by pkpk

fix dialogue_general_understanding python3 (#3887)

Parent 9a10a366
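The substantive Python 3 fix in this commit is the integer-division change in `EvalUDC.evaluate` below (`length = int(len(data) / 10)`); the remaining hunks are whitespace and formatting normalization. Under Python 2, `/` between two ints is floor division, but under Python 3 it is true division and yields a float, which `range()` rejects. A minimal standalone sketch of the failure mode (not from the commit):

```python
data = [0] * 20

# Python 2: 20 / 10 == 2 (int). Python 3: 20 / 10 == 2.0 (float).
length = len(data) / 10
# range(length) under Python 3 raises:
#   TypeError: 'float' object cannot be interpreted as an integer

length = int(len(data) / 10)  # the fix applied in EvalUDC.evaluate
for i in range(length):       # OK: iterates i = 0, 1
    pass
```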
```diff
@@ -22,26 +22,27 @@ class EvalDA(object):
     """
     evaluate da testset, swda|mrda
     """
-    def __init__(self, task_name, pred, refer):
+
+    def __init__(self, task_name, pred, refer):
         """
         predict file
         """
         self.pred_file = pred
         self.refer_file = refer

     def load_data(self):
         """
         load reference label and predict label
         """
         pred_label = []
         refer_label = []
         fr = io.open(self.refer_file, 'r', encoding="utf8")
         for line in fr:
             label = line.rstrip('\n').split('\t')[1]
             refer_label.append(int(label))
         idx = 0
         fr = io.open(self.pred_file, 'r', encoding="utf8")
         for line in fr:
             elems = line.rstrip('\n').split('\t')
             if len(elems) != 2 or not elems[0].isdigit():
                 continue
@@ -49,15 +50,15 @@ class EvalDA(object):
             pred_label.append(tag_id)
         return pred_label, refer_label

     def evaluate(self):
         """
         calculate acc metrics
         """
         pred_label, refer_label = self.load_data()
         common_num = 0
         total_num = len(pred_label)
         for i in range(total_num):
             if pred_label[i] == refer_label[i]:
                 common_num += 1
         acc = float(common_num) / total_num
         return acc
```
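The input layout is only implied by `load_data` above; inferred from the parsing (and not documented in this diff): the reference file is tab-separated with the integer label in column 1, and the prediction file carries `sample_index<TAB>predicted_label` rows, skipping any line whose first field is not a digit string. A hypothetical pair of inputs and the resulting accuracy:

```python
# Hypothetical inputs, inferred from EvalDA.load_data; not part of the commit.
refer_lines = ["utterance one\t3",   # label read from column index 1
               "utterance two\t7"]
pred_lines = ["acc: 0.5",            # skipped: first field is not a digit
              "0\t3",
              "1\t5"]

refer = [int(l.split('\t')[1]) for l in refer_lines]
pred = [int(l.split('\t')[1]) for l in pred_lines
        if len(l.split('\t')) == 2 and l.split('\t')[0].isdigit()]
acc = float(sum(p == r for p, r in zip(pred, refer))) / len(pred)  # 0.5
```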
```diff
@@ -67,26 +68,27 @@ class EvalATISIntent(object):
     """
     evaluate da testset, swda|mrda
     """
-    def __init__(self, pred, refer):
+
+    def __init__(self, pred, refer):
         """
         predict file
         """
         self.pred_file = pred
         self.refer_file = refer

     def load_data(self):
         """
         load reference label and predict label
         """
         pred_label = []
         refer_label = []
         fr = io.open(self.refer_file, 'r', encoding="utf8")
         for line in fr:
             label = line.rstrip('\n').split('\t')[0]
             refer_label.append(int(label))
         idx = 0
         fr = io.open(self.pred_file, 'r', encoding="utf8")
         for line in fr:
             elems = line.rstrip('\n').split('\t')
             if len(elems) != 2 or not elems[0].isdigit():
                 continue
@@ -94,45 +96,46 @@ class EvalATISIntent(object):
             pred_label.append(tag_id)
         return pred_label, refer_label

     def evaluate(self):
         """
         calculate acc metrics
         """
         pred_label, refer_label = self.load_data()
         common_num = 0
         total_num = len(pred_label)
         for i in range(total_num):
             if pred_label[i] == refer_label[i]:
                 common_num += 1
         acc = float(common_num) / total_num
         return acc


 class EvalATISSlot(object):
     """
     evaluate atis slot
     """
-    def __init__(self, pred, refer):
+
+    def __init__(self, pred, refer):
         """
         pred file
         """
         self.pred_file = pred
         self.refer_file = refer

     def load_data(self):
         """
         load reference label and predict label
         """
         pred_label = []
         refer_label = []
         fr = io.open(self.refer_file, 'r', encoding="utf8")
         for line in fr:
             labels = line.rstrip('\n').split('\t')[1].split()
             labels = [int(l) for l in labels]
             refer_label.append(labels)
         fr = io.open(self.pred_file, 'r', encoding="utf8")
         for line in fr:
             if len(line.split('\t')) != 2 or not line[0].isdigit():
                 continue
             labels = line.rstrip('\n').split('\t')[1].split()[1:]
             labels = [int(l) for l in labels]
@@ -140,15 +143,15 @@ class EvalATISSlot(object):
         pred_label_equal = []
         refer_label_equal = []
         assert len(refer_label) == len(pred_label)
         for i in range(len(refer_label)):
             num = len(refer_label[i])
             refer_label_equal.extend(refer_label[i])
-            pred_label[i] = pred_label[i][: num]
+            pred_label[i] = pred_label[i][:num]
             pred_label_equal.extend(pred_label[i])
         return pred_label_equal, refer_label_equal

     def evaluate(self):
         """
         evaluate f1_micro score
         """
@@ -156,13 +159,13 @@ class EvalATISSlot(object):
         tp = dict()
         fn = dict()
         fp = dict()
         for i in range(len(refer_label)):
             if refer_label[i] == pred_label[i]:
                 if refer_label[i] not in tp:
                     tp[refer_label[i]] = 0
                 tp[refer_label[i]] += 1
             else:
                 if pred_label[i] not in fp:
                     fp[pred_label[i]] = 0
                 fp[pred_label[i]] += 1
                 if refer_label[i] not in fn:
@@ -170,17 +173,17 @@ class EvalATISSlot(object):
                 fn[refer_label[i]] += 1
         results = ["label precision recall"]
         for i in range(0, 130):
             if i not in tp:
                 results.append(" %s: 0.0 0.0" % i)
                 continue
             if i in fp:
                 precision = float(tp[i]) / (tp[i] + fp[i])
             else:
                 precision = 1.0
             if i in fn:
                 recall = float(tp[i]) / (tp[i] + fn[i])
             else:
                 recall = 1.0
             results.append(" %s: %.4f %.4f" % (i, precision, recall))
         tp_total = sum(tp.values())
@@ -193,32 +196,33 @@ class EvalATISSlot(object):
         return "\n".join(results)
```
```diff
 class EvalUDC(object):
     """
     evaluate udc
     """
-    def __init__(self, pred, refer):
+
+    def __init__(self, pred, refer):
         """
         predict file
         """
         self.pred_file = pred
         self.refer_file = refer

     def load_data(self):
         """
         load reference label and predict label
         """
         data = []
         refer_label = []
         fr = io.open(self.refer_file, 'r', encoding="utf8")
         for line in fr:
             label = line.rstrip('\n').split('\t')[0]
             refer_label.append(label)
         idx = 0
         fr = io.open(self.pred_file, 'r', encoding="utf8")
         for line in fr:
             elems = line.rstrip('\n').split('\t')
             if len(elems) != 2 or not elems[0].isdigit():
                 continue
             match_prob = elems[1]
             data.append((float(match_prob), int(refer_label[idx])))
@@ -230,8 +234,8 @@ class EvalUDC(object):
         calculate precision in recall n
         """
         pos_score = data[ind][0]
-        curr = data[ind: ind + m]
-        curr = sorted(curr, key = lambda x: x[0], reverse = True)
+        curr = data[ind:ind + m]
+        curr = sorted(curr, key=lambda x: x[0], reverse=True)
         if curr[n - 1][0] <= pos_score:
             return 1
@@ -241,20 +245,20 @@ class EvalUDC(object):
         """
         calculate udc data
         """
         data = self.load_data()
         assert len(data) % 10 == 0
         p_at_1_in_2 = 0.0
         p_at_1_in_10 = 0.0
         p_at_2_in_10 = 0.0
         p_at_5_in_10 = 0.0
-        length = len(data)/10
+        length = int(len(data) / 10)
         for i in range(0, length):
             ind = i * 10
             assert data[ind][1] == 1
             p_at_1_in_2 += self.get_p_at_n_in_m(data, 1, 2, ind)
             p_at_1_in_10 += self.get_p_at_n_in_m(data, 1, 10, ind)
             p_at_2_in_10 += self.get_p_at_n_in_m(data, 2, 10, ind)
@@ -262,13 +266,14 @@ class EvalUDC(object):
         metrics_out = [p_at_1_in_2 / length, p_at_1_in_10 / length, \
             p_at_2_in_10 / length, p_at_5_in_10 / length]
         return metrics_out
```
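UDC evaluation assumes predictions arrive in blocks of 10 candidate responses per context, with the true response first (`assert data[ind][1] == 1`). `get_p_at_n_in_m` then asks whether the positive's score survives in the top `n` of the block's first `m` candidates, which yields the usual Rn@m recall metrics printed by the dispatcher below. A toy block with invented scores:

```python
# One candidate block of 10 (match_prob, label); positive first. Scores invented.
group = [(0.90, 1), (0.40, 0), (0.95, 0), (0.10, 0), (0.20, 0),
         (0.30, 0), (0.05, 0), (0.60, 0), (0.70, 0), (0.80, 0)]

def hit(n, m):
    pos_score = group[0][0]
    top = sorted(group[:m], key=lambda x: x[0], reverse=True)
    return 1 if top[n - 1][0] <= pos_score else 0

r1_at_2 = hit(1, 2)    # 1: among the first 2 candidates the positive ranks first
r1_at_10 = hit(1, 10)  # 0: the 0.95 distractor outranks it over all 10
```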
```diff
 class EvalDSTC2(object):
     """
     evaluate dst testset, dstc2
     """
+
     def __init__(self, task_name, pred, refer):
         """
         predict file
@@ -277,39 +282,39 @@ class EvalDSTC2(object):
         self.pred_file = pred
         self.refer_file = refer

     def load_data(self):
         """
         load reference label and predict label
         """
         pred_label = []
         refer_label = []
         fr = io.open(self.refer_file, 'r', encoding="utf8")
         for line in fr:
             line = line.strip('\n')
             labels = [int(l) for l in line.split('\t')[-1].split()]
             labels = sorted(list(set(labels)))
             refer_label.append(" ".join([str(l) for l in labels]))
         all_pred = []
         fr = io.open(self.pred_file, 'r', encoding="utf8")
         for line in fr:
             line = line.strip('\n')
             all_pred.append(line)
         all_pred = all_pred[len(all_pred) - len(refer_label):]
         for line in all_pred:
             labels = [int(l) for l in line.split('\t')[-1].split()]
             labels = sorted(list(set(labels)))
             pred_label.append(" ".join([str(l) for l in labels]))
         return pred_label, refer_label

     def evaluate(self):
         """
         calculate joint acc && overall acc
         """
         overall_all = 0.0
         correct_joint = 0
         pred_label, refer_label = self.load_data()
         for i in range(len(refer_label)):
             if refer_label[i] != pred_label[i]:
                 continue
             correct_joint += 1
         joint_all = float(correct_joint) / len(refer_label)
@@ -317,9 +322,9 @@ class EvalDSTC2(object):
         return metrics_out
```
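`EvalDSTC2` normalizes each turn's slot ids into a sorted, deduplicated, space-joined string and counts a turn as correct only on an exact match, i.e. joint goal accuracy; the overall-accuracy branch named in the docstring sits in the elided tail of the hunk. A two-turn sketch with invented states:

```python
# Invented dialogue states; the normalization mirrors EvalDSTC2.load_data.
def norm(ids):
    return " ".join(str(i) for i in sorted(set(ids)))

refer = [norm([12, 3, 7]), norm([5])]    # ["3 7 12", "5"]
pred = [norm([7, 3, 12]), norm([5, 9])]  # ["3 7 12", "5 9"]
joint_acc = float(sum(r == p for r, p in zip(refer, pred))) / len(refer)  # 0.5
```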
```diff
 def evaluate(task_name, pred_file, refer_file):
     """evaluate task metrics"""
     if task_name.lower() == 'udc':
         eval_inst = EvalUDC(pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("MATCHING TASK: %s metrics in testset: " % task_name)
@@ -328,45 +333,46 @@ def evaluate(task_name, pred_file, refer_file):
         print("R2@10: %s" % eval_metrics[2])
         print("R5@10: %s" % eval_metrics[3])
     elif task_name.lower() in ['swda', 'mrda']:
         eval_inst = EvalDA(task_name.lower(), pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("DA TASK: %s metrics in testset: " % task_name)
         print("ACC: %s" % eval_metrics)
     elif task_name.lower() == 'atis_intent':
         eval_inst = EvalATISIntent(pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("INTENTION TASK: %s metrics in testset: " % task_name)
         print("ACC: %s" % eval_metrics)
     elif task_name.lower() == 'atis_slot':
         eval_inst = EvalATISSlot(pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("SLOT FILLING TASK: %s metrics in testset: " % task_name)
         print(eval_metrics)
     elif task_name.lower() in ['dstc2', 'dstc2_asr']:
         eval_inst = EvalDSTC2(task_name.lower(), pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("DST TASK: %s metrics in testset: " % task_name)
         print("JOINT ACC: %s" % eval_metrics[0])
     elif task_name.lower() == "multi-woz":
         eval_inst = EvalMultiWoz(pred_file, refer_file)
         eval_metrics = eval_inst.evaluate()
         print("DST TASK: %s metrics in testset: " % task_name)
         print("JOINT ACC: %s" % eval_metrics[0])
         print("OVERALL ACC: %s" % eval_metrics[1])
     else:
-        print("task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]")
+        print(
+            "task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
+        )


 if __name__ == "__main__":
     if len(sys.argv[1:]) < 3:
         print("please input task_name predict_file reference_file")
     task_name = sys.argv[1]
     pred_file = sys.argv[2]
     refer_file = sys.argv[3]
     evaluate(task_name, pred_file, refer_file)
```
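As the `__main__` block shows, the script takes three positional arguments, `task_name predict_file reference_file`, e.g. `python eval.py atis_slot pred.txt refer.txt` (the filename `eval.py` and the data paths are placeholders; this page does not show the file's path). Note that when fewer than three arguments are given, the check only prints a usage hint and falls through, so the subsequent `sys.argv[1]` access still raises an `IndexError`.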