提交 feaf71d4 编写于 作者: H Hui Zhang

u2 kaldi mutli process test with batchsize one

上级 aaa87698
......@@ -444,7 +444,7 @@ class U2Tester(U2Trainer):
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode(
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
......@@ -462,14 +462,19 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
for utt, target, result, rec_tids in zip(
utts, target_transcripts, result_transcripts, result_tokenids):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref": target, "hyp": result})
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
......@@ -390,6 +390,10 @@ class U2Tester(U2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
......@@ -413,15 +417,11 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
target_transcripts = self.id2token(texts, texts_len, text_feature)
result_transcripts = self.model.decode(
target_transcripts = self.id2token(texts, texts_len, self.text_feature)
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
......@@ -436,14 +436,19 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
for i, (utt, target, result, rec_tids) in enumerate(zip(
utts, target_transcripts, result_transcripts, result_tokenids)):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref": target, "hyp": result})
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
......@@ -32,7 +32,7 @@ __all__ = ["SpeechCollator", "TripletSpeechCollator"]
logger = Log(__name__).getlog()
def tokenids(text, keep_transcription_text):
def _tokenids(text, keep_transcription_text):
# for training text is token ids
tokens = text # token ids
......@@ -93,6 +93,8 @@ class SpeechCollatorBase():
a user-defined shape) within one batch.
"""
self.keep_transcription_text = keep_transcription_text
self.train_mode = not keep_transcription_text
self.stride_ms = stride_ms
self.window_ms = window_ms
self.feat_dim = feat_dim
......@@ -192,6 +194,7 @@ class SpeechCollatorBase():
texts = []
text_lens = []
utts = []
tids = [] # tokenids
for idx, item in enumerate(batch):
utts.append(item['utt'])
......@@ -203,7 +206,7 @@ class SpeechCollatorBase():
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
tokens = tokenids(text, self.keep_transcription_text)
tokens = _tokenids(text, self.keep_transcription_text)
texts.append(tokens)
text_lens.append(tokens.shape[0])
......
......@@ -142,6 +142,15 @@ class BatchDataLoader():
collate_fn=batch_collate,
num_workers=self.n_iter_processes, )
def __len__(self):
return len(self.dataloader)
def __iter__(self):
return self.dataloader.__iter__()
def __call__(self):
return self.__iter__()
def __repr__(self):
echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
echo += f"train_mode: {self.train_mode}, "
......@@ -159,12 +168,3 @@ class BatchDataLoader():
echo += f"num_workers: {self.n_iter_processes}, "
echo += f"file: {self.json_file}"
return echo
def __len__(self):
return len(self.dataloader)
def __iter__(self):
return self.dataloader.__iter__()
def __call__(self):
return self.__iter__()
......@@ -809,7 +809,8 @@ class U2BaseModel(nn.Layer):
raise ValueError(f"Not support decoding method: {decoding_method}")
res = [text_feature.defeaturize(hyp) for hyp in hyps]
return res
res_tokenids = [hyp for hyp in hyps]
return res, res_tokenids
class U2Model(U2BaseModel):
......
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "speed",
"params": {
......@@ -16,6 +8,14 @@
},
"prob": 0.0
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
......
# LibriSpeech
## Data
| Data Subset | Duration in Seconds |
| data/manifest.train | 0.83s ~ 29.735s |
| data/manifest.dev | 1.065 ~ 35.155s |
| data/manifest.test-clean | 1.285s ~ 34.955s |
## Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
......@@ -5,9 +5,9 @@ data:
test_manifest: data/manifest.test-clean
collator:
vocab_filepath: data/train_960_unigram5000_units.txt
unit_type: 'spm'
spm_model_prefix: 'data/train_960_unigram5000'
vocab_filepath: data/bpe_unigram_5000_units.txt
unit_type: spm
spm_model_prefix: data/bpe_unigram_5000
feat_dim: 83
stride_ms: 10.0
window_ms: 25.0
......
......@@ -46,15 +46,17 @@ pids=() # initialize pids
for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
(
echo "${dmethd} decoding"
for rtask in ${recog_set}; do
(
decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
echo "${rtask} dataset"
decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
feat_recog_dir=${datadir}
mkdir -p ${expdir}/${decode_dir}
mkdir -p ${feat_recog_dir}
# split data
split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
split_json.sh manifest.${rtask} ${nj}
#### use CPU for decoding
ngpu=0
......@@ -74,17 +76,16 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
--opts decoding.batch_size ${batch_size} \
--opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict}
) &
pids+=($!) # store background pids
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
done
) &
pids+=($!) # store background pids
)
done
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
echo "Finished"
exit 0
......@@ -32,7 +32,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
......
......@@ -6,7 +6,7 @@ CC ?= gcc # used for sph2pipe
# CXX = clang++ # Uncomment these lines...
# CC = clang # ...to build with Clang.
WGET ?= wget
WGET ?= wget --no-check-certificate
.PHONY: all clean
......
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import json
import logging
import sys
import jsonlines
from utility import get_commandline_args
def get_parser():
parser = argparse.ArgumentParser(
description="convert a json to a transcription file with a token dictionary",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument("json", type=str, help="jsonlines files")
parser.add_argument("dict", type=str, help="dict, not used.")
parser.add_argument(
"--num-spkrs", type=int, default=1, help="number of speakers")
parser.add_argument(
"--refs", type=str, nargs="+", help="ref for all speakers")
parser.add_argument(
"--hyps", type=str, nargs="+", help="hyp for all outputs")
return parser
def main(args):
args = get_parser().parse_args(args)
convert(args.json, args.dict, args.refs, args.hyps, args.num_spkrs)
def convert(jsonf, dic, refs, hyps, num_spkrs=1):
n_ref = len(refs)
n_hyp = len(hyps)
assert n_ref == n_hyp
assert n_ref == num_spkrs
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info(get_commandline_args())
logging.info("reading %s", jsonf)
with jsonlines.open(jsonf, "r") as f:
j = [item for item in f]
logging.info("reading %s", dic)
with open(dic, "r") as f:
dictionary = f.readlines()
char_list = [entry.split(" ")[0] for entry in dictionary]
char_list.insert(0, "<blank>")
char_list.append("<eos>")
for ns in range(num_spkrs):
hyp_file = open(hyps[ns], "w")
ref_file = open(refs[ns], "w")
for x in j:
# recognition hypothesis
if num_spkrs == 1:
#seq = [char_list[int(i)] for i in x['hyps_tokenid'][0]]
seq = x['hyps'][0]
else:
seq = [char_list[int(i)] for i in x['hyps_tokenid'][ns]]
# In the recognition hypothesis,
# the <eos> symbol is usually attached in the last part of the sentence
# and it is removed below.
#hyp_file.write(" ".join(seq).replace("<eos>", ""))
hyp_file.write(seq.replace("<eos>", ""))
# spk-uttid
hyp_file.write(" (" + x["utt"] + ")\n")
# reference
if num_spkrs == 1:
seq = x["refs"][0]
else:
seq = x['refs'][ns]
# Unlike the recognition hypothesis,
# the reference is directly generated from a token without dictionary
# to avoid to include <unk> symbols in the reference to make scoring normal.
# The detailed discussion can be found at
# https://github.com/espnet/espnet/issues/993
# ref_file.write(
# seq + " (" + j["utts"][x]["utt2spk"].replace("-", "_") + "-" + x + ")\n"
# )
ref_file.write(seq + " (" + x['utt'] + ")\n")
hyp_file.close()
ref_file.close()
if __name__ == "__main__":
main(sys.argv[1:])
#!/usr/bin/env bash
set -e
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
......
......@@ -14,6 +14,7 @@
import hashlib
import json
import os
import sys
import tarfile
import zipfile
from typing import Text
......@@ -21,7 +22,7 @@ from typing import Text
__all__ = [
"check_md5sum", "getfile_insensitive", "download_multi", "download",
"unpack", "unzip", "md5file", "print_arguments", "add_arguments",
"read_manifest"
"read_manifest", "get_commandline_args"
]
......@@ -46,6 +47,40 @@ def read_manifest(manifest_path):
return manifest
def get_commandline_args():
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''") if all(char not in arg
for char in extra_chars) else
"'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
]
return sys.executable + " " + " ".join(argv)
def print_arguments(args, info=None):
"""Print argparse's arguments.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册