Unverified commit 0a3a840b, authored by Hui Zhang, committed by GitHub

more decoding methods (#618)

* more decoding methods

* all decode method test scripts; result readme

* exp libri config

* parallel data scripts; more mask test; need pybind11 repo

* speed perturb config

* libri conf test set
Parent 295f8bda
.DS_Store .DS_Store
*.pyc *.pyc
.vscode .vscode
*.log *log
*.pdmodel *.pdmodel
*.pdiparams* *.pdiparams*
*.zip *.zip
......
...@@ -168,7 +168,7 @@ class DeepSpeech2Trainer(Trainer): ...@@ -168,7 +168,7 @@ class DeepSpeech2Trainer(Trainer):
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=collate_fn, collate_fn=collate_fn,
num_workers=config.data.num_workers, ) num_workers=config.data.num_workers)
self.valid_loader = DataLoader( self.valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=config.data.batch_size, batch_size=config.data.batch_size,
......
...@@ -450,7 +450,7 @@ class U2Tester(U2Trainer): ...@@ -450,7 +450,7 @@ class U2Tester(U2Trainer):
logger.info(msg) logger.info(msg)
# test meta results # test meta results
err_meta_path = os.path.splitext(self.args.checkpoint_path)[0] + '.err' err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
err_type_str = "{}".format(error_rate_type) err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w') as f: with open(err_meta_path, 'w') as f:
data = json.dumps({ data = json.dumps({
...@@ -471,6 +471,8 @@ class U2Tester(U2Trainer): ...@@ -471,6 +471,8 @@ class U2Tester(U2Trainer):
errors_sum, errors_sum,
"ref_len": "ref_len":
len_refs, len_refs,
"decode_method":
self.config.decoding.decoding_method,
}) })
f.write(data + '\n') f.write(data + '\n')
......
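Because the error-meta path is now derived from the per-method result file rather than the checkpoint path, each decoding method writes its own `.err` JSON next to its `.rsl` output, and that JSON now records which decoder produced it. A small inspection sketch (the checkpoint layout is an assumption, not taken from the commit):

```bash
# Hypothetical experiment layout; adjust to your own checkpoint prefix.
ckpt_prefix=exp/conformer/checkpoints/avg_20
ls ${ckpt_prefix}.*    # e.g. avg_20.attention.rsl and avg_20.attention.err per method
# Pretty-print one meta file; besides other fields it carries counts such as
# "ref_len" and the newly added "decode_method" key.
python3 -m json.tool ${ckpt_prefix}.attention.err
```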
...@@ -66,19 +66,22 @@ fi ...@@ -66,19 +66,22 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size # format manifest with tokenids, vocab size
for dataset in train dev test; do for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \ --feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \ --manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}" --output_path="data/manifest.${dataset}"
done
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated." echo "Formt mnaifest failed. Terminated."
exit 1 exit 1
fi fi
} &
done
wait
fi fi
echo "Aishell data preparation done." echo "Aishell data preparation done."
......
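The brace-and-ampersand change above runs the three `format_data.py` jobs concurrently instead of serially, then waits for all of them. A standalone sketch of the same pattern (only a subset of the real flags is shown, and paths are illustrative):

```bash
#!/bin/bash
# Run one manifest-formatting job per dataset split in a background subshell.
for dataset in train dev test; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
            --manifest_path="data/manifest.${dataset}.raw" \
            --output_path="data/manifest.${dataset}" || {
            echo "Format manifest failed for ${dataset}."
            exit 1   # terminates only this background subshell
        }
    } &
done
wait   # block until every background job has finished
```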
# Aishell
## Conformer
| Model | Config | Augmentation | Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| conformer | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
| conformer | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
## Transformer
| Model | Config | Augmentation | Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| transformer | conf/transformer.yaml | spec_aug + shift | test | attention | - | - |
...@@ -14,7 +14,7 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then ...@@ -14,7 +14,7 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
python3 ${TARGET_DIR}/aishell/aishell.py \ python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \ --manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell" --target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated." echo "Prepare Aishell failed. Terminated."
exit 1 exit 1
...@@ -33,7 +33,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -33,7 +33,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--count_threshold=0 \ --count_threshold=0 \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train.raw" --manifest_paths "data/manifest.train.raw"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated." echo "Build vocabulary failed. Terminated."
exit 1 exit 1
...@@ -56,7 +56,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -56,7 +56,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--num_samples=-1 \ --num_samples=-1 \
--num_workers=${num_workers} \ --num_workers=${num_workers} \
--output_path="data/mean_std.json" --output_path="data/mean_std.json"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated." echo "Compute mean and stddev failed. Terminated."
exit 1 exit 1
...@@ -67,19 +67,22 @@ fi ...@@ -67,19 +67,22 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size # format manifest with tokenids, vocab size
for dataset in train dev test; do for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \ --feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \ --manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}" --output_path="data/manifest.${dataset}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done done
wait
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
fi fi
echo "Aishell data preparation done." echo "Aishell data preparation done."
......
...@@ -21,17 +21,39 @@ ckpt_prefix=$2 ...@@ -21,17 +21,39 @@ ckpt_prefix=$2
# exit 1 # exit 1
#fi #fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=64
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0 exit 0
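The new script loops over the batched decoders (`attention`, `ctc_greedy_search`, batch size 64) and then the beam-search based ones (`ctc_prefix_beam_search`, `attention_rescoring`, batch size 1), overriding the config through `--opts`. A hedged invocation sketch, assuming the first argument is the config and the second the checkpoint prefix (as `ckpt_prefix=$2` suggests); the paths are examples only:

```bash
# Illustrative paths, not taken from the commit.
bash local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20
# Expected outputs: one result file per decoding method, e.g.
#   exp/conformer/checkpoints/avg_20.attention.rsl
#   exp/conformer/checkpoints/avg_20.ctc_greedy_search.rsl
#   exp/conformer/checkpoints/avg_20.ctc_prefix_beam_search.rsl
#   exp/conformer/checkpoints/avg_20.attention_rescoring.rsl
```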
# ASR # ASR
* s0 is for deepspeech2 * s0 is for deepspeech2
* s1 is for U2 * s1 is for transformer/conformer/U2
# LibriSpeech
## Conformer
| Model | Config | Augmentation | Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| conformer | conf/conformer.yaml | spec_aug + shift | test-all | attention | test-all 6.35 | 0.057117 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.35 | 0.030162 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | test-all 6.35 | 0.037910 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | test-all 6.35 | 0.037761 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | test-all 6.35 | 0.032115 |
## Transformer
| Model | Config | Augmentation | Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| transformer | conf/transformer.yaml | spec_aug + shift | test-all | attention | test-all 6.98 | 0.066500 |
| transformer | conf/transformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.98 | 0.036 |
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
data: data:
train_manifest: data/manifest.train train_manifest: data/manifest.train
dev_manifest: data/manifest.dev dev_manifest: data/manifest.dev
test_manifest: data/manifest.test test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000' spm_model_prefix: 'data/bpe_unigram_5000'
...@@ -14,7 +14,7 @@ data: ...@@ -14,7 +14,7 @@ data:
min_output_len: 0.0 # tokens min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
max_output_input_ratio: 10.0 max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80 feat_dim: 80
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
data: data:
train_manifest: data/manifest.train train_manifest: data/manifest.train
dev_manifest: data/manifest.dev dev_manifest: data/manifest.dev
test_manifest: data/manifest.test test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000' spm_model_prefix: 'data/bpe_unigram_5000'
......
...@@ -21,17 +21,39 @@ ckpt_prefix=$2 ...@@ -21,17 +21,39 @@ ckpt_prefix=$2
# exit 1 # exit 1
#fi #fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=64
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0 exit 0
File mode changed from 100644 to 100755
...@@ -8,4 +8,5 @@ SoundFile==0.9.0.post1 ...@@ -8,4 +8,5 @@ SoundFile==0.9.0.post1
sox sox
tensorboardX tensorboardX
typeguard typeguard
yacs yacs
\ No newline at end of file
pybind11
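pybind11 is added to the Python requirements (the commit message also notes needing the pybind11 repo). A minimal setup sketch, assuming a standard pip-based install; the clone destination is an assumption:

```bash
# Install the updated requirements, including the new pybind11 dependency.
python3 -m pip install -r requirements.txt
# If a source checkout of pybind11 is needed (assumption based on the commit note):
git clone https://github.com/pybind/pybind11.git third_party/pybind11
```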
...@@ -50,7 +50,9 @@ class TestU2Model(unittest.TestCase): ...@@ -50,7 +50,9 @@ class TestU2Model(unittest.TestCase):
def test_make_pad_mask(self): def test_make_pad_mask(self):
res = make_pad_mask(self.lengths) res = make_pad_mask(self.lengths)
res1 = make_non_pad_mask(self.lengths).logical_not()
self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
if __name__ == '__main__': if __name__ == '__main__':
......