error_rate.py 4.0 KB
Newer Older
Y
yangyaming 已提交
1 2 3 4 5 6
# -*- coding: utf-8 -*-
"""
    This module provides functions to calculate error rates at different
    levels, e.g. WER at the word level and CER at the character level.
"""

Y
yangyaming 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import numpy as np


def levenshtein_distance(ref, hyp):
    """
    Calculate the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-item substitutions,
    insertions and deletions needed to turn ``ref`` into ``hyp``. Items are
    compared with ``==``, so the inputs can be strings (char-level) or lists
    of words (word-level).

    :param ref: The reference sequence.
    :type ref: str or list
    :param hyp: The hypothesis sequence.
    :type hyp: str or list
    :return: Edit distance between ``ref`` and ``hyp``.
    :rtype: int
    """
    ref_len = len(ref)
    hyp_len = len(hyp)

    # special cases: identical inputs, or one side is empty
    if ref == hyp:
        return 0
    if ref_len == 0:
        return hyp_len
    if hyp_len == 0:
        return ref_len

    # distance[i][j] holds the edit distance between ref[:i] and hyp[:j]
    distance = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int32)

    # initialize distance matrix: turning an empty prefix into a prefix of
    # length k (or the reverse) costs exactly k insertions (or deletions)
    distance[0, :] = np.arange(hyp_len + 1)
    distance[:, 0] = np.arange(ref_len + 1)

    # calculate levenshtein distance
    # `range` (not `xrange`) keeps this runnable on both Python 2 and 3
    for i in range(1, ref_len + 1):
        for j in range(1, hyp_len + 1):
            if ref[i - 1] == hyp[j - 1]:
                # matching items cost nothing extra
                distance[i][j] = distance[i - 1][j - 1]
            else:
                s_num = distance[i - 1][j - 1] + 1  # substitution
                i_num = distance[i][j - 1] + 1  # insertion
                d_num = distance[i - 1][j] + 1  # deletion
                distance[i][j] = min(s_num, i_num, d_num)

    # convert the numpy scalar so every return path yields a plain int
    return int(distance[ref_len][hyp_len])


Y
yangyaming 已提交
44
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """
    Calculate word error rate (WER). WER compares reference text and
    hypothesis text at the word level. WER is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    The Levenshtein distance between the two word lists gives
    Sw + Dw + Iw. Note that empty items produced by splitting the sentences
    on the delimiter (e.g. from consecutive delimiters) are removed.

    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether to ignore case when comparing words.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If the reference contains no words.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Build real lists rather than using filter(): on Python 3 filter()
    # returns an iterator, which supports neither len() nor indexing.
    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    if not ref_words:
        raise ValueError("Reference's word number should be greater than 0.")

    edit_distance = levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance) / len(ref_words)


Y
yangyaming 已提交
90
def cer(reference, hypothesis, ignore_case=False):
    """
    Calculate character error rate (CER). CER compares reference text and
    hypothesis text at the character level. CER is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    The Levenshtein distance between the two strings gives Sc + Dc + Ic.
    Chinese input should be passed as unicode. Note that leading and
    trailing spaces are removed and each run of consecutive spaces in a
    sentence is replaced by a single space before comparison; only the
    space character ' ' is treated this way.

    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: Whether to ignore case when comparing characters.
    :type ignore_case: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference is empty after space normalization.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Dropping the empty pieces of a split on ' ' and re-joining strips
    # leading/trailing spaces and collapses repeated spaces into one.
    reference = ' '.join(piece for piece in reference.split(' ') if piece)
    hypothesis = ' '.join(piece for piece in hypothesis.split(' ') if piece)

    if not reference:
        raise ValueError("Length of reference should be greater than 0.")

    edit_distance = levenshtein_distance(reference, hypothesis)
    return float(edit_distance) / len(reference)