# -*- coding: utf-8 -*-
"""
evaluate word segmentation for LAC and other open-source segmentation tools
"""
from __future__ import print_function
from __future__ import division

import sys
import os


def to_unicode(string):
    """ string compatibility for python2 & python3 """
    if sys.version_info.major == 2 and isinstance(string, str):
        return string.decode("utf-8")
    else:
        return string


def to_set(words):
    """ cut list to set of (string, off) """
    off = 0
    s= set()
    for w in words:
        if w:
            s.add((off, w))
        off += len(w)
    return s
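
# Example (illustrative values): to_set("ab c".split(" ")) returns
# {(0, "ab"), (2, "c")}. Each element is a (character offset, word) span,
# so two segmentations agree on a word only when both its position and its
# surface form match exactly.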


def cal_fscore(standard, result, split_delim=" "):
    """ caculate fscore for wordseg
    Param: standard, list of str, ground-truth labels , e.g. ["a b c", "d ef g"]
    Param: result, list of str, predicted result, e.g. ["ab c", "d e fg"]
    """
    assert len(standard) == len(result)
    std, rst, cor = 0, 0, 0
    for s, r in zip(standard, result):
        s = to_set(s.rstrip().split(split_delim))
        r = to_set(r.rstrip().split(split_delim))
        std += len(s)
        rst += len(r)
        cor += len(s & r)
    # guard against empty inputs to avoid division by zero
    p = 1.0 * cor / rst if rst else 0.0
    r = 1.0 * cor / std if std else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0

    print("std, rst, cor = %d, %d, %d" % (std, rst, cor))
    print("precision = %.5f, recall = %.5f, f1 = %.5f" % (p, r, f))
    print("")

    return p, r, f
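
# Worked example (illustrative numbers): for standard=["ab c"] and
# result=["a b c"], the gold spans are {(0, "ab"), (2, "c")} and the predicted
# spans are {(0, "a"), (1, "b"), (2, "c")}; only (2, "c") matches, so
# std=2, rst=3, cor=1, giving precision=1/3, recall=1/2, f1=0.4.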


def load_testdata(datapath="./data/test_data/test_part"):
    """ load the test set and recover the gold-standard segmentation
    from the per-character -B/-I/-O labels """
    sentences = []
    sent_seg_list = []
    with open(datapath) as fin:
        for line in fin:
            sent, label = line.strip().split("\t")
            sentences.append(sent)

            sent = to_unicode(sent)
            labels = label.split(" ")
            assert len(sent) == len(labels)

            # merge characters into words according to the label suffix
            words = []
            current_word = ""
            for w, l in zip(sent, labels):
                if l.endswith("-B"):
                    # a new word begins: flush the previous one
                    if current_word != "":
                        words.append(current_word)
                    current_word = w
                elif l.endswith("-I"):
                    # the character continues the current word
                    current_word += w
                elif l.endswith("-O"):
                    # a single-character token outside any word
                    if current_word != "":
                        words.append(current_word)
                    words.append(w)
                    current_word = ""
                else:
                    raise ValueError("wrong label: " + l)
            if current_word != "":
                words.append(current_word)
            sent_seg = " ".join(words)
            sent_seg_list.append(sent_seg)
    print("got %d lines" % (len(sent_seg_list)))
    return sent_seg_list, sentences
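
# Example test line (tab-separated; the tag set here is illustrative, the
# parser above only looks at the -B/-I/-O suffixes):
#   "今天天气好\tTIME-B TIME-I n-B n-I a-B"
# which is recovered as the segmentation "今天 天气 好".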


def get_lac_result():
    """
    get the LAC predicted result produced by:
        `sh run.sh | tail -n 100 > result.txt`
    """
    sent_seg_list = []
    with open("./result.txt") as fin:
        for line in fin:
            # each token looks like "word/tag"; only the word part is needed
            pairs = line.strip().split(" ")
            words = [pair.split("/")[0] for pair in pairs]
            sent_seg = " ".join(words)
            sent_seg = to_unicode(sent_seg)
            sent_seg_list.append(sent_seg)
    return sent_seg_list
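
# Expected result.txt line format (assumed from the "word/tag" parsing above;
# words and tags are illustrative):
#   "今天/TIME 天气/n 好/a"  ->  parsed as "今天 天气 好"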


def get_jieba_result(sentences):
    """
    Ref to: https://github.com/fxsjy/jieba
    Install by: `pip install jieba`
    """
    import jieba
    preds = []
    for sentence in sentences:
        sent_seg = " ".join(jieba.lcut(sentence))
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds


def get_thulac_result(sentences):
    """
    Ref to: http://thulac.thunlp.org/
    Install by: `pip install thulac`
    """
    import thulac
    preds = []
    lac = thulac.thulac(seg_only=True)
    for sentence in sentences:
        sent_seg = lac.cut(sentence, text=True)
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds


def get_pkuseg_result(sentences):
    """
    Ref to: https://github.com/lancopku/pkuseg-python
    Install by: `pip3 install pkuseg`
    Note that pkuseg-python only supports Python 3.
    """
    import pkuseg
    seg = pkuseg.pkuseg()
    preds = []
    for sentence in sentences:
        sent_seg = " ".join(seg.cut(sentence))
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds


def get_hanlp_result(sentences):
    """
    Ref to: https://github.com/hankcs/pyhanlp
    Install by: `pip install pyhanlp`
        (Before using pyhanlp, you need to download the model manually.)
    """
    from pyhanlp import HanLP
    preds = []
    for sentence in sentences:
        arraylist = HanLP.segment(sentence)
        sent_seg = " ".join([term.toString().split("/")[0] for term in arraylist])
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds


def get_nlpir_result(sentences):
    """
    Ref to: https://github.com/tsroten/pynlpir
    Install by `pip install pynlpir`
    Run `pynlpir update` to update the license
    """
    import pynlpir
    pynlpir.open()
    preds = []
    for sentence in sentences:
        sent_seg = " ".join(pynlpir.segment(sentence, pos_tagging=False))
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds


def get_ltp_result(sentences):
    """
    Ref to: https://github.com/HIT-SCIR/pyltp
        1. Install by `pip install pyltp`
        2. Download models from http://ltp.ai/download.html
    """
    from pyltp import Segmentor
    segmentor = Segmentor()
    model_path = "./ltp_data_v3.4.0/cws.model"
    if not os.path.exists(model_path):
        raise IOError("LTP model does not exist! Download it first!")
    segmentor.load(model_path)
    preds = []
    for sentence in sentences:
        sent_seg = " ".join(segmentor.segment(sentence))
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    segmentor.release()

    return preds


def print_array(array):
    """ print a few sample cases """
    for i in [1, 10, 20, 30, 40]:
        if i < len(array):
            print("case " + str(i) + ": \t" + array[i])


def evaluate_all():
    """ evaluate LAC and several open-source segmenters on the same test set """
    standard, sentences = load_testdata()
    print_array(standard)

    # evaluate lac
    preds = get_lac_result()
    print("lac result:")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)

    # evaluate jieba
    preds = get_jieba_result(sentences)
    print("jieba result")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)

    # evaluate thulac
    preds = get_thulac_result(sentences)
    print("thulac result")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)

    # evaluate pkuseg, which only supports Python 3
    if sys.version_info.major == 3:
        preds = get_pkuseg_result(sentences)
        print("pkuseg result")
        print_array(preds)
        cal_fscore(standard=standard, result=preds)

    # evaluate HanLP
    preds = get_hanlp_result(sentences)
    print("HanLP result")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)

    # evaluate NLPIR
    preds = get_nlpir_result(sentences)
    print("NLPIR result")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)

    # evaluate LTP
    preds = get_ltp_result(sentences)
    print("LTP result")
    print_array(preds)
    cal_fscore(standard=standard, result=preds)


if __name__ == "__main__":
    evaluate_all()