Commit 8e73d184 authored by Hui Zhang

tiny/s0/s1 can run all

Parent 6f83be1a
......@@ -18,8 +18,10 @@ import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
......@@ -78,26 +80,31 @@ def inference(config, args):
def start_server(config, args):
"""Start the ASR server"""
config.defrost()
config.data.manfiest = config.data.test_manifest
config.data.augmentation_config = ""
config.data.keep_transcription_text = True
config.data.manifest = config.data.test_manifest
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = True
config.collator.batch_size = 1
config.collator.num_workers = 0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
feature = test_loader.collate_fn.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, T, D]
audio_len = feature[0].shape[0]
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
vocab_list=test_loader.collate_fn.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
......@@ -138,7 +145,7 @@ if __name__ == "__main__":
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('host_port', int, 8089, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
......
......@@ -16,8 +16,10 @@ import functools
import numpy as np
import paddle
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
......@@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments
def start_server(config, args):
"""Start the ASR server"""
config.defrost()
config.data.manfiest = config.data.test_manifest
config.data.augmentation_config = ""
config.data.keep_transcription_text = True
config.data.manifest = config.data.test_manifest
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = True
config.collator.batch_size = 1
config.collator.num_workers = 0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
# prepare ASR inference handler
def file_to_transcript(filename):
feature = dataset.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
audio_len = feature[0].shape[1]
feature = test_loader.collate_fn.process_utterance(filename, "")
audio = np.array([feature[0]]).astype('float32') #[1, T, D]
# audio = audio.swapaxes(1,2)
print('---file_to_transcript feature----')
print(audio.shape)
audio_len = feature[0].shape[0]
print(audio_len)
audio_len = np.array([audio_len]).astype('int64') # [1]
result_transcript = model.decode(
paddle.to_tensor(audio),
paddle.to_tensor(audio_len),
vocab_list=dataset.vocab_list,
vocab_list=test_loader.collate_fn.vocab_list,
decoding_method=config.decoding.decoding_method,
lang_model_path=config.decoding.lang_model_path,
beam_alpha=config.decoding.alpha,
......@@ -91,7 +102,7 @@ if __name__ == "__main__":
add_arg('host_ip', str,
'localhost',
"Server's IP address.")
add_arg('host_port', int, 8086, "Server's IP port.")
add_arg('host_port', int, 8088, "Server's IP port.")
add_arg('speech_save_dir', str,
'demo_cache',
"Directory to save demo audios.")
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -30,6 +30,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -34,6 +34,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
......
......@@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
def main_sp(config, args):
exp = Trainer(config, args)
......@@ -30,7 +32,7 @@ def main_sp(config, args):
def main(config, args):
if args.device == "gpu" and args.nprocs > 1:
if args.nprocs > 0:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
......
......@@ -73,11 +73,11 @@ class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
def train_batch(self, batch_index, batch, msg):
train_conf = self.config.training
start = time.time()
loss, attention_loss, ctc_loss = self.model(*batch_data)
loss, attention_loss, ctc_loss = self.model(*batch)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
......@@ -219,7 +219,7 @@ class U2Trainer(Trainer):
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False)
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
......@@ -269,7 +269,7 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
logger.info("Setup train/valid/test Dataloader!")
def setup_model(self):
......@@ -345,7 +345,7 @@ class U2Tester(U2Trainer):
decoding_chunk_size=-1, # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1.
simulate_streaming=False, # simulate streaming inference. Defaults to False.
))
......@@ -428,7 +428,7 @@ class U2Tester(U2Trainer):
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
metrics = self.compute_metrics(*batch[:-1], fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum']
......@@ -476,12 +476,12 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
# def run_test(self):
# self.resume_or_scratch()
# try:
# self.test()
# except KeyboardInterrupt:
# sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
......@@ -512,36 +512,36 @@ class U2Tester(U2Trainer):
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
# def run_export(self):
# try:
# self.export()
# except KeyboardInterrupt:
# sys.exit(-1)
# def setup(self):
# """Setup the experiment.
# """
# paddle.set_device(self.args.device)
# self.setup_output_dir()
# self.setup_checkpointer()
# self.setup_dataloader()
# self.setup_model()
# self.iteration = 0
# self.epoch = 0
# def setup_output_dir(self):
# """Create a directory used for output.
# """
# # output dir
# if self.args.output:
# output_dir = Path(self.args.output).expanduser()
# output_dir.mkdir(parents=True, exist_ok=True)
# else:
# output_dir = Path(
# self.args.checkpoint_path).expanduser().parent.parent
# output_dir.mkdir(parents=True, exist_ok=True)
# self.output_dir = output_dir
......@@ -14,12 +14,27 @@
"""Contains the text featurizer class."""
import sentencepiece as spm
from deepspeech.frontend.utility import EOS
from deepspeech.frontend.utility import UNK
from ..utility import EOS
from ..utility import SPACE
from ..utility import UNK
from ..utility import SOS
from ..utility import BLANK
from ..utility import MASKCTC
from ..utility import load_dict
from deepspeech.utils.log import Log
class TextFeaturizer(object):
def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
logger = Log(__name__).getlog()
__all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
......@@ -34,20 +49,21 @@ class TextFeaturizer(object):
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab_filepath:
self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
self.unk_id = self._vocab_list.index(self.unk)
self.eos_id = self._vocab_list.index(EOS)
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
vocab_filepath, maskctc)
self.vocab_size = len(self.vocab_list)
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
def tokenize(self, text):
def tokenize(self, text, replace_space=True):
if self.unit_type == 'char':
tokens = self.char_tokenize(text)
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
......@@ -67,27 +83,27 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
text (str): Text.
Returns:
List[int]: List of token indices.
"""
tokens = self.tokenize(text)
ids = []
for token in tokens:
token = token if token in self._vocab_dict else self.unk
ids.append(self._vocab_dict[token])
token = token if token in self.vocab_dict else self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text to process.
str: Text.
"""
tokens = []
for idx in idxs:
......@@ -97,43 +113,22 @@ class TextFeaturizer(object):
text = self.detokenize(tokens)
return text
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return self._vocab_dict
def char_tokenize(self, text):
def char_tokenize(self, text, replace_space=True):
"""Character tokenizer.
Args:
text (str): text string.
replace_space (bool): Whether to replace spaces with <space>; False is only used by build_vocab.py.
Returns:
List[str]: tokens.
"""
return list(text.strip())
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
def char_detokenize(self, tokens):
"""Character detokenizer.
......@@ -144,6 +139,7 @@ class TextFeaturizer(object):
Returns:
str: text string.
"""
tokens = tokens.replace(SPACE, " ")
return "".join(tokens)
def word_tokenize(self, text):
......@@ -206,14 +202,28 @@ class TextFeaturizer(object):
return decode(tokens)
def _load_vocabulary_from_file(self, vocab_filepath):
def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
logger.info(f"Vocab: {vocab_list}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
return token2id, id2token, vocab_list
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"BLANK id: {blank_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id
......@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import codecs
import json
import math
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text
import jsonlines
import numpy as np
from deepspeech.utils.log import Log
......@@ -23,16 +28,40 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
"mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
"BLANK"
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
"EOS", "UNK", "BLANK", "MASKCTC", "SPACE"
]
IGNORE_ID = -1
SOS = "<sos/eos>"
# `sos` and `eos` use the same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
if dict_path is None:
return None
with open(dict_path, "r") as f:
dictionary = f.readlines()
# first token is `<blank>`
# multi line: `<blank> 0\n`
# one line: `<blank>`
# space is replaced with <space>
char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
char_list.append(EOS)
# for non-autoregressive maskctc model
if maskctc and MASKCTC not in char_list:
char_list.append(MASKCTC)
return char_list
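A minimal sketch of how `load_dict` treats the two dict formats described in the comments above (file path and contents hypothetical):
# vocab.txt holds either one token per line ("<blank>") or "token id"
# pairs ("<blank> 0"); only the first field of each line is kept.
from deepspeech.frontend.utility import load_dict
char_list = load_dict("data/vocab.txt")
# <blank> is inserted at index 0 and <eos> appended if missing, e.g.:
# ['<blank>', '<unk>', '<space>', 'a', 'b', ..., '<eos>']
char_list = load_dict("data/vocab.txt", maskctc=True)  # also appends '<mask>'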
def read_manifest(
......@@ -47,12 +76,20 @@ def read_manifest(
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
max_input_len ([type], optional): maximum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum output seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum output seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length / input seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length / input seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.
......@@ -62,29 +99,70 @@ def read_manifest(
"""
manifest = []
for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
feat_len = json_data["feat_shape"][
0] if 'feat_shape' in json_data else 1.0
token_len = json_data["token_shape"][
0] if 'token_shape' in json_data else 1.0
conditions = [
feat_len >= min_input_len,
feat_len <= max_input_len,
token_len >= min_output_len,
token_len <= max_output_len,
token_len / feat_len >= min_output_input_ratio,
token_len / feat_len <= max_output_input_ratio,
]
if all(conditions):
manifest.append(json_data)
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
feat_len = json_data["feat_shape"][
0] if 'feat_shape' in json_data else 1.0
token_len = json_data["token_shape"][
0] if 'token_shape' in json_data else 1.0
conditions = [
feat_len >= min_input_len,
feat_len <= max_input_len,
token_len >= min_output_len,
token_len <= max_output_len,
token_len / feat_len >= min_output_input_ratio,
token_len / feat_len <= max_output_input_ratio,
]
if all(conditions):
manifest.append(json_data)
return manifest
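For illustration, each manifest line is one JSON object, and the filter keeps entries whose lengths fall inside the configured bounds (values hypothetical):
# {"utt": "utt_001", "feat": "data/utt_001.wav", "feat_shape": [3.2],
#  "text": "hello world", "token_shape": [11]}
from deepspeech.frontend.utility import read_manifest
# Keep utterances between 1 and 20 seconds; other bounds use their defaults.
manifest = read_manifest("data/manifest.tiny",
                         min_input_len=1.0,
                         max_input_len=20.0)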
# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
def parse_tar(file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfos.
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def subfile_from_tar(file, local_data=None):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from the tar file
and cache the tar file info for the next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if local_data is None:
local_data = TarLocalData(tar2info={}, tar2object={})
assert isinstance(local_data, TarLocalData)
if 'tar2info' not in local_data.__dict__:
local_data.tar2info = {}
if 'tar2object' not in local_data.__dict__:
local_data.tar2object = {}
if tarpath not in local_data.tar2info:
fobj, infos = parse_tar(tarpath)
local_data.tar2info[tarpath] = infos
local_data.tar2object[tarpath] = fobj
else:
fobj = local_data.tar2object[tarpath]
infos = local_data.tar2info[tarpath]
return fobj.extractfile(infos[filename])
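A hypothetical usage sketch of the tar helpers: the file reference uses the `tar:tarpath#filename` form, and the opened tar object is cached in `local_data` across calls.
from deepspeech.frontend.utility import TarLocalData, subfile_from_tar
local_data = TarLocalData(tar2info={}, tar2object={})
# "tar:" prefix, then the tar path and member name separated by '#':
fobj = subfile_from_tar("tar:data/feats.tar#utt_001.wav", local_data)
audio_bytes = fobj.read()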
def rms_to_db(rms: float):
"""Root Mean Square to dB.
......@@ -101,7 +179,7 @@ def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is a mix of sine waves, so a 1-amp sine wave's full scale is 0.7071, equal to -3.0103 dB.
dB = dBFS + 3.0103
dBFS = dB - 3.0103
e.g. 0 dB = -3.0103 dBFS
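A two-line check of where the 3.0103 constant above comes from:
import math
rms = 1 / math.sqrt(2)       # RMS of a 1-amp (full-scale) sine wave ~ 0.7071
db = 20 * math.log10(rms)    # ~ -3.0103 dB, the offset used by rms_to_dbfs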
......@@ -116,26 +194,26 @@ def rms_to_dbfs(rms: float):
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
......@@ -155,7 +233,7 @@ def gain_db_to_ratio(gain_db: float):
def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
"""Nomalize audio to dBFS.
Args:
sample_data (np.ndarray): input wave samples, [-1, 1].
dbfs (float, optional): target dBFS. Defaults to -3.0103.
......@@ -254,6 +332,13 @@ def load_cmvn(cmvn_file: str, filetype: str):
cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file)
elif filetype == "npz":
eps = 1e-14
npzfile = np.load(cmvn_file)
mean = np.squeeze(npzfile["mean"])
std = np.squeeze(npzfile["std"])
istd = 1 / (std + eps)
cmvn = [mean, istd]
else:
raise ValueError(f"cmvn file type no support: {filetype}")
return cmvn[0], cmvn[1]
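A minimal sketch of the new `npz` branch (file name hypothetical): the archive only needs `mean` and `std` arrays.
import numpy as np
from deepspeech.frontend.utility import load_cmvn
# Per-dimension stats for, e.g., 80-dim fbank features:
np.savez("data/mean_std.npz",
         mean=np.zeros(80, dtype=np.float32),
         std=np.ones(80, dtype=np.float32))
mean, istd = load_cmvn("data/mean_std.npz", filetype="npz")
# istd == 1 / (std + 1e-14); features are normalized as (x - mean) * istd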
......@@ -23,7 +23,7 @@ logger = Log(__name__).getlog()
class SpeechCollator():
def __init__(self, keep_transcription_text=True):
def __init__(self, keep_transcription_text=True, return_utts=False):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
......@@ -31,6 +31,7 @@ class SpeechCollator():
If ``keep_transcription_text`` is False, text is token ids; otherwise it is the raw string.
"""
self._keep_transcription_text = keep_transcription_text
self.return_utts = return_utts
def __call__(self, batch):
"""batch examples
......@@ -51,7 +52,9 @@ class SpeechCollator():
audio_lens = []
texts = []
text_lens = []
for audio, text in batch:
utts = []
for utt, audio, text in batch:
utts.append(utt)
# audio
audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1])
......@@ -75,4 +78,7 @@ class SpeechCollator():
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, audio_lens, padded_texts, text_lens
if self.return_utts:
return padded_audios, audio_lens, padded_texts, text_lens, utts
else:
return padded_audios, audio_lens, padded_texts, text_lens
\ No newline at end of file
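To show what `return_utts` changes downstream, a sketch with a fake two-utterance batch (shapes hypothetical, and the full `__call__` body is partly elided above) matching the `(utt, audio, text)` items that `ManifestDataset.__getitem__` now yields:
import numpy as np
from deepspeech.io.collator import SpeechCollator
batch = [
    ("utt_001", np.random.randn(80, 50).astype('float32'), np.array([2, 7, 9])),
    ("utt_002", np.random.randn(80, 32).astype('float32'), np.array([4, 1])),
]  # audio is [D, T]; the collator transposes to [T, D] and pads
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=True)
padded_audios, audio_lens, padded_texts, text_lens, utts = collate_fn(batch)
# With return_utts=False (the training setting) the utt ids are dropped:
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
padded_audios, audio_lens, padded_texts, text_lens = collate_fn(batch)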
......@@ -347,4 +347,5 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
return self.process_utterance(instance["feat"], instance["text"])
feat, text = self.process_utterance(instance["feat"], instance["text"])
return instance["utt"], feat, text
......@@ -26,9 +26,9 @@ __all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
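A worked instance of the formula with hypothetical numbers:
# O = (I - F + Pstart + Pend) // S + 1
I, F, S = 200, 11, 2              # input length, filter length, stride
Pstart = Pend = 5                 # zero padding at each end
O = (I - F + Pstart + Pend) // S + 1
print(O)                          # (200 - 11 + 10) // 2 + 1 = 100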
......@@ -45,7 +45,7 @@ def conv_output_size(I, F, P, S):
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
......@@ -58,8 +58,8 @@ class ConvBn(nn.Layer):
:type num_channels_in: int
:param num_channels_out: Number of output channels.
:type num_channels_out: int
:param stride: The x dimension of the stride. Or input a tuple for two
image dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
......@@ -114,7 +114,7 @@ class ConvBn(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
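The `type_as` -> `astype` change works around the TODO above (bool tensors cannot be multiplied directly in Paddle); a minimal standalone sketch of the same pattern, with hypothetical shapes and an assumed import path for `make_non_pad_mask`:
import paddle
from deepspeech.modules.mask import make_non_pad_mask  # import path assumed
x = paddle.randn([2, 4, 8, 10])            # e.g. [B, C, D, T] activations
x_len = paddle.to_tensor([10, 6])
masks = make_non_pad_mask(x_len)           # bool mask, [B, T]
masks = masks.unsqueeze(1).unsqueeze(1)    # [B, 1, 1, T]
x = x.multiply(masks.astype(x.dtype))      # cast first: bool multiply unsupported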
......
......@@ -219,15 +219,17 @@ class DeepSpeech2Model(nn.Layer):
The model built from pretrained result.
"""
model = cls(
feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
#feat_size=dataloader.collate_fn.feature_size,
feat_size=dataloader.dataset.feature_size,
#dict_size=dataloader.collate_fn.vocab_size,
dict_size=dataloader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
......@@ -260,24 +262,8 @@ class DeepSpeech2Model(nn.Layer):
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, audio, audio_len):
"""export model function
......
......@@ -29,13 +29,13 @@ __all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
......@@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase):
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
......@@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase):
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
"""
......@@ -309,6 +309,6 @@ class RNNStack(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
......@@ -255,22 +255,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
blank_id=0, # index of blank in vocab.txt
))
ctc_grad_norm_type='instance', ))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
def __init__(
self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0,
ctc_grad_norm_type='instance', ):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
......@@ -290,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type='instance')
grad_norm_type=ctc_grad_norm_type)
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
......@@ -348,16 +350,18 @@ class DeepSpeech2ModelOnline(nn.Layer):
DeepSpeech2ModelOnline
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id)
model = cls(
feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
......@@ -376,42 +380,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
DeepSpeech2ModelOnline
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id)
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru,
blank_id=blank_id)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from contextlib import contextmanager
from pathlib import Path
import paddle
......@@ -29,37 +30,37 @@ logger = Log(__name__).getlog()
class Trainer():
"""
An experiment template in order to structure the training code and take
care of saving, loading, logging, visualization stuff. It's intended to
be flexible and simple.
So it only handles the output directory (create a directory for the output,
create a checkpoint directory, dump the config in use and create
visualizer and logger) in a standard way without enforcing any
input-output protocols on the model and dataloader. It leaves the main
part for the user to implement their own (setup the model, criterion,
optimizer, define a training step, define a validation function and
customize all the text and visual logs).
It does not save too much boilerplate code. The users still have to write
the forward/backward/update manually, but they are free to add
non-standard behaviors if needed.
We have some conventions to follow.
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
``valid_loader``, ``config`` and ``args`` attributes.
2. The config should have a ``training`` field, which has
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
used as the trigger to invoke validation, checkpointing and stop of the
experiment.
3. There are four methods, namely ``train_batch``, ``valid``,
``setup_model`` and ``setup_dataloader``, that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you
need.
Parameters
----------
config: yacs.config.CfgNode
The configuration used for the experiment.
args: argparse.Namespace
The parsed command line arguments.
Examples
......@@ -68,17 +69,17 @@ class Trainer():
>>> exp = Trainer(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.nprocs > 1 and args.device == "gpu":
>>> if args.nprocs > 0:
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else:
>>> main_sp(config, args)
......@@ -93,18 +94,24 @@ class Trainer():
self.checkpoint_dir = None
self.iteration = 0
self.epoch = 0
self._train = True
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
@contextmanager
def eval(self):
self._train = False
yield
self._train = True
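The context manager just toggles the `_train` flag around a block (used by `run_test`/`run_export` below, and checked by `dump_config` before rotating an old config file); schematically, with hypothetical `config`/`args`:
trainer = Trainer(config, args)
assert trainer._train is True
with trainer.eval():
    ...                         # decode/export; dump_config won't rotate config.yaml
assert trainer._train is True   # restored on exit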
def setup(self):
"""Setup the experiment.
"""
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
......@@ -114,10 +121,10 @@ class Trainer():
@property
def parallel(self):
"""A flag indicating whether the experiment should run with
"""A flag indicating whether the experiment should run with
multiprocessing.
"""
return self.args.device == "gpu" and self.args.nprocs > 1
return self.args.nprocs > 1
def init_parallel(self):
"""Init environment for multiprocess training.
......@@ -144,9 +151,9 @@ class Trainer():
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
......@@ -158,8 +165,8 @@ class Trainer():
checkpoint_path=self.args.checkpoint_path)
if infos:
# restore from ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] + 1
scratch = False
else:
self.iteration = 0
......@@ -237,31 +244,61 @@ class Trainer():
try:
self.train()
except KeyboardInterrupt:
self.save()
exit(-1)
finally:
self.destory()
logger.info("Training Done.")
logger.info("Train Done.")
def run_test(self):
"""Do Test/Decode"""
with self.eval():
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
logger.info("Test/Decode Done.")
def run_export(self):
"""Do Model Export"""
with self.eval():
try:
self.export()
except KeyboardInterrupt:
exit(-1)
logger.info("Export Done.")
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
if self.args.output:
output_dir = Path(self.args.output).expanduser()
elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = self.output_dir / "checkpoints"
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.log_dir = output_dir / "log"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.test_dir = output_dir / "test"
self.test_dir.mkdir(parents=True, exist_ok=True)
self.decode_dir = output_dir / "decode"
self.decode_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint_dir = checkpoint_dir
self.export_dir = output_dir / "export"
self.export_dir.mkdir(parents=True, exist_ok=True)
self.visual_dir = output_dir / "visual"
self.visual_dir.mkdir(parents=True, exist_ok=True)
self.config_dir = output_dir / "conf"
self.config_dir.mkdir(parents=True, exist_ok=True)
@mp_tools.rank_zero_only
def destory(self):
......@@ -273,27 +310,34 @@ class Trainer():
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.
The visual log is saved in the output directory.
Notes
------
Only the main process has a visualizer with it. Using multiple
visualizers in multiprocess to write to the same log file may cause
unexpected behaviors.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
visualizer = SummaryWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer
@mp_tools.rank_zero_only
def dump_config(self):
"""Save the configuration used for this experiment.
It is saved in to ``config.yaml`` in the output directory at the
"""Save the configuration used for this experiment.
It is saved in to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
config_file = self.config_dir / "config.yaml"
if self._train and config_file.exists():
time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.gmtime())
target_path = self.config_dir / ".".join(
[time_stamp, "config.yaml"])
config_file.rename(target_path)
with open(config_file, 'wt') as f:
print(self.config, file=f)
def train_batch(self):
......@@ -307,14 +351,26 @@ class Trainer():
"""
raise NotImplementedError("valid should be implemented.")
@paddle.no_grad()
def test(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("test should be implemented.")
@paddle.no_grad()
def export(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("export should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.
"""
raise NotImplementedError("setup_model should be implemented.")
def setup_dataloader(self):
"""Setup training dataloader and validation dataloader. A subclass
"""Setup training dataloader and validation dataloader. A subclass
should implement this method.
"""
raise NotImplementedError("setup_dataloader should be implemented.")
......@@ -120,14 +120,15 @@ class Autolog:
model_precision="fp32"):
import auto_log
pid = os.getpid()
if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
if os.environ.get('CUDA_VISIBLE_DEVICES', None):
gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
infer_config = inference.Config()
infer_config.enable_use_gpu(100, gpu_id)
else:
gpu_id = None
infer_config = inference.Config()
autolog = auto_log.AutoLogger(
self.autolog = auto_log.AutoLogger(
model_name=model_name,
model_precision=model_precision,
batch_size=batch_size,
......@@ -139,7 +140,6 @@ class Autolog:
gpu_ids=gpu_id,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
self.autolog = autolog
def getlog(self):
return self.autolog
......@@ -2,3 +2,4 @@ dev-clean/
manifest.dev-clean
manifest.train-clean
train-clean/
*.meta
......@@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
......@@ -80,10 +84,27 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
......@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
set -e
expdir=exp
datadir=data
nj=32
lmtag=
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
dict=$2
ckpt_prefix=$3
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
echo "chunk mode ${chunk_mode}"
# download language model
#bash local/download_lm_en.sh
......@@ -21,39 +42,46 @@ ckpt_prefix=$2
# exit 1
#fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
batch_size=64
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
pids=() # initialize pids
for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
(
for rtask in ${recog_set}; do
(
decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
feat_recog_dir=${datadir}
mkdir -p ${expdir}/${decode_dir}
mkdir -p ${feat_recog_dir}
# split data
split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
#### use CPU for decoding
ngpu=0
# set batchsize 0 to disable batch decoding
batch_size=1
${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${expdir}/${decode_dir}/data.JOB.json \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${dmethd} \
--opts decoding.batch_size ${batch_size} \
--opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
) &
pids+=($!) # store background pids
done
) &
pids+=($!) # store background pids
done
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
echo "Finished"
exit 0
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
......@@ -11,19 +11,28 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ngpu == 0 ];then
device=cpu
mkdir -p exp
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
#export FLAGS_cudnn_deterministic=True
echo "None"
fi
echo "using ${device}..."
mkdir -p exp
# export FLAGS_cudnn_exhaustive_search=true
# export FLAGS_conv_workspace_size_limit=4000
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} != 0 ]; then
#unset FLAGS_cudnn_deterministic
echo "None"
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
......@@ -4,6 +4,7 @@ data:
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 4
......@@ -35,6 +36,8 @@ model:
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 20
......
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 30.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 4
model:
num_conv_layers: 2
num_rnn_layers: 4
rnn_layer_size: 2048
rnn_direction: forward
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: True
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 10
accum_grad: 1
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1
checkpoint:
kbest_n: 3
latest_n: 2
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#! /usr/bin/env bash
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
......@@ -9,6 +9,11 @@ URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
if [ -e $TARGET ];then
echo "$TARGET exists."
exit 0
fi
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,19 +11,14 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
model_type=$4
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
......@@ -22,11 +19,11 @@ if [ $? -ne 0 ]; then
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
profiler_options=
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
device=gpu
if [ ngpu == 0 ];then
device=cpu
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
config_path=$1
ckpt_name=$2
model_type=$3
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type} \
--profiler-options "${profiler_options}" \
--seed ${seed}
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
export MAIN_ROOT=${PWD}/../../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -7,11 +7,12 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
......@@ -21,20 +22,20 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
......@@ -65,6 +65,8 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
#! /usr/bin/env bash
#!/bin/bash
stage=-1
stop_stage=100
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
......@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
......
#! /usr/bin/env bash
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
......@@ -8,30 +8,57 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0
#! /usr/bin/env bash
#!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--seed ${seed} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
......@@ -20,20 +20,26 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES= ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES= ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi