# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""evaluate task metrics"""

import sys
import io


class EvalDA(object):
    """
    Evaluate dialogue-act classification accuracy on the swda|mrda testsets.

    The reference file carries the gold label in its second tab-separated
    column. The prediction file holds "<index><TAB><label>" lines; any other
    line (e.g. a header) is ignored.
    """

    def __init__(self, task_name, pred, refer):
        """
        Args:
            task_name: name of the DA task ("swda" or "mrda"); not used by
                the evaluation itself.
            pred: path to the prediction file.
            refer: path to the reference (gold) file.
        """
        self.pred_file = pred
        self.refer_file = refer

    def load_data(self):
        """
        Load predicted labels and reference labels.

        Returns:
            (pred_label, refer_label): two lists of ints, in file order.
        """
        refer_label = []
        with io.open(self.refer_file, 'r', encoding="utf8") as fr:
            for line in fr:
                line = line.rstrip('\n')
                if not line:
                    # Skip blank lines (e.g. an extra empty line at EOF)
                    # instead of failing on the column lookup.
                    continue
                refer_label.append(int(line.split('\t')[1]))

        pred_label = []
        with io.open(self.pred_file, 'r', encoding="utf8") as fr:
            for line in fr:
                elems = line.rstrip('\n').split('\t')
                # Only "<index>\t<label>" lines are predictions.
                if len(elems) != 2 or not elems[0].isdigit():
                    continue
                pred_label.append(int(elems[1]))
        return pred_label, refer_label

    def evaluate(self):
        """
        Calculate accuracy of the predictions against the references.

        Returns:
            float: fraction of predictions equal to the reference label at
            the same position.

        Raises:
            ValueError: if there are no predictions, or more predictions
                than reference labels (the files cannot be aligned).
        """
        pred_label, refer_label = self.load_data()
        total_num = len(pred_label)
        if total_num == 0:
            raise ValueError("no predictions found in %s" % self.pred_file)
        if total_num > len(refer_label):
            raise ValueError(
                "%d predictions but only %d reference labels"
                % (total_num, len(refer_label)))
        common_num = sum(
            1 for pred, refer in zip(pred_label, refer_label) if pred == refer)
        return float(common_num) / total_num


class EvalATISIntent(object):
    """
    Evaluate ATIS intent classification accuracy.

    The reference file carries the gold label in its first tab-separated
    column. The prediction file holds "<index><TAB><label>" lines; any other
    line (e.g. a header) is ignored.
    """

    def __init__(self, pred, refer):
        """
        Args:
            pred: path to the prediction file.
            refer: path to the reference (gold) file.
        """
        self.pred_file = pred
        self.refer_file = refer

    def load_data(self):
        """
        Load predicted labels and reference labels.

        Returns:
            (pred_label, refer_label): two lists of ints, in file order.
        """
        refer_label = []
        with io.open(self.refer_file, 'r', encoding="utf8") as fr:
            for line in fr:
                line = line.rstrip('\n')
                if not line:
                    # Skip blank lines (e.g. an extra empty line at EOF).
                    continue
                refer_label.append(int(line.split('\t')[0]))

        pred_label = []
        with io.open(self.pred_file, 'r', encoding="utf8") as fr:
            for line in fr:
                elems = line.rstrip('\n').split('\t')
                # Only "<index>\t<label>" lines are predictions.
                if len(elems) != 2 or not elems[0].isdigit():
                    continue
                pred_label.append(int(elems[1]))
        return pred_label, refer_label

    def evaluate(self):
        """
        Calculate accuracy of the predictions against the references.

        Returns:
            float: fraction of predictions equal to the reference label at
            the same position.

        Raises:
            ValueError: if there are no predictions, or more predictions
                than reference labels (the files cannot be aligned).
        """
        pred_label, refer_label = self.load_data()
        total_num = len(pred_label)
        if total_num == 0:
            raise ValueError("no predictions found in %s" % self.pred_file)
        if total_num > len(refer_label):
            raise ValueError(
                "%d predictions but only %d reference labels"
                % (total_num, len(refer_label)))
        common_num = sum(
            1 for pred, refer in zip(pred_label, refer_label) if pred == refer)
        return float(common_num) / total_num


class EvalATISSlot(object):
    """
    Evaluate ATIS slot filling: per-label precision/recall and micro F1.
    """

    def __init__(self, pred, refer):
        """
        Args:
            pred: path to the prediction file.
            refer: path to the reference (gold) file.
        """
        self.pred_file = pred
        self.refer_file = refer

    def load_data(self):
        """
        Load token-level predicted and reference slot labels.

        Each reference line holds space-separated int labels in its second
        tab-separated column. Each prediction line is "<index><TAB><labels>".
        The first predicted label is dropped, presumably because it belongs
        to a leading special token (confirm against the predictor). The rest
        are truncated to the reference length of the same sentence.

        Returns:
            (pred_label_equal, refer_label_equal): flat, position-aligned
            lists of int token labels over all sentences.

        Raises:
            ValueError: if the sentence counts differ, or a prediction has
                fewer labels than its reference (the flattened sequences
                would otherwise be misaligned).
        """
        refer_label = []
        with io.open(self.refer_file, 'r', encoding="utf8") as fr:
            for line in fr:
                tokens = line.rstrip('\n').split('\t')[1].split()
                refer_label.append([int(tok) for tok in tokens])

        pred_label = []
        with io.open(self.pred_file, 'r', encoding="utf8") as fr:
            for line in fr:
                # Only "<index>\t<labels>" lines are predictions; headers
                # and malformed lines are skipped.
                if len(line.split('\t')) != 2 or not line[0].isdigit():
                    continue
                tokens = line.rstrip('\n').split('\t')[1].split()[1:]
                pred_label.append([int(tok) for tok in tokens])

        if len(refer_label) != len(pred_label):
            raise ValueError(
                "sentence count mismatch: %d references vs %d predictions"
                % (len(refer_label), len(pred_label)))

        pred_label_equal = []
        refer_label_equal = []
        for sent_id, (refer, pred) in enumerate(zip(refer_label, pred_label)):
            if len(pred) < len(refer):
                raise ValueError(
                    "prediction %d has %d labels but its reference has %d"
                    % (sent_id, len(pred), len(refer)))
            refer_label_equal.extend(refer)
            pred_label_equal.extend(pred[:len(refer)])
        return pred_label_equal, refer_label_equal

    def evaluate(self, num_labels=130):
        """
        Evaluate per-label precision/recall and micro-averaged F1.

        Args:
            num_labels: labels 0 .. num_labels-1 each get a report row
                (default 130, the previously hard-coded label count).

        Returns:
            str: newline-joined report consisting of a header line, one
            " <label>:    <precision>    <recall>" row per label and a
            final "f1_micro: <score>" line. Labels that are never predicted
            correctly report "0.0     0.0". Micro precision/recall/F1 fall
            back to 0.0 instead of dividing by zero.
        """
        pred_label, refer_label = self.load_data()
        tp, fp, fn = {}, {}, {}
        for gold, pred in zip(refer_label, pred_label):
            if gold == pred:
                tp[gold] = tp.get(gold, 0) + 1
            else:
                fp[pred] = fp.get(pred, 0) + 1
                fn[gold] = fn.get(gold, 0) + 1

        results = ["label    precision    recall"]
        for label in range(num_labels):
            if label not in tp:
                results.append(" %s:    0.0     0.0" % label)
                continue
            precision = float(tp[label]) / (tp[label] + fp.get(label, 0))
            recall = float(tp[label]) / (tp[label] + fn.get(label, 0))
            results.append(" %s:    %.4f    %.4f" % (label, precision, recall))

        tp_total = sum(tp.values())
        fn_total = sum(fn.values())
        fp_total = sum(fp.values())
        p_total = (float(tp_total) / (tp_total + fp_total)
                   if tp_total + fp_total else 0.0)
        r_total = (float(tp_total) / (tp_total + fn_total)
                   if tp_total + fn_total else 0.0)
        # With no correct token p_total == r_total == 0; F1 is then 0.0
        # rather than a ZeroDivisionError.
        f_micro = (2 * p_total * r_total / (p_total + r_total)
                   if p_total + r_total else 0.0)
        results.append("f1_micro: %.4f" % (f_micro))
        return "\n".join(results)


class EvalUDC(object):
    """
    Evaluate the udc matching task: R1@2, R1@10, R2@10 and R5@10.
    """

    def __init__(self, pred, refer):
        """
        Args:
            pred: path to the prediction file.
            refer: path to the reference (gold) file.
        """
        self.pred_file = pred
        self.refer_file = refer

    def load_data(self):
        """
        Pair each predicted match probability with its reference label.

        The reference label is the first tab-separated column of the
        reference file. Prediction lines are "<index><TAB><probability>";
        any other line (e.g. a header) is ignored. The n-th prediction is
        paired with the n-th reference line.

        Returns:
            list of (float match_prob, int label) tuples in file order.

        Raises:
            ValueError: if there are more predictions than reference lines.
        """
        with io.open(self.refer_file, 'r', encoding="utf8") as fr:
            refer_label = [line.rstrip('\n').split('\t')[0] for line in fr]

        data = []
        with io.open(self.pred_file, 'r', encoding="utf8") as fr:
            for line in fr:
                elems = line.rstrip('\n').split('\t')
                if len(elems) != 2 or not elems[0].isdigit():
                    continue
                idx = len(data)
                if idx >= len(refer_label):
                    raise ValueError(
                        "more predictions than the %d reference lines"
                        % len(refer_label))
                data.append((float(elems[1]), int(refer_label[idx])))
        return data

    def get_p_at_n_in_m(self, data, n, m, ind):
        """
        Check whether the positive candidate ranks within the top n of m.

        Args:
            data: list of (score, label) tuples.
            n: rank cut-off.
            m: number of candidates, data[ind:ind + m].
            ind: index of the positive candidate (first of the group).

        Returns:
            1 if the n-th highest score among the m candidates is <= the
            positive's score (ties count in the positive's favour), else 0.
        """
        pos_score = data[ind][0]
        curr = data[ind:ind + m]
        curr = sorted(curr, key=lambda x: x[0], reverse=True)

        if curr[n - 1][0] <= pos_score:
            return 1
        return 0

    def evaluate(self):
        """
        Calculate R1@2, R1@10, R2@10 and R5@10.

        The data is consumed in groups of 10 consecutive candidates whose
        first entry must be the positive one (label 1).

        Returns:
            [R1@2, R1@10, R2@10, R5@10] as floats averaged over all groups.

        Raises:
            ValueError: if no predictions were found.
        """
        data = self.load_data()
        if not data:
            raise ValueError("no predictions found in %s" % self.pred_file)
        assert len(data) % 10 == 0

        p_at_1_in_2 = 0.0
        p_at_1_in_10 = 0.0
        p_at_2_in_10 = 0.0
        p_at_5_in_10 = 0.0

        length = len(data) // 10
        for i in range(length):
            ind = i * 10
            assert data[ind][1] == 1

            p_at_1_in_2 += self.get_p_at_n_in_m(data, 1, 2, ind)
            p_at_1_in_10 += self.get_p_at_n_in_m(data, 1, 10, ind)
            p_at_2_in_10 += self.get_p_at_n_in_m(data, 2, 10, ind)
            p_at_5_in_10 += self.get_p_at_n_in_m(data, 5, 10, ind)

        return [p_at_1_in_2 / length, p_at_1_in_10 / length,
                p_at_2_in_10 / length, p_at_5_in_10 / length]


class EvalDSTC2(object):
    """
    Evaluate the dialogue state tracking testset dstc2 (joint accuracy).
    """

    def __init__(self, task_name, pred, refer):
        """
        Args:
            task_name: "dstc2" or "dstc2_asr".
            pred: path to the prediction file.
            refer: path to the reference (gold) file.
        """
        self.task_name = task_name
        self.pred_file = pred
        self.refer_file = refer

    @staticmethod
    def _canonical_labels(line):
        """Return the last tab column's int labels, deduped, sorted, space-joined."""
        labels = sorted(set(int(l) for l in line.strip('\n').split('\t')[-1].split()))
        return " ".join(str(l) for l in labels)

    def load_data(self):
        """
        Load predicted and reference label sets, one canonical string per turn.

        Only the last len(refer) lines of the prediction file are used, so any
        extra leading lines (e.g. a header) are ignored.

        Returns:
            (pred_label, refer_label): lists of canonical label strings.

        Raises:
            ValueError: if the prediction file has fewer lines than the
                reference file.
        """
        with io.open(self.refer_file, 'r', encoding="utf8") as fr:
            refer_label = [self._canonical_labels(line) for line in fr]

        with io.open(self.pred_file, 'r', encoding="utf8") as fr:
            all_pred = [line.strip('\n') for line in fr]

        if len(all_pred) < len(refer_label):
            # A negative slice start would silently keep too few lines.
            raise ValueError(
                "%d prediction lines but %d reference lines"
                % (len(all_pred), len(refer_label)))
        all_pred = all_pred[len(all_pred) - len(refer_label):]
        pred_label = [self._canonical_labels(line) for line in all_pred]
        return pred_label, refer_label

    def evaluate(self):
        """
        Calculate joint accuracy.

        Returns:
            [joint_acc, overall_acc]: joint_acc is the fraction of turns
            whose predicted label set equals the reference set; overall_acc
            is not computed for this task and is always 0.0.

        Raises:
            ValueError: if the reference file is empty.
        """
        pred_label, refer_label = self.load_data()
        if not refer_label:
            raise ValueError("no reference labels found in %s" % self.refer_file)
        correct_joint = sum(
            1 for pred, refer in zip(pred_label, refer_label) if pred == refer)
        joint_all = float(correct_joint) / len(refer_label)
        overall_all = 0.0
        return [joint_all, overall_all]


def evaluate(task_name, pred_file, refer_file):
    """
    Evaluate task metrics and print them.

    Args:
        task_name: one of udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr
            (case-insensitive). "multi-woz" is recognized but not available.
        pred_file: path to the prediction file.
        refer_file: path to the reference file.

    Returns:
        The metrics produced by the task's evaluator (a list, a float or a
        report string depending on the task), or None for an unknown task.

    Raises:
        NotImplementedError: for "multi-woz"; no EvalMultiWoz evaluator is
            defined in this module (the old branch raised NameError).
    """
    task = task_name.lower()
    if task == 'udc':
        eval_metrics = EvalUDC(pred_file, refer_file).evaluate()
        print("MATCHING TASK: %s metrics in testset: " % task_name)
        print("R1@2: %s" % eval_metrics[0])
        print("R1@10: %s" % eval_metrics[1])
        print("R2@10: %s" % eval_metrics[2])
        print("R5@10: %s" % eval_metrics[3])
    elif task in ['swda', 'mrda']:
        eval_metrics = EvalDA(task, pred_file, refer_file).evaluate()
        print("DA TASK: %s metrics in testset: " % task_name)
        print("ACC: %s" % eval_metrics)
    elif task == 'atis_intent':
        eval_metrics = EvalATISIntent(pred_file, refer_file).evaluate()
        print("INTENTION TASK: %s metrics in testset: " % task_name)
        print("ACC: %s" % eval_metrics)
    elif task == 'atis_slot':
        eval_metrics = EvalATISSlot(pred_file, refer_file).evaluate()
        print("SLOT FILLING TASK: %s metrics in testset: " % task_name)
        print(eval_metrics)
    elif task in ['dstc2', 'dstc2_asr']:
        eval_metrics = EvalDSTC2(task, pred_file, refer_file).evaluate()
        print("DST TASK: %s metrics in testset: " % task_name)
        print("JOINT ACC: %s" % eval_metrics[0])
    elif task == "multi-woz":
        raise NotImplementedError(
            "multi-woz evaluation is not available: no EvalMultiWoz "
            "evaluator is defined in this module")
    else:
        print(
            "task name not in [udc|swda|mrda|atis_intent|atis_slot|dstc2|dstc2_asr|multi-woz]"
        )
        return None
    return eval_metrics


if __name__ == "__main__":
    if len(sys.argv[1:]) < 3:
        print("please input task_name predict_file reference_file")
        # Stop here instead of falling through to an IndexError on sys.argv.
        sys.exit(1)

    task_name = sys.argv[1]
    pred_file = sys.argv[2]
    refer_file = sys.argv[3]

    evaluate(task_name, pred_file, refer_file)