提交 7310baa8 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #193 from pkuyym/error_rate_optional

Make the error rate type configurable (optional choice between WER and CER).
......@@ -9,7 +9,7 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import wer, cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
......@@ -111,6 +111,14 @@ parser.add_argument(
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
# CLI flag choosing the evaluation metric: word-level ('wer') or
# character-level ('cer') error rate.
parser.add_argument(
    "--error_rate_type",
    type=str,
    default='wer',
    choices=['wer', 'cer'],
    help="Error rate type for evaluation. 'wer' for word error rate "
    "and 'cer' for character error rate. (default: %(default)s)")
args = parser.parse_args()
......@@ -136,7 +144,8 @@ def evaluate():
rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath)
wer_sum, num_ins = 0.0, 0
error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0
for infer_data in batch_reader():
result_transcripts = ds2_model.infer_batch(
infer_data=infer_data,
......@@ -153,10 +162,12 @@ def evaluate():
for _, transcript in infer_data
]
for target, result in zip(target_transcripts, result_transcripts):
wer_sum += wer(target, result)
error_sum += error_rate_func(target, result)
num_ins += 1
print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins))
print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins))
print("Error rate [%s] (%d/?) = %f" %
(args.error_rate_type, num_ins, error_sum / num_ins))
print("Final error rate [%s] (%d/%d) = %f" %
(args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
def main():
......
......@@ -9,7 +9,7 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import wer, cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
......@@ -111,6 +111,14 @@ parser.add_argument(
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
# Select which metric the inference script reports: word error rate
# ('wer') or character error rate ('cer').
parser.add_argument(
    "--error_rate_type",
    type=str,
    choices=['wer', 'cer'],
    default='wer',
    help="Error rate type for evaluation. 'wer' for word error rate "
    "and 'cer' for character error rate. (default: %(default)s)")
args = parser.parse_args()
......@@ -147,6 +155,7 @@ def infer():
language_model_path=args.language_model_path,
num_processes=args.num_processes_beam_search)
error_rate_func = cer if args.error_rate_type == 'cer' else wer
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
......@@ -154,7 +163,8 @@ def infer():
for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
print("Current wer = %f" % wer(target, result))
print("Current error rate [%s] = %f" %
(args.error_rate_type, error_rate_func(target, result)))
def main():
......
......@@ -185,7 +185,7 @@ class DeepSpeech2Model(object):
# best path decode
for i, probs in enumerate(probs_split):
output_transcription = ctc_best_path_decoder(
probs_seq=probs, vocabulary=data_generator.vocab_list)
probs_seq=probs, vocabulary=vocab_list)
results.append(output_transcription)
elif decode_method == "beam_search":
# initialize external scorer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册