提交 ada40967 编写于 作者: Y yangyaming

Follow comments.

上级 0322d752
...@@ -2,14 +2,20 @@ ...@@ -2,14 +2,20 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
def levenshtein_distance(ref, hyp): def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference between
two sequences. Informally, the Levenshtein distance is defined as the minimum
number of single-character edits (substitutions, insertions or deletions)
required to change one word into the other. We can naturally extend the edits to
word level when calculating the Levenshtein distance for two sentences.
"""
ref_len = len(ref) ref_len = len(ref)
hyp_len = len(hyp) hyp_len = len(hyp)
...@@ -72,7 +78,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '): ...@@ -72,7 +78,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
:type delimiter: char :type delimiter: char
:return: Word error rate. :return: Word error rate.
:rtype: float :rtype: float
:raises ValueError: If reference length is zero. :raises ValueError: If the reference length is zero.
""" """
if ignore_case == True: if ignore_case == True:
reference = reference.lower() reference = reference.lower()
...@@ -84,7 +90,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '): ...@@ -84,7 +90,7 @@ def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
if len(ref_words) == 0: if len(ref_words) == 0:
raise ValueError("Reference's word number should be greater than 0.") raise ValueError("Reference's word number should be greater than 0.")
edit_distance = levenshtein_distance(ref_words, hyp_words) edit_distance = _levenshtein_distance(ref_words, hyp_words)
wer = float(edit_distance) / len(ref_words) wer = float(edit_distance) / len(ref_words)
return wer return wer
...@@ -118,7 +124,7 @@ def cer(reference, hypothesis, ignore_case=False): ...@@ -118,7 +124,7 @@ def cer(reference, hypothesis, ignore_case=False):
:type ignore_case: bool :type ignore_case: bool
:return: Character error rate. :return: Character error rate.
:rtype: float :rtype: float
:raises ValueError: If reference length is zero. :raises ValueError: If the reference length is zero.
""" """
if ignore_case == True: if ignore_case == True:
reference = reference.lower() reference = reference.lower()
...@@ -130,6 +136,6 @@ def cer(reference, hypothesis, ignore_case=False): ...@@ -130,6 +136,6 @@ def cer(reference, hypothesis, ignore_case=False):
if len(reference) == 0: if len(reference) == 0:
raise ValueError("Length of reference should be greater than 0.") raise ValueError("Length of reference should be greater than 0.")
edit_distance = levenshtein_distance(reference, hypothesis) edit_distance = _levenshtein_distance(reference, hypothesis)
cer = float(edit_distance) / len(reference) cer = float(edit_distance) / len(reference)
return cer return cer
...@@ -23,10 +23,8 @@ class TestParse(unittest.TestCase): ...@@ -23,10 +23,8 @@ class TestParse(unittest.TestCase):
def test_wer_3(self): def test_wer_3(self):
ref = ' ' ref = ' '
hyp = 'Hypothesis sentence' hyp = 'Hypothesis sentence'
try: with self.assertRaises(ValueError):
word_error_rate = error_rate.wer(ref, hyp) word_error_rate = error_rate.wer(ref, hyp)
except Exception as e:
self.assertTrue(isinstance(e, ValueError))
def test_cer_1(self): def test_cer_1(self):
ref = 'werewolf' ref = 'werewolf'
...@@ -53,10 +51,8 @@ class TestParse(unittest.TestCase): ...@@ -53,10 +51,8 @@ class TestParse(unittest.TestCase):
def test_cer_5(self): def test_cer_5(self):
ref = '' ref = ''
hyp = 'Hypothesis' hyp = 'Hypothesis'
try: with self.assertRaises(ValueError):
char_error_rate = error_rate.cer(ref, hyp) char_error_rate = error_rate.cer(ref, hyp)
except Exception as e:
self.assertTrue(isinstance(e, ValueError))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册