error_rate.py 6.9 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yangyaming 已提交
14 15
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
Y
yangyaming 已提交
16
"""
Y
yangyaming 已提交
17

Y
yangyaming 已提交
18 19 20
import numpy as np


Y
yangyaming 已提交
21
def _levenshtein_distance(ref, hyp):
    """Compute the Levenshtein distance between two sequences.

    The Levenshtein distance is the minimum number of single-item edits
    (substitutions, insertions or deletions) required to change one sequence
    into the other. Items are compared with ``==``, so passing two strings
    yields a character-level distance and passing two lists of words yields
    a word-level distance.

    :param ref: The reference sequence (e.g. a str or a list of words).
    :type ref: sequence
    :param hyp: The hypothesis sequence (e.g. a str or a list of words).
    :type hyp: sequence
    :return: Levenshtein distance between ``ref`` and ``hyp``.
    :rtype: int
    """
    m = len(ref)
    n = len(hyp)

    # special cases that need no dynamic programming
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    # The distance is symmetric, so make ``hyp`` the shorter sequence: the
    # rows below are indexed by ``hyp``, giving O(min(m, n)) extra space.
    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # Only two rows of the distance matrix are kept. Row 0 is the cost of
    # building each prefix of ``hyp`` from an empty ``ref`` prefix.
    prev_row = list(range(n + 1))
    cur_row = [0] * (n + 1)

    for i, ref_item in enumerate(ref, start=1):
        # Column 0: deleting all ``i`` items of the current ``ref`` prefix.
        cur_row[0] = i
        for j, hyp_item in enumerate(hyp, start=1):
            if ref_item == hyp_item:
                cur_row[j] = prev_row[j - 1]
            else:
                s_num = prev_row[j - 1] + 1  # substitution
                i_num = cur_row[j - 1] + 1  # insertion
                d_num = prev_row[j] + 1  # deletion
                cur_row[j] = min(s_num, i_num, d_num)
        # The row just filled becomes the previous row of the next pass; the
        # old previous row is recycled as scratch space.
        prev_row, cur_row = cur_row, prev_row

    # After the final swap ``prev_row`` holds the last computed row.
    return prev_row[n]
Y
yangyaming 已提交
66 67


68 69 70 71 72
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence in word-level.

    Both sentences are split on ``delimiter`` and empty items are removed
    before the distance is computed.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to lowercase both sentences before comparing
                        them (i.e. compare case-insensitively).
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: str
    :return: Levenshtein distance (as a float) and word number of reference
             sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    # split() yields empty strings for leading, trailing or repeated
    # delimiters; those are not words and must not be counted.
    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)


def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence in char-level.

    Before comparison both sentences are split on the space character and
    re-joined, which strips leading/trailing spaces and collapses runs of
    spaces into one space (or removes spaces entirely when ``remove_space``
    is set).

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to lowercase both sentences before comparing
                        them (i.e. compare case-insensitively).
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Levenshtein distance (as a float) and length of the normalized
             reference sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    join_char = '' if remove_space else ' '

    # filter(None, ...) drops the empty strings produced by leading, trailing
    # or consecutive spaces; str.join accepts the iterator directly.
    reference = join_char.join(filter(None, reference.split(' ')))
    hypothesis = join_char.join(filter(None, hypothesis.split(' ')))

    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)


Y
yangyaming 已提交
124
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER). WER compares reference text and
    hypothesis text in word-level. WER is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    We can use levenshtein distance to calculate WER. Please note that empty
    items will be removed when splitting sentences by delimiter.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: str
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If word number of reference is zero.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
                                         delimiter)

    # The rate is normalized by the reference word count, so it is undefined
    # for a reference without any words.
    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    return float(edit_distance) / ref_len


165
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate character error rate (CER). CER compares reference text and
    hypothesis text in char-level. CER is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    We can use levenshtein distance to calculate CER. Chinese input should be
    provided as unicode text. Please note that the leading and trailing space
    characters will be truncated and multiple consecutive space characters
    in a sentence will be replaced by one space character.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference length is zero.
    """
    edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
                                         remove_space)

    # The rate is normalized by the reference length, so it is undefined for
    # an empty reference.
    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    return float(edit_distance) / ref_len