evaluate_intent.py 1.7 KB
Newer Older
W
wangxiao1021 已提交
1 2 3 4 5 6 7 8 9 10
#  -*- coding: utf-8 -*-

import json
import numpy as np

def accuracy(preds, labels):
    """Return the fraction of predictions that equal their gold label.

    Args:
        preds: Sequence of predicted labels.
        labels: Sequence of gold labels, aligned element-wise with ``preds``.

    Returns:
        The mean of the element-wise equality between the two sequences.
    """
    predicted = np.array(preds)
    expected = np.array(labels)
    matches = predicted == expected
    return matches.mean()
  
W
wangxiao1021 已提交
11
def pre_recall_f1(preds, labels):
    """Compute precision, recall and F1 for the positive class ``'1'``.

    Elements are compared against the string values ``'1'`` (positive) and
    ``'0'`` (negative), so both sequences are expected to hold strings.

    Args:
        preds: Sequence of predicted labels.
        labels: Sequence of gold labels, aligned element-wise with ``preds``.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats. A metric whose
        denominator is zero (no positive predictions for precision, no
        positive labels for recall) is reported as 0.0 rather than NaN.
    """
    preds = np.array(preds)
    labels = np.array(labels)

    pred_pos = preds == '1'
    gold_pos = labels == '1'
    tp = int(np.sum(gold_pos & pred_pos))
    fp = int(np.sum((labels == '0') & pred_pos))
    fn = int(np.sum(gold_pos & (preds == '0')))

    # Recall = TP / (TP + FN); guard the empty-denominator case explicitly,
    # because 0 / 0 would otherwise propagate NaN into every metric.
    r = float(tp) / (tp + fn) if (tp + fn) else 0.0
    # Precision = TP / (TP + FP)
    p = float(tp) / (tp + fp) if (tp + fp) else 0.0

    # epsilon keeps F1 defined (as 0.0) when precision and recall are both 0.
    epsilon = 1e-31
    f1 = 2 * p * r / (p + r + epsilon)
    return p, r, f1
W
wangxiao1021 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54


def res_evaluate(res_dir="./outputs/predict-intent/predictions.json", eval_phase='test'):
    """Score intent predictions against gold labels and print the metrics.

    Gold labels are read from the first tab-separated column of the TSV file
    selected by ``eval_phase``; a row whose first column is ``label`` is
    treated as a header and skipped. Predictions are read from ``res_dir``,
    one JSON object per line, using its ``label`` field. Blank lines in
    either file are ignored.

    Args:
        res_dir: Path to the predictions file (JSON lines).
        eval_phase: ``'test'`` or ``'dev'``; selects the gold label file.

    Raises:
        AssertionError: If ``eval_phase`` is not ``'dev'`` or ``'test'``, or
            if the number of predictions differs from the number of labels.
    """
    if eval_phase == 'test':
        data_dir = "./data/atis/atis_intent/test.tsv"
    elif eval_phase == 'dev':
        # NOTE(review): the dev path is not under ./data/atis/atis_intent/
        # like the test path -- confirm this is intended.
        data_dir = "./data/dev.tsv"
    else:
        # Raised explicitly (instead of via `assert`) so the check survives
        # `python -O` while keeping the original exception type and message.
        raise AssertionError('eval_phase should be dev or test')

    labels = []
    with open(data_dir, "r") as label_file:
        for line in label_file:
            if not line.strip():
                continue
            # Strip the line ending so a label-only row does not keep "\n".
            label = line.rstrip("\r\n").split("\t")[0]
            if label == 'label':
                continue
            labels.append(str(label))

    preds = []
    with open(res_dir, "r") as pred_file:
        for line in pred_file:
            if not line.strip():
                continue
            preds.append(str(json.loads(line)['label']))

    if len(labels) != len(preds):
        raise AssertionError("prediction result doesn't match to labels")

    print('data num: {}'.format(len(labels)))
    p, r, f1 = pre_recall_f1(preds, labels)
    print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1))
W
wangxiao1021 已提交
57 58

# Run the evaluation with the default paths only when executed as a script,
# so that importing this module does not trigger file I/O.
if __name__ == "__main__":
    res_evaluate()