diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py
index bf6093acb8e14ec926d1aefb759207905e468f8d..4a4073c02279bfd74b8ce31d0877a5338400d93b 100644
--- a/fluid/DeepASR/infer_by_ckpt.py
+++ b/fluid/DeepASR/infer_by_ckpt.py
@@ -17,6 +17,7 @@ from decoder.post_decode_faster import Decoder
 from data_utils.util import lodtensor_to_ndarray
 from model_utils.model import stacked_lstmp_model
 from data_utils.util import split_infer_result
+from tools.error_rate import char_errors
 
 
 def parse_args():
@@ -86,6 +87,11 @@
         type=str,
         default='data/infer_label.lst',
         help='The label list path for inference. (default: %(default)s)')
+    parser.add_argument(
+        '--ref_txt',
+        type=str,
+        default='data/text.test',
+        help='The reference text for decoding. (default: %(default)s)')
     parser.add_argument(
         '--checkpoint',
         type=str,
@@ -111,6 +117,11 @@
         type=float,
         default=0.2,
         help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
+    parser.add_argument(
+        '--target_trans',
+        type=str,
+        default="./decoder/target_trans.txt",
+        help="The path to the target transcription. (default: %(default)s)")
 
     args = parser.parse_args()
     return args
@@ -122,6 +133,18 @@ def print_arguments(args):
     print('------------------------------------------------')
 
 
+def get_trg_trans(args):
+    trans_dict = {}
+    with open(args.target_trans) as trg_trans:
+        line = trg_trans.readline()
+        while line:
+            items = line.strip().split()
+            key = items[0]
+            trans_dict[key] = ''.join(items[1:])
+            line = trg_trans.readline()
+    return trans_dict
+
+
 def infer_from_ckpt(args):
     """Inference by using checkpoint."""
 
@@ -145,6 +168,7 @@
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    trg_trans = get_trg_trans(args)
     # load checkpoint.
     fluid.io.load_persistables(exe, args.checkpoint)
 
@@ -166,11 +190,12 @@
                                       args.infer_label_lst)
     infer_data_reader.set_transformers(ltrans)
     infer_costs, infer_accs = [], []
+    total_edit_dist, total_ref_len = 0.0, 0
     for batch_id, batch_data in enumerate(
             infer_data_reader.batch_iterator(args.batch_size,
                                              args.minimum_batch_size)):
         # load_data
-        (features, labels, lod) = batch_data
+        (features, labels, lod, name_lst) = batch_data
         feature_t.set(features, place)
         feature_t.set_lod([lod])
         label_t.set(labels, place)
@@ -186,11 +211,19 @@
 
 
         probs, lod = lodtensor_to_ndarray(results[0])
         infer_batch = split_infer_result(probs, lod)
-        for index, sample in enumerate(infer_batch):
-            key = "utter#%d" % (batch_id * args.batch_size + index)
-            print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
-    print(np.mean(infer_costs), np.mean(infer_accs))
+        for index, sample in enumerate(infer_batch):
+            key = name_lst[index]
+            ref = trg_trans[key]
+            hyp = decoder.decode(key, sample)
+            edit_dist, ref_len = char_errors(ref.decode("utf8"), hyp)
+            total_edit_dist += edit_dist
+            total_ref_len += ref_len
+            print(key + "|Ref:", ref)
+            print(key + "|Hyp:", hyp.encode("utf8"))
+            print("Instance CER: ", edit_dist / ref_len)
+
+    print("Total CER = %f" % (total_edit_dist / total_ref_len))
 
 
 if __name__ == '__main__':
diff --git a/fluid/DeepASR/tools/error_rate.py b/fluid/DeepASR/tools/error_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..215ad39d24a551879d0fd8d4c8892161a0708370
--- /dev/null
+++ b/fluid/DeepASR/tools/error_rate.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+"""This module provides functions to calculate error rates at different
+levels, e.g. WER at the word level and CER at the character level.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def _levenshtein_distance(ref, hyp):
+    """Levenshtein distance is a string metric for measuring the difference
+    between two sequences. Informally, the Levenshtein distance is defined
+    as the minimum number of single-character edits (substitutions,
+    insertions or deletions) required to change one word into the other. The
+    edits extend naturally to the word level when calculating the
+    Levenshtein distance between two sentences.
+    """
+    m = len(ref)
+    n = len(hyp)
+
+    # special cases
+    if ref == hyp:
+        return 0
+    if m == 0:
+        return n
+    if n == 0:
+        return m
+
+    if m < n:
+        ref, hyp = hyp, ref
+        m, n = n, m
+
+    # use O(min(m, n)) space: only two rows of the DP matrix are kept
+    distance = np.zeros((2, n + 1), dtype=np.int32)
+
+    # initialize the distance matrix
+    for j in xrange(n + 1):
+        distance[0][j] = j
+
+    # calculate the Levenshtein distance
+    for i in xrange(1, m + 1):
+        prev_row_idx = (i - 1) % 2
+        cur_row_idx = i % 2
+        distance[cur_row_idx][0] = i
+        for j in xrange(1, n + 1):
+            if ref[i - 1] == hyp[j - 1]:
+                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+            else:
+                s_num = distance[prev_row_idx][j - 1] + 1
+                i_num = distance[cur_row_idx][j - 1] + 1
+                d_num = distance[prev_row_idx][j] + 1
+                distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+    return distance[m % 2][n]
+
+
+def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
+    """Compute the Levenshtein distance between the reference sequence and
+    the hypothesis sequence at the word level.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case when comparing.
+    :type ignore_case: bool
+    :param delimiter: Delimiter of input sentences.
+    :type delimiter: char
+    :return: Levenshtein distance and word number of the reference sentence.
+    :rtype: tuple
+    """
+    if ignore_case:
+        reference = reference.lower()
+        hypothesis = hypothesis.lower()
+
+    ref_words = filter(None, reference.split(delimiter))
+    hyp_words = filter(None, hypothesis.split(delimiter))
+
+    edit_distance = _levenshtein_distance(ref_words, hyp_words)
+    return float(edit_distance), len(ref_words)
+
+
+def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
+    """Compute the Levenshtein distance between the reference sequence and
+    the hypothesis sequence at the character level.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case when comparing.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters.
+    :type remove_space: bool
+    :return: Levenshtein distance and length of the reference sentence.
+    :rtype: tuple
+    """
+    if ignore_case:
+        reference = reference.lower()
+        hypothesis = hypothesis.lower()
+
+    join_char = ' '
+    if remove_space:
+        join_char = ''
+
+    reference = join_char.join(filter(None, reference.split(' ')))
+    hypothesis = join_char.join(filter(None, hypothesis.split(' ')))
+
+    edit_distance = _levenshtein_distance(reference, hypothesis)
+    return float(edit_distance), len(reference)
+
+
+def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
+    """Calculate word error rate (WER). WER compares the reference text and
+    the hypothesis text at the word level. WER is defined as:
+    .. math::
+        WER = (Sw + Dw + Iw) / Nw
+    where
+    .. code-block:: text
+        Sw is the number of words substituted,
+        Dw is the number of words deleted,
+        Iw is the number of words inserted,
+        Nw is the number of words in the reference
+    WER can be computed via the Levenshtein distance. Note that empty items
+    are removed when splitting sentences by the delimiter.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case when comparing.
+    :type ignore_case: bool
+    :param delimiter: Delimiter of input sentences.
+    :type delimiter: char
+    :return: Word error rate.
+    :rtype: float
+    :raises ValueError: If the word number of the reference is zero.
+    """
+    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
+                                         delimiter)
+
+    if ref_len == 0:
+        raise ValueError("Reference's word number should be greater than 0.")
+
+    wer = float(edit_distance) / ref_len
+    return wer
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+    """Calculate character error rate (CER). CER compares the reference text
+    and the hypothesis text at the character level. CER is defined as:
+    .. math::
+        CER = (Sc + Dc + Ic) / Nc
+    where
+    .. code-block:: text
+        Sc is the number of characters substituted,
+        Dc is the number of characters deleted,
+        Ic is the number of characters inserted,
+        Nc is the number of characters in the reference
+    CER can be computed via the Levenshtein distance. Chinese input should
+    be encoded as unicode. Note that leading and trailing space characters
+    are stripped, and multiple consecutive space characters in a sentence
+    are replaced by a single space character.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case when comparing.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters.
+    :type remove_space: bool
+    :return: Character error rate.
+    :rtype: float
+    :raises ValueError: If the reference length is zero.
+    """
+    edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
+                                         remove_space)
+
+    if ref_len == 0:
+        raise ValueError("Length of reference should be greater than 0.")
+
+    cer = float(edit_distance) / ref_len
+    return cer
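
A note for reviewers: judging from `get_trg_trans`, each line of the `--target_trans` file is expected to carry an utterance key followed by the reference tokens, which are joined without separators to form a char-level reference. A minimal sketch of that per-line parsing, using a hypothetical line:

```python
# Sketch of the parsing done by get_trg_trans (the line content is made up).
line = "utt_001 hello world"
items = line.strip().split()
key = items[0]                # "utt_001"
ref = ''.join(items[1:])      # "helloworld": tokens joined with no separator
print(key, ref)
```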
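For reference, a quick usage sketch of the new `tools/error_rate.py` helpers; the strings are made-up examples, and the snippet follows the module's Python 2 style:

```python
from tools.error_rate import char_errors, wer, cer

# Word level: one substitution over a two-word reference.
print(wer('hello world', 'hello word'))      # 0.5

# Char level: one substitution over a six-char reference.
print(cer('abcdef', 'abcdeg'))               # ~0.1667

# char_errors returns the raw (edit_distance, ref_len) pair, which is
# exactly what infer_by_ckpt.py accumulates into the corpus-level CER.
edit_dist, ref_len = char_errors(u'abcd', u'abed')
print(edit_dist / ref_len)                   # 0.25
```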
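It may also be worth noting that `infer_from_ckpt` reports a corpus-level (micro-averaged) CER: it sums edit distances and reference lengths across utterances instead of averaging per-utterance CERs. The two differ whenever reference lengths vary, as this small sketch with hypothetical values shows:

```python
# (edit_dist, ref_len) per utterance -- hypothetical values.
pairs = [(1.0, 2), (1.0, 10)]

# Micro-average, as computed by the script: 2.0 / 12 ~= 0.167
total_edit_dist = sum(d for d, _ in pairs)
total_ref_len = sum(n for _, n in pairs)
print("Total CER = %f" % (total_edit_dist / total_ref_len))

# Macro-average (mean of per-utterance CERs) would give 0.3 instead.
print(sum(d / n for d, n in pairs) / len(pairs))
```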