Commit feaf71d4 authored by Hui Zhang

u2 kaldi multi-process test with batch size one

Parent aaa87698
...
@@ -444,7 +444,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         text_feature = self.test_loader.collate_fn.text_feature
         target_transcripts = self.ordid2token(texts, texts_len)
-        result_transcripts = self.model.decode(
+        result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
             text_feature=text_feature,
@@ -462,14 +462,19 @@ class U2Tester(U2Trainer):
             simulate_streaming=cfg.simulate_streaming)
         decode_time = time.time() - start_time

-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
+        for utt, target, result, rec_tids in zip(
+                utts, target_transcripts, result_transcripts, result_tokenids):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref": target, "hyp": result})
+                fout.write({
+                    "utt": utt,
+                    "refs": [target],
+                    "hyps": [result],
+                    "hyps_tokenid": [rec_tids],
+                })
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
...
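With this change, each record in the decode result file is a JSON line with list-valued fields, including the raw hypothesis token ids. A sketch of one output line (the utterance id, text, and token ids below are illustrative, not from the commit):

```python
# One jsonlines record as written by fout.write(...) above; values invented.
{
    "utt": "1089-134686-0000",
    "refs": ["he hoped there would be stew for dinner"],
    "hyps": ["he hoped there would be stew for dinner"],
    "hyps_tokenid": [[104, 18, 250, 7, 42]],
}
```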
...
@@ -390,6 +390,10 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
+        self.text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -413,15 +417,11 @@ class U2Tester(U2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

         start_time = time.time()
-        text_feature = TextFeaturizer(
-            unit_type=self.config.collator.unit_type,
-            vocab_filepath=self.config.collator.vocab_filepath,
-            spm_model_prefix=self.config.collator.spm_model_prefix)
-        target_transcripts = self.id2token(texts, texts_len, text_feature)
-        result_transcripts = self.model.decode(
+        target_transcripts = self.id2token(texts, texts_len, self.text_feature)
+        result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=cfg.decoding_method,
             lang_model_path=cfg.lang_model_path,
             beam_alpha=cfg.alpha,
@@ -436,14 +436,19 @@ class U2Tester(U2Trainer):
             simulate_streaming=cfg.simulate_streaming)
         decode_time = time.time() - start_time

-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
+        for i, (utt, target, result, rec_tids) in enumerate(zip(
+                utts, target_transcripts, result_transcripts, result_tokenids)):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref": target, "hyp": result})
+                fout.write({
+                    "utt": utt,
+                    "refs": [target],
+                    "hyps": [result],
+                    "hyps_tokenid": [rec_tids],
+                })
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
...
...
@@ -32,7 +32,7 @@ __all__ = ["SpeechCollator", "TripletSpeechCollator"]

 logger = Log(__name__).getlog()

-def tokenids(text, keep_transcription_text):
+def _tokenids(text, keep_transcription_text):
     # for training text is token ids
     tokens = text  # token ids
...
@@ -93,6 +93,8 @@ class SpeechCollatorBase():
             a user-defined shape) within one batch.
         """
         self.keep_transcription_text = keep_transcription_text
+        self.train_mode = not keep_transcription_text
+
         self.stride_ms = stride_ms
         self.window_ms = window_ms
         self.feat_dim = feat_dim
...
@@ -192,6 +194,7 @@ class SpeechCollatorBase():
         texts = []
         text_lens = []
         utts = []
+        tids = []  # token ids

         for idx, item in enumerate(batch):
             utts.append(item['utt'])
...
@@ -203,7 +206,7 @@ class SpeechCollatorBase():
             audios.append(audio)  # [T, D]
             audio_lens.append(audio.shape[0])

-            tokens = tokenids(text, self.keep_transcription_text)
+            tokens = _tokenids(text, self.keep_transcription_text)
             texts.append(tokens)
             text_lens.append(tokens.shape[0])
...
...
@@ -142,6 +142,15 @@ class BatchDataLoader():
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )

+    def __len__(self):
+        return len(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
+
     def __repr__(self):
         echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
         echo += f"train_mode: {self.train_mode}, "
...
@@ -159,12 +168,3 @@ class BatchDataLoader():
         echo += f"num_workers: {self.n_iter_processes}, "
         echo += f"file: {self.json_file}"
         return echo
-
-    def __len__(self):
-        return len(self.dataloader)
-
-    def __iter__(self):
-        return self.dataloader.__iter__()
-
-    def __call__(self):
-        return self.__iter__()
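The relocated dunder methods are unchanged; they delegate to the wrapped DataLoader, so the class behaves as a sized iterable. A minimal usage sketch (the construction step is assumed, since the constructor is not shown in this diff):

```python
# Hypothetical usage of BatchDataLoader; construction details omitted.
loader = make_test_loader()   # assume this returns a BatchDataLoader
print(len(loader))            # batch count, via len(self.dataloader)
for batch in loader:          # __iter__ delegates to self.dataloader
    pass
for batch in loader():        # __call__ is an alias for __iter__
    pass
```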
...
@@ -809,7 +809,8 @@ class U2BaseModel(nn.Layer):
             raise ValueError(f"Not support decoding method: {decoding_method}")

         res = [text_feature.defeaturize(hyp) for hyp in hyps]
-        return res
+        res_tokenids = [hyp for hyp in hyps]
+        return res, res_tokenids


 class U2Model(U2BaseModel):
...
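Callers of `decode` must now unpack two parallel lists: the detokenized transcripts and the raw hypothesis token ids, one entry per utterance. Mirroring the call site in the tester above (decoding options abbreviated):

```python
result_transcripts, result_tokenids = self.model.decode(
    audio,
    audio_len,
    text_feature=text_feature,
    decoding_method=cfg.decoding_method,
    # ... remaining decoding options as in the test config ...
    simulate_streaming=cfg.simulate_streaming)
for text, tids in zip(result_transcripts, result_tokenids):
    logger.info(f"{text} {tids}")
```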
 [
-    {
-        "type": "shift",
-        "params": {
-            "min_shift_ms": -5,
-            "max_shift_ms": 5
-        },
-        "prob": 1.0
-    },
     {
         "type": "speed",
         "params": {
@@ -16,6 +8,14 @@
         },
         "prob": 0.0
     },
+    {
+        "type": "shift",
+        "params": {
+            "min_shift_ms": -5,
+            "max_shift_ms": 5
+        },
+        "prob": 1.0
+    },
     {
         "type": "specaug",
         "params": {
...
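The augmentation entries are only reordered: `shift` now comes after `speed`. Assuming the augmentation pipeline applies transforms in the order they appear in this file, speed perturbation now runs before time shifting. A quick way to check the resulting order (the config path is illustrative):

```python
import json

# List the transform order in the augmentation config.
with open("conf/augmentation.json") as f:
    print([aug["type"] for aug in json.load(f)])  # e.g. ['speed', 'shift', 'specaug']
```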
 # LibriSpeech

-## Data
-| Data Subset | Duration in Seconds |
-| data/manifest.train | 0.83s ~ 29.735s |
-| data/manifest.dev | 1.065 ~ 35.155s |
-| data/manifest.test-clean | 1.285s ~ 34.955s |
-
-## Conformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
-
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
-
-## Chunk Conformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |

 ## Transformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
-
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
+| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
...
@@ -5,9 +5,9 @@ data:
   test_manifest: data/manifest.test-clean

 collator:
-  vocab_filepath: data/train_960_unigram5000_units.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/train_960_unigram5000'
+  vocab_filepath: data/bpe_unigram_5000_units.txt
+  unit_type: spm
+  spm_model_prefix: data/bpe_unigram_5000
   feat_dim: 83
   stride_ms: 10.0
   window_ms: 25.0
...
...
@@ -46,15 +46,17 @@ pids=() # initialize pids

 for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
 (
+    echo "${dmethd} decoding"
     for rtask in ${recog_set}; do
     (
-        decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
+        echo "${rtask} dataset"
+        decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
         feat_recog_dir=${datadir}
         mkdir -p ${expdir}/${decode_dir}
         mkdir -p ${feat_recog_dir}

         # split data
-        split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
+        split_json.sh manifest.${rtask} ${nj}

         #### use CPU for decoding
         ngpu=0
@@ -74,17 +76,16 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
             --opts decoding.batch_size ${batch_size} \
             --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}

-        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
+        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict}
     ) &
     pids+=($!) # store background pids
+    i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
+    [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
     done
-) &
-pids+=($!) # store background pids
+)
 done
-i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
-[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false

 echo "Finished"
 exit 0
...
@@ -32,7 +32,7 @@ fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
...
...
@@ -6,7 +6,7 @@ CC ?= gcc  # used for sph2pipe
 # CXX = clang++  # Uncomment these lines...
 # CC = clang     # ...to build with Clang.

-WGET ?= wget
+WGET ?= wget --no-check-certificate

 .PHONY: all clean
...
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#           2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import json
import logging
import sys

import jsonlines
from utility import get_commandline_args


def get_parser():
    parser = argparse.ArgumentParser(
        description="convert a json to a transcription file with a token dictionary",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument("json", type=str, help="jsonlines files")
    parser.add_argument("dict", type=str, help="dict, not used.")
    parser.add_argument(
        "--num-spkrs", type=int, default=1, help="number of speakers")
    parser.add_argument(
        "--refs", type=str, nargs="+", help="ref for all speakers")
    parser.add_argument(
        "--hyps", type=str, nargs="+", help="hyp for all outputs")
    return parser


def main(args):
    args = get_parser().parse_args(args)
    convert(args.json, args.dict, args.refs, args.hyps, args.num_spkrs)


def convert(jsonf, dic, refs, hyps, num_spkrs=1):
    n_ref = len(refs)
    n_hyp = len(hyps)
    assert n_ref == n_hyp
    assert n_ref == num_spkrs

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with jsonlines.open(jsonf, "r") as f:
        j = [item for item in f]

    logging.info("reading %s", dic)
    with open(dic, "r") as f:
        dictionary = f.readlines()
    char_list = [entry.split(" ")[0] for entry in dictionary]
    char_list.insert(0, "<blank>")
    char_list.append("<eos>")

    for ns in range(num_spkrs):
        hyp_file = open(hyps[ns], "w")
        ref_file = open(refs[ns], "w")

        for x in j:
            # recognition hypothesis
            if num_spkrs == 1:
                # seq = [char_list[int(i)] for i in x['hyps_tokenid'][0]]
                seq = x['hyps'][0]
            else:
                seq = [char_list[int(i)] for i in x['hyps_tokenid'][ns]]
            # In the recognition hypothesis, the <eos> symbol is usually
            # attached at the end of the sentence; it is removed below.
            # hyp_file.write(" ".join(seq).replace("<eos>", ""))
            hyp_file.write(seq.replace("<eos>", ""))
            # spk-uttid
            hyp_file.write(" (" + x["utt"] + ")\n")

            # reference
            if num_spkrs == 1:
                seq = x["refs"][0]
            else:
                seq = x['refs'][ns]
            # Unlike the recognition hypothesis, the reference is generated
            # directly from the token text without the dictionary, to avoid
            # introducing <unk> symbols into the reference and breaking
            # scoring. The detailed discussion can be found at
            # https://github.com/espnet/espnet/issues/993
            # ref_file.write(
            #     seq + " (" + j["utts"][x]["utt2spk"].replace("-", "_") + "-" + x + ")\n"
            # )
            ref_file.write(seq + " (" + x['utt'] + ")\n")

        hyp_file.close()
        ref_file.close()


if __name__ == "__main__":
    main(sys.argv[1:])
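A hypothetical invocation of this converter, driving `main` directly from Python (all file paths below are illustrative, not from the commit):

```python
from json2trn import main

# Convert one jsonlines decode result into sclite-style ref/hyp trn files.
main([
    "exp/decode_test_clean/data.json",  # decode results written by the tester
    "data/bpe_unigram_5000_units.txt",  # token dictionary (unused on the single-speaker text path)
    "--refs", "exp/decode_test_clean/ref.trn",
    "--hyps", "exp/decode_test_clean/hyp.trn",
])
```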
 #!/usr/bin/env bash
+set -e

 # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
...
...
@@ -14,6 +14,7 @@
 import hashlib
 import json
 import os
+import sys
 import tarfile
 import zipfile
 from typing import Text
@@ -21,7 +22,7 @@ from typing import Text
 __all__ = [
     "check_md5sum", "getfile_insensitive", "download_multi", "download",
     "unpack", "unzip", "md5file", "print_arguments", "add_arguments",
-    "read_manifest"
+    "read_manifest", "get_commandline_args"
 ]
...
@@ -46,6 +47,40 @@ def read_manifest(manifest_path):
     return manifest


+def get_commandline_args():
+    extra_chars = [
+        " ",
+        ";",
+        "&",
+        "(",
+        ")",
+        "|",
+        "^",
+        "<",
+        ">",
+        "?",
+        "*",
+        "[",
+        "]",
+        "$",
+        "`",
+        '"',
+        "\\",
+        "!",
+        "{",
+        "}",
+    ]
+
+    # Escape the extra characters for shell
+    argv = [
+        arg.replace("'", "'\\''") if all(char not in arg
+                                         for char in extra_chars) else
+        "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
+    ]
+
+    return sys.executable + " " + " ".join(argv)
+
+
 def print_arguments(args, info=None):
     """Print argparse's arguments.
...
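`get_commandline_args` rebuilds a shell-safe rendering of the current command line: any argument containing a shell metacharacter from `extra_chars` is wrapped in single quotes, and embedded single quotes are escaped as `'\''`. A small demonstration (the simulated argv and printed result are illustrative):

```python
import sys
from utility import get_commandline_args

# Simulate a command line; "ref file.trn" contains a space, so it gets quoted.
sys.argv = ["json2trn.py", "data.json", "--refs", "ref file.trn"]
print(get_commandline_args())
# e.g.: /usr/bin/python3 json2trn.py data.json --refs 'ref file.trn'
```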