未验证 提交 2b8c08e3 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1673 from Jackwaterveg/CER

[asr] Add new cer tools
...@@ -5,6 +5,8 @@ if [ $# != 4 ];then ...@@ -5,6 +5,8 @@ if [ $# != 4 ];then
exit -1 exit -1
fi fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
...@@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then ...@@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
python3 -u ${BIN_DIR}/test.py \ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--ngpu ${ngpu} \ # format the reference test file
--config ${config_path} \ python utils/format_rsl.py \
--decode_cfg ${decode_config_path} \ --origin_ref data/manifest.test.raw \
--result_file ${ckpt_prefix}.rsl \ --trans_ref data/manifest.test.text
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!" echo "Failed in evaluation!"
exit 1 exit 1
fi
# format the hyp file
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp ${ckpt_prefix}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error
fi fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref_sclite data/manifest.test.text.sclite
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite
mkdir -p ${ckpt_prefix}_sclite
sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII
fi
exit 0 exit 0
...@@ -5,6 +5,8 @@ if [ $# != 3 ];then ...@@ -5,6 +5,8 @@ if [ $# != 3 ];then
exit -1 exit -1
fi fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
...@@ -24,7 +26,13 @@ fi ...@@ -24,7 +26,13 @@ fi
#fi #fi
for type in attention ctc_greedy_search; do if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# format the reference test file
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref data/manifest.test.text
for type in attention ctc_greedy_search; do
echo "decoding ${type}" echo "decoding ${type}"
if [ ${chunk_mode} == true ];then if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1 # stream decoding only support batchsize=1
...@@ -46,10 +54,18 @@ for type in attention ctc_greedy_search; do ...@@ -46,10 +54,18 @@ for type in attention ctc_greedy_search; do
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in evaluation!" echo "Failed in evaluation!"
exit 1 exit 1
fi fi
done # format the hyp file
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl \
--trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
for type in ctc_prefix_beam_search attention_rescoring; do for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}" echo "decoding ${type}"
batch_size=1 batch_size=1
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
...@@ -67,6 +83,29 @@ for type in ctc_prefix_beam_search attention_rescoring; do ...@@ -67,6 +83,29 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "Failed in evaluation!" echo "Failed in evaluation!"
exit 1 exit 1
fi fi
done python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl
--trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
# format the reference test file for sclite
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref_sclite data/manifest.test.text.sclite
output_dir=${ckpt_prefix}
for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl
--trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite
mkdir -p ${output_dir}/${type}_sclite
sclite -i wsj -r data/manifest.test.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
done
fi
exit 0 exit 0
...@@ -7,7 +7,7 @@ stage=0 ...@@ -7,7 +7,7 @@ stage=0
stop_stage=50 stop_stage=50
conf_path=conf/conformer.yaml conf_path=conf/conformer.yaml
decode_conf_path=conf/tuning/decode.yaml decode_conf_path=conf/tuning/decode.yaml
avg_num=20 avg_num=30
audio_file=data/demo_01_03.wav audio_file=data/demo_01_03.wav
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......
...@@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref": target, "hyp": result}) fout.write({"utt": utt, "refs": [target], "hyps": [result]})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")
......
此差异已折叠。
import os
import argparse
import jsonlines
def trans_hyp(origin_hyp,
trans_hyp = None,
trans_hyp_sclite = None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_hyp, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0]
if trans_hyp is not None:
with open(trans_hyp, "w+", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_hyp_sclite is not None:
with open(trans_hyp_sclite, "w+") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
f.write(line)
def trans_ref(origin_ref,
trans_ref = None,
trans_ref_sclite = None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_ref, "r", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["text"]
if trans_ref is not None:
with open(trans_ref, "w", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_ref_sclite is not None:
with open(trans_ref_sclite, "w") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp',
type=str,
default = None,
help='origin hyp file')
parser.add_argument(
'--trans_hyp', type=str, default = None, help='hyp file for caculating CER/WER')
parser.add_argument(
'--trans_hyp_sclite', type=str, default = None, help='hyp file for caculating CER/WER by sclite')
parser.add_argument(
'--origin_ref',
type=str,
default = None,
help='origin ref file')
parser.add_argument(
'--trans_ref', type=str, default = None, help='ref file for caculating CER/WER')
parser.add_argument(
'--trans_ref_sclite', type=str, default = None, help='ref file for caculating CER/WER by sclite')
parser_args = parser.parse_args()
if parser_args.origin_hyp is not None:
trans_hyp(
origin_hyp = parser_args.origin_hyp,
trans_hyp = parser_args.trans_hyp,
trans_hyp_sclite = parser_args.trans_hyp_sclite, )
if parser_args.origin_ref is not None:
trans_ref(
origin_ref = parser_args.origin_ref,
trans_ref = parser_args.trans_ref,
trans_ref_sclite = parser_args.trans_ref_sclite, )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册