Commit fa449c72 authored by 1024的传说, committed by pkpk

fix dialogue_general_understanding python3 (#3887)

Parent 9a10a366
@@ -22,26 +22,27 @@ class EvalDA(object):
"""
evaluate da testset, swda|mrda
"""
    def __init__(self, task_name, pred, refer):
"""
predict file
"""
self.pred_file = pred
self.refer_file = refer
    def load_data(self):
"""
load reference label and predict label
"""
pred_label = []
refer_label = []
fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
label = line.rstrip('\n').split('\t')[1]
refer_label.append(int(label))
idx = 0
fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
elems = line.rstrip('\n').split('\t')
if len(elems) != 2 or not elems[0].isdigit():
continue
@@ -49,15 +50,15 @@ class EvalDA(object):
pred_label.append(tag_id)
return pred_label, refer_label
    def evaluate(self):
"""
calculate acc metrics
"""
pred_label, refer_label = self.load_data()
common_num = 0
total_num = len(pred_label)
        for i in range(total_num):
            if pred_label[i] == refer_label[i]:
common_num += 1
acc = float(common_num) / total_num
return acc
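For reference, the accuracy above is plain exact match over aligned label lists; a minimal standalone sketch of the same computation (the helper name is illustrative, not part of this file):

# Illustrative helper, not part of this commit: exact-match accuracy
# over two aligned integer label lists, as EvalDA.evaluate computes it.
def exact_match_accuracy(pred_label, refer_label):
    common = sum(1 for p, r in zip(pred_label, refer_label) if p == r)
    return float(common) / len(pred_label)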
@@ -67,26 +68,27 @@ class EvalATISIntent(object):
"""
    evaluate atis intent testset
"""
    def __init__(self, pred, refer):
"""
predict file
"""
self.pred_file = pred
self.refer_file = refer
    def load_data(self):
"""
load reference label and predict label
"""
pred_label = []
refer_label = []
fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
label = line.rstrip('\n').split('\t')[0]
refer_label.append(int(label))
idx = 0
fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
elems = line.rstrip('\n').split('\t')
if len(elems) != 2 or not elems[0].isdigit():
continue
@@ -94,45 +96,46 @@ class EvalATISIntent(object):
pred_label.append(tag_id)
return pred_label, refer_label
    def evaluate(self):
"""
calculate acc metrics
"""
pred_label, refer_label = self.load_data()
common_num = 0
total_num = len(pred_label)
        for i in range(total_num):
            if pred_label[i] == refer_label[i]:
common_num += 1
acc = float(common_num) / total_num
return acc
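The only difference from EvalDA.load_data is the label column: the ATIS intent reference keeps the label in column 0, while swda/mrda keep it in column 1. Hypothetical tab-separated reference lines for each format:

# Hypothetical reference lines, assuming the tab-separated layouts parsed above.
swda_refer_line = "what do you think\t7"         # EvalDA reads split('\t')[1]
atis_refer_line = "7\ti want to fly to denver"   # EvalATISIntent reads split('\t')[0]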
class EvalATISSlot(object):
"""
evaluate atis slot
"""
    def __init__(self, pred, refer):
"""
pred file
"""
self.pred_file = pred
self.refer_file = refer
    def load_data(self):
"""
load reference label and predict label
"""
pred_label = []
refer_label = []
fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
labels = line.rstrip('\n').split('\t')[1].split()
labels = [int(l) for l in labels]
refer_label.append(labels)
fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            if len(line.split('\t')) != 2 or not line[0].isdigit():
continue
labels = line.rstrip('\n').split('\t')[1].split()[1:]
labels = [int(l) for l in labels]
@@ -140,15 +143,15 @@ class EvalATISSlot(object):
pred_label_equal = []
refer_label_equal = []
assert len(refer_label) == len(pred_label)
        for i in range(len(refer_label)):
num = len(refer_label[i])
refer_label_equal.extend(refer_label[i])
            pred_label[i] = pred_label[i][:num]
pred_label_equal.extend(pred_label[i])
return pred_label_equal, refer_label_equal
    def evaluate(self):
"""
evaluate f1_micro score
"""
@@ -156,13 +159,13 @@ class EvalATISSlot(object):
tp = dict()
fn = dict()
fp = dict()
        for i in range(len(refer_label)):
if refer_label[i] == pred_label[i]:
                if refer_label[i] not in tp:
tp[refer_label[i]] = 0
tp[refer_label[i]] += 1
            else:
                if pred_label[i] not in fp:
fp[pred_label[i]] = 0
fp[pred_label[i]] += 1
if refer_label[i] not in fn:
@@ -170,17 +173,17 @@ class EvalATISSlot(object):
fn[refer_label[i]] += 1
results = ["label precision recall"]
        for i in range(0, 130):
            if i not in tp:
results.append(" %s: 0.0 0.0" % i)
continue
            if i in fp:
precision = float(tp[i]) / (tp[i] + fp[i])
            else:
precision = 1.0
            if i in fn:
recall = float(tp[i]) / (tp[i] + fn[i])
            else:
recall = 1.0
results.append(" %s: %.4f %.4f" % (i, precision, recall))
tp_total = sum(tp.values())
@@ -193,32 +196,33 @@ class EvalATISSlot(object):
return "\n".join(results)
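The docstring says evaluate returns a micro-averaged F1, and the elided lines after tp_total presumably aggregate fp and fn the same way. A hedged sketch of that arithmetic (the real elided code may format or report the result differently):

# Sketch only: micro-averaged precision/recall/F1 from the tp/fp/fn
# dicts built above; the elided lines in this hunk are assumed to do
# the equivalent aggregation.
def micro_f1(tp, fp, fn):
    tp_total = sum(tp.values())
    fp_total = sum(fp.values())
    fn_total = sum(fn.values())
    precision = float(tp_total) / (tp_total + fp_total)
    recall = float(tp_total) / (tp_total + fn_total)
    return 2 * precision * recall / (precision + recall)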
class EvalUDC(object):
"""
evaluate udc
"""
    def __init__(self, pred, refer):
"""
predict file
"""
self.pred_file = pred
self.refer_file = refer
    def load_data(self):
"""
load reference label and predict label
"""
        data = []
refer_label = []
fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
label = line.rstrip('\n').split('\t')[0]
refer_label.append(label)
idx = 0
fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
elems = line.rstrip('\n').split('\t')
            if len(elems) != 2 or not elems[0].isdigit():
continue
match_prob = elems[1]
data.append((float(match_prob), int(refer_label[idx])))
@@ -230,8 +234,8 @@ class EvalUDC(object):
calculate precision in recall n
"""
pos_score = data[ind][0]
        curr = data[ind:ind + m]
        curr = sorted(curr, key=lambda x: x[0], reverse=True)
if curr[n - 1][0] <= pos_score:
return 1
@@ -241,20 +245,20 @@ class EvalUDC(object):
"""
calculate udc data
"""
        data = self.load_data()
assert len(data) % 10 == 0
p_at_1_in_2 = 0.0
p_at_1_in_10 = 0.0
p_at_2_in_10 = 0.0
p_at_5_in_10 = 0.0
        length = int(len(data) / 10)
for i in range(0, length):
ind = i * 10
assert data[ind][1] == 1
p_at_1_in_2 += self.get_p_at_n_in_m(data, 1, 2, ind)
p_at_1_in_10 += self.get_p_at_n_in_m(data, 1, 10, ind)
p_at_2_in_10 += self.get_p_at_n_in_m(data, 2, 10, ind)
@@ -262,13 +266,14 @@ class EvalUDC(object):
        metrics_out = [p_at_1_in_2 / length, p_at_1_in_10 / length,
                       p_at_2_in_10 / length, p_at_5_in_10 / length]
        return metrics_out
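get_p_at_n_in_m scores one block of m candidates: it takes the positive response's score (assumed to sit at offset ind, first in its block), sorts the block by score, and returns 1 if the positive ranks within the top n. A small worked example with hypothetical scores:

# Hypothetical block of (match_prob, label) pairs; the positive (label 1)
# comes first, matching the assert in evaluate().
block = [(0.90, 1), (0.95, 0), (0.30, 0), (0.20, 0), (0.10, 0),
         (0.05, 0), (0.04, 0), (0.03, 0), (0.02, 0), (0.01, 0)]
# R1@2 (n=1, m=2): top-1 of the first two is 0.95 > 0.90, so this block scores 0.
# R2@10 (n=2, m=10): 0.90 is the second-highest of ten, so this block scores 1.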
class EvalDSTC2(object):
"""
evaluate dst testset, dstc2
"""
def __init__(self, task_name, pred, refer):
"""
predict file
@@ -277,39 +282,39 @@ class EvalDSTC2(object):
self.pred_file = pred
self.refer_file = refer
    def load_data(self):
"""
load reference label and predict label
"""
pred_label = []
refer_label = []
fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
line = line.strip('\n')
labels = [int(l) for l in line.split('\t')[-1].split()]
labels = sorted(list(set(labels)))
refer_label.append(" ".join([str(l) for l in labels]))
all_pred = []
fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
line = line.strip('\n')
all_pred.append(line)
all_pred = all_pred[len(all_pred) - len(refer_label):]
        for line in all_pred:
labels = [int(l) for l in line.split('\t')[-1].split()]
labels = sorted(list(set(labels)))
pred_label.append(" ".join([str(l) for l in labels]))
return pred_label, refer_label
    def evaluate(self):
"""
calculate joint acc && overall acc
"""
overall_all = 0.0
correct_joint = 0
pred_label, refer_label = self.load_data()
        for i in range(len(refer_label)):
            if refer_label[i] != pred_label[i]:
continue
correct_joint += 1
joint_all = float(correct_joint) / len(refer_label)
@@ -317,9 +322,9 @@ class EvalDSTC2(object):
return metrics_out
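Both sides are canonicalized in load_data to a sorted, deduplicated label string, so a turn only counts toward joint accuracy when the full slot-label set matches exactly. A hypothetical turn:

# Hypothetical labels showing the canonicalization used in load_data:
refer = " ".join(str(l) for l in sorted(set([3, 17, 42])))    # "3 17 42"
pred = " ".join(str(l) for l in sorted(set([17, 3, 42, 3])))  # "3 17 42", joint match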
def evaluate(task_name, pred_file, refer_file):
"""evaluate task metrics"""
    if task_name.lower() == 'udc':
eval_inst = EvalUDC(pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("MATCHING TASK: %s metrics in testset: " % task_name)
@@ -328,45 +333,46 @@ def evaluate(task_name, pred_file, refer_file):
print("R2@10: %s" % eval_metrics[2])
print("R5@10: %s" % eval_metrics[3])
    elif task_name.lower() in ['swda', 'mrda']:
eval_inst = EvalDA(task_name.lower(), pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("DA TASK: %s metrics in testset: " % task_name)
print("ACC: %s" % eval_metrics)
    elif task_name.lower() == 'atis_intent':
eval_inst = EvalATISIntent(pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("INTENTION TASK: %s metrics in testset: " % task_name)
print("ACC: %s" % eval_metrics)
    elif task_name.lower() == 'atis_slot':
eval_inst = EvalATISSlot(pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("SLOT FILLING TASK: %s metrics in testset: " % task_name)
print(eval_metrics)
    elif task_name.lower() in ['dstc2', 'dstc2_asr']:
eval_inst = EvalDSTC2(task_name.lower(), pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("DST TASK: %s metrics in testset: " % task_name)
print("JOINT ACC: %s" % eval_metrics[0])
elif task_name.lower() == "multi-woz":
elif task_name.lower() == "multi-woz":
eval_inst = EvalMultiWoz(pred_file, refer_file)
eval_metrics = eval_inst.evaluate()
print("DST TASK: %s metrics in testset: " % task_name)
print("JOINT ACC: %s" % eval_metrics[0])
print("OVERALL ACC: %s" % eval_metrics[1])
    else:
        print(
            "task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
        )
if __name__ == "__main__":
if len(sys.argv[1:]) < 3:
if __name__ == "__main__":
if len(sys.argv[1:]) < 3:
print("please input task_name predict_file reference_file")
task_name = sys.argv[1]
pred_file = sys.argv[2]
refer_file = sys.argv[3]
evaluate(task_name, pred_file, refer_file)
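An example invocation, assuming the file is run directly (the script and file names below are illustrative):

# Illustrative command line; argument order is task_name, predict_file,
# reference_file, as read from sys.argv above.
#   python eval.py atis_intent pred_output.txt test_refer.txt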