提交 561d5cf0 编写于 作者: H Hui Zhang

refactor feature, dict and argument for new config format

上级 27daa92a
......@@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
select =
E,
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
......
......@@ -31,6 +31,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -34,6 +34,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser
......@@ -54,6 +55,14 @@ if __name__ == "__main__":
type=str,
default='test',
help='run mode, e.g. test, align, export')
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
# save asr result to
parser.add_argument(
"--result-file", type=str, help="path of save the asr result")
# save jit model to
parser.add_argument(
"--export-path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -25,6 +25,8 @@ import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from deepspeech.frontend.featurizer import TextFeaturizer
from deepspeech.frontend.utility import load_dict
from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
......@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
......@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
......@@ -305,10 +308,8 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.train_loader.vocab_size
model_conf.freeze()
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
......@@ -379,13 +380,13 @@ class U2Tester(U2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
trans.append(text_feature.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
......@@ -401,8 +402,11 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
target_transcripts = self.id2token(texts, texts_len, text_feature)
result_transcripts = self.model.decode(
audio,
audio_len,
......@@ -450,7 +454,7 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.collate_fn.stride_ms
stride_ms = self.config.collator.stride_ms
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
......@@ -525,8 +529,9 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.config.collate.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
stride_ms = self.config.collater.stride_ms
token_dict = self.args.char_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
......@@ -613,6 +618,11 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt:
sys.exit(-1)
def setup_dict(self):
# load dictionary for debug log
self.args.char_list = load_dict(self.args.dict_path,
"maskctc" in self.args.model_name)
def setup(self):
"""Setup the experiment.
"""
......@@ -624,6 +634,8 @@ class U2Tester(U2Trainer):
self.setup_dataloader()
self.setup_model()
self.setup_dict()
self.iteration = 0
self.epoch = 0
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -34,6 +34,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_featurizer import AudioFeaturizer #noqa: F401
from .speech_featurizer import SpeechFeaturizer
from .text_featurizer import TextFeaturizer
......@@ -18,7 +18,7 @@ from python_speech_features import logfbank
from python_speech_features import mfcc
class AudioFeaturizer(object):
class AudioFeaturizer():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
......
......@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
class SpeechFeaturizer():
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
......
......@@ -14,12 +14,19 @@
"""Contains the text featurizer class."""
import sentencepiece as spm
from deepspeech.frontend.utility import EOS
from deepspeech.frontend.utility import UNK
from ..utility import EOS
from ..utility import load_dict
from ..utility import UNK
__all__ = ["TextFeaturizer"]
class TextFeaturizer(object):
def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
class TextFeaturizer():
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
......@@ -34,11 +41,12 @@ class TextFeaturizer(object):
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab_filepath:
self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
self.unk_id = self._vocab_list.index(self.unk)
self.eos_id = self._vocab_list.index(EOS)
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
vocab_filepath, maskctc)
self.vocab_size = len(self.vocab_list)
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
......@@ -67,7 +75,7 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
text (str): Text.
Returns:
List[int]: List of token indices.
......@@ -75,8 +83,8 @@ class TextFeaturizer(object):
tokens = self.tokenize(text)
ids = []
for token in tokens:
token = token if token in self._vocab_dict else self.unk
ids.append(self._vocab_dict[token])
token = token if token in self.vocab_dict else self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
......@@ -87,7 +95,7 @@ class TextFeaturizer(object):
idxs (List[int]): List of token indices.
Returns:
str: Text to process.
str: Text.
"""
tokens = []
for idx in idxs:
......@@ -97,33 +105,6 @@ class TextFeaturizer(object):
text = self.detokenize(tokens)
return text
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return self._vocab_dict
def char_tokenize(self, text):
"""Character tokenizer.
......@@ -206,14 +187,16 @@ class TextFeaturizer(object):
return decode(tokens)
def _load_vocabulary_from_file(self, vocab_filepath):
def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
return token2id, id2token, vocab_list
unk_id = vocab_list.index(UNK)
eos_id = vocab_list.index(EOS)
return token2id, id2token, vocab_list, unk_id, eos_id
......@@ -15,6 +15,9 @@
import codecs
import json
import math
from typing import List
from typing import Optional
from typing import Text
import numpy as np
......@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
"mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
"BLANK"
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
"EOS", "UNK", "BLANK", "MASKCTC"
]
IGNORE_ID = -1
SOS = "<sos/eos>"
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
if dict_path is None:
return None
with open(dict_path, "r") as f:
dictionary = f.readlines()
char_list = [entry.split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
char_list.append(EOS)
# for non-autoregressive maskctc model
if maskctc and MASKCTC not in char_list:
char_list.append(MASKCTC)
return char_list
def read_manifest(
......@@ -47,12 +69,20 @@ def read_manifest(
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
max_input_len ([type], optional): maximum output seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum input seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length/output seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.
......
......@@ -47,18 +47,11 @@ def default_argument_parser():
# data and output
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
# parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
# load from saved checkpoint
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
# save jit model to
parser.add_argument("--export_path", type=str, help="path of the jit model to save")
# save asr result to
parser.add_argument("--result_file", type=str, help="path of save the asr result")
# running
parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
help="device type to use, cpu and gpu are supported.")
......
......@@ -33,4 +33,4 @@
},
"prob": 1.0
}
]
\ No newline at end of file
]
......@@ -3,17 +3,11 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
vocab_filepath: data/train_960_unigram5000_units.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
spm_model_prefix: 'data/train_960_unigram5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
fi
......@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
dict_path=$2
ckpt_prefix=$3
batch_size=1
output_dir=${ckpt_prefix}
......@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/test.py \
--run_mode 'align' \
--model-name 'u2_kaldi' \
--run-mode 'align' \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--result-file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
......
......@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode 'export' \
--model-name 'u2_kaldi' \
--run-mode 'export' \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
fi
......@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
fi
config_path=$1
ckpt_prefix=$2
dict_path=$2
ckpt_prefix=$3
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
......@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--model-name u2_kaldi \
--run-mode test \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
......@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--model-name u2_kaldi \
--run-mode test \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
......
......@@ -5,6 +5,7 @@ source path.sh
stage=0
stop_stage=100
conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt
avg_num=5
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......@@ -29,12 +30,12 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
......
......@@ -29,8 +29,7 @@
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
"replace_with_zero": true
},
"prob": 1.0
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册