Commit a7244593 authored by Hui Zhang

refactor data

Parent 553aa359
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-lender",
"metadata": {},
"outputs": [],
"source": [
"eng=\"one minute a voice said and the time buzzer sounded\"\n",
"chn=\"可控是病毒武器最基本的要求\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ruled-kuwait",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"o\n",
"n\n",
"e\n",
" \n",
"m\n",
"i\n",
"n\n",
"u\n",
"t\n",
"e\n",
" \n",
"a\n",
" \n",
"v\n",
"o\n",
"i\n",
"c\n",
"e\n",
" \n",
"s\n",
"a\n",
"i\n",
"d\n",
" \n",
"a\n",
"n\n",
"d\n",
" \n",
"t\n",
"h\n",
"e\n",
" \n",
"t\n",
"i\n",
"m\n",
"e\n",
" \n",
"b\n",
"u\n",
"z\n",
"z\n",
"e\n",
"r\n",
" \n",
"s\n",
"o\n",
"u\n",
"n\n",
"d\n",
"e\n",
"d\n"
]
}
],
"source": [
"for char in eng:\n",
" print(char)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "passive-petite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"可\n",
"控\n",
"是\n",
"病\n",
"毒\n",
"武\n",
"器\n",
"最\n",
"基\n",
"本\n",
"的\n",
"要\n",
"求\n"
]
}
],
"source": [
"for char in chn:\n",
" print(char)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "olympic-realtor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"one\n",
"minute\n",
"a\n",
"voice\n",
"said\n",
"and\n",
"the\n",
"time\n",
"buzzer\n",
"sounded\n"
]
}
],
"source": [
"for word in eng.split():\n",
" print(word)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "induced-enhancement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"可控是病毒武器最基本的要求\n"
]
}
],
"source": [
"for word in chn.split():\n",
" print(word)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "lovely-bottle",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'StringIO'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-3e4825b8299f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'StringIO'"
]
}
],
"source": [
"import StringIO"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "interested-cardiff",
"metadata": {},
"outputs": [],
"source": [
"from io import StringIO"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "portable-ivory",
"metadata": {},
"outputs": [],
"source": [
"inputs = StringIO()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "compatible-destination",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"64"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "federal-margin",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n",
"\n"
]
}
],
"source": [
"print(inputs.getvalue())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "consecutive-entity",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"64"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "desirable-anxiety",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n",
"nor is mister quilter's manner less interesting than his matter\n",
"\n"
]
}
],
"source": [
"print(inputs.getvalue())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "employed-schedule",
"metadata": {},
"outputs": [],
"source": [
"import tempfile"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "unlikely-honduras",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__ne__', '__new__', '__next__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_checkClosed', '_checkReadable', '_checkSeekable', '_checkWritable', '_dealloc_warn', '_finalizing', 'close', 'closed', 'detach', 'fileno', 'flush', 'isatty', 'mode', 'name', 'peek', 'raw', 'read', 'read1', 'readable', 'readinto', 'readinto1', 'readline', 'readlines', 'seek', 'seekable', 'tell', 'truncate', 'writable', 'write', 'writelines']\n",
"57\n"
]
}
],
"source": [
"with tempfile.TemporaryFile() as fp:\n",
" print(dir(fp))\n",
" print(fp.name)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "needed-trail",
"metadata": {},
"outputs": [],
"source": [
"a = tempfile.mkstemp(suffix=None, prefix='test', dir=None, text=False)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "hazardous-choir",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'count', 'index']\n"
]
}
],
"source": [
"print(dir(a))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "front-sauce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(57, '/tmp/test27smzbzc')\n"
]
}
],
"source": [
"print(a)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "shared-wages",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<built-in method index of tuple object at 0x7f999b525648>\n"
]
}
],
"source": [
"print(a.index)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "charged-carnival",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_closer', 'close', 'delete', 'file', 'name']\n",
"/tmp/tmpfjn7mygy\n"
]
}
],
"source": [
"fp= tempfile.NamedTemporaryFile(mode='w', delete=False)\n",
"print(dir(fp))\n",
"print(fp.name)\n",
"fp.close()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "religious-terror",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/tmpfjn7mygy\n"
]
}
],
"source": [
"import os\n",
"os.path.exists(fp.name)\n",
"print(fp.name)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "communist-gospel",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function BufferedRandom.write(buffer, /)>"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fp.write"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "simplified-clarity",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'example'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s='/home/ubuntu/python/example.py'\n",
"os.path.splitext(os.path.basename(s))[0]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "popular-genius",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "studied-burner",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('hello', 1), ('world', 1)])\n"
]
}
],
"source": [
"counter = Counter()\n",
"counter.update([\"hello\"])\n",
"counter.update([\"world\"])\n",
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "mineral-ceremony",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('h', 1), ('e', 1), ('l', 3), ('o', 2), ('w', 1), ('r', 1), ('d', 1)])\n"
]
}
],
"source": [
"counter = Counter()\n",
"counter.update(\"hello\")\n",
"counter.update(\"world\")\n",
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "nonprofit-freedom",
"metadata": {},
"outputs": [],
"source": [
"counter.update(list(\"hello\"))"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "extended-methodology",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('h', 2), ('e', 2), ('l', 5), ('o', 3), ('w', 1), ('r', 1), ('d', 1)])\n"
]
}
],
"source": [
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "grand-benjamin",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['h', 'e', 'l', 'l', 'o']"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(\"hello\")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "marine-fundamentals",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{}\n"
]
}
],
"source": [
"from io import StringIO\n",
"a = StringIO(initial_value='{}', newline='')\n",
"print(a.read())"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "suitable-charlotte",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "expected str, bytes or os.PathLike object, not _io.StringIO",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-56-4323a912120d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not _io.StringIO"
]
}
],
"source": [
"with io.open(a) as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "institutional-configuration",
"metadata": {},
"outputs": [],
"source": [
"io.open?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pregnant-modem",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
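The notebook above is scratch work for this refactor: it probes per-character versus per-word iteration for English and Chinese text, the Python 3 home of StringIO (io.StringIO, not the Python 2 StringIO module), tempfile behavior, and Counter updates. The one trick the diffs below lean on is io.StringIO as an in-memory stand-in for an augmentation-config file. A minimal standalone sketch of that pattern (not repo code):

    import io

    # An in-memory "file" holding an empty JSON augmentation config;
    # reading it yields the same text a real '{}' file on disk would.
    aug_file = io.StringIO(initial_value='{}', newline='')
    print(aug_file.read())  # -> {}

    # read() consumes the stream, so rewind before reading again;
    # the diffs below sidestep this by creating a fresh StringIO per dataset.
    aug_file.seek(0)
    assert aug_file.read() == '{}'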
......@@ -83,27 +83,11 @@ def inference(config, args):
def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
config.data.manifest = config.data.test_manifest
config.data.augmentation_config = io.StringIO(
initial_value='{}', newline='')
config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path)
......
......@@ -35,27 +35,12 @@ from deepspeech.io.dataset import ManifestDataset
def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
config.data.manifest = config.data.test_manifest
config.data.augmentation_config = io.StringIO(
initial_value='{}', newline='')
config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path)
model.eval()
......
......@@ -41,34 +41,18 @@ def tune(config, args):
if not args.num_betas >= 0:
raise ValueError("num_betas must be non-negative!")
dev_dataset = ManifestDataset(
config.data.dev_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = io.StringIO(
initial_value='{}', newline='')
config.data.keep_transcription_text = True
dev_dataset = ManifestDataset.from_config(config)
valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(is_training=False))
collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
args.checkpoint_path)
......
......@@ -33,8 +33,8 @@ _C.data = CN(
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delat_delta=False, # 'mfcc', 'fbank'
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
......
......@@ -145,52 +145,15 @@ class DeepSpeech2Trainer(Trainer):
def setup_dataloader(self):
config = self.config
config.data.keep_transcription_text = False
train_dataset = ManifestDataset(
config.data.train_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=io.open(
config.data.augmentation_config, mode='r',
encoding='utf8').read(),
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=False)
dev_dataset = ManifestDataset(
config.data.dev_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=False)
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
config.data.augmentation_config = io.StringIO(
initial_value='{}', newline='')
dev_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
......@@ -211,7 +174,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(is_training=True)
collate_fn = SpeechCollator(keep_transcription_text=False)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -367,27 +330,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_dataloader(self):
config = self.config
# return raw text
test_dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
config.data.manifest = config.data.test_manifest
config.data.augmentation_config = io.StringIO(
initial_value='{}', newline='')
config.data.keep_transcription_text = True
test_dataset = ManifestDataset.from_config(config)
# return text ord id
self.test_loader = DataLoader(
......@@ -395,7 +343,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(is_training=False))
collate_fn=SpeechCollator(keep_transcription_text=True))
self.logger.info("Setup test Dataloader!")
def setup_output_dir(self):
......
......@@ -330,7 +330,7 @@ class AudioFeaturizer(object):
nfft=512,
lowfreq=0,
highfreq=max_freq,
preemph=0.97, )
fbank_feat = np.transpose(fbank_feat)
if delta_delta:
fbank_feat = self._concat_delta_delta(fbank_feat)
......
......@@ -48,7 +48,7 @@ class TextFeaturizer(object):
tokens = self.char_tokenize(text)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
tokens = self.spm_tokenize(text)
return tokens
......
......@@ -42,6 +42,7 @@ class SpeechSegment(AudioSegment):
"""
AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript
# must init `tokens` with `token_ids` at the same time
self._tokens = tokens
self._token_ids = token_ids
......@@ -183,7 +184,7 @@ class SpeechSegment(AudioSegment):
@property
def has_token(self):
if self._tokens or self._token_ids:
if self._tokens and self._token_ids:
return True
return False
......
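The has_token change above tightens ``or`` to ``and``: per the new comment in __init__, _tokens must be initialized together with _token_ids, so a segment only truly has tokens when both are present. A toy illustration (hypothetical values, not repo code):

    tokens, token_ids = ["hello"], None  # inconsistent, half-initialized state
    # old predicate: truthy because tokens alone is non-empty
    assert bool(tokens or token_ids) is True
    # new predicate: correctly reports the pair as incomplete
    assert bool(tokens and token_ids) is False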
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import numpy as np
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
......@@ -26,12 +28,18 @@ def create_dataloader(manifest_path,
mean_std_filepath,
spm_model_prefix,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
feat_dim=None,
delta_delta=False,
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
......@@ -43,20 +51,24 @@ def create_dataloader(manifest_path,
dist=False):
dataset = ManifestDataset(
manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
manifest_path=manifest_path,
unit_type=unit_type,
vocab_filepath=vocab_filepath,
mean_std_filepath=mean_std_filepath,
spm_model_prefix=spm_model_prefix,
augmentation_config=augmentation_config,
max_duration=max_duration,
min_duration=min_duration,
max_input_len=max_input_len,
min_input_len=min_input_len,
max_output_len=max_output_len,
min_output_len=min_output_len,
max_output_input_ratio=max_output_input_ratio,
min_output_input_ratio=min_output_input_ratio,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
specgram_type=specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
feat_dim=feat_dim,
delta_delta=delta_delta,
use_dB_normalization=use_dB_normalization,
random_seed=random_seed,
keep_transcription_text=keep_transcription_text)
......@@ -80,7 +92,10 @@ def create_dataloader(manifest_path,
sortagrad=is_training,
shuffle_method=shuffle_method)
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
def padding_batch(batch,
padding_to=-1,
flatten=False,
keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
......@@ -113,10 +128,10 @@ def create_dataloader(manifest_path,
audio_lens.append(audio.shape[1])
padded_text = np.zeros([max_text_length])
if is_training:
padded_text[:len(text)] = text #ids
else:
if keep_transcription_text:
padded_text[:len(text)] = [ord(t) for t in text] # string
else:
padded_text[:len(text)] = text #ids
texts.append(padded_text)
text_lens.append(len(text))
......@@ -124,11 +139,13 @@ def create_dataloader(manifest_path,
audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
return padded_audios, audio_lens, texts, text_lens
#collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
collate_fn = SpeechCollator(keep_transcription_text=keep_transcription_text)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=partial(padding_batch, is_training=is_training),
collate_fn=collate_fn,
num_workers=num_workers)
return loader
......@@ -25,14 +25,14 @@ __all__ = ["SpeechCollator"]
class SpeechCollator():
def __init__(self, is_training=True):
def __init__(self, keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
if ``is_training`` is True, text is token ids else is raw string.
if ``keep_transcription_text`` is False, text is token ids; otherwise it is the raw string.
"""
self._is_training = is_training
self._keep_transcription_text = keep_transcription_text
def __call__(self, batch):
"""batch examples
......@@ -61,15 +61,15 @@ class SpeechCollator():
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = []
if self._is_training:
tokens = text # token ids
else:
assert isinstance(text, str)
if self._keep_transcription_text:
assert isinstance(text, str), type(text)
tokens = [ord(t) for t in text]
else:
tokens = text # token ids
tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64)
texts.append(tokens)
text_lens.append(len(text))
text_lens.append(tokens.shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
......
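The token branch above is the heart of the collator change: with keep_transcription_text=True the raw transcript survives batching as Unicode code points (via ord), to be turned back into text at decode time; with False the text is already token ids. A minimal standalone sketch of that branch (the function name is illustrative):

    import numpy as np

    def encode_text(text, keep_transcription_text=True):
        # True: text is a raw string; keep it as Unicode code points
        # False: text is already a sequence of token ids
        if keep_transcription_text:
            assert isinstance(text, str), type(text)
            tokens = [ord(t) for t in text]
        else:
            tokens = text
        return np.asarray(tokens, dtype=np.int64)

    print(encode_text("abc"))                                     # [97 98 99]
    print(encode_text([5, 7, 9], keep_transcription_text=False))  # [5 7 9]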
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import math
import random
import tarfile
......@@ -43,8 +44,12 @@ class ManifestDataset(Dataset):
mean_std_filepath,
spm_model_prefix=None,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
......@@ -66,8 +71,12 @@ class ManifestDataset(Dataset):
mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
max_duration (float, optional): audio length in seconds must less than this. Defaults to float('inf').
min_duration (float, optional): audio length is seconds must greater than this. Defaults to 0.0.
max_input_len (float, optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to float('inf').
min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to float('inf').
min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.0.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
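These six bounds replace the old max_duration/min_duration pair. A hypothetical sketch of the filter read_manifest presumably applies to each manifest entry; the field names follow the manifest JSON used elsewhere in this commit, but the helper itself is illustrative, not the repo implementation:

    def keeps(entry,
              max_input_len=float('inf'), min_input_len=0.0,
              max_output_len=float('inf'), min_output_len=0.0,
              max_output_input_ratio=float('inf'), min_output_input_ratio=0.0):
        # input length: seconds for raw wav, frame count for precomputed features
        n_in = entry["feat_shape"][0]
        # output length: transcript length in modeling units
        n_out = len(entry["text"])
        return (min_input_len <= n_in <= max_input_len
                and min_output_len <= n_out <= max_output_len
                and min_output_input_ratio <= n_out / n_in <= max_output_input_ratio)

    print(keeps({"feat_shape": [120, 80], "text": "hello"}))  # True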
......@@ -82,9 +91,13 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): if True (typically when not in training mode), skip tokenization and keep the raw transcript text. Defaults to False.
"""
super().__init__()
self._max_input_len = max_input_len
self._min_input_len = min_input_len
self._max_output_len = max_output_len
self._min_output_len = min_output_len
self._max_output_input_ratio = max_output_input_ratio
self._min_output_input_ratio = min_output_input_ratio
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._audio_augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
......@@ -102,6 +115,7 @@ class ManifestDataset(Dataset):
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
# for caching tar files info
......@@ -112,9 +126,58 @@ class ManifestDataset(Dataset):
# read manifest
self._manifest = read_manifest(
manifest_path=manifest_path,
max_duration=self._max_duration,
min_duration=self._min_duration)
self._manifest.sort(key=lambda x: x["duration"])
max_input_len=max_input_len,
min_input_len=min_input_len,
max_output_len=max_output_len,
min_output_len=min_output_len,
max_output_input_ratio=max_output_input_ratio,
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
@classmethod
def from_config(cls, config):
"""Build a ManifestDataset object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
ManifestDataset: dataset object.
"""
assert 'manifest' in config.data
assert 'keep_transcription_text' in config.data
if isinstance(config.data.augmentation_config, (str, bytes)):
aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8')
else:
aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO)
dataset = cls(
manifest_path=config.data.manifest,
unit_type=config.data.unit_type,
vocab_filepath=config.data.vocab_filepath,
mean_std_filepath=config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=aug_file.read(),
max_input_len=config.data.max_input_len,
min_input_len=config.data.min_input_len,
max_output_len=config.data.max_output_len,
min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=config.data.keep_transcription_text)
return dataset
@property
def manifest(self):
......
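With from_config in place, every call site in this commit collapses to the same three-step pattern: point config.data.manifest at the split to load, swap in an in-memory '{}' augmentation config, set keep_transcription_text, then build. A condensed sketch of that flow (config is the repo's yacs config node; shown only to make the pattern explicit):

    import io
    from deepspeech.io.dataset import ManifestDataset

    def make_test_dataset(config):
        # 1. point the dataset at the split it should read
        config.data.manifest = config.data.test_manifest
        # 2. disable augmentation via an in-memory empty JSON config
        config.data.augmentation_config = io.StringIO(
            initial_value='{}', newline='')
        # 3. keep raw transcripts so decoding can report plain text
        config.data.keep_transcription_text = True
        return ManifestDataset.from_config(config)

The same three steps appear at every call site above (server, tune, trainer/tester); only the chosen manifest and the keep_transcription_text flag differ.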
......@@ -35,7 +35,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('count_threshold', int, 0,
        "Truncation threshold for char/word counts. Default 0: no truncation.")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
......@@ -61,7 +61,7 @@ def count_manifest(counter, text_feature, manifest_path):
for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'])
counter.update(line)
def dump_text_manifest(fileobj, manifest_path):
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
......@@ -97,7 +97,7 @@ def main():
# encode
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
counter = Counter()
for manifest_path in args.manifest_paths:
count_manifest(counter, text_feature, manifest_path)
......
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2014 David Snyder
# This script combines the data from multiple source directories into
# a single destination directory.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
# about what these directories contain.
# Begin configuration section.
extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
skip_fix=false # skip the fix_data_dir.sh in the end
# End configuration section.
echo "$0 $@" # Print the command line for logging
if [ -f path.sh ]; then . ./path.sh; fi
if [ -f parse_options.sh ]; then . parse_options.sh || exit 1; fi
if [ $# -lt 2 ]; then
echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
echo "Note, files that don't appear in all source dirs will not be combined,"
echo "with the exception of utt2uniq and segments, which are created where necessary."
exit 1
fi
dest=$1;
shift;
first_src=$1;
rm -r $dest 2>/dev/null
mkdir -p $dest;
export LC_ALL=C
for dir in $*; do
if [ ! -f $dir/utt2spk ]; then
echo "$0: no such file $dir/utt2spk"
exit 1;
fi
done
# Check that frame_shift are compatible, where present together with features.
dir_with_frame_shift=
for dir in $*; do
if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
if [[ $dir_with_frame_shift ]] &&
! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
echo "$0:error: different frame_shift in directories $dir and " \
"$dir_with_frame_shift. Cannot combine features."
exit 1;
fi
dir_with_frame_shift=$dir
fi
done
# W.r.t. the utt2uniq file this script behaves differently than for other files:
# it is not compulsory for it to exist in the src directories, but if it exists in
# even one it should exist in all. We will create the files where necessary.
has_utt2uniq=false
for in_dir in $*; do
if [ -f $in_dir/utt2uniq ]; then
has_utt2uniq=true
break
fi
done
if $has_utt2uniq; then
# we are going to create an utt2uniq file in the destdir
for in_dir in $*; do
if [ ! -f $in_dir/utt2uniq ]; then
# we assume that utt2uniq is a one to one mapping
cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
else
cat $in_dir/utt2uniq
fi
done | sort -k1 > $dest/utt2uniq
echo "$0: combined utt2uniq"
else
echo "$0 [info]: not combining utt2uniq as it does not exist"
fi
# some of the old scripts might provide utt2uniq as an extra file, so just remove it
extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
# segments are treated similarly to utt2uniq. If it exists in some, but not all
# src directories, then we generate segments where necessary.
has_segments=false
for in_dir in $*; do
if [ -f $in_dir/segments ]; then
has_segments=true
break
fi
done
if $has_segments; then
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
utils/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
done | sort -k1 > $dest/segments
echo "$0: combined segments"
else
echo "$0 [info]: not combining segments as it does not exist"
fi
for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
exists_somewhere=false
absent_somewhere=false
for d in $*; do
if [ -f $d/$file ]; then
exists_somewhere=true
else
absent_somewhere=true
fi
done
if ! $absent_somewhere; then
set -o pipefail
( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
set +o pipefail
echo "$0: combined $file"
else
if ! $exists_somewhere; then
echo "$0 [info]: not combining $file as it does not exist"
else
echo "$0 [info]: **not combining $file as it does not exist everywhere**"
fi
fi
done
tools/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
tools/fix_data_dir.sh $dest || exit 1;
fi
exit 0
\ No newline at end of file
......@@ -85,7 +85,7 @@ def main():
raise NotImplementedError('kaldi feat is not supported yet!')
fout.write(json.dumps(line_json) + '\n')
count += 1
print(f"Examples number: {count}")
fout.close()
......
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
$ignore_oov = 0;
for($x = 0; $x < 2; $x++) {
if ($ARGV[0] eq "--map-oov") {
shift @ARGV;
$map_oov = shift @ARGV;
if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") {
# disallow '-f', the empty string and anything ending in words.txt as the
# OOV symbol because these are likely command-line errors.
die "the --map-oov option requires an argument";
}
}
if ($ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}
}
$symtab = shift @ARGV;
if (!defined $symtab) {
print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" .
"options: [--map-oov <oov-symbol> ] [-f <field-range> ]\n" .
"note: <field-range> can look like 4-5, or 4-, or 5-, or 1.\n";
}
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "bad line in symbol table file: $_";
$sym2int{$A[0]} = $A[1] + 0;
}
if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up
if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; }
$map_oov = $sym2int{$map_oov};
}
$num_warning = 0;
$max_warning = 20;
while (<>) {
@A = split(" ", $_);
@B = ();
for ($n = 0; $n < @A; $n++) {
$a = $A[$n];
if ( (!defined $field_begin || $n >= $field_begin)
&& (!defined $field_end || $n <= $field_end)) {
$i = $sym2int{$a};
if (!defined ($i)) {
if (defined $map_oov) {
if ($num_warning++ < $max_warning) {
print STDERR "sym2int.pl: replacing $a with $map_oov\n";
if ($num_warning == $max_warning) {
print STDERR "sym2int.pl: not warning for OOVs any more times\n";
}
}
$i = $map_oov;
} else {
$pos = $n+1;
die "sym2int.pl: undefined symbol $a (in position $pos)\n";
}
}
$a = $i;
}
push @B, $a;
}
print join(" ", @B);
print "\n";
}
if ($num_warning > 0) {
print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n";
}
exit(0);
\ No newline at end of file
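For readers who do not speak Perl, here is the core of sym2int.pl re-sketched in Python (illustrative only; option parsing and the -f field-range handling are omitted, and map_oov is assumed to be an integer id):

    def sym2int(lines, table, map_oov=None):
        # table maps symbol -> integer id, as read from the symtab file
        for line in lines:
            out = []
            for sym in line.split():
                if sym in table:
                    out.append(table[sym])
                elif map_oov is not None:
                    out.append(map_oov)  # OOV fallback, like --map-oov
                else:
                    raise KeyError(f"undefined symbol {sym}")
            yield " ".join(map(str, out))

    table = {"hello": 1, "world": 2}
    print(list(sym2int(["hello world"], table)))  # ['1 2']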