Commit e55b5baf authored by Hui Zhang

fix data

Parent e5347c48
@@ -168,7 +168,7 @@ if __name__ == "__main__":
        default=False,
        help="Whether use gpu.")
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
@@ -180,7 +180,7 @@ if __name__ == "__main__":
    print(config)
    args.warmup_manifest = config.data.test_manifest
-    print_arguments(args)
+    print_arguments(args, globals())
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
......
@@ -98,7 +98,7 @@ if __name__ == "__main__":
        "Directory to save demo audios.")
    add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
@@ -110,7 +110,7 @@ if __name__ == "__main__":
    print(config)
    args.warmup_manifest = config.data.test_manifest
-    print_arguments(args)
+    print_arguments(args, globals())
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
......
@@ -33,7 +33,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -32,7 +32,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -37,7 +37,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -168,7 +168,7 @@ if __name__ == "__main__":
    add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -33,7 +33,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -34,7 +34,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -38,7 +38,7 @@ def main(config, args):
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
-    print_arguments(args)
+    print_arguments(args, globals())
    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
......
@@ -34,11 +34,12 @@ class TextFeaturizer(object):
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
-        self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
-            vocab_filepath)
        self.unk = UNK
-        self.unk_id = self._vocab_list.index(self.unk)
-        self.eos_id = self._vocab_list.index(EOS)
+        if vocab_filepath:
+            self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
+                vocab_filepath)
+            self.unk_id = self._vocab_list.index(self.unk)
+            self.eos_id = self._vocab_list.index(EOS)

        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
......
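Editor's note: the constructor change above makes the vocabulary load optional, so a featurizer can be built before any vocab file exists; build_vocab.py further down in this commit passes an empty path for exactly this reason. Below is a minimal, hedged standalone sketch of the guarded-load pattern using a hypothetical class, not the repo's TextFeaturizer.

# Hedged sketch, not the repo's implementation: a hypothetical VocabHolder that
# mirrors the guarded load in the hunk above. Without a vocab file the object
# is still usable as a plain tokenizer; id lookups only exist after loading.
class VocabHolder:
    UNK, EOS = "<unk>", "<eos>"

    def __init__(self, vocab_filepath=None):
        self.unk = self.UNK
        self.vocab_list = None
        if vocab_filepath:  # skipped while the vocab is still being built
            with open(vocab_filepath, encoding="utf-8") as f:
                self.vocab_list = [line.strip() for line in f]
            self.unk_id = self.vocab_list.index(self.unk)
            self.eos_id = self.vocab_list.index(self.EOS)

# VocabHolder() works before data/vocab.txt exists;
# VocabHolder("data/vocab.txt") additionally exposes unk_id and eos_id.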
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains common utility functions."""
+import os
import math
import distutils.util
from typing import List
@@ -20,7 +21,7 @@ from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]

-def print_arguments(args):
+def print_arguments(args, info=None):
    """Print argparse's arguments.

    Usage:
@@ -35,10 +36,14 @@ def print_arguments(args):
    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
-    print("----------- Configuration Arguments -----------")
+    filename = ""
+    if info:
+        filename = info["__file__"]
+        filename = os.path.basename(filename)
+    print(f"----------- {filename} Configuration Arguments -----------")
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
-    print("------------------------------------------------")
+    print("-----------------------------------------------------------")

def add_arguments(argname, type, default, help, argparser, **kwargs):
......
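Editor's note: a hedged usage sketch of the new signature. Callers pass globals() so that print_arguments can read __file__ and prepend the script name to the banner. The import path below is an assumption, not confirmed by the diff; adjust it to wherever utility.py lives in the repo.

# Hedged usage sketch of print_arguments(args, info=None).
import argparse

from deepspeech.utils.utility import print_arguments  # assumed module path

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="conf/deepspeech2.yaml")
args = parser.parse_args()

# globals() of a script contains __file__, so the banner becomes e.g.
# "----------- train.py Configuration Arguments -----------"
print_arguments(args, globals())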
* s0 for deepspeech2
+* s1 for u2

#! /usr/bin/env bash

mkdir -p data

TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

# download data, generate manifests
-PYTHONPATH=.:$PYTHONPATH python3 ${TARGET_DIR}/aishell/aishell.py \
+python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
@@ -16,11 +15,17 @@ if [ $? -ne 0 ]; then
fi

+for dataset in train dev test; do
+    mv data/manifest.${dataset} data/manifest.${dataset}.raw
+done

# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
+ --unit_type="char" \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
- --manifest_paths "data/manifest.train" "data/manifest.dev"
+ --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
@@ -30,9 +35,11 @@ fi

# compute mean and stddev for normalizer
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
- --manifest_path="data/manifest.train" \
+ --manifest_path="data/manifest.train.raw" \
--num_samples=2000 \
- --specgram_type="linear" \
+ --specgram_type="fbank" \
+ --feat_dim=80 \
+ --delta_delta=false \
--output_path="data/mean_std.npz"

if [ $? -ne 0 ]; then
@@ -41,5 +48,21 @@ if [ $? -ne 0 ]; then
fi

+# format manifest with tokenids, vocab size
+for dataset in train dev test; do
+    python3 ${MAIN_ROOT}/utils/format_data.py \
+    --feat_type "raw" \
+    --cmvn_path "data/mean_std.npz" \
+    --unit_type "char" \
+    --vocab_path="data/vocab.txt" \
+    --manifest_path="data/manifest.${dataset}.raw" \
+    --output_path="data/manifest.${dataset}"
+done
+
+if [ $? -ne 0 ]; then
+    echo "Format manifest failed. Terminated."
+    exit 1
+fi

echo "Aishell data preparation done."
exit 0
#! /usr/bin/env bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/pretrain
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_fluid.tar.gz'
MD5=2bf0cc8b6d5da2a2a787b5cc36a496b5
TARGET=${DIR}/aishell_model_fluid.tar.gz
echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!"
exit 1
fi
tar -zxvf $TARGET -C ${DIR}
exit 0
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]
../../s0/local/data.sh
\ No newline at end of file
../../s0/local/download_lm_ch.sh
\ No newline at end of file
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: export ckpt_path jit_model_path"
exit -1
fi
python3 -u ${BIN_DIR}/export.py \
--config conf/deepspeech2.yaml \
--checkpoint_path ${1} \
--export_path ${2}
if [ $? -ne 0 ]; then
echo "Failed in export!"
exit 1
fi
exit 0
#! /usr/bin/env bash
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--output ckpt
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
#! /usr/bin/env bash
# train model
# if you wish to resume from an existing model, uncomment --init_from_pretrained_model
export FLAGS_sync_nccl_allreduce=0
ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
echo "using $ngpu gpus..."
python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \
--nproc ${ngpu} \
--config conf/deepspeech2.yaml \
--output ckpt-${1}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
export MAIN_ROOT=${PWD}/../../../
export PATH=${MAIN_ROOT}:${PWD}/tools:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
#!/bin/bash
source path.sh
# only demos
# prepare data
bash ./local/data.sh
# train model
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh
# test model
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh
# infer model
CUDA_VISIBLE_DEVICES=0 bash ./local/infer.sh ckpt/checkpoints/step-3284
# export model
bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model
\ No newline at end of file
@@ -5,7 +5,7 @@ TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

# download data, generate manifests
-PYTHONPATH=.:$PYTHONPATH python3 ${TARGET_DIR}/librispeech/librispeech.py \
+python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="False"
@@ -24,7 +24,7 @@ bpeprefix="data/bpe_${bpemode}_${nbpe}"

# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type "spm" \
- --vocab_size=${nbpe} \
+ --spm_vocab_size=${nbpe} \
--spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
......
@@ -45,9 +45,9 @@ add_arg('manifest_paths', str,
        nargs='+',
        required=True)
# bpe
-add_arg('vocab_size', int, 0, "Vocab size for spm.")
+add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
-add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)", "spm model prefix, only need when `unit_type` is spm")
+add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
# yapf: disable
args = parser.parse_args()
@@ -64,7 +64,7 @@ def dump_text_manifest(fileobj, manifest_path):
        fileobj.write(line_json['text'] + "\n")

def main():
-    print_arguments(args)
+    print_arguments(args, globals())
    fout = open(args.vocab_path, 'w', encoding='utf-8')
    fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
@@ -91,7 +91,7 @@ def main():
        os.unlink(fp.name)

    # encode
-    text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
+    text_feature = TextFeaturizer(args.unit_type, "", args.spm_model_prefix)
    counter = Counter()
    for manifest_path in args.manifest_paths:
......
@@ -46,7 +46,7 @@ args = parser.parse_args()

def main():
-    print_arguments(args)
+    print_arguments(args, globals())
    augmentation_pipeline = AugmentationPipeline('{}')
    audio_featurizer = AudioFeaturizer(
......
@@ -48,12 +48,12 @@ args = parser.parse_args()

def main():
-    print_arguments(args)
+    print_arguments(args, globals())
    fout = open(args.output_path, 'w', encoding='utf-8')

    # get feat dim
    mean, std = load_cmvn(args.cmvn_path, filetype='npz')
-    feat_dim = mean.shape[0]
+    feat_dim = mean.shape[1]  # (1, D)
    print(f"Feature dim: {feat_dim}")

    text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
......
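Editor's note: the shape change above rests on the assumption that mean_std.npz stores the CMVN statistics as (1, D) row vectors, as the inline comment in the diff indicates. A small hedged sketch of that assumption:

# Hedged sketch of the assumed layout of mean_std.npz: mean/std are (1, D)
# row vectors, so the feature dimension is axis 1 (shape[0] would just be 1).
import numpy as np

mean = np.zeros((1, 80))  # hypothetical 80-dim fbank statistics
std = np.ones((1, 80))

feat_dim = mean.shape[1]
print(f"Feature dim: {feat_dim}")  # prints 80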