From 58d761f056b8958fe0c06ed11794ed95533880de Mon Sep 17 00:00:00 2001 From: Junkun Date: Tue, 20 Jul 2021 23:31:11 -0700 Subject: [PATCH] add related scripts of TIMIT --- .../timit/timit_kaldi_standard_split.py | 110 ++++++++++++++++++ examples/timit/README.md | 3 + examples/timit/conf/augmentation.json | 34 ++++++ examples/timit/conf/dev_spk.list | 50 ++++++++ examples/timit/conf/phones.60-48-39.map | 61 ++++++++++ examples/timit/conf/test_spk.list | 24 ++++ examples/timit/conf/transformer.yaml | 110 ++++++++++++++++++ examples/timit/local/align.sh | 37 ++++++ examples/timit/local/data.sh | 87 ++++++++++++++ examples/timit/local/export.sh | 34 ++++++ examples/timit/local/test.sh | 71 +++++++++++ examples/timit/local/timit_data_prep.sh | 90 ++++++++++++++ examples/timit/local/timit_norm_trans.pl | 91 +++++++++++++++ examples/timit/local/train.sh | 33 ++++++ examples/timit/path.sh | 13 +++ examples/timit/run.sh | 45 +++++++ 16 files changed, 893 insertions(+) create mode 100644 examples/dataset/timit/timit_kaldi_standard_split.py create mode 100644 examples/timit/README.md create mode 100644 examples/timit/conf/augmentation.json create mode 100644 examples/timit/conf/dev_spk.list create mode 100644 examples/timit/conf/phones.60-48-39.map create mode 100644 examples/timit/conf/test_spk.list create mode 100644 examples/timit/conf/transformer.yaml create mode 100755 examples/timit/local/align.sh create mode 100755 examples/timit/local/data.sh create mode 100755 examples/timit/local/export.sh create mode 100755 examples/timit/local/test.sh create mode 100644 examples/timit/local/timit_data_prep.sh create mode 100644 examples/timit/local/timit_norm_trans.pl create mode 100755 examples/timit/local/train.sh create mode 100644 examples/timit/path.sh create mode 100755 examples/timit/run.sh diff --git a/examples/dataset/timit/timit_kaldi_standard_split.py b/examples/dataset/timit/timit_kaldi_standard_split.py new file mode 100644 index 00000000..beb5a63e --- /dev/null +++ b/examples/dataset/timit/timit_kaldi_standard_split.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare TIMIT dataset (Standard split from Kaldi) + +Create manifest files from splited datased. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os + +import soundfile + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--src_dir", + default="", + type=str, + help="Directory to kaldi splited data. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + + data_types = ['train', 'dev', 'test'] + for dtype in data_types: + del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + + phn_path = os.path.join(data_dir, dtype+'.text') + phn_dict = {} + for line in codecs.open(phn_path, 'r', 'utf-8'): + line = line.strip() + if line == '': + continue + audio_id, text = line.split(' ', 1) + phn_dict[audio_id] = text + + audio_dir = os.path.join(data_dir, dtype+'_sph.scp') + for line in codecs.open(audio_dir, 'r', 'utf-8'): + audio_id, audio_path = line.strip().split() + # if no transcription for audio then raise error + assert audio_id in phn_dict + audio_data, samplerate = soundfile.read(audio_path) + duration = float(len(audio_data) / samplerate) + text = phn_dict[audio_id] + json_lines.append( + json.dumps( + { + 'utt': audio_id, + 'feat': audio_path, + 'feat_shape': (duration, ), # second + 'text': text + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(text) + total_num += 1 + + manifest_path = manifest_path_prefix + '.' + dtype + '.raw' + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + +def prepare_dataset(src_dir, manifest_path=None): + """create manifest file.""" + if os.path.isdir(manifest_path): + manifest_path = os.path.join(manifest_path, 'manifest') + if manifest_path: + create_manifest(src_dir, manifest_path) + + +def main(): + if args.src_dir.startswith('~'): + args.src_dir = os.path.expanduser(args.src_dir) + + prepare_dataset( + src_dir=args.src_dir, + manifest_path=args.manifest_prefix) + + print("manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/timit/README.md b/examples/timit/README.md new file mode 100644 index 00000000..2dd8a719 --- /dev/null +++ b/examples/timit/README.md @@ -0,0 +1,3 @@ +# TIMIT + +Results will be organized and updated soon. \ No newline at end of file diff --git a/examples/timit/conf/augmentation.json b/examples/timit/conf/augmentation.json new file mode 100644 index 00000000..c1078393 --- /dev/null +++ b/examples/timit/conf/augmentation.json @@ -0,0 +1,34 @@ +[ + { + "type": "shift", + "params": { + "min_shift_ms": -5, + "max_shift_ms": 5 + }, + "prob": 1.0 + }, + { + "type": "speed", + "params": { + "min_speed_rate": 0.9, + "max_speed_rate": 1.1, + "num_rates": 3 + }, + "prob": 0.0 + }, + { + "type": "specaug", + "params": { + "F": 10, + "T": 50, + "n_freq_masks": 2, + "n_time_masks": 2, + "p": 1.0, + "W": 80, + "adaptive_number_ratio": 0, + "adaptive_size_ratio": 0, + "max_n_time_masks": 20 + }, + "prob": 1.0 + } +] diff --git a/examples/timit/conf/dev_spk.list b/examples/timit/conf/dev_spk.list new file mode 100644 index 00000000..edcb3ef7 --- /dev/null +++ b/examples/timit/conf/dev_spk.list @@ -0,0 +1,50 @@ +faks0 +fdac1 +fjem0 +mgwt0 +mjar0 +mmdb1 +mmdm2 +mpdf0 +fcmh0 +fkms0 +mbdg0 +mbwm0 +mcsh0 +fadg0 +fdms0 +fedw0 +mgjf0 +mglb0 +mrtk0 +mtaa0 +mtdt0 +mthc0 +mwjg0 +fnmr0 +frew0 +fsem0 +mbns0 +mmjr0 +mdls0 +mdlf0 +mdvc0 +mers0 +fmah0 +fdrw0 +mrcs0 +mrjm4 +fcal1 +mmwh0 +fjsj0 +majc0 +mjsw0 +mreb0 +fgjd0 +fjmg0 +mroa0 +mteb0 +mjfc0 +mrjr0 +fmml0 +mrws1 \ No newline at end of file diff --git a/examples/timit/conf/phones.60-48-39.map b/examples/timit/conf/phones.60-48-39.map new file mode 100644 index 00000000..946f3bef --- /dev/null +++ b/examples/timit/conf/phones.60-48-39.map @@ -0,0 +1,61 @@ +aa aa aa +ae ae ae +ah ah ah +ao ao aa +aw aw aw +ax ax ah +ax-h ax ah +axr er er +ay ay ay +b b b +bcl vcl sil +ch ch ch +d d d +dcl vcl sil +dh dh dh +dx dx dx +eh eh eh +el el l +em m m +en en n +eng ng ng +epi epi sil +er er er +ey ey ey +f f f +g g g +gcl vcl sil +h# sil sil +hh hh hh +hv hh hh +ih ih ih +ix ix ih +iy iy iy +jh jh jh +k k k +kcl cl sil +l l l +m m m +n n n +ng ng ng +nx n n +ow ow ow +oy oy oy +p p p +pau sil sil +pcl cl sil +q +r r r +s s s +sh sh sh +t t t +tcl cl sil +th th th +uh uh uh +uw uw uw +ux uw uw +v v v +w w w +y y y +z z z +zh zh sh \ No newline at end of file diff --git a/examples/timit/conf/test_spk.list b/examples/timit/conf/test_spk.list new file mode 100644 index 00000000..3cfa8f5d --- /dev/null +++ b/examples/timit/conf/test_spk.list @@ -0,0 +1,24 @@ +mdab0 +mwbt0 +felc0 +mtas1 +mwew0 +fpas0 +mjmp0 +mlnt0 +fpkt0 +mlll0 +mtls0 +fjlm0 +mbpm0 +mklt0 +fnlp0 +mcmj0 +mjdh0 +fmgd0 +mgrt0 +mnjm0 +fdhc0 +mjln0 +mpam0 +fmld0 \ No newline at end of file diff --git a/examples/timit/conf/transformer.yaml b/examples/timit/conf/transformer.yaml new file mode 100644 index 00000000..131173ce --- /dev/null +++ b/examples/timit/conf/transformer.yaml @@ -0,0 +1,110 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.5 # second + max_input_len: 30.0 # second + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.05 + max_output_input_ratio: 100.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: "word" + mean_std_filepath: "" + augmentation_config: "" + batch_size: 64 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + +# network architecture +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +training: + n_epoch: 120 + accum_grad: 2 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.004 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + + +decoding: + batch_size: 64 + error_rate_type: wer + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. + + diff --git a/examples/timit/local/align.sh b/examples/timit/local/align.sh new file mode 100755 index 00000000..ad6c84bc --- /dev/null +++ b/examples/timit/local/align.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/timit/local/data.sh b/examples/timit/local/data.sh new file mode 100755 index 00000000..1d16f454 --- /dev/null +++ b/examples/timit/local/data.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +unit_type=word +TIMIT_path= + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/timit/timit_kaldi_standard_split.py \ + --manifest_prefix="data/manifest" \ + --src="data/local" \ + + if [ $? -ne 0 ]; then + echo "Prepare TIMIT failed. Terminated." + exit 1 + fi +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # build vocabulary + python3 ${MAIN_ROOT}/utils/build_vocab.py \ + --unit_type ${unit_type} \ + --count_threshold=0 \ + --vocab_path="data/vocab.txt" \ + --manifest_paths="data/manifest.train.raw" + + if [ $? -ne 0 ]; then + echo "Build vocabulary failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=-1 \ + --specgram_type="fbank" \ + --feat_dim=80 \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=25.0 \ + --use_dB_normalization=False \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type ${unit_type} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest.${set} failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "TIMIT Data preparation done." +exit 0 diff --git a/examples/timit/local/export.sh b/examples/timit/local/export.sh new file mode 100755 index 00000000..f99a15ba --- /dev/null +++ b/examples/timit/local/export.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_path_prefix=$2 +jit_model_export_path=$3 + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi + +python3 -u ${BIN_DIR}/export.py \ +--device ${device} \ +--nproc ${ngpu} \ +--config ${config_path} \ +--checkpoint_path ${ckpt_path_prefix} \ +--export_path ${jit_model_export_path} + + +if [ $? -ne 0 ]; then + echo "Failed in export!" + exit 1 +fi + + +exit 0 diff --git a/examples/timit/local/test.sh b/examples/timit/local/test.sh new file mode 100755 index 00000000..fe01d700 --- /dev/null +++ b/examples/timit/local/test.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi + +config_path=$1 +ckpt_prefix=$2 + +chunk_mode=false +if [[ ${config_path} =~ ^chunk_ ]];then + chunk_mode=true +fi + + +# download language model +#bash local/download_lm_en.sh +#if [ $? -ne 0 ]; then +# exit 1 +#fi + +for type in attention ctc_greedy_search; do + echo "decoding ${type}" + if [ ${chunk_mode} == true ];then + # stream decoding only support batchsize=1 + batch_size=1 + else + batch_size=64 + fi + python3 -u ${BIN_DIR}/test.py \ + --device ${device} \ + --nproc 1 \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done + +for type in ctc_prefix_beam_search attention_rescoring; do + echo "decoding ${type}" + batch_size=1 + python3 -u ${BIN_DIR}/test.py \ + --device ${device} \ + --nproc 1 \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done + + +exit 0 diff --git a/examples/timit/local/timit_data_prep.sh b/examples/timit/local/timit_data_prep.sh new file mode 100644 index 00000000..22e6f343 --- /dev/null +++ b/examples/timit/local/timit_data_prep.sh @@ -0,0 +1,90 @@ +#!/usr/bin/env bash + +# Copyright 2013 (Authors: Bagher BabaAli, Daniel Povey, Arnab Ghoshal) +# 2014 Brno University of Technology (Author: Karel Vesely) +# Apache 2.0. + +if [ $# -ne 1 ]; then + echo "Argument should be the Timit directory, see ../run.sh for example." + exit 1; +fi + +dir=`pwd`/data/local +mkdir -p $dir +local=`pwd`/local +utils=`pwd`/utils +conf=`pwd`/conf + +[ -f $conf/test_spk.list ] || error_exit "$PROG: Eval-set speaker list not found."; +[ -f $conf/dev_spk.list ] || error_exit "$PROG: dev-set speaker list not found."; + +# First check if the train & test directories exist (these can either be upper- +# or lower-cased +if [ ! -d $*/TRAIN -o ! -d $*/TEST ] && [ ! -d $*/train -o ! -d $*/test ]; then + echo "timit_data_prep.sh: Spot check of command line argument failed" + echo "Command line argument must be absolute pathname to TIMIT directory" + echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT" + exit 1; +fi + +# Now check what case the directory structure is +uppercased=false +train_dir=train +test_dir=test +if [ -d $*/TRAIN ]; then + uppercased=true + train_dir=TRAIN + test_dir=TEST +fi + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT + +# Get the list of speakers. The list of speakers in the 24-speaker core test +# set and the 50-speaker development set must be supplied to the script. All +# speakers in the 'train' directory are used for training. +if $uppercased; then + tr '[:lower:]' '[:upper:]' < $conf/dev_spk.list > $tmpdir/dev_spk + tr '[:lower:]' '[:upper:]' < $conf/test_spk.list > $tmpdir/test_spk + ls -d "$*"/TRAIN/DR*/* | sed -e "s:^.*/::" > $tmpdir/train_spk +else + tr '[:upper:]' '[:lower:]' < $conf/dev_spk.list > $tmpdir/dev_spk + tr '[:upper:]' '[:lower:]' < $conf/test_spk.list > $tmpdir/test_spk + ls -d "$*"/train/dr*/* | sed -e "s:^.*/::" > $tmpdir/train_spk +fi + +cd $dir +for x in train dev test; do + # First, find the list of audio files (use only si & sx utterances). + # Note: train & test sets are under different directories, but doing find on + # both and grepping for the speakers will work correctly. + find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ + | grep -f $tmpdir/${x}_spk > ${x}_sph.flist + + sed -e 's:.*/\(.*\)/\(.*\).\(WAV\|wav\)$:\1_\2:' ${x}_sph.flist \ + > $tmpdir/${x}_sph.uttids + paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \ + | sort -k1,1 > ${x}_sph.scp + + cat ${x}_sph.scp | awk '{print $1}' > ${x}.uttids + + # Now, Convert the transcripts into our format (no normalization yet) + # Get the transcripts: each line of the output contains an utterance + # ID followed by the transcript. + find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \ + | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist + sed -e 's:.*/\(.*\)/\(.*\).\(PHN\|phn\)$:\1_\2:' $tmpdir/${x}_phn.flist \ + > $tmpdir/${x}_phn.uttids + while read line; do + [ -f $line ] || error_exit "Cannot find transcription file '$line'"; + cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' + done < $tmpdir/${x}_phn.flist > $tmpdir/${x}_phn.trans + paste $tmpdir/${x}_phn.uttids $tmpdir/${x}_phn.trans \ + | sort -k1,1 > ${x}.trans + + # Do normalization steps. + cat ${x}.trans | $local/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 39 | sort > $x.text || exit 1; + +done + +echo "Data preparation succeeded" \ No newline at end of file diff --git a/examples/timit/local/timit_norm_trans.pl b/examples/timit/local/timit_norm_trans.pl new file mode 100644 index 00000000..702d9b15 --- /dev/null +++ b/examples/timit/local/timit_norm_trans.pl @@ -0,0 +1,91 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +# Copyright 2012 Arnab Ghoshal + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script normalizes the TIMIT phonetic transcripts that have been +# extracted in a format where each line contains an utterance ID followed by +# the transcript, e.g.: +# fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h# + +my $usage = "Usage: timit_norm_trans.pl -i transcript -m phone_map -from [60|48] -to [48|39] > normalized\n +Normalizes phonetic transcriptions for TIMIT, by mapping the phones to a +smaller set defined by the -m option. This script assumes that the mapping is +done in the \"standard\" fashion, i.e. to 48 or 39 phones. The input is +assumed to have 60 phones (+1 for glottal stop, which is deleted), but that can +be changed using the -from option. The input format is assumed to be utterance +ID followed by transcript on the same line.\n"; + +use strict; +use Getopt::Long; +die "$usage" unless(@ARGV >= 1); +my ($in_trans, $phone_map, $num_phones_out); +my $num_phones_in = 60; +GetOptions ("i=s" => \$in_trans, # Input transcription + "m=s" => \$phone_map, # File containing phone mappings + "from=i" => \$num_phones_in, # Input #phones: must be 60 or 48 + "to=i" => \$num_phones_out ); # Output #phones: must be 48 or 39 + +die $usage unless(defined($in_trans) && defined($phone_map) && + defined($num_phones_out)); +if ($num_phones_in != 60 && $num_phones_in != 48) { + die "Can only used 60 or 48 for -from (used $num_phones_in)." +} +if ($num_phones_out != 48 && $num_phones_out != 39) { + die "Can only used 48 or 39 for -to (used $num_phones_out)." +} +unless ($num_phones_out < $num_phones_in) { + die "Argument to -from ($num_phones_in) must be greater than that to -to ($num_phones_out)." +} + + +open(M, "<$phone_map") or die "Cannot open mappings file '$phone_map': $!"; +my (%phonemap, %seen_phones); +my $num_seen_phones = 0; +while () { + chomp; + next if ($_ =~ /^q\s*.*$/); # Ignore glottal stops. + m:^(\S+)\s+(\S+)\s+(\S+)$: or die "Bad line: $_"; + my $mapped_from = ($num_phones_in == 60)? $1 : $2; + my $mapped_to = ($num_phones_out == 48)? $2 : $3; + if (!defined($seen_phones{$mapped_to})) { + $seen_phones{$mapped_to} = 1; + $num_seen_phones += 1; + } + $phonemap{$mapped_from} = $mapped_to; +} +if ($num_seen_phones != $num_phones_out) { + die "Trying to map to $num_phones_out phones, but seen only $num_seen_phones"; +} + +open(T, "<$in_trans") or die "Cannot open transcription file '$in_trans': $!"; +while () { + chomp; + $_ =~ m:^(\S+)\s+(.+): or die "Bad line: $_"; + my $utt_id = $1; + my $trans = $2; + + $trans =~ s/q//g; # Remove glottal stops. + $trans =~ s/^\s*//; $trans =~ s/\s*$//; # Normalize spaces + + print $utt_id; + for my $phone (split(/\s+/, $trans)) { + if(exists $phonemap{$phone}) { print " $phonemap{$phone}"; } + if(not exists $phonemap{$phone}) { print " $phone"; } + } + print "\n"; +} \ No newline at end of file diff --git a/examples/timit/local/train.sh b/examples/timit/local/train.sh new file mode 100755 index 00000000..f3eb98da --- /dev/null +++ b/examples/timit/local/train.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_name=$2 + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +echo "using ${device}..." + +mkdir -p exp + +python3 -u ${BIN_DIR}/train.py \ +--device ${device} \ +--nproc ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} + +if [ $? -ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 diff --git a/examples/timit/path.sh b/examples/timit/path.sh new file mode 100644 index 00000000..95427277 --- /dev/null +++ b/examples/timit/path.sh @@ -0,0 +1,13 @@ +export MAIN_ROOT=${PWD}/../../ +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + + +MODEL=u2 +export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin diff --git a/examples/timit/run.sh b/examples/timit/run.sh new file mode 100755 index 00000000..d2b7f362 --- /dev/null +++ b/examples/timit/run.sh @@ -0,0 +1,45 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=50 +conf_path=conf/transformer.yaml +avg_num=10 +TIMIT_path= +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + bash ./local/timit_data_prep.sh ${TIMIT_path} + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # export ckpt avg_n + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit +fi -- GitLab