From 18db3cf77324db5203fcba9af208b0915d46a652 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 4 Jul 2018 18:50:08 -0700
Subject: [PATCH] Handle the special tokens in scoring cer

---
 fluid/DeepASR/score_error_rate.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/fluid/DeepASR/score_error_rate.py b/fluid/DeepASR/score_error_rate.py
index 5ecbca08..dde5a244 100644
--- a/fluid/DeepASR/score_error_rate.py
+++ b/fluid/DeepASR/score_error_rate.py
@@ -16,10 +16,18 @@ def parse_args():
         default='cer',
         choices=['cer', 'wer'],
         help="Error rate type. (default: %(default)s)")
+    parser.add_argument(
+        '--special_tokens',
+        type=str,
+        default='',
+        help="Special tokens in scoring CER, seperated by space. "
+        "They shouldn't be splitted and should be treated as one special "
+        "character. Example: ' ' "
+        "(default: %(default)s)")
     parser.add_argument(
         '--ref', type=str, required=True, help="The ground truth text.")
     parser.add_argument(
-        '--hyp', type=str, required=True, help="The decoding result.")
+        '--hyp', type=str, required=True, help="The decoding result text.")
     args = parser.parse_args()
     return args
 
@@ -31,6 +39,10 @@ if __name__ == '__main__':
     sum_errors, sum_ref_len = 0.0, 0
     sent_cnt, not_in_ref_cnt = 0, 0
 
+    # Split on any whitespace: ''.split(" ") would yield [''], and
+    # str.replace('', '\0') inserts '\0' between every character.
+    special_tokens = args.special_tokens.split()
+
     with open(args.ref, "r") as ref_txt:
         line = ref_txt.readline()
         while line:
@@ -51,6 +63,8 @@ if __name__ == '__main__':
                 continue
 
             if args.error_rate_type == 'cer':
+                for sp_tok in special_tokens:
+                    sent = sent.replace(sp_tok, '\0')
                 errors, ref_len = char_errors(
                     ref_dict[key].decode("utf8"),
                     sent.decode("utf8"),
-- 
GitLab