error_rate.py 4.6 KB
Newer Older
Y
yangyaming 已提交
1
# -*- coding: utf-8 -*-
Y
yangyaming 已提交
2 3
"""This module provides functions to calculate error rates at different levels,
e.g. WER at the word level and CER at the character level.
"""
Y
yangyaming 已提交
5 6 7
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Y
yangyaming 已提交
8

Y
yangyaming 已提交
9 10 11
import numpy as np


Y
yangyaming 已提交
12 13 14 15 16 17 18
def _levenshtein_distance(ref, hyp):
    """Return the Levenshtein (edit) distance between two sequences.

    The Levenshtein distance is the minimum number of single-item edits
    (substitutions, insertions or deletions) required to turn ``ref`` into
    ``hyp``. Items are compared with ``==``, so the same routine works at
    character level (two strings) and at word level (two lists of words).

    :param ref: The reference sequence.
    :type ref: str or list
    :param hyp: The hypothesis sequence.
    :type hyp: str or list
    :return: Number of edits needed to transform ``ref`` into ``hyp``.
    :rtype: int
    """
    ref_len = len(ref)
    hyp_len = len(hyp)

    # Trivial cases that need no dynamic-programming table.
    if ref == hyp:
        return 0
    if ref_len == 0:
        return hyp_len
    if hyp_len == 0:
        return ref_len

    # distance[i, j] holds the edit distance between ref[:i] and hyp[:j].
    distance = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int32)

    # Turning an empty prefix into a prefix of length k takes k insertions
    # (first row) or k deletions (first column).
    distance[0, :] = np.arange(hyp_len + 1)
    distance[:, 0] = np.arange(ref_len + 1)

    # ``range`` (not ``xrange``) keeps this runnable on Python 2 and 3.
    for i in range(1, ref_len + 1):
        for j in range(1, hyp_len + 1):
            if ref[i - 1] == hyp[j - 1]:
                # Matching items cost nothing extra.
                distance[i, j] = distance[i - 1, j - 1]
            else:
                substitution = distance[i - 1, j - 1] + 1
                insertion = distance[i, j - 1] + 1
                deletion = distance[i - 1, j] + 1
                distance[i, j] = min(substitution, insertion, deletion)

    # Convert to a plain int so every return path yields the same type.
    return int(distance[ref_len, hyp_len])


Y
yangyaming 已提交
52
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate the word error rate (WER).

    WER compares the reference text and the hypothesis text at word level.
    It is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    The numerator is obtained as the word-level Levenshtein distance. Note
    that empty items produced by splitting on ``delimiter`` (e.g. from
    leading, trailing or repeated delimiters) are discarded.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If true, both sentences are lower-cased before
                        comparison.
    :type ignore_case: bool
    :param delimiter: Delimiter used to split the sentences into words.
    :type delimiter: str
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If the reference contains no words.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Build real lists rather than ``filter`` objects: on Python 3 ``filter``
    # is lazy and supports neither ``len()`` nor indexing.
    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    if not ref_words:
        raise ValueError("Reference's word number should be greater than 0.")

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance) / len(ref_words)


Y
yangyaming 已提交
98
def cer(reference, hypothesis, ignore_case=False):
    """Calculate the character error rate (CER).

    CER compares the reference text and the hypothesis text at character
    level. It is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    The numerator is obtained as the character-level Levenshtein distance.
    Before comparison, leading and trailing spaces are removed and runs of
    consecutive spaces are collapsed into a single space. Only the ``' '``
    character is normalised this way; other whitespace such as tabs or
    newlines is left untouched. Non-ASCII text (e.g. Chinese) should be
    passed as unicode text rather than encoded bytes so that each character
    is compared as one unit.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: If true, both sentences are lower-cased before
                        comparison.
    :type ignore_case: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference is empty after space normalisation.
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # Drop the empty pieces produced by leading/trailing/repeated spaces and
    # re-join with exactly one space between the remaining pieces.
    reference = ' '.join(part for part in reference.split(' ') if part)
    hypothesis = ' '.join(part for part in hypothesis.split(' ') if part)

    if not reference:
        raise ValueError("Length of reference should be greater than 0.")

    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance) / len(reference)