error_rate.py 4.2 KB
Newer Older
Y
yangyaming 已提交
1
# -*- coding: utf-8 -*-
Y
yangyaming 已提交
2 3
"""This module provides functions to calculate error rates at different levels,
e.g. WER (word error rate) at word level and CER (character error rate) at
character level.
"""

Y
yangyaming 已提交
6 7 8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Y
yangyaming 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
import numpy as np


def levenshtein_distance(ref, hyp):
    """Compute the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element substitutions,
    insertions and deletions needed to turn ``ref`` into ``hyp``. Elements
    are compared with ``==``, so the inputs may be strings (char-level) or
    lists of words (word-level).

    :param ref: The reference sequence.
    :type ref: sequence
    :param hyp: The hypothesis sequence.
    :type hyp: sequence
    :return: The edit distance between ``ref`` and ``hyp``.
    :rtype: int
    """
    ref_len = len(ref)
    hyp_len = len(hyp)

    # special cases
    if ref == hyp:
        return 0
    if ref_len == 0:
        return hyp_len
    if hyp_len == 0:
        return ref_len

    # Row i of the DP table depends only on row i - 1, so two rows suffice
    # (O(hyp_len) memory instead of a full (ref_len + 1) x (hyp_len + 1)
    # matrix). `range` is used instead of `xrange` so this runs on Python 3.
    prev_row = list(range(hyp_len + 1))  # distance from empty ref prefix
    for i in range(1, ref_len + 1):
        cur_row = [i] + [0] * hyp_len  # distance to empty hyp prefix is i
        for j in range(1, hyp_len + 1):
            if ref[i - 1] == hyp[j - 1]:
                cur_row[j] = prev_row[j - 1]
            else:
                s_num = prev_row[j - 1] + 1  # substitution
                i_num = cur_row[j - 1] + 1   # insertion
                d_num = prev_row[j] + 1      # deletion
                cur_row[j] = min(s_num, i_num, d_num)
        prev_row = cur_row

    return prev_row[hyp_len]


Y
yangyaming 已提交
46
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER). WER compares reference text and
    hypothesis text at word level. WER is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    We can use levenshtein distance to calculate WER. Note that empty items
    produced when splitting sentences by ``delimiter`` are removed.

    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: If True, both sentences are lower-cased before
                        comparison (case-insensitive WER).
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If reference length is zero.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Build real lists: on Python 3 `filter` returns an iterator, which has
    # no len() and would make the emptiness check below raise TypeError.
    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    if len(ref_words) == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    edit_distance = levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance) / len(ref_words)


Y
yangyaming 已提交
92
def cer(reference, hypothesis, ignore_case=False):
    """Calculate character error rate (CER). CER compares reference text and
    hypothesis text at character level. CER is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    We can use levenshtein distance to calculate CER. Chinese input should be
    unicode text (not encoded bytes) so that each character is compared and
    counted as one element. Note that leading and trailing space characters
    are removed and runs of consecutive spaces inside a sentence are replaced
    by a single space. Only the space character ``' '`` is normalized; other
    whitespace such as tabs or newlines is kept and counted as ordinary
    characters.

    :param reference: The reference sentence.
    :type reference: basestring
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: basestring
    :param ignore_case: If True, both sentences are lower-cased before
                        comparison (case-insensitive CER).
    :type ignore_case: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If reference length is zero.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Splitting on ' ' yields empty pieces for leading, trailing and repeated
    # spaces; dropping them and re-joining collapses each run to one space.
    reference = ' '.join([piece for piece in reference.split(' ') if piece])
    hypothesis = ' '.join([piece for piece in hypothesis.split(' ') if piece])

    if len(reference) == 0:
        raise ValueError("Length of reference should be greater than 0.")

    edit_distance = levenshtein_distance(reference, hypothesis)
    return float(edit_distance) / len(reference)