Commit f4e59293 authored by: H huangyuxin

Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into fix_bug

@@ -18,5 +18,7 @@ tools/sox-14.4.2
 tools/soxbindings
 tools/montreal-forced-aligner/
 tools/Montreal-Forced-Aligner/
+tools/sctk
+tools/sctk-20159b5/
 *output/
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"L = nn.Linear(256, 2048)\n",
"L2 = nn.Linear(2048, 256)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n",
"Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n",
" [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n",
" [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n",
" ...,\n",
" [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n",
" [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n",
" [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n",
"\n",
" [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n",
" [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n",
" [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n",
" ...,\n",
" [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n",
" [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n",
" [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n",
"tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n",
" [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n",
" [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n",
" ...,\n",
" [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n",
" [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n",
" [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n",
"\n",
" [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n",
" [ 0.1022, 0.8130, 2.3687, ..., 0.1682, 0.2924, 1.0541],\n",
" [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n",
" ...,\n",
" [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n",
" [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n",
" [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"print(px)\n",
"print(tx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "mechanical-prisoner",
"metadata": {},
"outputs": [],
"source": [
"data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"t_norm_ff = data['norm_ff']\n",
"t_ff_out = data['ff_out']\n",
"t_ff_l_x = data['ff_l_x']\n",
"t_ff_l_a_x = data['ff_l_a_x']\n",
"t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"t_ps = data['ps']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "indie-marriage",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "assured-zambia",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n",
"L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n",
"\n",
"ps = []\n",
"for n, p in L.named_parameters():\n",
" ps.append(p)\n",
"\n",
"for n, p in L2.state_dict().items():\n",
" ps.append(p)\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p.numpy(), tp.T))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "committed-jacob",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "extreme-traffic",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"# t_norm_ff = data['norm_ff']\n",
"# t_ff_out = data['ff_out']\n",
"# t_ff_l_x = data['ff_l_x']\n",
"# t_ff_l_a_x = data['ff_l_a_x']\n",
"# t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"# t_ps = data['ps']\n",
"TL = torch.nn.Linear(256, 2048)\n",
"TL2 = torch.nn.Linear(2048, 256)\n",
"TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n",
"TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n",
"\n",
"# for n, p in TL.named_parameters():\n",
"# print(n, p)\n",
"# for n, p in TL2.named_parameters():\n",
"# print(n, p)\n",
"\n",
"ps = []\n",
"for n, p in TL.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for n, p in TL2.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p, tp))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n",
" -0.5787626 ]\n",
" [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n",
" -0.03350811]\n",
" [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n",
" -1.012989 ]\n",
" ...\n",
" [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n",
" 0.04070369]\n",
" [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n",
" -0.8450228 ]\n",
" [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n",
" -0.06663495]]\n",
"\n",
" [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n",
" -0.8161046 ]\n",
" [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n",
" -0.20766167]\n",
" [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n",
" -0.26811332]\n",
" ...\n",
" [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n",
" 0.06020184]\n",
" [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n",
" 0.15847842]\n",
" [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n",
" -0.6460344 ]]]\n",
"[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n",
" -0.5787625 ]\n",
" [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n",
" -0.03350812]\n",
" [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n",
" -1.0129892 ]\n",
" ...\n",
" [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n",
" 0.04070365]\n",
" [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n",
" -0.84502304]\n",
" [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n",
" -0.066635 ]]\n",
"\n",
" [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n",
" -0.8161046 ]\n",
" [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n",
" -0.20766166]\n",
" [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n",
" -0.26811326]\n",
" ...\n",
" [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n",
" 0.06020182]\n",
" [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n",
" 0.15847899]\n",
" [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n",
" -0.6460344 ]]]\n",
"True\n",
"False\n"
]
}
],
"source": [
"y = L(px)\n",
"print(y.numpy())\n",
"\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))"
]
},
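{
"cell_type": "markdown",
"id": "weight-layout-note",
"metadata": {},
"source": [
"Why the `.T` in `set_state_dict` above: `paddle.nn.Linear` stores `weight` as `[in_features, out_features]`, while `torch.nn.Linear` stores it as `[out_features, in_features]`, so a torch-side parameter array must be transposed before loading into paddle. A minimal sketch of that check (cell id is illustrative, reusing `L` and `TL` from above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "weight-layout-check",
"metadata": {},
"outputs": [],
"source": [
"# paddle keeps weight as [in, out]; torch keeps it as [out, in].\n",
"print(L.weight.shape)   # paddle: [256, 2048]\n",
"print(TL.weight.shape)  # torch: torch.Size([2048, 256])\n",
"# After loading the same t_ps arrays, the two weights are transposes.\n",
"print(np.allclose(L.weight.numpy(), TL.weight.detach().numpy().T))"
]
},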
{
"cell_type": "code",
"execution_count": null,
"id": "incorrect-allah",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.04476918 0.554463 -0.3027508 ... -0.49600336 0.3751858\n",
" 0.8254095 ]\n",
" [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n",
" -0.06974669]]\n",
"[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n",
" 0.8254096 ]\n",
" [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n",
" -0.0697467 ]]\n",
"True\n",
"False\n",
"True\n"
]
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"y = L(px)\n",
"print(y.numpy())\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))"
]
},
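{
"cell_type": "markdown",
"id": "tolerance-note",
"metadata": {},
"source": [
"The `False` above is expected rather than a porting bug: both layers run the same float32 GEMM, but different kernels accumulate in different orders, so the results differ in the last few bits. `np.allclose` defaults to `rtol=1e-05, atol=1e-08`, which is tighter than float32 matmul accuracy; the final line shows the comparison passes once `atol` is loosened to `1e-5`. A small sketch that makes the residual visible:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tolerance-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: the max absolute difference is tiny, but large enough\n",
"# to exceed np.allclose's defaults (rtol=1e-05, atol=1e-08).\n",
"print(np.abs(y.numpy() - ty.detach().numpy()).max())"
]
},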
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "improved-civilization",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5e7e7c9fde8350084abf1898cf52651cfc84b17a\n"
]
}
],
"source": [
"print(paddle.version.commit)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d1e2d3b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c880c719",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.1.0\n"
]
}
],
"source": [
"print(paddle.version.full_version)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f26977bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n",
"None\n"
]
}
],
"source": [
"print(paddle.version.show())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "04ad47f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.6.0\n"
]
}
],
"source": [
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e1e03830",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" '__version__',\n",
" 'cuda',\n",
" 'debug',\n",
" 'git_version',\n",
" 'hip']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(torch.version)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4ad0389b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'b31f58de6fa8bbda5353b3c77d9be4914399724d'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.git_version"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7870ea10",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'10.2'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.cuda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db8ee5a7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6321ec2a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d6a0e098",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"import torch\n",
"from torch.optim.lr_scheduler import _LRScheduler\n",
"\n",
"from typeguard import check_argument_types\n",
"\n",
"\n",
"class WarmupLR(_LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" optimizer: torch.optim.Optimizer,\n",
" warmup_steps: Union[int, float] = 25000,\n",
" last_epoch: int = -1,\n",
" ):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
"\n",
" # __init__() must be invoked before setting field\n",
" # because step() is also invoked in __init__()\n",
" super().__init__(optimizer, last_epoch)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return [\n",
" lr\n",
" * self.warmup_steps ** 0.5\n",
" * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n",
" for lr in self.base_lrs\n",
" ]\n",
"\n",
" def set_step(self, step: int):\n",
" self.last_epoch = step"
]
},
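{
"cell_type": "markdown",
"id": "warmup-peak-note",
"metadata": {},
"source": [
"A quick sanity check on the formula: for `step < warmup_steps` the `min` picks the linear term, giving `lr = base_lr * step / warmup_steps`; at `step == warmup_steps` both terms equal `warmup_steps ** -0.5`, so the lr peaks at exactly `base_lr`; afterwards it decays as `base_lr * (warmup_steps / step) ** 0.5`. A minimal sketch in plain Python (function name and values illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "warmup-peak-check",
"metadata": {},
"outputs": [],
"source": [
"# Closed form of the schedule, independent of torch/paddle.\n",
"def warmup_lr(base_lr, step, warmup):\n",
"    return base_lr * warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)\n",
"\n",
"base_lr, warmup = 0.001, 25000\n",
"print(warmup_lr(base_lr, warmup // 2, warmup))  # 0.0005 (linear ramp)\n",
"print(warmup_lr(base_lr, warmup, warmup))       # 0.001  (peak == base_lr)\n",
"print(warmup_lr(base_lr, 4 * warmup, warmup))   # 0.0005 (inverse-sqrt decay)"
]
},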
{
"cell_type": "code",
"execution_count": 4,
"id": "0d496677",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"model = torch.nn.Linear(10, 200)\n",
"optimizer = optim.Adam(model.parameters())\n",
"scheduler = WarmupLR(optimizer, warmup_steps=25000)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3e3f3dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.0 -1\n"
]
}
],
"source": [
"infos = {}\n",
"start_epoch = infos.get('epoch', -1) + 1\n",
"cv_loss = infos.get('cv_loss', 0.0)\n",
"step = infos.get('step', -1)\n",
"print(start_epoch, cv_loss, step)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc3d550c",
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_step(step)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e527634e",
"metadata": {},
"outputs": [],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" scheduler.step()\n",
" lrs.append(scheduler.get_lr())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f1452db9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting matplotlib\n",
" Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n",
"\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n",
"\u001b[?25hCollecting kiwisolver>=1.0.1\n",
" Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n",
"\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n",
"Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n",
"Collecting cycler>=0.10\n",
" Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n",
"Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n",
"Installing collected packages: kiwisolver, cycler, matplotlib\n",
"Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n"
]
}
],
"source": [
"!pip install matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0f36d04f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0c39aa82e0>]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wz
uUATwG/bquTc+4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3
N+3l3Y2VvLiiAoDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHS
hYwUAr1IBQBGJIgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz
7cQmDt4lj7zd7VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4f4e282c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"from typing import Union\n",
"\n",
"from paddle.optimizer.lr import LRScheduler\n",
"from typeguard import check_argument_types\n",
"\n",
"class WarmupLR(LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" warmup_steps: Union[int, float]=25000,\n",
" learning_rate=1.0,\n",
" last_epoch=-1,\n",
" verbose=False):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
" super().__init__(learning_rate, last_epoch, verbose)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return self.base_lr * self.warmup_steps**0.5 * min(\n",
" step_num**-0.5, step_num * self.warmup_steps**-1.5)\n",
"\n",
" def set_step(self, step: int):\n",
" self.step(step)"
]
},
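{
"cell_type": "markdown",
"id": "parity-note",
"metadata": {},
"source": [
"Note that paddle's `LRScheduler` takes `learning_rate` directly instead of reading it from an optimizer, and `set_step` delegates to `step(epoch)` rather than assigning `last_epoch`. Since both classes implement the same closed form, the schedules should match step for step; a hedged sketch of that check (`sc2` is a throwaway instance, step values illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "parity-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative parity check: the paddle schedule should follow the same\n",
"# closed form as the torch docstring above.\n",
"sc2 = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n",
"for s in [0, 999, 24999, 99999]:\n",
"    sc2.set_step(s)\n",
"    step_num = s + 1  # get_lr() evaluates at last_epoch + 1\n",
"    ref = 0.001 * 25000 ** 0.5 * min(step_num ** -0.5, step_num * 25000 ** -1.5)\n",
"    print(step_num, sc2.get_lr(), ref)"
]
},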
{
"cell_type": "code",
"execution_count": 22,
"id": "8c40b202",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1\n"
]
}
],
"source": [
"sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n",
"print(step)\n",
"#sc.set_step(step)\n",
"sc.set_step(0)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ecbc7e37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0ba6dd9c40>]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA
/gj8JumOjnnnnLO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0g
QVGKf87ghAyvrak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBR
t6GQBwLEqACgi4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfp
N39cQkoMo7/FwIXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" sc.step()\n",
" lrs.append(sc.get_lr())\n",
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e613fe16",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0fd9f40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 94,
"id": "matched-camera",
"metadata": {},
"outputs": [],
"source": [
"from nnAudio import Spectrogram\n",
"from scipy.io import wavfile\n",
"import torch\n",
"import soundfile as sf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "quarterly-solution",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n"
]
}
],
"source": [
"import scipy.io.wavfile as wav\n",
"\n",
"rate,sig = wav.read('./BAC009S0764W0124.wav')\n",
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sig)\n",
"print(song)\n",
"print(sample)"
]
},
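{
"cell_type": "markdown",
"id": "loader-note",
"metadata": {},
"source": [
"Both scipy entry points (`wav.read` and `wavfile.read` are the same function) and `soundfile` with `dtype='int16'` return the same raw int16 PCM samples, so any of them can feed the feature pipeline. A minimal sketch making that equivalence explicit (cell id illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "loader-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check that the three readers agree sample for sample.\n",
"print(sig.dtype, song.dtype, sample.dtype)\n",
"print(np.array_equal(sig, song) and np.array_equal(song, sample))"
]
},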
{
"cell_type": "code",
"execution_count": 96,
"id": "middle-salem",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2733 seconds\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n"
]
}
],
"source": [
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"print(sr)\n",
"print(song)\n",
"print(song.shape)\n",
"print(song.dtype)\n",
"x = song\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(spec)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "finished-sterling",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"True\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2001 seconds\n",
"torch.Size([1, 1025, 164, 2])\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n",
"True\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav)\n",
"print(wav.shape)\n",
"print(wav.dtype)\n",
"print(np.allclose(wav, song))\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(wav_spec.shape)\n",
"print(wav_spec)\n",
"print(np.allclose(wav_spec, spec))"
]
},
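{
"cell_type": "markdown",
"id": "stft-shape-note",
"metadata": {},
"source": [
"The output shape `torch.Size([1, 1025, 164, 2])` decodes as `[batch, freq_bins, frames, real/imag]`: the one-sided spectrum has `n_fft // 2 + 1 = 1025` bins, and with `center=True` the signal is reflect-padded by `n_fft // 2` on each side, so the frame count is `1 + len(x) // hop_length`. A quick sketch of that arithmetic (reusing `wav` from above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stft-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Expected STFT output shape for the settings above (n_fft=2048, hop=512).\n",
"n_fft, hop = 2048, 512\n",
"freq_bins = n_fft // 2 + 1   # 1025 one-sided bins\n",
"frames = 1 + len(wav) // hop # 1 + 83792 // 512 = 164 (center=True)\n",
"print(freq_bins, frames)     # matches torch.Size([1, 1025, 164, 2])"
]
},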
{
"cell_type": "code",
"execution_count": 98,
"id": "running-technology",
"metadata": {},
"outputs": [],
"source": [
"import decimal\n",
"\n",
"import numpy\n",
"import math\n",
"import logging\n",
"def round_half_up(number):\n",
" return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n",
"\n",
"\n",
"def rolling_window(a, window, step=1):\n",
" # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n",
" shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n",
" strides = a.strides + (a.strides[-1],)\n",
" return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n",
"\n",
"\n",
"def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n",
" \"\"\"Frame a signal into overlapping frames.\n",
"\n",
" :param sig: the audio signal to frame.\n",
" :param frame_len: length of each frame measured in samples.\n",
" :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n",
" :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n",
" :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n",
" :returns: an array of frames. Size is NUMFRAMES by frame_len.\n",
" \"\"\"\n",
" slen = len(sig)\n",
" frame_len = int(round_half_up(frame_len))\n",
" frame_step = int(round_half_up(frame_step))\n",
" if slen <= frame_len:\n",
" numframes = 1\n",
" else:\n",
" numframes = 1 + (( slen - frame_len) // frame_step)\n",
"\n",
" # check kaldi/src/feat/feature-window.h\n",
" padsignal = sig[:(numframes-1)*frame_step+frame_len]\n",
" if wintype is 'povey':\n",
" win = numpy.empty(frame_len)\n",
" for i in range(frame_len):\n",
" win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n",
" else: # the hamming window\n",
" win = numpy.hamming(frame_len)\n",
" \n",
" if stride_trick:\n",
" frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n",
" else:\n",
" indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n",
" numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n",
" indices = numpy.array(indices, dtype=numpy.int32)\n",
" frames = padsignal[indices]\n",
" win = numpy.tile(win, (numframes, 1))\n",
" \n",
" frames = frames.astype(numpy.float32)\n",
" raw_frames = numpy.zeros(frames.shape)\n",
" for frm in range(frames.shape[0]):\n",
" raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n",
" frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n",
" # raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n",
"\n",
" return frames * win, raw_frames\n",
"\n",
"\n",
"def magspec(frames, NFFT):\n",
" \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n",
" \"\"\"\n",
" if numpy.shape(frames)[1] > NFFT:\n",
" logging.warn(\n",
" 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n",
" numpy.shape(frames)[1], NFFT)\n",
" complex_spec = numpy.fft.rfft(frames, NFFT)\n",
" return numpy.absolute(complex_spec)\n",
"\n",
"\n",
"def powspec(frames, NFFT):\n",
" \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n",
" \"\"\"\n",
" return numpy.square(magspec(frames, NFFT))\n",
"\n",
"\n",
"def do_dither(signal, dither_value=1.0):\n",
" signal += numpy.random.normal(size=signal.shape) * dither_value\n",
" return signal\n",
" \n",
"def do_remove_dc_offset(signal):\n",
" signal -= numpy.mean(signal)\n",
" return signal\n",
"\n",
"def do_preemphasis(signal, coeff=0.97):\n",
" \"\"\"perform preemphasis on the input signal.\n",
"\n",
" :param signal: The signal to filter.\n",
" :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n",
" :returns: the filtered signal.\n",
" \"\"\"\n",
" return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])"
]
},
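{
"cell_type": "code",
"execution_count": null,
"id": "framing-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# A minimal, hedged sanity check (not part of the original run): exercise\n",
"# framesig/magspec/powspec on a 1 s synthetic tone. The 440 Hz tone and the\n",
"# 25 ms / 10 ms framing are arbitrary illustration values.\n",
"fs = 16000\n",
"tone = np.sin(2 * np.pi * 440 * np.arange(fs) / fs).astype(np.float32)\n",
"frm, raw = framesig(tone, 0.025 * fs, 0.01 * fs, dither=0.0,\n",
" preemph=0.97, remove_dc_offset=True, wintype='hamming')\n",
"print(frm.shape) # (1 + (16000 - 400) // 160, 400) = (98, 400)\n",
"print(magspec(frm, 512).shape) # (98, 257): numframes x (nfft // 2 + 1)\n",
"print(powspec(frm, 512).shape) # same shape, squared magnitudes\n"
]
},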
{
"cell_type": "code",
"execution_count": 99,
"id": "ignored-retreat",
"metadata": {},
"outputs": [],
"source": [
"def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" spec = magspec(frames, nfft) # nearly the same until this part\n",
" rspec = magspec(raw_frames, nfft)\n",
" return spec, rspec\n",
"\n",
"\n",
"\n",
"def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" return raw_frames"
]
},
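{
"cell_type": "code",
"execution_count": null,
"id": "fbank-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch (not part of the original run): fbank returns the magnitude\n",
"# spectrum of the processed, windowed frames (spec) and of the untouched raw\n",
"# frames (rspec). The random signal below is an arbitrary illustration input.\n",
"sig = np.random.RandomState(0).randn(16000).astype(np.float32)\n",
"spec_demo, rspec_demo = fbank(sig, samplerate=16000, winlen=0.025, winstep=0.01,\n",
" nfft=512, dither=0.0, remove_dc_offset=False,\n",
" preemph=1.0, wintype='hamming')\n",
"print(spec_demo.shape, rspec_demo.shape) # both (numframes, nfft//2 + 1) = (98, 257)\n"
]
},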
{
"cell_type": "code",
"execution_count": 100,
"id": "federal-teacher",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"import scipy # used only in CFP\n",
"\n",
"import numpy as np\n",
"from time import time\n",
"\n",
"def pad_center(data, size, axis=-1, **kwargs):\n",
"\n",
" kwargs.setdefault('mode', 'constant')\n",
"\n",
" n = data.shape[axis]\n",
"\n",
" lpad = int((size - n) // 2)\n",
"\n",
" lengths = [(0, 0)] * data.ndim\n",
" lengths[axis] = (lpad, int(size - n - lpad))\n",
"\n",
" if lpad < 0:\n",
" raise ParameterError(('Target size ({:d}) must be '\n",
" 'at least input size ({:d})').format(size, n))\n",
"\n",
" return np.pad(data, lengths, **kwargs)\n",
"\n",
"\n",
"\n",
"sz_float = 4 # size of a float\n",
"epsilon = 10e-8 # fudge factor for normalization\n",
"\n",
"def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n",
" freq_scale='linear', window='hann', verbose=True):\n",
"\n",
" if freq_bins==None: freq_bins = n_fft//2+1\n",
" if win_length==None: win_length = n_fft\n",
"\n",
" s = np.arange(0, n_fft, 1.)\n",
" wsin = np.empty((freq_bins,1,n_fft))\n",
" wcos = np.empty((freq_bins,1,n_fft))\n",
" start_freq = fmin\n",
" end_freq = fmax\n",
" bins2freq = []\n",
" binslist = []\n",
"\n",
" # num_cycles = start_freq*d/44000.\n",
" # scaling_ind = np.log(end_freq/start_freq)/k\n",
"\n",
" # Choosing window shape\n",
"\n",
" #window_mask = get_window(window, int(win_length), fftbins=True)\n",
" window_mask = np.hamming(int(win_length))\n",
" window_mask = pad_center(window_mask, n_fft)\n",
"\n",
" if freq_scale == 'linear':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" \n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n",
" bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n",
" binslist.append((k*scaling_ind+start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'log':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = np.log(end_freq/start_freq)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n",
" bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n",
" binslist.append((np.exp(k*scaling_ind)*start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'no':\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" bins2freq.append(k*sr/n_fft)\n",
" binslist.append(k)\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
" else:\n",
" print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n",
" return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n",
"\n",
"\n",
"\n",
"def broadcast_dim(x):\n",
" \"\"\"\n",
" Auto broadcast input so that it can fits into a Conv1d\n",
" \"\"\"\n",
"\n",
" if x.dim() == 2:\n",
" x = x[:, None, :]\n",
" elif x.dim() == 1:\n",
" # If nn.DataParallel is used, this broadcast doesn't work\n",
" x = x[None, None, :]\n",
" elif x.dim() == 3:\n",
" pass\n",
" else:\n",
" raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n",
" return x\n",
"\n",
"\n",
"\n",
"### --------------------------- Spectrogram Classes ---------------------------###\n",
"class STFT(torch.nn.Module):\n",
"\n",
" def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n",
" freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n",
" fmin=50, fmax=6000, sr=22050, trainable=False,\n",
" output_format=\"Complex\", verbose=True):\n",
"\n",
" super().__init__()\n",
"\n",
" # Trying to make the default setting same as librosa\n",
" if win_length==None: win_length = n_fft\n",
" if hop_length==None: hop_length = int(win_length // 4)\n",
"\n",
" self.output_format = output_format\n",
" self.trainable = trainable\n",
" self.stride = hop_length\n",
" self.center = center\n",
" self.pad_mode = pad_mode\n",
" self.n_fft = n_fft\n",
" self.freq_bins = freq_bins\n",
" self.trainable = trainable\n",
" self.pad_amount = self.n_fft // 2\n",
" self.window = window\n",
" self.win_length = win_length\n",
" self.iSTFT = iSTFT\n",
" self.trainable = trainable\n",
" start = time()\n",
"\n",
"\n",
"\n",
" # Create filter windows for stft\n",
" kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n",
" win_length=win_length,\n",
" freq_bins=freq_bins,\n",
" window=window,\n",
" freq_scale=freq_scale,\n",
" fmin=fmin,\n",
" fmax=fmax,\n",
" sr=sr,\n",
" verbose=verbose)\n",
"\n",
"\n",
" kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n",
" kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n",
" \n",
" # In this way, the inverse kernel and the forward kernel do not share the same memory...\n",
" kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n",
" kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n",
" \n",
" if iSTFT:\n",
" self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n",
" self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n",
"\n",
" # Applying window functions to the Fourier kernels\n",
" if window:\n",
" window_mask = torch.tensor(window_mask)\n",
" wsin = kernel_sin * window_mask\n",
" wcos = kernel_cos * window_mask\n",
" else:\n",
" wsin = kernel_sin\n",
" wcos = kernel_cos\n",
" \n",
" if self.trainable==False:\n",
" self.register_buffer('wsin', wsin)\n",
" self.register_buffer('wcos', wcos) \n",
" \n",
" if self.trainable==True:\n",
" wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n",
" wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n",
" self.register_parameter('wsin', wsin)\n",
" self.register_parameter('wcos', wcos) \n",
" \n",
" # Prepare the shape of window mask so that it can be used later in inverse\n",
" # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n",
" \n",
" if verbose==True:\n",
" print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n",
" else:\n",
" pass\n",
"\n",
" def forward(self, x, output_format=None):\n",
" \"\"\"\n",
" Convert a batch of waveforms to spectrograms.\n",
" \n",
" Parameters\n",
" ----------\n",
" x : torch tensor\n",
" Input signal should be in either of the following shapes.\\n\n",
" 1. ``(len_audio)``\\n\n",
" 2. ``(num_audio, len_audio)``\\n\n",
" 3. ``(num_audio, 1, len_audio)``\n",
" It will be automatically broadcast to the right shape\n",
" \n",
" output_format : str\n",
" Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n",
" Default value is ``Complex``. \n",
" \n",
" \"\"\"\n",
" output_format = output_format or self.output_format\n",
" self.num_samples = x.shape[-1]\n",
" \n",
" x = broadcast_dim(x)\n",
" if self.center:\n",
" if self.pad_mode == 'constant':\n",
" padding = nn.ConstantPad1d(self.pad_amount, 0)\n",
"\n",
" elif self.pad_mode == 'reflect':\n",
" if self.num_samples < self.pad_amount:\n",
" raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n",
" padding = nn.ReflectionPad1d(self.pad_amount)\n",
"\n",
" x = padding(x)\n",
" spec_imag = conv1d(x, self.wsin, stride=self.stride)\n",
" spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n",
"\n",
" # remove redundant parts\n",
" spec_real = spec_real[:, :self.freq_bins, :]\n",
" spec_imag = spec_imag[:, :self.freq_bins, :]\n",
"\n",
" if output_format=='Magnitude':\n",
" spec = spec_real.pow(2) + spec_imag.pow(2)\n",
" if self.trainable==True:\n",
" return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n",
" else:\n",
" return torch.sqrt(spec)\n",
"\n",
" elif output_format=='Complex':\n",
" return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n",
"\n",
" elif output_format=='Phase':\n",
" return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n",
"\n",
" def inverse(self, X, onesided=True, length=None, refresh_win=True):\n",
" \"\"\"\n",
" This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n",
" which is to convert spectrograms back to waveforms. \n",
" It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n",
" please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n",
" \n",
" Parameters\n",
" ----------\n",
" onesided : bool\n",
" If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n",
" else use ``onesided=False``\n",
"\n",
" length : int\n",
" To make sure the inverse STFT has the same output length of the original waveform, please\n",
" set `length` as your intended waveform length. By default, ``length=None``,\n",
" which will remove ``n_fft//2`` samples from the start and the end of the output.\n",
" \n",
" refresh_win : bool\n",
" Recalculating the window sum square. If you have an input with fixed number of timesteps,\n",
" you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n",
" \n",
" \n",
" \"\"\"\n",
" if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n",
" raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n",
" \n",
" assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n",
" \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n",
" \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n",
" if onesided:\n",
" X = extend_fbins(X) # extend freq\n",
"\n",
" \n",
" X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n",
"\n",
" # broadcast dimensions to support 2D convolution\n",
" X_real_bc = X_real.unsqueeze(1)\n",
" X_imag_bc = X_imag.unsqueeze(1)\n",
" a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n",
" b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n",
" \n",
" # compute real and imag part. signal lies in the real part\n",
" real = a1 - b2\n",
" real = real.squeeze(-2)*self.window_mask\n",
"\n",
" # Normalize the amplitude with n_fft\n",
" real /= (self.n_fft)\n",
"\n",
" # Overlap and Add algorithm to connect all the frames\n",
" real = overlap_add(real, self.stride)\n",
" \n",
" # Prepare the window sumsqure for division\n",
" # Only need to create this window once to save time\n",
" # Unless the input spectrograms have different time steps\n",
" if hasattr(self, 'w_sum')==False or refresh_win==True:\n",
" self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n",
" self.nonzero_indices = (self.w_sum>1e-10) \n",
" else:\n",
" pass\n",
" real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n",
" # Remove padding\n",
" if length is None: \n",
" if self.center:\n",
" real = real[:, self.pad_amount:-self.pad_amount]\n",
"\n",
" else:\n",
" if self.center:\n",
" real = real[:, self.pad_amount:self.pad_amount + length] \n",
" else:\n",
" real = real[:, :length] \n",
" \n",
" return real\n",
" \n",
" def extra_repr(self) -> str:\n",
" return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n",
" self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n",
" ) "
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "unusual-baker",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"(83792,)\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.0153 seconds\n",
"torch.Size([521, 257])\n",
"(522, 257)\n",
"[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n",
" 2.32206573e+01 1.90274487e+01]\n",
" [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n",
" 1.80534729e+02 1.84179596e+02]\n",
" [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n",
" 1.21321175e+02 1.08345497e+02]\n",
" ...\n",
" [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n",
" 1.25670158e+02 1.20691467e+02]\n",
" [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n",
" 9.01831894e+01 9.84055099e+01]\n",
" [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n",
" 8.94473553e+00 9.61348629e+00]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav.shape)\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n",
" window='', freq_scale='linear', center=False, pad_mode='constant',\n",
" fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"wav_spec = wav_spec[0].T\n",
"print(wav_spec.shape)\n",
"\n",
"\n",
"spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(spec.shape)\n",
"\n",
"print(wav_spec.numpy())\n",
"print(rspec)\n",
"# print(spec)\n",
"\n",
"# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n",
"# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
"# dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
"# wintype='hamming')\n",
"# print(rspec)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "white-istanbul",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 129,
"id": "modern-rescue",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n",
" 0.75 0.41317591 0.11697778 0. ]\n"
]
},
{
"data": {
"text/plain": [
"array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n",
" 0.9045085, 0.6545085, 0.3454915, 0.0954915])"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(np.hanning(10))\n",
"from scipy.signal import get_window\n",
"get_window('hann', 10, fftbins=True)"
]
},
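{
"cell_type": "code",
"execution_count": null,
"id": "window-symmetric-vs-periodic",
"metadata": {},
"outputs": [],
"source": [
"# Why the two arrays above differ: np.hanning(N) is the symmetric window\n",
"# (endpoints zero), while get_window(..., fftbins=True) returns the periodic\n",
"# variant meant for STFT analysis. Asking scipy for the symmetric form\n",
"# reproduces numpy exactly:\n",
"print(np.allclose(np.hanning(10), get_window('hann', 10, fftbins=False)))  # True\n",
"print(np.allclose(np.hanning(11)[:10], get_window('hann', 10, fftbins=True)))  # True\n"
]
},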
{
"cell_type": "code",
"execution_count": null,
"id": "professional-journalism",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 153,
"id": "involved-motion",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(522, 400)\n",
"[[ 43. 75. 69. ... 46. 46. 45.]\n",
" [ 210. 215. 216. ... -86. -89. -91.]\n",
" [ 128. 128. 128. ... -154. -151. -151.]\n",
" ...\n",
" [ -60. -61. -61. ... 112. 109. 110.]\n",
" [ 20. 22. 24. ... 91. 87. 87.]\n",
" [ 111. 107. 108. ... -6. -4. -8.]]\n",
"torch.Size([1, 1, 83792])\n",
"torch.Size([400, 1, 512])\n",
"torch.Size([1, 400, 521])\n",
"conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n",
" [ 210., 215., 216., ..., -86., -89., -91.],\n",
" [ 128., 128., 128., ..., -154., -151., -151.],\n",
" ...,\n",
" [-143., -141., -142., ..., 96., 101., 101.],\n",
" [ -60., -61., -61., ..., 112., 109., 110.],\n",
" [ 20., 22., 24., ..., 91., 87., 87.]])\n",
"xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"torch.Size([521, 257])\n",
"yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"yy (522, 257)\n",
"[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(f.shape)\n",
"print(f)\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,n_fft))\n",
"wcos = np.empty((freq_bins,1,n_fft))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
"\n",
"\n",
"wsin = np.empty((n_fft,1,n_fft))\n",
"wcos = np.empty((n_fft,1,n_fft))\n",
"for k in range(n_fft): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" \n",
" \n",
"wsin = np.empty((400,1,n_fft))\n",
"wcos = np.empty((400,1,n_fft))\n",
"for k in range(400): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(400, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(400, n_fft)[k]\n",
" \n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :]\n",
"print(x.size())\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160)\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"print(spec_imag.size())\n",
"print(\"conv frame\", spec_imag[0].T)\n",
"# print(spec_imag[0].T[:, :400])\n",
"\n",
"# remove redundant parts\n",
"# spec_real = spec_real[:, :freq_bins, :]\n",
"# spec_imag = spec_imag[:, :freq_bins, :]\n",
"# spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"# spec = torch.sqrt(spec)\n",
"# print(spec)\n",
"\n",
"\n",
"\n",
"s = np.arange(0, 512, 1.)\n",
"# s = s[::-1]\n",
"wsin = np.empty((freq_bins, 400))\n",
"wcos = np.empty((freq_bins, 400))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
"spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n",
"spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n",
"\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"\n",
"print('xx', spec.numpy())\n",
"print(spec.size())\n",
"print('yy', rspec[:521, :])\n",
"print('yy', rspec.shape)\n",
"\n",
"\n",
"x = spec.numpy()\n",
"y = rspec[:-1, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
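{
"cell_type": "code",
"execution_count": null,
"id": "conv-framing-note",
"metadata": {},
"outputs": [],
"source": [
"# Hedged recap of the trick above: a conv1d whose kernels are the rows of a\n",
"# 400 x 512 identity matrix, applied with stride 160, performs exactly the\n",
"# strided framing (frame_len=400, frame_step=160); the DFT can then be done\n",
"# per frame as a matmul against cos/sin rows truncated to 400 samples.\n",
"eye_rows = np.eye(400, 512, dtype=np.float32)[:, None, :]  # [Cout, Cin, K]\n",
"frames_conv = conv1d(torch.tensor(wav[None, None, :]).float(),\n",
" torch.tensor(eye_rows), stride=160)[0].T  # [T, 400]\n",
"print(np.allclose(frames_conv.numpy(), f[:frames_conv.shape[0]]))  # True\n"
]
},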
{
"cell_type": "code",
"execution_count": 160,
"id": "mathematical-traffic",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([257, 1, 400])\n",
"tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n",
" 3.8693e+04, 3.1020e+04],\n",
" [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n",
" 1.3598e+04, 1.5920e+04],\n",
" [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n",
" 6.7707e+03, 4.3020e+03],\n",
" ...,\n",
" [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n",
" 6.1071e+01, 5.3685e+01],\n",
" [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n",
" 5.1310e+01, 6.3620e+01],\n",
" [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n",
" 3.5000e+01, 4.4000e+01]]])\n",
"[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n",
" 2.4064537e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n",
" 9.5781006e+01 1.4200000e+02]\n",
" ...\n",
" [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n",
" 7.8405350e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n",
" 5.1310150e+01 3.5000000e+01]\n",
" [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n",
" 6.3619797e+01 4.4000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,400))\n",
"wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :] #[B, C, T]\n",
"\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"print(spec)\n",
"\n",
"x = spec[0].T.numpy()\n",
"y = rspec[:, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "olive-nicaragua",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"27241"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(np.abs(x -y) / np.abs(y))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "ultimate-assault",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "institutional-stock",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.2412265e-10"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "integrated-courage",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y, x)"
]
},
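{
"cell_type": "code",
"execution_count": null,
"id": "allclose-diagnostics",
"metadata": {},
"outputs": [],
"source": [
"# Hedged diagnostic sketch: np.allclose tests |x - y| <= atol + rtol * |y|,\n",
"# so the bin found above (y == 0.0, x ~ 4e-10) is actually covered by the\n",
"# default atol; the False verdict more likely comes from float32 rounding in\n",
"# the conv path versus the float64 rfft path. Look at both error kinds:\n",
"abs_err = np.abs(x - y)\n",
"print('max abs err:', abs_err.max())\n",
"mask = np.abs(y) > 1.0\n",
"print('max rel err (|y| > 1):', (abs_err[mask] / np.abs(y)[mask]).max())\n",
"print(np.allclose(x, y, rtol=1e-3, atol=1e-2))  # tolerances suited to float32\n"
]
},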
{
"cell_type": "code",
"execution_count": null,
"id": "different-operation",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "purple-consequence",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defensive-mason",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "patient-convention",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n"
]
}
],
"source": [
"import argparse\n",
"import functools\n",
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.utils.utility import add_arguments\n",
"from deepspeech.utils.utility import print_arguments\n",
"\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool,\n",
" False,\n",
" \"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train.raw',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('num_workers',\n",
" default=16,\n",
" type=int,\n",
" help='num of subprocess workers for processing')\n",
"add_arg('output_path', str,\n",
" 'data/librispeech/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(args)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "enormous-currency",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for feat in batch:\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n",
" self.feature_func = feature_func\n",
" self._rng = rng\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.sample(manifest, num_samples)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" key = self.items[idx]['feat']\n",
" audioseg = AudioSegment.from_file(key)\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
" return feat"
]
},
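{
"cell_type": "code",
"execution_count": null,
"id": "collate-toy-check",
"metadata": {},
"outputs": [],
"source": [
"# Hedged toy check (illustration only): CollateFunc reduces a batch of (D, T)\n",
"# feature matrices to the sufficient statistics for CMVN -- total frame count,\n",
"# per-dimension sum, and per-dimension sum of squares.\n",
"toy_batch = [np.ones((3, 4)), 2 * np.ones((3, 6))]  # two (D=3, T) feats\n",
"n, m, v = CollateFunc()(toy_batch)\n",
"print(n)  # 10 frames in total (4 + 6)\n",
"print(m)  # per-dim sum: 1*4 + 2*6 = 16\n",
"print(v)  # per-dim sum of squares: 1*4 + 4*6 = 28\n"
]
},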
{
"cell_type": "code",
"execution_count": 5,
"id": "armed-semester",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"process 1000 wavs,450739 frames\n",
"process 2000 wavs,887447 frames\n",
"process 3000 wavs,1354148 frames\n",
"process 4000 wavs,1816494 frames\n",
"process 5000 wavs,2359211 frames\n",
"process 6000 wavs,2828455 frames\n",
"process 7000 wavs,3276186 frames\n",
"process 8000 wavs,3692234 frames\n",
"process 9000 wavs,4139360 frames\n",
"process 10000 wavs,4591528 frames\n",
"process 11000 wavs,5020114 frames\n",
"process 12000 wavs,5459523 frames\n",
"process 13000 wavs,5899534 frames\n",
"process 14000 wavs,6323242 frames\n",
"process 15000 wavs,6736597 frames\n",
"process 16000 wavs,7207686 frames\n",
"process 17000 wavs,7637800 frames\n",
"process 18000 wavs,8093004 frames\n",
"process 19000 wavs,8529518 frames\n",
"process 20000 wavs,8906022 frames\n",
"process 21000 wavs,9352652 frames\n",
"process 22000 wavs,9807495 frames\n",
"process 23000 wavs,10247938 frames\n",
"process 24000 wavs,10700011 frames\n",
"process 25000 wavs,11126134 frames\n",
"process 26000 wavs,11558061 frames\n",
"process 27000 wavs,12010359 frames\n",
"process 28000 wavs,12470938 frames\n",
"process 29000 wavs,12916013 frames\n",
"process 30000 wavs,13345816 frames\n",
"process 31000 wavs,13752365 frames\n",
"process 32000 wavs,14174801 frames\n",
"process 33000 wavs,14642170 frames\n",
"process 34000 wavs,15053557 frames\n",
"process 35000 wavs,15531890 frames\n",
"process 36000 wavs,16022711 frames\n",
"process 37000 wavs,16437688 frames\n",
"process 38000 wavs,16859517 frames\n",
"process 39000 wavs,17307676 frames\n",
"process 40000 wavs,17796629 frames\n",
"process 41000 wavs,18264151 frames\n",
"process 42000 wavs,18711898 frames\n",
"process 43000 wavs,19159890 frames\n",
"process 44000 wavs,19576435 frames\n",
"process 45000 wavs,19992793 frames\n",
"process 46000 wavs,20464449 frames\n",
"process 47000 wavs,20886021 frames\n",
"process 48000 wavs,21317318 frames\n",
"process 49000 wavs,21738034 frames\n",
"process 50000 wavs,22171890 frames\n",
"process 51000 wavs,22622238 frames\n",
"process 52000 wavs,23100734 frames\n",
"process 53000 wavs,23526901 frames\n",
"process 54000 wavs,23969746 frames\n",
"process 55000 wavs,24418691 frames\n",
"process 56000 wavs,24862546 frames\n",
"process 57000 wavs,25336448 frames\n",
"process 58000 wavs,25778435 frames\n",
"process 59000 wavs,26216199 frames\n",
"process 60000 wavs,26694692 frames\n",
"process 61000 wavs,27148978 frames\n",
"process 62000 wavs,27617088 frames\n",
"process 63000 wavs,28064946 frames\n",
"process 64000 wavs,28519843 frames\n",
"process 65000 wavs,28989722 frames\n",
"process 66000 wavs,29470156 frames\n",
"process 67000 wavs,29952931 frames\n",
"process 68000 wavs,30360555 frames\n",
"process 69000 wavs,30797929 frames\n",
"process 70000 wavs,31218227 frames\n",
"process 71000 wavs,31663934 frames\n",
"process 72000 wavs,32107468 frames\n",
"process 73000 wavs,32541943 frames\n",
"process 74000 wavs,33010702 frames\n",
"process 75000 wavs,33448082 frames\n",
"process 76000 wavs,33886812 frames\n",
"process 77000 wavs,34338108 frames\n",
"process 78000 wavs,34761495 frames\n",
"process 79000 wavs,35199730 frames\n",
"process 80000 wavs,35669630 frames\n",
"process 81000 wavs,36122402 frames\n",
"process 82000 wavs,36604561 frames\n",
"process 83000 wavs,37085552 frames\n",
"process 84000 wavs,37517500 frames\n",
"process 85000 wavs,37987196 frames\n",
"process 86000 wavs,38415721 frames\n",
"process 87000 wavs,38889467 frames\n",
"process 88000 wavs,39337809 frames\n",
"process 89000 wavs,39792342 frames\n",
"process 90000 wavs,40287946 frames\n",
"process 91000 wavs,40719461 frames\n",
"process 92000 wavs,41178919 frames\n",
"process 93000 wavs,41659635 frames\n",
"process 94000 wavs,42132985 frames\n",
"process 95000 wavs,42584564 frames\n",
"process 96000 wavs,43018598 frames\n",
"process 97000 wavs,43480662 frames\n",
"process 98000 wavs,43973670 frames\n",
"process 99000 wavs,44448190 frames\n",
"process 100000 wavs,44935034 frames\n",
"process 101000 wavs,45379812 frames\n",
"process 102000 wavs,45821207 frames\n",
"process 103000 wavs,46258420 frames\n",
"process 104000 wavs,46743733 frames\n",
"process 105000 wavs,47206922 frames\n",
"process 106000 wavs,47683041 frames\n",
"process 107000 wavs,48122809 frames\n",
"process 108000 wavs,48594623 frames\n",
"process 109000 wavs,49086358 frames\n",
"process 110000 wavs,49525568 frames\n",
"process 111000 wavs,49985820 frames\n",
"process 112000 wavs,50428262 frames\n",
"process 113000 wavs,50897957 frames\n",
"process 114000 wavs,51344589 frames\n",
"process 115000 wavs,51774621 frames\n",
"process 116000 wavs,52243372 frames\n",
"process 117000 wavs,52726025 frames\n",
"process 118000 wavs,53170026 frames\n",
"process 119000 wavs,53614141 frames\n",
"process 120000 wavs,54071271 frames\n"
]
}
],
"source": [
"\n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc()\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" augment_and_featurize, \n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader()):\n",
" #for batch in data_loader():\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "danish-executive",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, -1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 
15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n"
]
}
],
"source": [
"print(cmvn_info)"
]
},
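{
"cell_type": "code",
"execution_count": null,
"id": "cmvn-from-stats",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of turning the accumulated statistics into CMVN parameters\n",
"# (one standard formulation; the 1e-20 variance floor is an assumption, not\n",
"# necessarily what FeatureNormalizer uses internally):\n",
"#   mean = sum / N,  var = sum_sq / N - mean**2,  istd = 1 / sqrt(var)\n",
"mean = np.array(cmvn_info['mean_stat']) / cmvn_info['frame_num']\n",
"var = np.array(cmvn_info['var_stat']) / cmvn_info['frame_num'] - mean ** 2\n",
"istd = 1.0 / np.sqrt(np.maximum(var, 1e-20))\n",
"print(mean.shape, istd.shape)  # one value per feature dimension\n",
"# a feature matrix feat of shape (D, T) would then be normalized as\n",
"# (feat - mean[:, None]) * istd[:, None]\n"
]
},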
{
"cell_type": "code",
"execution_count": null,
"id": "accurate-terminal",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dominant-abuse",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 1000 wavs,450240 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 2000 wavs,886411 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 3000 wavs,1352580 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 4000 wavs,1814397 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 5000 wavs,2356587 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 6000 wavs,2825310 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 7000 wavs,3272506 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 8000 wavs,3688045 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 9000 wavs,4134669 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 10000 wavs,4586357 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 11000 wavs,5014429 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 12000 wavs,5453334 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 13000 wavs,5892888 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 14000 wavs,6316059 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 15000 wavs,6728870 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 16000 wavs,7199442 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 17000 wavs,7629055 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 18000 wavs,8083729 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 19000 wavs,8519732 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 20000 wavs,8895694 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 21000 wavs,9341778 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 22000 wavs,9796126 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 23000 wavs,10236057 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 24000 wavs,10687461 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 25000 wavs,11113082 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 26000 wavs,11544482 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 27000 wavs,11996273 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 28000 wavs,12456350 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 29000 wavs,12900895 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 30000 wavs,13330353 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 31000 wavs,13736568 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 32000 wavs,14158472 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 33000 wavs,14625316 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 34000 wavs,15036206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 35000 wavs,15514001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 36000 wavs,16004323 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 37000 wavs,16418799 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 38000 wavs,16840100 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 39000 wavs,17287752 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 40000 wavs,17776206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 41000 wavs,18243209 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 42000 wavs,18690449 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 43000 wavs,19137940 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 44000 wavs,19553966 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 45000 wavs,19969813 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 46000 wavs,20440963 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 47000 wavs,20862022 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 48000 wavs,21292801 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 49000 wavs,21713004 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 50000 wavs,22146346 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 51000 wavs,22596172 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 52000 wavs,23074160 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 53000 wavs,23499823 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 54000 wavs,23942151 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 55000 wavs,24390566 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 56000 wavs,24833905 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 57000 wavs,25307270 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 58000 wavs,25748720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 59000 wavs,26185964 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 60000 wavs,26663953 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 61000 wavs,27117720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 62000 wavs,27585349 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 63000 wavs,28032693 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 64000 wavs,28487074 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 65000 wavs,28956462 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 66000 wavs,29436358 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 67000 wavs,29918569 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 68000 wavs,30325682 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 69000 wavs,30762528 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 70000 wavs,31182319 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 71000 wavs,31627526 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 72000 wavs,32070556 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 73000 wavs,32504534 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 74000 wavs,32972775 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 75000 wavs,33409637 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 76000 wavs,33847861 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 77000 wavs,34298647 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 78000 wavs,34721536 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 79000 wavs,35159236 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 80000 wavs,35628628 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 81000 wavs,36080909 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 82000 wavs,36562496 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 83000 wavs,37042976 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 84000 wavs,37474403 frames\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 85000 wavs,37943596 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 86000 wavs,38371620 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 87000 wavs,38844874 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 88000 wavs,39292686 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 89000 wavs,39746715 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 90000 wavs,40241800 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 91000 wavs,40672817 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 92000 wavs,41131773 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 93000 wavs,41612001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 94000 wavs,42084822 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 95000 wavs,42535878 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 96000 wavs,42969365 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 97000 wavs,43430890 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 98000 wavs,43923378 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 99000 wavs,44397370 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 100000 wavs,44883695 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 101000 wavs,45327968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 102000 wavs,45768860 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 103000 wavs,46205602 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 104000 wavs,46690407 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 105000 wavs,47153089 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 106000 wavs,47628699 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 107000 wavs,48067945 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 108000 wavs,48539256 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 109000 wavs,49030485 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 110000 wavs,49469189 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 111000 wavs,49928968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 112000 wavs,50370921 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 113000 wavs,50840090 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 114000 wavs,51286249 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 115000 wavs,51715786 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 116000 wavs,52184017 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 117000 wavs,52666156 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 118000 wavs,53109645 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 119000 wavs,53553253 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 120000 wavs,54009877 frames\n",
"{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n"
]
}
],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"# https://github.com/PaddlePaddle/Paddle/pull/31481\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self, feature_func):\n",
" self.feature_func = feature_func\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for item in batch:\n",
" audioseg = AudioSegment.from_file(item['feat'])\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
"\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n",
" self._rng = rng if rng else np.random.RandomState(random_seed)\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.items[idx]\n",
" \n",
" \n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc(augment_and_featurize)\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader):\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}\n",
"print(cmvn_info)"
]
},
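  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hedged-cmvn",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch (not taken from the repo): turn the accumulated CMVN\n",
    "# statistics above into per-dimension mean and inverse std, the form a\n",
    "# feature normalizer typically consumes. `cmvn_info` comes from the previous\n",
    "# cell; the 1e-20 variance floor is an assumption, not a repo value.\n",
    "import numpy as np\n",
    "\n",
    "frame_num = cmvn_info['frame_num']\n",
    "mean = np.array(cmvn_info['mean_stat']) / frame_num\n",
    "var = np.array(cmvn_info['var_stat']) / frame_num - mean * mean\n",
    "istd = 1.0 / np.sqrt(np.maximum(var, 1e-20))\n",
    "print(mean.shape, istd.shape)"
   ]
  },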
{
"cell_type": "code",
"execution_count": null,
"id": "unlike-search",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n",
"from data_utils.utility import read_manifest\n",
"from data_utils.augmentor.augmentation import AugmentationPipeline\n",
"from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from data_utils.speech import SpeechSegment\n",
"from data_utils.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from data_utils.dataset import (\n",
" DeepSpeech2Dataset,\n",
" DeepSpeech2DistributedBatchSampler,\n",
" DeepSpeech2BatchSampler,\n",
" SpeechCollator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader(manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config='{}',\t\n",
" max_duration=float('inf'),\t\n",
" min_duration=0.0,\t\n",
" stride_ms=10.0,\t\n",
" window_ms=20.0,\t\n",
" max_freq=None,\t\n",
" specgram_type='linear',\t\n",
" use_dB_normalization=True,\t\n",
" random_seed=0,\t\n",
" keep_transcription_text=False,\t\n",
" is_training=False,\t\n",
" batch_size=1,\t\n",
" num_workers=0,\t\n",
" sortagrad=False,\t\n",
" shuffle_method=None,\t\n",
" dist=False):\t\n",
"\n",
" dataset = DeepSpeech2Dataset(\t\n",
" manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config=augmentation_config,\t\n",
" max_duration=max_duration,\t\n",
" min_duration=min_duration,\t\n",
" stride_ms=stride_ms,\t\n",
" window_ms=window_ms,\t\n",
" max_freq=max_freq,\t\n",
" specgram_type=specgram_type,\t\n",
" use_dB_normalization=use_dB_normalization,\t\n",
" random_seed=random_seed,\t\n",
" keep_transcription_text=keep_transcription_text)\t\n",
"\n",
" if dist:\t\n",
" batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n",
" dataset,\t\n",
" batch_size,\t\n",
" num_replicas=None,\t\n",
" rank=None,\t\n",
" shuffle=is_training,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
" else:\t\n",
" batch_sampler = DeepSpeech2BatchSampler(\t\n",
" dataset,\t\n",
" shuffle=is_training,\t\n",
" batch_size=batch_size,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
"\n",
" def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n",
" \"\"\"\t\n",
" Padding audio features with zeros to make them have the same shape (or\t\n",
" a user-defined shape) within one bach.\t\n",
"\n",
" If ``padding_to`` is -1, the maximun shape in the batch will be used\t\n",
" as the target shape for padding. Otherwise, `padding_to` will be the\t\n",
" target shape (only refers to the second axis).\t\n",
"\n",
" If `flatten` is True, features will be flatten to 1darray.\t\n",
" \"\"\"\t\n",
" new_batch = []\t\n",
" # get target shape\t\n",
" max_length = max([audio.shape[1] for audio, text in batch])\t\n",
" if padding_to != -1:\t\n",
" if padding_to < max_length:\t\n",
" raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n",
" \"than any instance's shape in the batch\")\t\n",
" max_length = padding_to\t\n",
" max_text_length = max([len(text) for audio, text in batch])\t\n",
" # padding\t\n",
" padded_audios = []\t\n",
" audio_lens = []\t\n",
" texts, text_lens = [], []\t\n",
" for audio, text in batch:\t\n",
" padded_audio = np.zeros([audio.shape[0], max_length])\t\n",
" padded_audio[:, :audio.shape[1]] = audio\t\n",
" if flatten:\t\n",
" padded_audio = padded_audio.flatten()\t\n",
" padded_audios.append(padded_audio)\t\n",
" audio_lens.append(audio.shape[1])\t\n",
"\n",
" padded_text = np.zeros([max_text_length])\n",
" if is_training:\n",
" padded_text[:len(text)] = text\t# ids\n",
" else:\n",
" padded_text[:len(text)] = [ord(t) for t in text] # string\n",
" \n",
" texts.append(padded_text)\t\n",
" text_lens.append(len(text))\t\n",
"\n",
" padded_audios = np.array(padded_audios).astype('float32')\t\n",
" audio_lens = np.array(audio_lens).astype('int64')\t\n",
" texts = np.array(texts).astype('int32')\t\n",
" text_lens = np.array(text_lens).astype('int64')\t\n",
" return padded_audios, texts, audio_lens, text_lens\t\n",
"\n",
" loader = DataLoader(\t\n",
" dataset,\t\n",
" batch_sampler=batch_sampler,\t\n",
" collate_fn=partial(padding_batch, is_training=is_training),\t\n",
" num_workers=num_workers)\t\n",
" return loader"
]
},
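  {
   "cell_type": "code",
   "execution_count": null,
   "id": "padding-toy-example",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A toy restatement of the zero-padding idea used by `padding_batch` above,\n",
    "# spelled out here because the original is a closure inside\n",
    "# `create_dataloader`. Random values, shapes only; not part of the pipeline.\n",
    "import numpy as np\n",
    "\n",
    "toy_batch = [(np.random.randn(161, 120), [10, 20, 30]),\n",
    "             (np.random.randn(161, 90), [4, 5])]\n",
    "max_len = max(audio.shape[1] for audio, _ in toy_batch)  # -> 120\n",
    "padded = np.stack([\n",
    "    np.pad(audio, ((0, 0), (0, max_len - audio.shape[1])))\n",
    "    for audio, _ in toy_batch\n",
    "]).astype('float32')\n",
    "print(padded.shape)  # (2, 161, 120)"
   ]
  },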
{
"cell_type": "code",
"execution_count": 21,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bearing-physics",
"metadata": {},
"outputs": [],
"source": [
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[22823, 26102, 20195, 37324, 0 , 0 ],\n",
" [22238, 26469, 23601, 22909, 0 , 0 ],\n",
" [20108, 26376, 22235, 26085, 0 , 0 ],\n",
" [36824, 35201, 20445, 25345, 32654, 24863],\n",
" [29042, 27748, 21463, 23456, 0 , 0 ]])\n",
"test raw 大时代里\n",
"test raw 煲汤受宠\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "minus-modern",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "medieval-monday",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/tiny/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/tiny/s1/data/manifest.tiny',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/tiny/s1/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/tiny/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/tiny/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "wired-principal",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/aishell/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/s1/data/manifest.test',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" '',\n",
" \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bearing-physics",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from deepspeech.frontend.speech import SpeechSegment\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from deepspeech.io.collator import SpeechCollator\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"from deepspeech.io.sampler import (\n",
" SortagradDistributedBatchSampler,\n",
" SortagradBatchSampler,\n",
")\n",
"from deepspeech.io import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[ -750 -1254 -1107 ... 2276 1889 2067] int16\n",
"fbank\n",
"[ -127 -199 -149 ... -5243 -5065 -5398] int16\n",
"fbank\n",
"[ 465 783 677 ... 980 903 1008] int16\n",
"fbank\n",
"[ 90 160 157 ... -2 -16 -21] int16\n",
"fbank\n",
"[ 213 345 295 ... 2483 2246 2501] int16\n",
"fbank\n",
"[ -86 -159 -131 ... 270 258 290] int16\n",
"fbank\n",
"[-1023 -1714 -1505 ... 1532 1596 1575] int16\n",
"fbank\n",
"[-366 -602 -527 ... 374 370 379] int16\n",
"fbank\n",
"[ 761 1275 1127 ... 369 413 295] int16\n",
"fbank\n",
"[382 621 550 ... 161 161 174] int16\n",
"fbank\n",
"[ -28 -91 -120 ... 28 34 11] int16\n",
"fbank\n",
"[ -5 -5 -5 ... 268 294 341] int16\n",
"fbank\n",
"[240 417 684 ... 267 262 219] int16\n",
"fbank\n",
"[131 206 194 ... 383 320 343] int16\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n",
" [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n",
" [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n",
" [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n",
" [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n",
"fbank\n",
"[1155 1890 1577 ... 1092 989 1130] int16\n",
"fbank\n",
"[296 358 296 ... 140 140 168] int16\n",
"fbank\n",
"[-50 -91 -63 ... 104 104 86] int16\n",
"fbank\n",
"[-37 -66 -50 ... -31 -45 -52] int16\n",
"fbank\n",
"[-401 -652 -547 ... -339 -307 -344] int16\n",
"fbank\n",
"[-21 -47 -51 ... 94 81 107] int16\n",
"fbank\n",
"[ 533 887 755 ... 3074 2853 3254] int16\n",
"fbank\n",
"[ 44 71 66 ... -628 -733 -601] int16\n",
"fbank\n",
"[ 50 86 79 ... 129 116 138] int16\n",
"fbank\n",
"[ 92 146 126 ... -208 -193 -179] int16\n",
"test raw: 祝可爱的你\n",
"test raw: 去行政化\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n",
" [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n",
" [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n",
" [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n",
" [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n",
" [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n",
" [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n",
" [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n",
" [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n",
" [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n",
" [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n",
" ...,\n",
" [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n",
" [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n",
" [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "minus-modern",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n",
" [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n",
" [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n",
" [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n",
" [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n",
"test raw: ઇǹज৹©\n",
"test raw: ǝണٕƜ\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n",
" [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n",
" [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n",
" [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n",
" [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n",
" [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n",
" [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n",
" [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n",
" [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n",
" [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n",
" [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n",
" ...,\n",
" [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n",
" [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n",
" [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n"
]
}
],
"source": [
"keep_transcription_text=False\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=keep_transcription_text,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" print('audio len:', audio_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "competitive-mounting",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "knowing-military",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"\n",
"add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str, 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('output_path', str,\n",
" 'examples/aishell/s1/data/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"args = parser.parse_args([])\n",
"print(vars(args))\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "unnecessary-province",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"\n",
"\n",
"def mean(args):\n",
" augmentation_pipeline = AugmentationPipeline('{}')\n",
" audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"\n",
" def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
" normalizer = FeatureNormalizer(\n",
" mean_std_filepath=None,\n",
" manifest_path=args.manifest_path,\n",
" featurize_func=augment_and_featurize,\n",
" num_samples=args.num_samples)\n",
" normalizer.write_to_file(args.output_path)\n",
"\n"
]
},
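  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cmvn-usage-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for the mean() helper above (left commented: it\n",
    "# assumes args.manifest_path exists on this machine, and the exact .npz key\n",
    "# names are whatever this repo's FeatureNormalizer writes):\n",
    "# mean(args)\n",
    "# stats = np.load(args.output_path)\n",
    "# print(stats.files)  # inspect the stored mean/std arrays"
   ]
  },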
{
"cell_type": "code",
"execution_count": 16,
"id": "interested-camping",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.]\n",
"29746\n",
"fbank\n",
"[54 90 77 ... 58 58 61] int16\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n"
]
}
],
"source": [
"wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n",
"test='祝可爱的你'\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=False,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"samples = AudioSegment.from_file(wav)\n",
"print(samples._samples)\n",
"print(samples._samples * 2**15)\n",
"print(len(samples._samples))\n",
"feat = audio_featurizer.featurize(samples, False, False)\n",
"feat = feat.T\n",
"print(feat.shape, feat.dtype)\n",
"print(feat)\n",
"\n",
"from python_speech_features import logfbank\n",
"max_freq = args.sample_rate / 2\n",
"fbank_feat = logfbank(\n",
" signal=samples.to('int16'),\n",
" samplerate=args.sample_rate,\n",
" winlen=0.001 * args.window_ms,\n",
" winstep=0.001 * args.stride_ms,\n",
" nfilt=args.feat_dim,\n",
" nfft=512,\n",
" lowfreq=20,\n",
" highfreq=max_freq,\n",
" preemph=0.97,\n",
" dither=0.0,\n",
" wintype='povey')\n",
"print(fbank_feat.shape, fbank_feat.dtype)\n",
"print(fbank_feat)"
]
},
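  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbank-crosscheck",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (a sketch, reusing `feat` and `fbank_feat` from the cell\n",
    "# above): with identical settings, the in-repo AudioFeaturizer and\n",
    "# python_speech_features' logfbank should agree to numerical precision.\n",
    "print(np.abs(feat - fbank_feat).max())\n",
    "print(np.allclose(feat, fbank_feat, atol=1e-5))"
   ]
  },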
{
"cell_type": "code",
"execution_count": 17,
"id": "numeric-analyst",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 160)\n",
"[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n",
" 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n",
" 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n",
" 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n",
" 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n",
" 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n",
" 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n",
" 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n",
" 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n",
" 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n",
" 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n",
" 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n",
" 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n",
" 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n",
" 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n",
" 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n",
" 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n",
" 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n",
" 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n",
" 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n",
" 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n",
" 13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n",
" 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n",
" 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n",
" 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n",
" 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n",
" 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n",
" 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n",
" 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n",
" 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n",
" 6.94653756 7.19874283 7.28515661 7.17574078]\n",
"(184,)\n",
"(184,)\n",
"[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n",
" 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n",
" 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n",
" 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n",
" 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n",
" 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n",
" 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n",
" 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n",
" 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n",
" 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n",
" 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n",
" 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n",
" 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n",
" 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n",
" 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n",
" 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n",
" 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n",
" 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n",
" 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n",
" 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n",
" 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n",
" 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n",
" 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n",
" 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n",
" 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n",
" 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n",
" 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n",
" 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n",
" 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n",
" 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n",
" 1.67784678 1.72180861 1.60703743 1.64850255]\n"
]
}
],
"source": [
"a = np.hstack([feat, feat])\n",
"print(a.shape)\n",
"m = np.mean(a, axis=1)\n",
"print(m)\n",
"print(m.shape)\n",
"std = np.std(a, axis=1)\n",
"print(std.shape)\n",
"print(std)"
]
},
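  {
   "cell_type": "code",
   "execution_count": null,
   "id": "norm-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check (a sketch): applying the per-row mean/std computed above\n",
    "# should give (near) zero mean and unit std along axis 1.\n",
    "norm = (a - m[:, None]) / (std[:, None] + 1e-20)\n",
    "print(norm.mean(axis=1)[:5])\n",
    "print(norm.std(axis=1)[:5])"
   ]
  },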
{
"cell_type": "code",
"execution_count": null,
"id": "nonprofit-potato",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "hispanic-ethics",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"import torchaudio.compliance.kaldi as kaldi\n",
"import torchaudio.sox_effects as sox_effects\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"torchaudio.set_audio_backend(\"sox\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "changing-calvin",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 29746])\n",
"tensor([[54., 90., 77., ..., 58., 58., 61.]])\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n",
"-----------\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"(184, 80)\n",
"[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n",
" [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n",
" [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n",
" ...\n",
" [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n",
" [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n",
" [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n",
"**************\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.] float32\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use \"torchaudio.load\".\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
}
],
"source": [
"waveform, sample_rate = torchaudio.load_wav(wav)\n",
"print(waveform.shape)\n",
"print(waveform)\n",
"mat = kaldi.fbank(\n",
" waveform,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('-----------')\n",
"print(samples._samples)\n",
"aud = torch.tensor(samples._samples).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('**************')\n",
"print(samples._samples)\n",
"tmp = samples.to('int16').astype('float32')\n",
"print(tmp, tmp.dtype)\n",
"aud = torch.tensor(tmp).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)"
]
},
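  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kaldi-fbank-diff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-check (a sketch): the last `mat` above (int16-scaled waveform)\n",
    "# should match `feat` from the AudioFeaturizer cell up to float32 rounding.\n",
    "print(mat.shape, feat.shape)\n",
    "print(np.abs(mat - feat).max())"
   ]
  },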
{
"cell_type": "code",
"execution_count": null,
"id": "buried-dependence",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "silver-printing",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "outer-space",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(29746,)\n",
"[54 90 77 ... 58 58 61]\n",
"(184, 80)\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 13)\n",
"[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ 11.92813809 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"from python_speech_features import mfcc\n",
"from python_speech_features import delta\n",
"from python_speech_features import logfbank\n",
"import scipy.io.wavfile as iowav\n",
"\n",
"(rate,sig) = iowav.read(wav)\n",
"print(sig.shape)\n",
"print(sig)\n",
"\n",
"# note that generally nfilt=40 is used for speech recognition\n",
"fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"# the computed fbank coefficents of english.wav with dimension [110,23]\n",
"# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n",
"# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n",
"\n",
"mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)\n",
"\n",
"# the computed mfcc coefficents of english.wav with dimension [110,13]\n",
"# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n",
"# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "sporting-school",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 80)\n",
"[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n",
" -12.98770428]\n",
" [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n",
" -11.95581463]\n",
" [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n",
" -12.83273051]\n",
" ...\n",
" [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n",
" -12.03825021]\n",
" [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n",
" -13.10154358]\n",
" [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n",
" -12.63434434]]\n",
"(184, 13)\n",
"[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ -8.86627732 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)"
]
},
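  {
   "cell_type": "code",
   "execution_count": null,
   "id": "scale-offset-note",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Why the float and int16 runs differ by a constant: log-mel energies are\n",
    "# logs of power, so scaling the waveform by 2**15 adds 2*ln(2**15) to every\n",
    "# bin (compare 10.6174 vs -10.1770 in the outputs above).\n",
    "print(2 * np.log(2 ** 15))  # ~20.794"
   ]
  },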
{
"cell_type": "code",
"execution_count": null,
"id": "restricted-license",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "specialized-threat",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 147,
"id": "extensive-venice",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/\n"
]
},
{
"data": {
"text/plain": [
"'/'"
]
},
"execution_count": 147,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 148,
"id": "correct-window",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"manifest.dev\t manifest.test-clean\t manifest.train\r\n",
"manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n"
]
}
],
"source": [
"!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/"
]
},
{
"cell_type": "code",
"execution_count": 149,
"id": "exceptional-cheese",
"metadata": {},
"outputs": [],
"source": [
"dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'"
]
},
{
"cell_type": "code",
"execution_count": 150,
"id": "extraordinary-orleans",
"metadata": {},
"outputs": [],
"source": [
"from deepspeech.frontend.utility import read_manifest"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "returning-lighter",
"metadata": {},
"outputs": [],
"source": [
"dev_json = read_manifest(dev_data)"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "western-founder",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n",
" 'name': 'input1',\n",
" 'shape': [1063, 83]}],\n",
" 'output': [{'name': 'target1',\n",
" 'shape': [41, 5002],\n",
" 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n",
" 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n",
" 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n",
" 'HITHER AND THITHER',\n",
" 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n",
" 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n",
" 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n",
" '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n",
" '▁THITHER',\n",
" 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n",
" '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n",
" '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n",
" '302 206 4504 4843 2394 627 4526'}],\n",
" 'utt': '116-288045-0000',\n",
" 'utt2spk': '116-288045'}\n",
"5542\n",
"<class 'list'>\n"
]
}
],
"source": [
"from pprint import pprint\n",
"pprint(dev_json[0])\n",
"print(len(dev_json))\n",
"print(type(dev_json))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "motivated-receptor",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"import itertools\n",
"\n",
"import numpy as np\n",
"\n",
"from deepspeech.utils.log import Log\n",
"\n",
"__all__ = [\"make_batchset\"]\n",
"\n",
"logger = Log(__name__).getlog()\n",
"\n",
"\n",
"def batchfy_by_seq(\n",
" sorted_data,\n",
" batch_size,\n",
" max_length_in,\n",
" max_length_out,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" iaxis=0,\n",
" okey=\"output\",\n",
" oaxis=0, ):\n",
" \"\"\"Make batch set from json dictionary\n",
"\n",
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
" :param int batch_size: batch size\n",
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
" :param int min_batch_size: mininum batch size (for multi-gpu)\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
" :param str ikey: key to access input\n",
" (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n",
" :param int iaxis: dimension to access input\n",
" (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
" :param str okey: key to access output\n",
" (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n",
" :param int oaxis: dimension to access output\n",
" (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n",
" :return: List[List[Tuple[str, dict]]] list of batches\n",
" \"\"\"\n",
" if batch_size <= 0:\n",
" raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
"\n",
" # check #utts is more than min_batch_size\n",
" if len(sorted_data) < min_batch_size:\n",
" raise ValueError(\n",
" f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n",
" )\n",
"\n",
" # make list of minibatches\n",
" minibatches = []\n",
" start = 0\n",
" while True:\n",
" _, info = sorted_data[start]\n",
" ilen = int(info[ikey][iaxis][\"shape\"][0])\n",
" olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n",
" max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n",
" factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n",
" # change batchsize depending on the input and output length\n",
" # if ilen = 1000 and max_length_in = 800\n",
" # then b = batchsize / 2\n",
" # and max(min_batches, .) avoids batchsize = 0\n",
" bs = max(min_batch_size, int(batch_size / (1 + factor)))\n",
" end = min(len(sorted_data), start + bs)\n",
" minibatch = sorted_data[start:end]\n",
" if shortest_first:\n",
" minibatch.reverse()\n",
"\n",
" # check each batch is more than minimum batchsize\n",
" if len(minibatch) < min_batch_size:\n",
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
" additional_minibatch = [\n",
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
" ]\n",
" if shortest_first:\n",
" additional_minibatch.reverse()\n",
" minibatch.extend(additional_minibatch)\n",
" minibatches.append(minibatch)\n",
"\n",
" if end == len(sorted_data):\n",
" break\n",
" start = end\n",
"\n",
" # batch: List[List[Tuple[str, dict]]]\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_by_bin(\n",
" sorted_data,\n",
" batch_bins,\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" okey=\"output\", ):\n",
" \"\"\"Make variably sized batch set, which maximizes\n",
"\n",
" the number of bins up to `batch_bins`.\n",
"\n",
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
" :param int batch_bins: Maximum frames of a batch\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param int test: Return only every `test` batches\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
"\n",
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
"\n",
" :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
" \"\"\"\n",
" if batch_bins <= 0:\n",
" raise ValueError(f\"invalid batch_bins={batch_bins}\")\n",
" length = len(sorted_data)\n",
" idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n",
" odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" minibatches = []\n",
" start = 0\n",
" n = 0\n",
" while True:\n",
" # Dynamic batch size depending on size of samples\n",
" b = 0\n",
" next_size = 0\n",
" max_olen = 0\n",
" while next_size < batch_bins and (start + b) < length:\n",
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n",
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n",
" if olen > max_olen:\n",
" max_olen = olen\n",
" next_size = (max_olen + ilen) * (b + 1)\n",
" if next_size <= batch_bins:\n",
" b += 1\n",
" elif next_size == 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n",
" f\"Please increase the value\")\n",
" end = min(length, start + max(min_batch_size, b))\n",
" batch = sorted_data[start:end]\n",
" if shortest_first:\n",
" batch.reverse()\n",
" minibatches.append(batch)\n",
" # Check for min_batch_size and fixes the batches if needed\n",
" i = -1\n",
" while len(minibatches[i]) < min_batch_size:\n",
" missing = min_batch_size - len(minibatches[i])\n",
" if -i == len(minibatches):\n",
" minibatches[i + 1].extend(minibatches[i])\n",
" minibatches = minibatches[1:]\n",
" break\n",
" else:\n",
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
" i -= 1\n",
" if end == length:\n",
" break\n",
" start = end\n",
" n += 1\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" lengths = [len(x) for x in minibatches]\n",
" logger.info(\n",
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
" + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
" int(np.mean(lengths))) + \" samples).\")\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_by_frame(\n",
" sorted_data,\n",
" max_frames_in,\n",
" max_frames_out,\n",
" max_frames_inout,\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" okey=\"output\", ):\n",
" \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n",
"\n",
" :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n",
" :param int max_frames_in: Maximum input frames of a batch\n",
" :param int max_frames_out: Maximum output frames of a batch\n",
" :param int max_frames_inout: Maximum input+output frames of a batch\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param int test: Return only every `test` batches\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
"\n",
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
"\n",
" :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
" \"\"\"\n",
" if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n",
" raise ValueError(\n",
" \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n",
" \"`--batch-frames-inout` should be > 0\")\n",
" length = len(sorted_data)\n",
" minibatches = []\n",
" start = 0\n",
" end = 0\n",
" while end != length:\n",
" # Dynamic batch size depending on size of samples\n",
" b = 0\n",
" max_olen = 0\n",
" max_ilen = 0\n",
" while (start + b) < length:\n",
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n",
" if ilen > max_frames_in and max_frames_in != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n",
" f\"Please increase the value\")\n",
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n",
" if olen > max_frames_out and max_frames_out != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n",
" f\"Please increase the value\")\n",
" if ilen + olen > max_frames_inout and max_frames_inout != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n",
" f\"Please increase the value\")\n",
" max_olen = max(max_olen, olen)\n",
" max_ilen = max(max_ilen, ilen)\n",
" in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n",
" out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n",
" inout_ok = (max_ilen + max_olen) * (\n",
" b + 1) <= max_frames_inout or max_frames_inout == 0\n",
" if in_ok and out_ok and inout_ok:\n",
" # add more seq in the minibatch\n",
" b += 1\n",
" else:\n",
" # no more seq in the minibatch\n",
" break\n",
" end = min(length, start + b)\n",
" batch = sorted_data[start:end]\n",
" if shortest_first:\n",
" batch.reverse()\n",
" minibatches.append(batch)\n",
" # Check for min_batch_size and fixes the batches if needed\n",
" i = -1\n",
" while len(minibatches[i]) < min_batch_size:\n",
" missing = min_batch_size - len(minibatches[i])\n",
" if -i == len(minibatches):\n",
" minibatches[i + 1].extend(minibatches[i])\n",
" minibatches = minibatches[1:]\n",
" break\n",
" else:\n",
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
" i -= 1\n",
" start = end\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" lengths = [len(x) for x in minibatches]\n",
" logger.info(\n",
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
" + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n",
" int(np.mean(lengths))) + \" samples).\")\n",
"\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n",
" shortest_first):\n",
" import random\n",
"\n",
" logger.info(\"use shuffled batch.\")\n",
" sorted_data = random.sample(data.items(), len(data.items()))\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" # make list of minibatches\n",
" minibatches = []\n",
" start = 0\n",
" while True:\n",
" end = min(len(sorted_data), start + batch_size)\n",
" # check each batch is more than minimum batchsize\n",
" minibatch = sorted_data[start:end]\n",
" if shortest_first:\n",
" minibatch.reverse()\n",
" if len(minibatch) < min_batch_size:\n",
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
" additional_minibatch = [\n",
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
" ]\n",
" if shortest_first:\n",
" additional_minibatch.reverse()\n",
" minibatch.extend(additional_minibatch)\n",
" minibatches.append(minibatch)\n",
" if end == len(sorted_data):\n",
" break\n",
" start = end\n",
"\n",
" # for debugging\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" logger.info(\"# minibatches: \" + str(len(minibatches)))\n",
" return minibatches\n",
"\n",
"\n",
"BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n",
"BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n",
"\n",
"\n",
"def make_batchset(\n",
" data,\n",
" batch_size=0,\n",
" max_length_in=float(\"inf\"),\n",
" max_length_out=float(\"inf\"),\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" batch_sort_key=\"input\",\n",
" count=\"auto\",\n",
" batch_bins=0,\n",
" batch_frames_in=0,\n",
" batch_frames_out=0,\n",
" batch_frames_inout=0,\n",
" iaxis=0,\n",
" oaxis=0, ):\n",
" \"\"\"Make batch set from json dictionary\n",
"\n",
" if utts have \"category\" value,\n",
"\n",
" >>> data = {'utt1': {'category': 'A', 'input': ...},\n",
" ... 'utt2': {'category': 'B', 'input': ...},\n",
" ... 'utt3': {'category': 'B', 'input': ...},\n",
" ... 'utt4': {'category': 'A', 'input': ...}}\n",
" >>> make_batchset(data, batchsize=2, ...)\n",
" [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n",
"\n",
" Note that if any utts doesn't have \"category\",\n",
" perform as same as batchfy_by_{count}\n",
"\n",
" :param List[Dict[str, Any]] data: dictionary loaded from data.json\n",
" :param int batch_size: maximum number of sequences in a minibatch.\n",
" :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n",
" :param int batch_frames_in: maximum number of input frames in a minibatch.\n",
" :param int batch_frames_out: maximum number of output frames in a minibatch.\n",
" :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n",
" :param str count: strategy to count maximum size of batch.\n",
" For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n",
"\n",
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
" :param str batch_sort_key: how to sort data before creating minibatches\n",
" [\"input\", \"output\", \"shuffle\"]\n",
" :param bool swap_io: if True, use \"input\" as output and \"output\"\n",
" as input in `data` dict\n",
" :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n",
" as input in `data` dict\n",
" :param int iaxis: dimension to access input\n",
" (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
" :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n",
" reserved for future research, -1 means all axis.)\n",
" :return: List[List[Tuple[str, dict]]] list of batches\n",
" \"\"\"\n",
"\n",
" # check args\n",
" if count not in BATCH_COUNT_CHOICES:\n",
" raise ValueError(\n",
" f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n",
" if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n",
" raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n",
" f\"one of {BATCH_SORT_KEY_CHOICES}\")\n",
"\n",
" ikey = \"input\"\n",
" okey = \"output\"\n",
" batch_sort_axis = 0 # index of list \n",
"\n",
" if count == \"auto\":\n",
" if batch_size != 0:\n",
" count = \"seq\"\n",
" elif batch_bins != 0:\n",
" count = \"bin\"\n",
" elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n",
" count = \"frame\"\n",
" else:\n",
" raise ValueError(\n",
" f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n",
" )\n",
" logger.info(f\"count is auto detected as {count}\")\n",
"\n",
" if count != \"seq\" and batch_sort_key == \"shuffle\":\n",
" raise ValueError(\n",
" \"batch_sort_key=shuffle is only available if batch_count=seq\")\n",
"\n",
" category2data = {} # Dict[str, dict]\n",
" for v in data:\n",
" k = v['utt']\n",
" category2data.setdefault(v.get(\"category\"), {})[k] = v\n",
"\n",
" batches_list = [] # List[List[List[Tuple[str, dict]]]]\n",
" for d in category2data.values():\n",
" if batch_sort_key == \"shuffle\":\n",
" batches = batchfy_shuffle(d, batch_size, min_batch_size,\n",
" num_batches, shortest_first)\n",
" batches_list.append(batches)\n",
" continue\n",
"\n",
" # sort it by input lengths (long to short)\n",
" sorted_data = sorted(\n",
" d.items(),\n",
" key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n",
" reverse=not shortest_first, )\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" \n",
" if count == \"seq\":\n",
" batches = batchfy_by_seq(\n",
" sorted_data,\n",
" batch_size=batch_size,\n",
" max_length_in=max_length_in,\n",
" max_length_out=max_length_out,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" iaxis=iaxis,\n",
" okey=okey,\n",
" oaxis=oaxis, )\n",
" if count == \"bin\":\n",
" batches = batchfy_by_bin(\n",
" sorted_data,\n",
" batch_bins=batch_bins,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" okey=okey, )\n",
" if count == \"frame\":\n",
" batches = batchfy_by_frame(\n",
" sorted_data,\n",
" max_frames_in=batch_frames_in,\n",
" max_frames_out=batch_frames_out,\n",
" max_frames_inout=batch_frames_inout,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" okey=okey, )\n",
" batches_list.append(batches)\n",
"\n",
" if len(batches_list) == 1:\n",
" batches = batches_list[0]\n",
" else:\n",
" # Concat list. This way is faster than \"sum(batch_list, [])\"\n",
" batches = list(itertools.chain(*batches_list))\n",
"\n",
" # for debugging\n",
" if num_batches > 0:\n",
" batches = batches[:num_batches]\n",
" logger.info(\"# minibatches: \" + str(len(batches)))\n",
"\n",
" # batch: List[List[Tuple[str, dict]]]\n",
" return batches\n"
]
},
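  {
   "cell_type": "code",
   "execution_count": null,
   "id": "batchfy-toy-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A tiny synthetic check of batchfy_by_seq's adaptive sizing (hypothetical\n",
    "# utterances; shapes follow the data.json convention above). Longer inputs\n",
    "# shrink the per-batch utterance count via bs = batch_size / (1 + factor).\n",
    "toy = [(f'utt{i}', {'input': [{'shape': [100 * (i + 1), 83]}],\n",
    "                    'output': [{'shape': [10, 5002]}]}) for i in range(6)]\n",
    "toy = sorted(toy, key=lambda kv: -kv[1]['input'][0]['shape'][0])\n",
    "print([len(b) for b in batchfy_by_seq(\n",
    "    toy, batch_size=4, max_length_in=300, max_length_out=100)])  # [1, 2, 2, 1]"
   ]
  },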
{
"cell_type": "code",
"execution_count": 98,
"id": "acquired-hurricane",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n",
"[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n",
"[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"555\n"
]
}
],
"source": [
"batch_size=10\n",
"maxlen_in=300\n",
"maxlen_out=400\n",
"minibatches=0 # for debug\n",
"min_batch_size=2\n",
"use_sortagrad=True\n",
"batch_count='seq'\n",
"batch_bins=0\n",
"batch_frames_in=3000\n",
"batch_frames_out=0\n",
"batch_frames_inout=0\n",
" \n",
"dev_data = make_batchset(\n",
" dev_json,\n",
" batch_size,\n",
" maxlen_in,\n",
" maxlen_out,\n",
" minibatches, # for debug\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=use_sortagrad,\n",
" batch_sort_key=\"shuffle\",\n",
" count=batch_count,\n",
" batch_bins=batch_bins,\n",
" batch_frames_in=batch_frames_in,\n",
" batch_frames_out=batch_frames_out,\n",
" batch_frames_inout=batch_frames_inout,\n",
" iaxis=0,\n",
" oaxis=0, )\n",
"print(len(dev_data))\n",
"# for i in range(len(dev_data)):\n",
"# print(len(dev_data[i]))\n"
]
},
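  {
   "cell_type": "code",
   "execution_count": null,
   "id": "peek-minibatch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one minibatch (a sketch): each batch is a list of (uttid, info)\n",
    "# tuples, ready for LoadInputsAndTargets below.\n",
    "print(len(dev_data[0]))\n",
    "print(dev_data[0][0][0])"
   ]
  },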
{
"cell_type": "code",
"execution_count": 99,
"id": "warming-malpractice",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n",
"Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n",
"\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n",
"You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"!pip install kaldiio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "equipped-subject",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 100,
"id": "superb-methodology",
"metadata": {},
"outputs": [],
"source": [
"from collections import OrderedDict\n",
"import kaldiio\n",
"\n",
"class LoadInputsAndTargets():\n",
" \"\"\"Create a mini-batch from a list of dicts\n",
"\n",
" >>> batch = [('utt1',\n",
" ... dict(input=[dict(feat='some.ark:123',\n",
" ... filetype='mat',\n",
" ... name='input1',\n",
" ... shape=[100, 80])],\n",
" ... output=[dict(tokenid='1 2 3 4',\n",
" ... name='target1',\n",
" ... shape=[4, 31])]]))\n",
" >>> l = LoadInputsAndTargets()\n",
" >>> feat, target = l(batch)\n",
"\n",
" :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n",
" :param: str preprocess_conf: The path of a json file for pre-processing\n",
" :param: bool load_input: If False, not to load the input data\n",
" :param: bool load_output: If False, not to load the output data\n",
" :param: bool sort_in_input_length: Sort the mini-batch in descending order\n",
" of the input length\n",
" :param: bool use_speaker_embedding: Used for tts mode only\n",
" :param: bool use_second_target: Used for tts mode only\n",
" :param: dict preprocess_args: Set some optional arguments for preprocessing\n",
" :param: Optional[dict] preprocess_args: Used for tts mode only\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" mode=\"asr\",\n",
" preprocess_conf=None,\n",
" load_input=True,\n",
" load_output=True,\n",
" sort_in_input_length=True,\n",
" preprocess_args=None,\n",
" keep_all_data_on_mem=False, ):\n",
" self._loaders = {}\n",
"\n",
" if mode not in [\"asr\"]:\n",
" raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n",
"\n",
" if preprocess_conf is not None:\n",
" self.preprocessing = AugmentationPipeline(preprocess_conf)\n",
" logging.warning(\n",
" \"[Experimental feature] Some preprocessing will be done \"\n",
" \"for the mini-batch creation using {}\".format(\n",
" self.preprocessing))\n",
" else:\n",
" # If conf doesn't exist, this function don't touch anything.\n",
" self.preprocessing = None\n",
"\n",
" self.mode = mode\n",
" self.load_output = load_output\n",
" self.load_input = load_input\n",
" self.sort_in_input_length = sort_in_input_length\n",
" if preprocess_args is None:\n",
" self.preprocess_args = {}\n",
" else:\n",
" assert isinstance(preprocess_args, dict), type(preprocess_args)\n",
" self.preprocess_args = dict(preprocess_args)\n",
"\n",
" self.keep_all_data_on_mem = keep_all_data_on_mem\n",
"\n",
" def __call__(self, batch, return_uttid=False):\n",
" \"\"\"Function to load inputs and targets from list of dicts\n",
"\n",
" :param List[Tuple[str, dict]] batch: list of dict which is subset of\n",
" loaded data.json\n",
" :param bool return_uttid: return utterance ID information for visualization\n",
" :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n",
" :return: list of input feature sequences\n",
" [(T_1, D), (T_2, D), ..., (T_B, D)]\n",
" :rtype: list of float ndarray\n",
" :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n",
" :rtype: list of int ndarray\n",
"\n",
" \"\"\"\n",
" x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
" y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
" uttid_list = [] # List[str]\n",
"\n",
" for uttid, info in batch:\n",
" uttid_list.append(uttid)\n",
"\n",
" if self.load_input:\n",
" # Note(kamo): This for-loop is for multiple inputs\n",
" for idx, inp in enumerate(info[\"input\"]):\n",
" # {\"input\":\n",
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # \"name\": \"input1\", ...}], ...}\n",
" x = self._get_from_loader(\n",
" filepath=inp[\"feat\"],\n",
" filetype=inp.get(\"filetype\", \"mat\"))\n",
" x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
"\n",
" if self.load_output:\n",
" for idx, inp in enumerate(info[\"output\"]):\n",
" if \"tokenid\" in inp:\n",
" # ======= Legacy format for output =======\n",
" # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n",
" x = np.fromiter(\n",
" map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n",
" else:\n",
" # ======= New format =======\n",
" # {\"input\":\n",
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # \"name\": \"target1\", ...}], ...}\n",
" x = self._get_from_loader(\n",
" filepath=inp[\"feat\"],\n",
" filetype=inp.get(\"filetype\", \"mat\"))\n",
"\n",
" y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
"\n",
" if self.mode == \"asr\":\n",
" return_batch, uttid_list = self._create_batch_asr(\n",
" x_feats_dict, y_feats_dict, uttid_list)\n",
" else:\n",
" raise NotImplementedError(self.mode)\n",
"\n",
" if self.preprocessing is not None:\n",
" # Apply pre-processing all input features\n",
" for x_name in return_batch.keys():\n",
" if x_name.startswith(\"input\"):\n",
" return_batch[x_name] = self.preprocessing(\n",
" return_batch[x_name], uttid_list,\n",
" **self.preprocess_args)\n",
"\n",
" if return_uttid:\n",
" return tuple(return_batch.values()), uttid_list\n",
"\n",
" # Doesn't return the names now.\n",
" return tuple(return_batch.values())\n",
"\n",
" def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n",
" \"\"\"Create a OrderedDict for the mini-batch\n",
"\n",
" :param OrderedDict x_feats_dict:\n",
" e.g. {\"input1\": [ndarray, ndarray, ...],\n",
" \"input2\": [ndarray, ndarray, ...]}\n",
" :param OrderedDict y_feats_dict:\n",
" e.g. {\"target1\": [ndarray, ndarray, ...],\n",
" \"target2\": [ndarray, ndarray, ...]}\n",
" :param: List[str] uttid_list:\n",
" Give uttid_list to sort in the same order as the mini-batch\n",
" :return: batch, uttid_list\n",
" :rtype: Tuple[OrderedDict, List[str]]\n",
" \"\"\"\n",
" # handle single-input and multi-input (paralell) asr mode\n",
" xs = list(x_feats_dict.values())\n",
"\n",
" if self.load_output:\n",
" ys = list(y_feats_dict.values())\n",
" assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n",
"\n",
" # get index of non-zero length samples\n",
" nonzero_idx = list(\n",
" filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n",
" for n in range(1, len(y_feats_dict)):\n",
" nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n",
" else:\n",
" # Note(kamo): Be careful not to make nonzero_idx to a generator\n",
" nonzero_idx = list(range(len(xs[0])))\n",
"\n",
" if self.sort_in_input_length:\n",
" # sort in input lengths based on the first input\n",
" nonzero_sorted_idx = sorted(\n",
" nonzero_idx, key=lambda i: -len(xs[0][i]))\n",
" else:\n",
" nonzero_sorted_idx = nonzero_idx\n",
"\n",
" if len(nonzero_sorted_idx) != len(xs[0]):\n",
" logging.warning(\n",
" \"Target sequences include empty tokenid (batch {} -> {}).\".\n",
" format(len(xs[0]), len(nonzero_sorted_idx)))\n",
"\n",
" # remove zero-length samples\n",
" xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n",
" uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n",
"\n",
" x_names = list(x_feats_dict.keys())\n",
" if self.load_output:\n",
" ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n",
" y_names = list(y_feats_dict.keys())\n",
"\n",
" # Keeping x_name and y_name, e.g. input1, for future extension\n",
" return_batch = OrderedDict([\n",
" * [(x_name, x) for x_name, x in zip(x_names, xs)],\n",
" * [(y_name, y) for y_name, y in zip(y_names, ys)],\n",
" ])\n",
" else:\n",
" return_batch = OrderedDict(\n",
" [(x_name, x) for x_name, x in zip(x_names, xs)])\n",
" return return_batch, uttid_list\n",
"\n",
" def _get_from_loader(self, filepath, filetype):\n",
" \"\"\"Return ndarray\n",
"\n",
" In order to make the fds to be opened only at the first referring,\n",
" the loader are stored in self._loaders\n",
"\n",
" >>> ndarray = loader.get_from_loader(\n",
" ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n",
"\n",
" :param: str filepath:\n",
" :param: str filetype:\n",
" :return:\n",
" :rtype: np.ndarray\n",
" \"\"\"\n",
" if filetype == \"hdf5\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = h5py.File(filepath, \"r\")\n",
" self._loaders[filepath] = loader\n",
" return loader[key][()]\n",
" elif filetype == \"sound.hdf5\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"sound.hdf5\",\n",
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n",
" self._loaders[filepath] = loader\n",
" array, rate = loader[key]\n",
" return array\n",
" elif filetype == \"sound\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.wav\",\n",
" # \"filetype\": \"sound\"},\n",
" # Assume PCM16\n",
" if not self.keep_all_data_on_mem:\n",
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
" return array\n",
" if filepath not in self._loaders:\n",
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
" self._loaders[filepath] = array\n",
" return self._loaders[filepath]\n",
" elif filetype == \"npz\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"npz\",\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = np.load(filepath)\n",
" self._loaders[filepath] = loader\n",
" return loader[key]\n",
" elif filetype == \"npy\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.npy\",\n",
" # \"filetype\": \"npy\"},\n",
" if not self.keep_all_data_on_mem:\n",
" return np.load(filepath)\n",
" if filepath not in self._loaders:\n",
" self._loaders[filepath] = np.load(filepath)\n",
" return self._loaders[filepath]\n",
" elif filetype in [\"mat\", \"vec\"]:\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.ark:123\",\n",
" # \"filetype\": \"mat\"}]},\n",
" # In this case, \"123\" indicates the starting points of the matrix\n",
" # load_mat can load both matrix and vector\n",
" if not self.keep_all_data_on_mem:\n",
" return kaldiio.load_mat(filepath)\n",
" if filepath not in self._loaders:\n",
" self._loaders[filepath] = kaldiio.load_mat(filepath)\n",
" return self._loaders[filepath]\n",
" elif filetype == \"scp\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"scp\",\n",
" filepath, key = filepath.split(\":\", 1)\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = kaldiio.load_scp(filepath)\n",
" self._loaders[filepath] = loader\n",
" return loader[key]\n",
" else:\n",
" raise NotImplementedError(\n",
" \"Not supported: loader_type={}\".format(filetype))\n"
]
},
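  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feat-ref-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 'feat' fields in the manifest use kaldiio's 'path.ark:offset'\n",
    "# convention, which the 'mat' branch above resolves with kaldiio.load_mat.\n",
    "feat_ref = dev_json[0]['input'][0]['feat']\n",
    "print(feat_ref)\n",
    "# kaldiio.load_mat(feat_ref) would return the (1063, 83) matrix promised by\n",
    "# dev_json[0]['input'][0]['shape'] (left commented: the .ark lives on the\n",
    "# original machine)."
   ]
  },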
{
"cell_type": "code",
"execution_count": 101,
"id": "monthly-muscle",
"metadata": {},
"outputs": [],
"source": [
"preprocess_conf=None\n",
"train_mode=True\n",
"load = LoadInputsAndTargets(\n",
" mode=\"asr\",\n",
" load_output=True,\n",
" preprocess_conf=preprocess_conf,\n",
" preprocess_args={\"train\":\n",
" train_mode}, # Switch the mode of preprocessing\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "periodic-senegal",
"metadata": {},
"outputs": [],
"source": [
"res = load(dev_data[0])"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "502d3f4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'tuple'>\n",
"2\n",
"10\n",
"10\n",
"(1174, 83) float32\n",
"(29,) int64\n"
]
}
],
"source": [
"print(type(res))\n",
"print(len(res))\n",
"print(len(res[0]))\n",
"print(len(res[1]))\n",
"print(res[0][0].shape, res[0][0].dtype)\n",
"print(res[1][0].shape, res[1][0].dtype)\n",
"# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n",
"# 2[10, 10]\n",
"# feats, labels"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "humanitarian-container",
"metadata": {},
"outputs": [],
"source": [
"(inputs, outputs), utts = load(dev_data[0], return_uttid=True)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "heard-prize",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n",
"10\n"
]
}
],
"source": [
"print(utts, len(utts))\n",
"print(len(inputs))"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "convinced-animation",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from deepspeech.io.utility import pad_list\n",
"class CustomConverter():\n",
" \"\"\"Custom batch converter.\n",
"\n",
" Args:\n",
" subsampling_factor (int): The subsampling factor.\n",
" dtype (paddle.dtype): Data type to convert.\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, subsampling_factor=1, dtype=np.float32):\n",
" \"\"\"Construct a CustomConverter object.\"\"\"\n",
" self.subsampling_factor = subsampling_factor\n",
" self.ignore_id = -1\n",
" self.dtype = dtype\n",
"\n",
" def __call__(self, batch):\n",
" \"\"\"Transform a batch and send it to a device.\n",
"\n",
" Args:\n",
" batch (list): The batch to transform.\n",
"\n",
" Returns:\n",
" tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n",
"\n",
" \"\"\"\n",
" # batch should be located in list\n",
" assert len(batch) == 1\n",
" (xs, ys), utts = batch[0]\n",
"\n",
" # perform subsampling\n",
" if self.subsampling_factor > 1:\n",
" xs = [x[::self.subsampling_factor, :] for x in xs]\n",
"\n",
" # get batch of lengths of input sequences\n",
" ilens = np.array([x.shape[0] for x in xs])\n",
"\n",
" # perform padding and convert to tensor\n",
" # currently only support real number\n",
" if xs[0].dtype.kind == \"c\":\n",
" xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n",
" xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n",
" # Note(kamo):\n",
" # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n",
" # Don't create ComplexTensor and give it E2E here\n",
" # because torch.nn.DataParellel can't handle it.\n",
" xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n",
" else:\n",
" xs_pad = pad_list(xs, 0).astype(self.dtype)\n",
"\n",
" # NOTE: this is for multi-output (e.g., speech translation)\n",
" ys_pad = pad_list(\n",
" [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n",
" self.ignore_id)\n",
"\n",
" olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n",
" return utts, xs_pad, ilens, ys_pad, olens"
]
},
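{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-pad-list",
"metadata": {},
"outputs": [],
"source": [
"# pad_list (imported above from deepspeech.io.utility) right-pads a list of\n",
"# variable-length arrays into one [batch, max_len, ...] array. The NumPy\n",
"# sketch below only illustrates the assumed semantics; it is not the\n",
"# library implementation.\n",
"import numpy as np\n",
"\n",
"def pad_list_sketch(xs, pad_value):\n",
"    n_batch = len(xs)\n",
"    max_len = max(x.shape[0] for x in xs)\n",
"    pad = np.full((n_batch, max_len) + xs[0].shape[1:], pad_value,\n",
"                  dtype=xs[0].dtype)\n",
"    for i, x in enumerate(xs):\n",
"        pad[i, :x.shape[0]] = x  # copy each sequence into the padded buffer\n",
"    return pad\n",
"\n",
"demo_ys = [np.array([1, 2, 3]), np.array([4, 5])]\n",
"print(pad_list_sketch(demo_ys, -1))  # [[ 1  2  3] [ 4  5 -1]], like ys_pad above"
]
},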
{
"cell_type": "code",
"execution_count": 107,
"id": "0b92ade5",
"metadata": {},
"outputs": [],
"source": [
"convert = CustomConverter()"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "8dbd847c",
"metadata": {},
"outputs": [],
"source": [
"utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "31c085f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n",
"(10, 1174, 83)\n",
"(10,)\n",
"[1174 821 716 628 597 473 463 441 419 358]\n",
"(10, 32)\n",
"[[4502 2404 4223 3204 4502 587 1018 3861 2932 713 2458 2916 253 4508\n",
" 627 1395 713 4504 957 2761 209 2967 3173 3918 2598 4100 3 2816\n",
" 4990 -1 -1 -1]\n",
" [1005 451 210 278 3411 206 482 2307 573 4502 3848 4577 4273 2388\n",
" 4444 89 4919 278 1264 4501 2371 3 139 113 2603 4962 3158 3325\n",
" 4577 814 4587 1422]\n",
" [2345 4144 2291 200 713 2345 532 999 2458 3076 545 2458 4832 3038\n",
" 4499 482 2812 1260 3080 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2345 832 4577 4920 4501 2345 2298 1236 381 288 389 101 2495 4172\n",
" 4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n",
" 4723 2371 -1 -1]\n",
" [4502 4207 432 3204 4502 2396 125 935 433 2598 483 18 327 2\n",
" 389 627 4512 2340 713 482 1981 4525 4031 269 2030 1340 101 2495\n",
" 4013 4844 -1 -1]\n",
" [4502 4892 3204 1892 3780 389 482 2774 3013 89 192 2495 4502 3475\n",
" 389 66 370 343 404 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2458 2314 4577 2340 2863 1254 303 269 2 389 932 2079 4577 299\n",
" 195 3233 4508 2 89 814 3144 1091 3204 3250 2193 3414 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2391 1785 443 78 39 4962 2340 829 599 4593 278 4681 202 407\n",
" 269 194 182 4577 482 4308 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [ 627 4873 2175 363 202 404 1018 4577 4502 3412 4875 2286 107 122\n",
" 4832 2345 3896 89 2368 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [ 481 174 474 599 1881 3252 2842 742 4502 2545 107 88 3204 4525\n",
" 4517 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]]\n",
"[29 32 19 30 30 19 26 20 19 15]\n",
"float32\n",
"int64\n",
"int64\n",
"int64\n"
]
}
],
"source": [
"print(utts)\n",
"print(xs.shape)\n",
"print(ilen.shape)\n",
"print(ilen)\n",
"print(ys.shape)\n",
"print(ys)\n",
"print(olen)\n",
"print(xs.dtype)\n",
"print(ilen.dtype)\n",
"print(ys.dtype)\n",
"print(olen.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "72e9ba60",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 230,
"id": "64593e5f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from paddle.io import DataLoader\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.io.batchfy import make_batchset\n",
"from deepspeech.io.converter import CustomConverter\n",
"from deepspeech.io.dataset import TransformDataset\n",
"from deepspeech.io.reader import LoadInputsAndTargets\n",
"from deepspeech.utils.log import Log\n",
"\n",
"\n",
"logger = Log(__name__).getlog()\n",
"\n",
"\n",
"class BatchDataLoader():\n",
" def __init__(self,\n",
" json_file: str,\n",
" train_mode: bool,\n",
" sortagrad: bool=False,\n",
" batch_size: int=0,\n",
" maxlen_in: float=float('inf'),\n",
" maxlen_out: float=float('inf'),\n",
" minibatches: int=0,\n",
" mini_batch_size: int=1,\n",
" batch_count: str='auto',\n",
" batch_bins: int=0,\n",
" batch_frames_in: int=0,\n",
" batch_frames_out: int=0,\n",
" batch_frames_inout: int=0,\n",
" preprocess_conf=None,\n",
" n_iter_processes: int=1,\n",
" subsampling_factor: int=1,\n",
" num_encs: int=1):\n",
" self.json_file = json_file\n",
" self.train_mode = train_mode\n",
" self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n",
" self.batch_size = batch_size\n",
" self.maxlen_in = maxlen_in\n",
" self.maxlen_out = maxlen_out\n",
" self.batch_count = batch_count\n",
" self.batch_bins = batch_bins\n",
" self.batch_frames_in = batch_frames_in\n",
" self.batch_frames_out = batch_frames_out\n",
" self.batch_frames_inout = batch_frames_inout\n",
" self.subsampling_factor = subsampling_factor\n",
" self.num_encs = num_encs\n",
" self.preprocess_conf = preprocess_conf\n",
" self.n_iter_processes = n_iter_processes\n",
"\n",
" \n",
" # read json data\n",
" self.data_json = read_manifest(json_file)\n",
"\n",
" # make minibatch list (variable length)\n",
" self.minibaches = make_batchset(\n",
" self.data_json,\n",
" batch_size,\n",
" maxlen_in,\n",
" maxlen_out,\n",
" minibatches, # for debug\n",
" min_batch_size=mini_batch_size,\n",
" shortest_first=self.use_sortagrad,\n",
" count=batch_count,\n",
" batch_bins=batch_bins,\n",
" batch_frames_in=batch_frames_in,\n",
" batch_frames_out=batch_frames_out,\n",
" batch_frames_inout=batch_frames_inout,\n",
" iaxis=0,\n",
" oaxis=0, )\n",
"\n",
" # data reader\n",
" self.reader = LoadInputsAndTargets(\n",
" mode=\"asr\",\n",
" load_output=True,\n",
" preprocess_conf=preprocess_conf,\n",
" preprocess_args={\"train\":\n",
" train_mode}, # Switch the mode of preprocessing\n",
" )\n",
"\n",
" # Setup a converter\n",
" if num_encs == 1:\n",
" self.converter = CustomConverter(\n",
" subsampling_factor=subsampling_factor, dtype=np.float32)\n",
" else:\n",
" assert NotImplementedError(\"not impl CustomConverterMulEnc.\")\n",
"\n",
" # hack to make batchsize argument as 1\n",
" # actual bathsize is included in a list\n",
" # default collate function converts numpy array to pytorch tensor\n",
" # we used an empty collate function instead which returns list\n",
" self.dataset = TransformDataset(self.minibaches, \n",
" lambda data: self.converter([self.reader(data, return_uttid=True)]))\n",
" self.dataloader = DataLoader(\n",
" dataset=self.dataset,\n",
" batch_size=1,\n",
" shuffle=not use_sortagrad if train_mode else False,\n",
" collate_fn=lambda x: x[0],\n",
" num_workers=n_iter_processes, )\n",
"\n",
" def __repr__(self):\n",
" echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n",
" echo += f\"train_mode: {self.train_mode}, \"\n",
" echo += f\"sortagrad: {self.use_sortagrad}, \"\n",
" echo += f\"batch_size: {self.batch_size}, \"\n",
" echo += f\"maxlen_in: {self.maxlen_in}, \"\n",
" echo += f\"maxlen_out: {self.maxlen_out}, \"\n",
" echo += f\"batch_count: {self.batch_count}, \"\n",
" echo += f\"batch_bins: {self.batch_bins}, \"\n",
" echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n",
" echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n",
" echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n",
" echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n",
" echo += f\"num_encs: {self.num_encs}, \"\n",
" echo += f\"num_workers: {self.n_iter_processes}, \"\n",
" echo += f\"file: {self.json_file}\"\n",
" return echo\n",
" \n",
" def __len__(self):\n",
" return len(self.dataloader)\n",
" \n",
" def __iter__(self):\n",
" return self.dataloader.__iter__()\n",
" \n",
" def __call__(self):\n",
" return self.__iter__()\n"
]
},
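{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-identity-collate",
"metadata": {},
"outputs": [],
"source": [
"# Why collate_fn=lambda x: x[0] above: each dataset element is already a\n",
"# fully converted minibatch, so with batch_size=1 the DataLoader hands the\n",
"# collate function a one-element list and the identity collate unwraps it.\n",
"# A toy illustration with assumed shapes, not the real ASR pipeline:\n",
"import numpy as np\n",
"from paddle.io import Dataset\n",
"\n",
"class PrebuiltBatches(Dataset):\n",
"    def __init__(self):\n",
"        # two \"minibatches\" of different sizes, built ahead of time\n",
"        self.batches = [np.zeros((4, 10), dtype='float32'),\n",
"                        np.ones((3, 10), dtype='float32')]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.batches[idx]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.batches)\n",
"\n",
"toy_loader = DataLoader(PrebuiltBatches(), batch_size=1,\n",
"                        collate_fn=lambda x: x[0])\n",
"for b in toy_loader:\n",
"    print(b.shape)  # (4, 10) then (3, 10): no extra leading batch dimension\n",
"    # with the default collate the shapes would be (1, 4, 10) and (1, 3, 10)"
]
},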
{
"cell_type": "code",
"execution_count": 231,
"id": "fcea3fd0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n",
"[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n",
"[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n"
]
}
],
"source": [
"train = BatchDataLoader(dev_data, True, batch_size=20)"
]
},
{
"cell_type": "code",
"execution_count": 232,
"id": "e2a2c9a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"278\n",
"['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n",
"<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n",
"278\n"
]
}
],
"source": [
"print(len(train.dataloader))\n",
"print(dir(train.dataloader))\n",
"print(train)\n",
"print(len(train))"
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "a5ba7d6e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n",
"Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[-1.99415934, -1.80315673, -1.88801885, ..., 0.86933994, -0.59853148, 0.02596200],\n",
" [-1.95346808, -1.84891188, -2.17492867, ..., 0.83640492, -0.59853148, -0.11333394],\n",
" [-2.27899861, -2.21495342, -2.58480024, ..., 0.91874266, -0.59853148, -0.31453922],\n",
" ...,\n",
" [-2.64522028, -2.35221887, -2.91269732, ..., 1.48994756, -0.16100442, 0.36646330],\n",
" [-2.40107250, -2.21495342, -2.37986445, ..., 1.44072104, -0.13220564, 0.12656468],\n",
" [-2.15692472, -1.89466715, -2.25690317, ..., 1.31273174, -0.09620714, -0.15202725]],\n",
"\n",
" [[-0.28859532, -0.29033494, -0.86576819, ..., 1.37753224, -0.30570769, 0.25806731],\n",
" [-0.20149794, -0.17814466, -0.59891301, ..., 1.35188794, -0.30570769, -0.02964944],\n",
" [-0.34947991, -0.33597648, -0.96877253, ..., 1.38394332, -0.30570769, -0.38376236],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.44914246, -0.33902276, -0.78237975, ..., 1.38218808, 0.29214793, -0.16815147],\n",
" [-0.55490732, -0.41596055, -0.84425378, ..., 1.34530187, 0.25002354, -0.04004869],\n",
" [-0.83694696, -0.62112784, -1.07112527, ..., 1.19160914, 0.20789915, 0.37984371],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[-1.24343657, -0.94188881, -1.41092563, ..., 0.96716309, 0.60345763, 0.15360183],\n",
" [-1.19466043, -0.80585432, -0.49723154, ..., 1.06735480, 0.60345763, 0.14511746],\n",
" [-0.94079566, -0.59330046, -0.40948665, ..., 0.82244170, 0.55614340, 0.28086722],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.21757117, 0.11361472, -0.33262897, ..., 0.76338506, -0.10711290, -0.57754958],\n",
" [-1.00205481, -0.61152041, -0.47124696, ..., 1.11897349, -0.10711290, 0.24931324],\n",
" [-1.03929281, -1.20336759, -1.16433656, ..., 0.88888687, -0.10711290, -0.04115745],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-1.25289667, -1.05046368, -0.82881606, ..., 1.23991334, 0.61702502, 0.05275881],\n",
" [-1.19659519, -0.78677225, -0.80407262, ..., 1.27644968, 0.61702502, -0.35079369],\n",
" [-1.49687004, -1.01750231, -0.82881606, ..., 1.29106426, 0.65006059, 0.17958963],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]]])\n",
"Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n",
"Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[3098, 1595, 389, ..., -1 , -1 , -1 ],\n",
" [2603, 4832, 482, ..., -1 , -1 , -1 ],\n",
" [2796, 303, 269, ..., -1 , -1 , -1 ],\n",
" ...,\n",
" [3218, 3673, 206, ..., -1 , -1 , -1 ],\n",
" [2371, 4832, 4031, ..., -1 , -1 , -1 ],\n",
" [2570, 2433, 4285, ..., -1 , -1 , -1 ]])\n",
"Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n"
]
}
],
"source": [
"for batch in train:\n",
" utts, xs, ilens, ys, olens = batch\n",
" print(utts)\n",
" print(xs)\n",
" print(ilens)\n",
" print(ys)\n",
" print(olens)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c974a1e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "breeding-haven",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "appropriate-theta",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LICENSE deepspeech examples\t\t requirements.txt tools\r\n",
"README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n",
"README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "entire-bloom",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n",
"WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n"
]
}
],
"source": [
"from deepspeech.modules import loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "governmental-aircraft",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import paddle"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "proprietary-disaster",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function deepspeech.modules.repeat(xs: paddle.VarBase, *size: Any) -> paddle.VarBase>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.repeat"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "first-diagram",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<property at 0x7fb515eeeb88>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.size"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "intelligent-david",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function paddle.tensor.manipulation.concat(x, axis=0, name=None)>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.cat"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bronze-tenant",
"metadata": {},
"outputs": [],
"source": [
"a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "balanced-bearing",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "extreme-republic",
"metadata": {},
"outputs": [],
"source": [
"def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n",
" nargs = len(args)\n",
" assert (nargs <= 1)\n",
" s = paddle.shape(xs)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"\n",
"# logger.warn(\n",
"# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n",
"# )\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "gross-addiction",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size(0)\n",
"a.size()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "adverse-dining",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size()"
]
},
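{
"cell_type": "code",
"execution_count": null,
"id": "editor-note-size-shim",
"metadata": {},
"outputs": [],
"source": [
"# Unlike torch.Tensor.size, the shim above returns a paddle Tensor, because\n",
"# paddle.shape itself yields an int32 Tensor (see the [7] outputs above).\n",
"# Where a plain Python int is needed, one assumed workaround is to cast the\n",
"# one-element result:\n",
"print(int(a.size(0)))  # -> 7"
]
},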
{
"cell_type": "code",
"execution_count": null,
"id": "popular-potato",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"import argparse\n",
"import functools\n",
"import paddle\n",
"import numpy as np\n",
"\n",
"from deepspeech.utils.socket_server import warm_up_test\n",
"from deepspeech.utils.socket_server import AsrTCPServer\n",
"from deepspeech.utils.socket_server import AsrRequestHandler\n",
"\n",
"from deepspeech.training.cli import default_argument_parser\n",
"from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"\n",
"from deepspeech.models.ds2 import DeepSpeech2Model\n",
"from deepspeech.models.ds2 import DeepSpeech2InferModel\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"\n",
"\n",
"\n",
"from deepspeech.frontend.utility import read_manifest"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0.0\n",
"e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"OFF\n",
"OFF\n",
"commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"None\n",
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(paddle.__version__)\n",
"print(paddle.version.commit)\n",
"print(paddle.version.with_mkl)\n",
"print(paddle.version.mkl())\n",
"print(paddle.version.show())\n",
"print(paddle.version.patch)\n",
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data:\n",
" augmentation_config: conf/augmentation.config\n",
" batch_size: 64\n",
" dev_manifest: data/manifest.dev\n",
" keep_transcription_text: False\n",
" max_duration: 27.0\n",
" max_freq: None\n",
" mean_std_filepath: examples/aishell/data/mean_std.npz\n",
" min_duration: 0.0\n",
" n_fft: None\n",
" num_workers: 0\n",
" random_seed: 0\n",
" shuffle_method: batch_shuffle\n",
" sortagrad: True\n",
" specgram_type: linear\n",
" stride_ms: 10.0\n",
" target_dB: -20\n",
" target_sample_rate: 16000\n",
" test_manifest: examples/aishell/data/manifest.test\n",
" train_manifest: data/manifest.train\n",
" use_dB_normalization: True\n",
" vocab_filepath: examples/aishell/data/vocab.txt\n",
" window_ms: 20.0\n",
"decoding:\n",
" alpha: 2.6\n",
" batch_size: 128\n",
" beam_size: 300\n",
" beta: 5.0\n",
" cutoff_prob: 0.99\n",
" cutoff_top_n: 40\n",
" decoding_method: ctc_beam_search\n",
" error_rate_type: cer\n",
" lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n",
" num_proc_bsearch: 10\n",
"model:\n",
" num_conv_layers: 2\n",
" num_rnn_layers: 3\n",
" rnn_layer_size: 1024\n",
" share_rnn_weights: False\n",
" use_gru: True\n",
"training:\n",
" global_grad_clip: 5.0\n",
" lr: 0.0005\n",
" lr_decay: 0.83\n",
" n_epoch: 30\n",
" weight_decay: 1e-06\n",
"----------- Configuration Arguments -----------\n",
"checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n",
"config: examples/aishell/conf/deepspeech2.yaml\n",
"device: gpu\n",
"dump_config: None\n",
"export_path: None\n",
"host_ip: localhost\n",
"host_port: 8086\n",
"model_dir: None\n",
"model_file: examples/aishell/jit.model.pdmodel\n",
"nprocs: 1\n",
"opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n",
"output: None\n",
"params_file: examples/aishell/jit.model.pdiparams\n",
"speech_save_dir: demo_cache\n",
"use_gpu: False\n",
"warmup_manifest: examples/aishell/data/manifest.test\n",
"------------------------------------------------\n"
]
}
],
"source": [
"parser = default_argument_parser()\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"add_arg('host_ip', str,\n",
" 'localhost',\n",
" \"Server's IP address.\")\n",
"add_arg('host_port', int, 8086, \"Server's IP port.\")\n",
"add_arg('speech_save_dir', str,\n",
" 'demo_cache',\n",
" \"Directory to save demo audios.\")\n",
"add_arg('warmup_manifest', \n",
" str, \n",
" \"examples/aishell/data/manifest.test\", \n",
" \"Filepath of manifest to warm up.\")\n",
"add_arg(\n",
" \"--model_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdmodel\",\n",
" help=\"Model filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--params_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdiparams\",\n",
" help=\n",
" \"Parameter filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--model_dir\",\n",
" type=str,\n",
" default=None,\n",
" help=\n",
" \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n",
")\n",
"add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n",
"\n",
"\n",
"args = parser.parse_args(\n",
" \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n",
")\n",
"\n",
"\n",
"config = get_cfg_defaults()\n",
"if args.config:\n",
" config.merge_from_file(args.config)\n",
"if args.opts:\n",
" config.merge_from_list(args.opts)\n",
"config.freeze()\n",
"print(config)\n",
"\n",
"args.warmup_manifest = config.data.test_manifest\n",
"\n",
"print_arguments(args)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dataset = ManifestDataset(\n",
" config.data.test_manifest,\n",
" config.data.unit_type,\n",
" config.data.vocab_filepath,\n",
" config.data.mean_std_filepath,\n",
" augmentation_config=\"{}\",\n",
" max_duration=config.data.max_duration,\n",
" min_duration=config.data.min_duration,\n",
" stride_ms=config.data.stride_ms,\n",
" window_ms=config.data.window_ms,\n",
" n_fft=config.data.n_fft,\n",
" max_freq=config.data.max_freq,\n",
" target_sample_rate=config.data.target_sample_rate,\n",
" specgram_type=config.data.specgram_type,\n",
" feat_dim=config.data.feat_dim,\n",
" delta_delta=config.data.delat_delta,\n",
" use_dB_normalization=config.data.use_dB_normalization,\n",
" target_dB=config.data.target_dB,\n",
" random_seed=config.data.random_seed,\n",
" keep_transcription_text=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"layer summary:\n",
"encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n",
"encoder.conv.conv_in.bn.weight|[32]|32\n",
"encoder.conv.conv_in.bn.bias|[32]|32\n",
"encoder.conv.conv_in.bn._mean|[32]|32\n",
"encoder.conv.conv_in.bn._variance|[32]|32\n",
"encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n",
"encoder.conv.conv_stack.0.bn.weight|[32]|32\n",
"encoder.conv.conv_stack.0.bn.bias|[32]|32\n",
"encoder.conv.conv_stack.0.bn._mean|[32]|32\n",
"encoder.conv.conv_stack.0.bn._variance|[32]|32\n",
"encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n",
"decoder.ctc_lo.weight|[2048, 4300]|8806400\n",
"decoder.ctc_lo.bias|[4300]|4300\n",
"layer has 66 parameters, 80148012 elements.\n"
]
}
],
"source": [
"model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n",
" args.checkpoint_path)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"examples/aishell/jit.model.pdmodel\n",
"examples/aishell/jit.model.pdiparams\n",
"0\n",
"False\n"
]
}
],
"source": [
"\n",
"from paddle.inference import Config\n",
"from paddle.inference import PrecisionType\n",
"from paddle.inference import create_predictor\n",
"\n",
"args.use_gpu=False\n",
"paddle.set_device('cpu')\n",
"\n",
"def init_predictor(args):\n",
" if args.model_dir is not None:\n",
" config = Config(args.model_dir)\n",
" else:\n",
" config = Config(args.model_file, args.params_file)\n",
"\n",
" if args.use_gpu:\n",
" config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n",
"# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n",
"# use_calib_mode=True) # 开启TensorRT预测,精度为fp32,开启int8离线量化\n",
" else:\n",
" # If not specific mkldnn, you can set the blas thread.\n",
" # The thread num should not be greater than the number of cores in the CPU.\n",
" config.set_cpu_math_library_num_threads(1)\n",
" config.enable_mkldnn()\n",
" \n",
" config.enable_memory_optim()\n",
" config.switch_ir_optim(True)\n",
" \n",
" print(config.model_dir())\n",
" print(config.prog_file())\n",
" print(config.params_file())\n",
" print(config.gpu_device_id())\n",
" print(args.use_gpu)\n",
" predictor = create_predictor(config)\n",
" return predictor\n",
"\n",
"def run(predictor, audio, audio_len):\n",
" # copy img data to input tensor\n",
" input_names = predictor.get_input_names()\n",
" for i, name in enumerate(input_names):\n",
" print(\"input:\", i, name)\n",
" \n",
" audio_tensor = predictor.get_input_handle('audio')\n",
" audio_tensor.reshape(audio.shape)\n",
" audio_tensor.copy_from_cpu(audio.copy())\n",
" \n",
" audiolen_tensor = predictor.get_input_handle('audio_len')\n",
" audiolen_tensor.reshape(audio_len.shape)\n",
" audiolen_tensor.copy_from_cpu(audio_len.copy())\n",
"\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" print(\"output:\", i, name)\n",
"\n",
" # do the inference\n",
" predictor.run()\n",
"\n",
" results = []\n",
" # get out data from output tensor\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" output_tensor = predictor.get_output_handle(name)\n",
" output_data = output_tensor.copy_to_cpu()\n",
" results.append(output_data)\n",
"\n",
" return results\n",
"\n",
"\n",
"predictor = init_predictor(args)\n",
"\n",
"def file_to_transcript(filename):\n",
" print(filename)\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" \n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0], type(i_probs[0]))\n",
" \n",
" audio = paddle.to_tensor(audio)\n",
" audio_len = paddle.to_tensor(audio_len)\n",
" print(audio.shape)\n",
" print(audio_len.shape)\n",
" \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.softmax(eouts)\n",
" probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n",
" \n",
" flag = np.allclose(i_probs[0], probs.numpy())\n",
" print(flag)\n",
" \n",
" return probs\n",
"\n",
"# result_transcript = model.decode(\n",
"# audio,\n",
"# audio_len,\n",
"# vocab_list=dataset.vocab_list,\n",
"# decoding_method=config.decoding.decoding_method,\n",
"# lang_model_path=config.decoding.lang_model_path,\n",
"# beam_alpha=config.decoding.alpha,\n",
"# beam_beta=config.decoding.beta,\n",
"# beam_size=config.decoding.beam_size,\n",
"# cutoff_prob=config.decoding.cutoff_prob,\n",
"# cutoff_top_n=config.decoding.cutoff_top_n,\n",
"# num_processes=config.decoding.num_proc_bsearch)\n",
"# return result_transcript[0]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n",
" 8.91573707e-12 4.64317296e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n",
" 1.24587264e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n",
" 3.89252270e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]] <class 'numpy.ndarray'>\n",
"[1, 161, 522]\n",
"[1]\n",
"paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n",
"False\n"
]
}
],
"source": [
"manifest = read_manifest(args.warmup_manifest)\n",
"\n",
"for idx, sample in enumerate(manifest[:1]):\n",
" print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n",
" start_time = time.time()\n",
" transcript = file_to_transcript(sample['audio_filepath'])\n",
" finish_time = time.time()\n",
"# print(\"Response Time: %f, Transcript: %s\" %\n",
"# (finish_time - start_time, transcript))\n",
" break"
]
},
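{
"cell_type": "code",
"execution_count": null,
"id": "editor-note-float32-tolerance",
"metadata": {},
"outputs": [],
"source": [
"# The jit and eager probabilities above agree to roughly six significant\n",
"# figures, i.e. at the float32 rounding level, so the strict np.allclose\n",
"# check inside file_to_transcript reports False. A tolerance suited to\n",
"# float32 is more informative; the values below are assumptions to tune,\n",
"# not part of the original comparison:\n",
"def close_f32(x, y, rtol=1e-4, atol=1e-6):\n",
"    return np.allclose(x, y, rtol=rtol, atol=atol)\n",
"\n",
"demo = np.random.rand(4).astype('float32')\n",
"perturbed = demo + np.float32(1e-7)  # difference at float32 rounding scale\n",
"print(close_f32(demo, perturbed))    # True: treated as equal under the tolerance"
]
},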
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 161, 522) (1,)\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n"
]
}
],
"source": [
"def test(filename):\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" print(audio.shape, audio_len.shape)\n",
"\n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0])\n",
" return i_probs\n",
" \n",
"probs = test(sample['audio_filepath'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 32,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n",
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n"
]
}
],
"source": [
"L = nn.LayerNorm(256, epsilon=1e-12)\n",
"for p in L.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": [
"y = L(paddle.to_tensor(x, dtype='float32'))"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1.], requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" requires_grad=True)\n"
]
}
],
"source": [
"TL = torch.nn.LayerNorm(256, eps=1e-12)\n",
"for p in TL.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [],
"source": [
"ty = TL(torch.tensor(x, dtype=torch.float32))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "incorrect-allah",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
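{
"cell_type": "code",
"execution_count": null,
"id": "editor-note-layernorm-tolerance",
"metadata": {},
"outputs": [],
"source": [
"# The False above is most likely a tolerance artifact rather than a real\n",
"# mismatch: both layer norms run in float32, so per-element differences on\n",
"# the (2, 51, 256) input sit at the float32 rounding level, which\n",
"# np.allclose's defaults (rtol=1e-05, atol=1e-08) can reject. With a looser,\n",
"# float32-appropriate tolerance (assumed values, tune as needed):\n",
"print(np.allclose(y.numpy(), ty.detach().numpy(), rtol=1e-4, atol=1e-5))"
]
},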
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 52,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"y = L(paddle.to_tensor(x, dtype='float32'))\n",
"ty = TL(torch.tensor(x, dtype=torch.float32))\n",
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "primary-organic",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "stopped-semester",
"metadata": {},
"outputs": [],
"source": [
"def mask_finished_scores(score: torch.Tensor,\n",
" flag: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.size(-1)\n",
" zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
" if beam_size > 1:\n",
" unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "agreed-portuguese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ True],\n",
" [False]])\n",
"tensor([[-0.8841, 0.7381, -0.9986],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ True, True],\n",
" [False, False]])\n"
]
}
],
"source": [
"score = torch.randn((2, 3))\n",
"flag = torch.ones((2, 1), dtype=torch.bool)\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.repeat([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "clean-aspect",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[False, True, True],\n",
" [False, False, False]])\n",
"tensor([[ True, False, False],\n",
" [False, False, False]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n"
]
}
],
"source": [
"r = mask_finished_scores(score, flag)\n",
"print(r)\n",
"print(score)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "thrown-airline",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True ],\n",
" [False]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"import paddle\n",
"\n",
"score = paddle.randn((2, 3))\n",
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.tile([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "internal-patent",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[False, True , True ],\n",
" [False, False, False]])\n",
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , False, False],\n",
" [False, False, False]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n"
]
}
],
"source": [
"paddle.bool = 'bool'\n",
"\n",
"def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print(xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" xs = paddle.where(mask, trues, xs)\n",
" return xs\n",
"\n",
"def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print('x', xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" ret = paddle.where(mask, trues, xs)\n",
" print('2', xs)\n",
" paddle.assign(ret, output=xs)\n",
" print('3', xs)\n",
"\n",
"paddle.Tensor.masked_fill = masked_fill\n",
"paddle.Tensor.masked_fill_ = masked_fill_\n",
"\n",
"def mask_finished_scores_pd(score: paddle.Tensor,\n",
" flag: paddle.Tensor) -> paddle.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.shape[-1]\n",
" zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
" if beam_size > 1:\n",
" unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" \n",
" #score.masked_fill_(unfinished, -float('inf'))\n",
" #score.masked_fill_(finished, 0)\n",
"# infs = paddle.ones_like(score) * -float('inf')\n",
"# score = paddle.where(unfinished, infs, score)\n",
"# score = paddle.where(finished, paddle.zeros_like(score), score)\n",
"\n",
"# score = score.masked_fill(unfinished, -float('inf'))\n",
"# score = score.masked_fill(finished, 0)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score\n",
"\n",
"r = mask_finished_scores_pd(score, flag)\n",
"print(r)"
]
},
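  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hypothetical-fill-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check (illustrative values, not from the original repo): the\n",
    "# paddle.where-based shim should agree with torch.Tensor.masked_fill on\n",
    "# identical inputs.\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype='float32')\n",
    "m_np = np.array([[True, False], [False, True]])\n",
    "pd_out = masked_fill(paddle.to_tensor(x_np), paddle.to_tensor(m_np), 0.0)\n",
    "pt_out = torch.tensor(x_np).masked_fill(torch.tensor(m_np), 0.0)\n",
    "print(np.allclose(pd_out.numpy(), pt_out.numpy()))"
   ]
  },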
{
"cell_type": "code",
"execution_count": 57,
"id": "vocal-prime",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method PyCapsule.value of Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])>"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"score.value"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "bacterial-adolescent",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union, Any"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "absent-fiber",
"metadata": {},
"outputs": [],
"source": [
"def repeat(xs : paddle.Tensor, *size: Any):\n",
" print(size)\n",
" return paddle.tile(xs, size)\n",
"paddle.Tensor.repeat = repeat"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "material-harbor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "acute-brighton",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [1]), 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(paddle.to_tensor(1), 2))"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "european-rugby",
"metadata": {},
"outputs": [],
"source": [
"def size(xs, *args: int):\n",
" nargs = len(args)\n",
" s = paddle.shape(xs)\n",
" assert(nargs <= 1)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "moral-special",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2, 1])"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "ahead-coach",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [1])"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(1)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "incomplete-fitness",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(0)"
]
},
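  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hypothetical-size-note",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note (assumption based on the outputs above): unlike torch.Tensor.size(),\n",
    "# this shim returns a paddle int Tensor rather than a Python int, so cast\n",
    "# explicitly wherever a plain int is required.\n",
    "n = int(flag.size(0).numpy()[0])\n",
    "print(n, type(n))"
   ]
  },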
{
"cell_type": "code",
"execution_count": null,
"id": "upset-connectivity",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "designing-borough",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"True\n",
"True\n"
]
}
],
"source": [
"import torch\n",
"import math\n",
"import numpy as np\n",
"\n",
"max_len=100\n",
"d_model=256\n",
"\n",
"pe = torch.zeros(max_len, d_model)\n",
"position = torch.arange(0, max_len,\n",
" dtype=torch.float32).unsqueeze(1)\n",
"toruch_position = position\n",
"div_term = torch.exp(\n",
" torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
" -(math.log(10000.0) / d_model))\n",
"tourch_div_term = div_term.cpu().detach().numpy()\n",
"\n",
"\n",
"\n",
"torhc_sin = torch.sin(position * div_term)\n",
"torhc_cos = torch.cos(position * div_term)\n",
"print(torhc_sin.cpu().detach().numpy())\n",
"np_sin = np.sin((position * div_term).cpu().detach().numpy())\n",
"np_cos = np.cos((position * div_term).cpu().detach().numpy())\n",
"print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n",
"pe[:, 0::2] = torhc_sin\n",
"pe[:, 1::2] = torhc_cos\n",
"tourch_pe = pe.cpu().detach().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "swiss-referral",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"False\n",
"False\n",
"False\n",
"False\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
"----\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
")))))))\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"----\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n"
]
}
],
"source": [
"import paddle\n",
"paddle.set_device('cpu')\n",
"ppe = paddle.zeros((max_len, d_model), dtype='float32')\n",
"position = paddle.arange(0, max_len,\n",
" dtype='float32').unsqueeze(1)\n",
"print(np.allclose(position.numpy(), toruch_position))\n",
"div_term = paddle.exp(\n",
" paddle.arange(0, d_model, 2, dtype='float32') *\n",
" -(math.log(10000.0) / d_model))\n",
"print(np.allclose(div_term.numpy(), tourch_div_term))\n",
"\n",
"\n",
"\n",
"p_sin = paddle.sin(position * div_term)\n",
"p_cos = paddle.cos(position * div_term)\n",
"print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n",
"print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n",
"ppe[:, 0::2] = p_sin\n",
"ppe[:, 1::2] = p_cos\n",
"print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n",
"print(p_cos.numpy())\n",
"print(\"----\")\n",
"print(torhc_cos.cpu().detach().numpy())\n",
"print(\")))))))\")\n",
"print(p_sin.numpy())\n",
"print(\"----\")\n",
"print(torhc_sin.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "integrated-boards",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"print(np.allclose(ppe.numpy(), pe.numpy()))"
]
},
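  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hypothetical-diff-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The exact comparison above fails; a hedged follow-up (assumption: the gap\n",
    "# is float32 rounding between the paddle and torch kernels) makes the\n",
    "# worst-case absolute difference explicit.\n",
    "diff = np.abs(ppe.numpy() - pe.numpy())\n",
    "print(diff.max())\n",
    "print(np.allclose(ppe.numpy(), pe.numpy(), rtol=1e-6, atol=1e-6))"
   ]
  },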
{
"cell_type": "code",
"execution_count": null,
"id": "flying-reserve",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "revised-divide",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cloudy-glass",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISISBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "grand-stephen",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "isolated-prize",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "romance-samuel",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "timely-bikini",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "organized-warrior",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[14 , 34 , 322 , 233 , 0 , 0 ],\n",
" [238 , 38 , 122 , 164 , 0 , 0 ],\n",
" [8 , 52 , 49 , 42 , 0 , 0 ],\n",
" [109 , 47 , 146 , 193 , 210 , 479 ],\n",
" [3330, 1751, 208 , 1923, 0 , 0 ]])\n",
"test raw 大时代里的的\n",
"test raw 煲汤受宠的的\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
" for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "confidential-radius",
"metadata": {},
"outputs": [],
"source": [
"# reader = batch_reader()\n",
"# audio, test , audio_len, text_len = reader.next()\n",
"# print('test', text)\n",
"# print('t len', text_len) #[B, T]\n",
"# print('audio len', audio_len)\n",
"# print(audio)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "future-vermont",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"煲汤受宠\n"
]
}
],
"source": [
"print(u'\\u7172\\u6c64\\u53d7\\u5ba0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dental-sweden",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "sunrise-contact",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hispanic-asthma",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hearing-leadership",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "skilled-friday",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "copyrighted-measure",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "employed-lightweight",
"metadata": {},
"outputs": [],
"source": [
"from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n",
"\n",
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"from paddle.nn import functional as F\n",
"from paddle.nn import initializer as I\n",
"\n",
"import math\n",
"\n",
"def brelu(x, t_min=0.0, t_max=24.0, name=None):\n",
" t_min = paddle.to_tensor(t_min)\n",
" t_max = paddle.to_tensor(t_max)\n",
" return x.maximum(t_min).minimum(t_max)\n",
"\n",
"def sequence_mask(x_len, max_len=None, dtype='float32'):\n",
" max_len = max_len or x_len.max()\n",
" x_len = paddle.unsqueeze(x_len, -1)\n",
" row_vector = paddle.arange(max_len)\n",
" mask = row_vector > x_len # maybe a bug\n",
" mask = paddle.cast(mask, dtype)\n",
" print(f'seq mask: {mask}')\n",
" return mask\n",
"\n",
"\n",
"class ConvBn(nn.Layer):\n",
" def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n",
" padding, act):\n",
"\n",
" super().__init__()\n",
" self.kernel_size = kernel_size\n",
" self.stride = stride\n",
" self.padding = padding\n",
"\n",
" self.conv = nn.Conv2D(\n",
" num_channels_in,\n",
" num_channels_out,\n",
" kernel_size=kernel_size,\n",
" stride=stride,\n",
" padding=padding,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
"\n",
" self.bn = nn.BatchNorm2D(\n",
" num_channels_out,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
" self.act = F.relu if act == 'relu' else brelu\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x(Tensor): audio, shape [B, C, D, T]\n",
" \"\"\"\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.act(x)\n",
"\n",
" x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n",
" ) // self.stride[1] + 1\n",
"\n",
" # reset padding part to 0\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n",
" x = x.multiply(masks)\n",
"\n",
" return x, x_len\n",
"\n",
"\n",
"class ConvStack(nn.Layer):\n",
" def __init__(self, feat_size, num_stacks):\n",
" super().__init__()\n",
" self.feat_size = feat_size # D\n",
" self.num_stacks = num_stacks\n",
"\n",
" self.conv_in = ConvBn(\n",
" num_channels_in=1,\n",
" num_channels_out=32,\n",
" kernel_size=(41, 11), #[D, T]\n",
" stride=(2, 3),\n",
" padding=(20, 5),\n",
" act='brelu')\n",
"\n",
" out_channel = 32\n",
" self.conv_stack = nn.Sequential([\n",
" ConvBn(\n",
" num_channels_in=32,\n",
" num_channels_out=out_channel,\n",
" kernel_size=(21, 11),\n",
" stride=(2, 1),\n",
" padding=(10, 5),\n",
" act='brelu') for i in range(num_stacks - 1)\n",
" ])\n",
"\n",
" # conv output feat_dim\n",
" output_height = (feat_size - 1) // 2 + 1\n",
" for i in range(self.num_stacks - 1):\n",
" output_height = (output_height - 1) // 2 + 1\n",
" self.output_height = out_channel * output_height\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, C, D, T]\n",
" x_len : shape [B]\n",
" \"\"\"\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = self.conv_in(x, x_len)\n",
" for i, conv in enumerate(self.conv_stack):\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = conv(x, x_len)\n",
" print(f\"conv out: {x_len}\")\n",
" return x, x_len\n",
" \n",
" \n",
"\n",
"class RNNCell(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n",
" computes the outputs and updates states.\n",
" The formula used is as follows:\n",
" .. math::\n",
" h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`act` is for :attr:`activation`.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" hidden_size,\n",
" activation=\"tanh\",\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n",
" raise ValueError(\n",
" \"activation for SimpleRNNCell should be tanh or relu, \"\n",
" \"but get {}\".format(activation))\n",
" self.activation = activation\n",
" self._activation_fn = paddle.tanh \\\n",
" if activation == \"tanh\" \\\n",
" else F.relu\n",
" if activation == 'brelu':\n",
" self._activation_fn = brelu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
" pre_h = states\n",
" i2h = inputs\n",
" if self.bias_ih is not None:\n",
" i2h += self.bias_ih\n",
" h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h2h += self.bias_hh\n",
" h = self._activation_fn(i2h + h2h)\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class GRUCellShare(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n",
" it computes the outputs and updates states.\n",
" The formula for GRU used is as follows:\n",
" .. math::\n",
" r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n",
" z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n",
" \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n",
" h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n",
" multiplication operator.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" input_size,\n",
" hidden_size,\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (3 * hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (3 * hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (3 * hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" self.input_size = input_size\n",
" self._gate_activation = F.sigmoid\n",
" #self._activation = paddle.tanh\n",
" self._activation = F.relu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
"\n",
" pre_hidden = states\n",
" x_gates = inputs\n",
" if self.bias_ih is not None:\n",
" x_gates = x_gates + self.bias_ih\n",
" h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h_gates = h_gates + self.bias_hh\n",
"\n",
" x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n",
" h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n",
"\n",
" r = self._gate_activation(x_r + h_r)\n",
" z = self._gate_activation(x_z + h_z)\n",
" c = self._activation(x_c + r * h_c) # apply reset gate after mm\n",
" h = (pre_hidden - c) * z + c\n",
" # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n",
" #h = (1-z) * pre_hidden + z * c\n",
"\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" r\"\"\"\n",
" The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n",
" size would be automatically inserted into shape). The shape corresponds\n",
" to the shape of :math:`h_{t-1}`.\n",
" \"\"\"\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class BiRNNWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer parameters.\n",
" :type name: string\n",
" :param size: Dimension of RNN cells.\n",
" :type size: int\n",
" :param share_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" :type share_weights: bool\n",
" :return: Bidirectional simple rnn layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, share_weights):\n",
" super().__init__()\n",
" self.share_weights = share_weights\n",
" if self.share_weights:\n",
" #input-hidden weights shared between bi-directional rnn.\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" # batch norm is only performed on input-state projection\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = self.fw_fc\n",
" self.bw_bn = self.fw_bn\n",
" else:\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class BiGRUWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer.\n",
" :type name: string\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of GRU cells.\n",
" :type size: int\n",
" :param act: Activation type.\n",
" :type act: string\n",
" :return: Bidirectional GRU layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, act):\n",
" super().__init__()\n",
" hidden_size = h_size * 3\n",
" self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class RNNStack(nn.Layer):\n",
" \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n",
"\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of RNN cells in each layer.\n",
" :type size: int\n",
" :param num_stacks: Number of stacked rnn layers.\n",
" :type num_stacks: int\n",
" :param use_gru: Use gru if set True. Use simple rnn if set False.\n",
" :type use_gru: bool\n",
" :param share_rnn_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" It is only available when use_gru=False.\n",
" :type share_weights: bool\n",
" :return: Output layer of the RNN group.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n",
" super().__init__()\n",
" self.rnn_stacks = nn.LayerList()\n",
" for i in range(num_stacks):\n",
" if use_gru:\n",
" #default:GRU using tanh\n",
" self.rnn_stacks.append(\n",
" BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n",
" else:\n",
" self.rnn_stacks.append(\n",
" BiRNNWithBN(\n",
" i_size=i_size,\n",
" h_size=h_size,\n",
" share_weights=share_rnn_weights))\n",
" i_size = h_size * 2\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, T, D]\n",
" x_len: shpae [B]\n",
" \"\"\"\n",
" for i, rnn in enumerate(self.rnn_stacks):\n",
" x, x_len = rnn(x, x_len)\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(-1) # [B, T, 1]\n",
" x = x.multiply(masks)\n",
" return x, x_len\n",
"\n",
" \n",
"class DeepSpeech2Test(DeepSpeech2):\n",
" def __init__(self,\n",
" feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True):\n",
" super().__init__(feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True)\n",
" self.feat_size = feat_size # 161 for linear\n",
" self.dict_size = dict_size\n",
"\n",
" self.conv = ConvStack(feat_size, num_conv_layers)\n",
" \n",
"# self.fc = nn.Linear(1312, dict_size + 1)\n",
"\n",
" i_size = self.conv.output_height # H after conv stack\n",
" self.rnn = RNNStack(\n",
" i_size=i_size,\n",
" h_size=rnn_size,\n",
" num_stacks=num_rnn_layers,\n",
" use_gru=use_gru,\n",
" share_rnn_weights=share_rnn_weights)\n",
" \n",
" self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n",
" \n",
" def infer(self, audio, audio_len):\n",
" # [B, D, T] -> [B, C=1, D, T]\n",
" audio = audio.unsqueeze(1)\n",
"\n",
" # convolution group\n",
" x, audio_len = self.conv(audio, audio_len)\n",
" print('conv out', x.shape)\n",
"\n",
" # convert data from convolution feature map to sequence of vectors\n",
" B, C, D, T = paddle.shape(x)\n",
" x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n",
" x = x.reshape([B, T, C * D]) #[B, T, C*D]\n",
" print('rnn input', x.shape)\n",
"\n",
" # remove padding part\n",
" x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n",
" print('rnn output', x.shape)\n",
"\n",
" logits = self.fc(x) #[B, T, V + 1]\n",
"\n",
" #ctcdecoder need probs, not log_probs\n",
" probs = F.softmax(logits)\n",
"\n",
" return logits, probs, audio_len\n",
"\n",
" def forward(self, audio, audio_len, text, text_len):\n",
" \"\"\"\n",
" audio: shape [B, D, T]\n",
" text: shape [B, T]\n",
" audio_len: shape [B]\n",
" text_len: shape [B]\n",
" \"\"\"\n",
" return self.infer(audio, audio_len)\n",
" \n",
"\n",
"feat_dim=161\n",
"\n",
"model = DeepSpeech2Test(\n",
" feat_size=feat_dim,\n",
" dict_size=batch_reader.dataset.vocab_size,\n",
" num_conv_layers=args.num_conv_layers,\n",
" num_rnn_layers=args.num_rnn_layers,\n",
" rnn_size=1024,\n",
" use_gru=args.use_gru,\n",
" share_rnn_weights=args.share_rnn_weights,\n",
" )\n",
"dp_model = model\n",
"#dp_model = paddle.DataParallel(model)\n",
"\n",
"loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "divided-incentive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"id": "discrete-conjunction",
"metadata": {},
"outputs": [],
"source": [
"audio, audio_len, text, text_len = None, None, None, None\n",
"\n",
"for idx, inputs in enumerate(batch_reader):\n",
" audio, audio_len, text, text_len = inputs\n",
"# print(idx)\n",
"# print('a', audio.shape, audio.place)\n",
"# print('t', text)\n",
"# print('al', audio_len)\n",
"# print('tl', text_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "protected-announcement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"conv out [5, 32, 41, 62]\n",
"rnn input [5, 62, 1312]\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"rnn output [5, 62, 2048]\n",
"logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [2316.82153320])\n"
]
}
],
"source": [
"outputs = dp_model(audio, audio_len, text, text_len)\n",
"logits, _, logits_len = outputs\n",
"print('logits len', logits_len)\n",
"loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
"print('loss', loss)"
]
},
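  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hypothetical-length-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check (hedged; input lengths taken from the batch printed above) of\n",
    "# the length formula ConvBn uses: T_out = (T - kernel + 2*pad) // stride + 1.\n",
    "# With kernel=11, pad=5, stride=3 this reproduces the conv-out lengths.\n",
    "for t in [163, 167, 180, 186]:\n",
    "    print(t, (t - 11 + 2 * 5) // 3 + 1)"
   ]
  },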
{
"cell_type": "code",
"execution_count": 24,
"id": "universal-myrtle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n"
]
}
],
"source": [
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
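  {
   "cell_type": "markdown",
   "id": "grad-none-note",
   "metadata": {},
   "source": [
    "Before `loss.backward()` runs, every parameter's `grad` is `None`. The BatchNorm running statistics (`bn._mean`, `bn._variance`) additionally report `stop_grad: True`: they are buffers updated by a moving average during the forward pass, never by gradients. A minimal standalone sketch of that distinction (the `bn` layer below is illustrative, not part of the model above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grad-none-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: trainable parameters vs. BN running-stat buffers.\n",
    "import paddle\n",
    "from paddle import nn\n",
    "\n",
    "bn = nn.BatchNorm1D(4)\n",
    "for n, p in bn.named_parameters():\n",
    "    # weight/bias -> stop_gradient False (trainable);\n",
    "    # _mean/_variance -> stop_gradient True (running-stat buffers)\n",
    "    print(n, p.stop_gradient, p.grad)  # grad is None before any backward()"
   ]
  },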
{
"cell_type": "code",
"execution_count": 25,
"id": "referenced-double",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n",
" 4.757795 ]\n",
" [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n",
" 5.1787405 ]\n",
" [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n",
" 3.5026932 ]\n",
" ...\n",
" [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n",
" 2.359191 ]\n",
" [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n",
" 4.829253 ]\n",
" [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n",
" 3.7477999 ]]]\n",
"\n",
"\n",
" [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n",
" 1.0245132 ]\n",
" [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n",
" -0.41603735]\n",
" [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n",
" -0.96788657]\n",
" ...\n",
" [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n",
" 2.6364415 ]\n",
" [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n",
" 3.3839993 ]\n",
" [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n",
" 4.1078367 ]]]\n",
"\n",
"\n",
" [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n",
" -6.139402 ]\n",
" [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n",
" -6.158736 ]\n",
" [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n",
" -5.416684 ]\n",
" ...\n",
" [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n",
" -0.8772557 ]\n",
" [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n",
" -3.6650617 ]\n",
" [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n",
" -2.5240796 ]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n",
" 2.101063 ]\n",
" [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n",
" 1.4128599 ]\n",
" [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n",
" 1.6222539 ]\n",
" ...\n",
" [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n",
" 2.7494807 ]\n",
" [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n",
" 2.2483966 ]\n",
" [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n",
" 1.7467536 ]]]\n",
"\n",
"\n",
" [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n",
" -2.6987422 ]\n",
" [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n",
" -3.6595795 ]\n",
" [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n",
" -3.481941 ]\n",
" ...\n",
" [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n",
" -1.2167063 ]\n",
" [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n",
" 0.0715822 ]\n",
" [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n",
" 1.4772662 ]]]\n",
"\n",
"\n",
" [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n",
" -3.5771027 ]\n",
" [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n",
" -2.0113711 ]\n",
" [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n",
" -4.3225455 ]\n",
" ...\n",
" [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n",
" 2.6247358 ]\n",
" [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n",
" -0.9618193 ]\n",
" [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n",
" -0.79297894]]]]\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n",
" 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n",
" 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n",
" -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n",
" 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n",
" 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n",
" -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n",
" 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n",
" -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n",
" 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n",
" 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n",
" 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n",
" -5.628145 -1.0894046 ]\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n",
" -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n",
" -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n",
" 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n",
" 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n",
" -4.116946 -0.9909375 ]\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n",
" 6.71104014e-01 -1.95339048e+00]\n",
" [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n",
" -9.42245364e-01 -3.34277248e+00]\n",
" [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n",
" 1.36039746e+00 -2.94621849e+00]\n",
" ...\n",
" [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n",
" -3.23889256e-02 -2.30832791e+00]\n",
" [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n",
" -3.51897454e+00 -2.10093403e+00]\n",
" [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n",
" -7.38576293e-01 -9.44581270e-01]]\n",
"\n",
" [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n",
" -4.09437418e-01 -1.74107194e+00]\n",
" [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n",
" -3.03250861e+00 -2.37099743e+00]\n",
" [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n",
" -2.08650708e+00 -1.82109618e+00]\n",
" ...\n",
" [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n",
" -1.31862307e+00 -2.07174897e+00]\n",
" [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n",
" -9.57888126e-01 -3.82121086e-01]\n",
" [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n",
" 1.78124309e-02 -3.40287232e+00]]\n",
"\n",
" [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n",
" 1.65331578e+00 2.58466601e-02]\n",
" [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n",
" 1.94632411e+00 -1.06698799e+00]\n",
" [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n",
" 4.27859402e+00 3.66381669e+00]\n",
" ...\n",
" [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n",
" 4.30536079e+00 -2.49779320e+00]\n",
" [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n",
" 2.55930424e-01 4.50821817e-02]\n",
" [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n",
" 1.09565187e+00 -4.96661961e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n",
" 2.56748104e+00 1.67592216e+00]\n",
" [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n",
" 2.92086434e+00 2.08308399e-01]\n",
" [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n",
" 3.55370259e+00 4.75085378e-01]\n",
" ...\n",
" [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n",
" -1.66419971e+00 -1.70375630e-01]\n",
" [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n",
" 9.49141383e-02 8.88008535e-01]\n",
" [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n",
" 1.93200278e+00 8.19505215e-01]]\n",
"\n",
" [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n",
" 3.19132543e+00 3.17435169e+00]\n",
" [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n",
" 2.76059532e+00 4.83479440e-01]\n",
" [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n",
" 2.80596399e+00 1.52781498e+00]\n",
" ...\n",
" [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n",
" -2.26546407e-01 -1.18434143e+00]\n",
" [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n",
" -1.31319761e-01 -3.85830402e-01]\n",
" [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n",
" 1.68037796e+00 -6.50003672e-01]]\n",
"\n",
" [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n",
" 2.07197094e+00 4.48752761e-01]\n",
" [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n",
" 1.88659012e+00 1.48252523e+00]\n",
" [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n",
" 3.75945044e+00 -3.09408635e-01]\n",
" ...\n",
" [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n",
" -1.26373076e+00 -1.37826729e+00]\n",
" [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n",
" 1.50002241e-02 1.84633851e+00]\n",
" [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n",
" 1.53240371e+00 7.47349262e-02]]]\n",
"\n",
"\n",
" [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n",
" -1.08945215e+00 4.19084758e-01]\n",
" [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n",
" 8.21905077e-01 1.89097953e+00]\n",
" [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n",
" -2.78253704e-01 6.96343541e-01]\n",
" ...\n",
" [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n",
" -9.71581042e-02 1.71877682e+00]\n",
" [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n",
" -1.15296638e+00 4.35728967e-01]\n",
" [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n",
" -5.39051890e-01 1.13107657e+00]]\n",
"\n",
" [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n",
" 2.95773447e-02 1.27177119e+00]\n",
" [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n",
" 3.85564029e-01 4.57195908e-01]\n",
" [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n",
" -6.58698231e-02 -3.61778140e-01]\n",
" ...\n",
" [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n",
" 1.21484184e+00 2.22764325e+00]\n",
" [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n",
" 2.39702511e+00 2.43942714e+00]\n",
" [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n",
" 2.50086927e+00 1.91134131e+00]]\n",
"\n",
" [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n",
" -3.79540777e+00 -1.09796596e+00]\n",
" [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n",
" -2.59143281e+00 7.74220943e-01]\n",
" [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n",
" -1.23293114e+00 1.58148706e+00]\n",
" ...\n",
" [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n",
" 2.04232907e+00 1.97716308e+00]\n",
" [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n",
" 2.27003932e+00 -1.11734614e-01]\n",
" [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n",
" 5.78273535e-01 -2.59924948e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n",
" -1.53534544e+00 -6.67162359e-01]\n",
" [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n",
" 2.68609226e-01 -1.35124445e+00]\n",
" [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n",
" -5.04359961e-01 -1.81241560e+00]\n",
" ...\n",
" [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n",
" 3.29260826e-01 5.83679557e-01]\n",
" [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n",
" 2.91730791e-01 1.33400351e-01]\n",
" [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n",
" 4.90157723e-01 2.40562439e+00]]\n",
"\n",
" [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n",
" 1.91490650e+00 1.76274943e+00]\n",
" [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n",
" 2.53548074e+00 2.44395185e+00]\n",
" [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n",
" 2.94634581e+00 2.21452284e+00]\n",
" ...\n",
" [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n",
" -2.00367713e+00 -7.09145784e-01]\n",
" [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n",
" -7.83945560e-01 -1.01922750e-01]\n",
" [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n",
" 1.60286129e-02 -7.26446986e-01]]\n",
"\n",
" [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n",
" 8.52760911e-01 1.27920091e-01]\n",
" [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n",
" 3.97556186e-01 1.69182217e+00]\n",
" [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n",
" -1.07720590e+00 -3.81854475e-01]\n",
" ...\n",
" [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n",
" -4.48192549e+00 -6.14600539e-01]\n",
" [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n",
" -4.44984436e+00 -2.28429958e-01]\n",
" [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n",
" -1.29259872e+00 1.30360842e-01]]]\n",
"\n",
"\n",
" [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n",
" -6.68072224e-01 -1.13282454e+00]\n",
" [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n",
" -1.00205469e+00 -3.58956242e+00]\n",
" [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n",
" 2.25133196e-01 -1.02802217e+00]\n",
" ...\n",
" [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n",
" -1.63268471e+00 6.00247443e-01]\n",
" [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n",
" -1.47013855e+00 -1.45779204e+00]\n",
" [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n",
" -1.36285520e+00 -1.68160331e+00]]\n",
"\n",
" [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n",
" -2.32403350e+00 -4.25943947e+00]\n",
" [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n",
" -1.16508257e+00 -1.30353463e+00]\n",
" [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n",
" -1.91106403e+00 -1.07801402e+00]\n",
" ...\n",
" [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n",
" -5.27612507e-01 1.44604075e+00]\n",
" [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n",
" 4.67946082e-01 -7.24248111e-01]\n",
" [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n",
" -7.51657963e-01 -1.00530767e+00]]\n",
"\n",
" [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n",
" -2.35717297e-02 -1.71374834e+00]\n",
" [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n",
" -1.41204190e+00 -7.58325040e-01]\n",
" [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n",
" -1.07798588e+00 -1.48433352e+00]\n",
" ...\n",
" [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n",
" 1.24199319e+00 1.89949930e-01]\n",
" [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n",
" -5.78772545e-01 2.62665659e-01]\n",
" [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n",
" 2.79776216e-01 7.63269484e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n",
" 1.16390383e+00 -1.47593319e-01]\n",
" [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n",
" 1.89411759e+00 1.03220642e-01]\n",
" [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n",
" 2.07282662e+00 -2.56800652e-01]\n",
" ...\n",
" [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n",
" 1.80851030e+00 8.32634747e-01]\n",
" [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n",
" 1.30439472e+00 3.61029744e-01]\n",
" [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n",
" -9.16651785e-02 9.93353873e-02]]\n",
"\n",
" [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n",
" 3.94066262e+00 4.89832020e+00]\n",
" [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n",
" 3.22004294e+00 2.99872303e+00]\n",
" [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n",
" 2.98195982e+00 3.37315011e+00]\n",
" ...\n",
" [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n",
" 2.13971066e+00 1.94932449e+00]\n",
" [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n",
" 2.29117441e+00 2.52368808e+00]\n",
" [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n",
" 1.86903965e+00 2.27776885e+00]]\n",
"\n",
" [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n",
" 7.87163258e-01 1.20610237e+00]\n",
" [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n",
" 9.66498017e-01 1.18883348e+00]\n",
" [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n",
" 5.25399208e-01 -1.16850233e+00]\n",
" ...\n",
" [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n",
" -1.49128008e+00 -1.85064209e+00]\n",
" [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n",
" 9.67048705e-02 -4.37043965e-01]\n",
" [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n",
" -7.83308804e-01 -1.64871216e+00]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n",
" -1.29671764e+00 -1.61429560e+00]\n",
" [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n",
" -1.26383066e+00 -2.27464601e-01]\n",
" [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n",
" -3.68550658e-01 -2.58849680e-01]\n",
" ...\n",
" [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n",
" 2.67662477e+00 2.31014514e+00]\n",
" [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n",
" 2.77418542e+00 1.22875953e+00]\n",
" [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n",
" 1.95602298e+00 7.64306366e-01]]\n",
"\n",
" [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n",
" 1.90896463e+00 8.02283168e-01]\n",
" [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n",
" 1.53455496e+00 1.94702053e+00]\n",
" [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n",
" 7.79394627e-01 2.02670670e+00]\n",
" ...\n",
" [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n",
" 9.61961091e-01 1.40342009e+00]\n",
" [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n",
" 9.43485856e-01 2.34856772e+00]\n",
" [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n",
" 1.38143206e+00 2.00714707e+00]]\n",
"\n",
" [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n",
" -2.52133775e+00 -8.87715697e-01]\n",
" [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n",
" -6.10453069e-01 2.06534386e+00]\n",
" [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n",
" -1.12928271e+00 4.87097502e-02]\n",
" ...\n",
" [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n",
" 1.06069386e+00 1.36404145e+00]\n",
" [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n",
" 1.48393130e+00 2.69196725e+00]\n",
" [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n",
" 3.67927134e-01 2.55976987e+00]]\n",
"\n",
" ...\n",
"\n",
" [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n",
" 2.57500267e+00 -1.38083458e-01]\n",
" [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n",
" 2.34770632e+00 6.48133337e-01]\n",
" [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n",
" 4.34810448e+00 1.31588542e+00]\n",
" ...\n",
" [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n",
" 1.06280947e+00 -5.93325794e-02]\n",
" [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n",
" 8.06131065e-02 8.08142900e-01]\n",
" [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n",
" 3.13675582e-01 2.13851762e+00]]\n",
"\n",
" [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n",
" 3.17874146e+00 1.47012353e+00]\n",
" [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n",
" 2.73698759e+00 2.17057824e+00]\n",
" [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n",
" 3.43153954e+00 1.21802723e+00]\n",
" ...\n",
" [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n",
" 1.64856291e+00 1.15853786e+00]\n",
" [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n",
" 2.24890351e+00 2.22156644e+00]\n",
" [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n",
" 3.71970820e+00 1.69852686e+00]]\n",
"\n",
" [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n",
" 1.48759770e+00 -1.72001183e+00]\n",
" [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n",
" 1.74327165e-01 -1.83029580e+00]\n",
" [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n",
" 1.05627227e+00 -8.02519619e-01]\n",
" ...\n",
" [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n",
" 4.66731453e+00 4.98701715e+00]\n",
" [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n",
" 3.32219076e+00 3.83035636e+00]\n",
" [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n",
" 2.13785577e+00 4.33214951e+00]]]\n",
"\n",
"\n",
" [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n",
" 2.72914767e-01 3.92112255e-01]\n",
" [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n",
" -1.44200325e-01 4.07531261e-02]\n",
" [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n",
" -2.28934169e+00 -7.38609374e-01]\n",
" ...\n",
" [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n",
" -1.23878837e+00 2.21006930e-01]\n",
" [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n",
" -6.22699618e-01 -1.73162484e+00]\n",
" [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n",
" -1.15777385e+00 -1.13671994e+00]]\n",
"\n",
" [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n",
" 8.81283879e-02 -1.37762165e+00]\n",
" [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n",
" -9.98012185e-01 -1.11348581e+00]\n",
" [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n",
" -3.65677416e-01 -2.25250268e+00]\n",
" ...\n",
" [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n",
" -2.52748466e+00 -2.16131711e+00]\n",
" [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n",
" -2.44736362e+00 -8.78590107e-01]\n",
" [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n",
" -2.45255780e+00 -1.82384634e+00]]\n",
"\n",
" [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n",
" 4.13422853e-01 1.12761199e+00]\n",
" [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n",
" -8.72656107e-01 1.52415514e-01]\n",
" [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n",
" 1.87347198e+00 9.48901057e-01]\n",
" ...\n",
" [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n",
" -6.83587790e-01 1.49154460e+00]\n",
" [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n",
" 1.54248989e+00 -1.46732473e+00]\n",
" [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n",
" 1.12405360e-01 -6.87821358e-02]]\n",
"\n",
" ...\n",
"\n",
" [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n",
" 4.03424358e+00 3.73335361e+00]\n",
" [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n",
" 2.29408574e+00 3.14105940e+00]\n",
" [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n",
" 3.91481900e+00 5.32456112e+00]\n",
" ...\n",
" [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n",
" 3.87909842e+00 3.33359623e+00]\n",
" [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n",
" 1.55827272e+00 2.50741863e+00]\n",
" [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n",
" 3.28739667e+00 3.02895570e+00]]\n",
"\n",
" [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n",
" 1.52944064e+00 2.08810711e+00]\n",
" [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n",
" 7.46028006e-01 1.74789345e+00]\n",
" [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n",
" 1.49640536e+00 2.16613841e+00]\n",
" ...\n",
" [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n",
" 3.87935114e+00 3.69385266e+00]\n",
" [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n",
" 5.36551809e+00 2.94915986e+00]\n",
" [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n",
" 3.93462896e+00 3.68200874e+00]]\n",
"\n",
" [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n",
" 1.60273385e+00 4.90113163e+00]\n",
" [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n",
" 1.40544355e+00 2.84833241e+00]\n",
" [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n",
" 6.87886119e-01 2.58609629e+00]\n",
" ...\n",
" [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n",
" 2.30406880e+00 1.29892707e+00]\n",
" [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n",
" 8.59923482e-01 1.38731599e+00]\n",
" [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n",
" 1.44120216e+00 1.79993320e+00]]]\n",
"\n",
"\n",
" [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n",
" 1.65043330e+00 3.45897645e-01]\n",
" [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n",
" 9.12241280e-01 1.27429694e-01]\n",
" [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n",
" 1.00041127e+00 5.87104559e-02]\n",
" ...\n",
" [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n",
" 2.35047388e+00 1.61532843e+00]\n",
" [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n",
" 2.32770801e+00 9.96662021e-01]\n",
" [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n",
" 2.05776072e+00 5.34575820e-01]]\n",
"\n",
" [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n",
" -9.85817850e-01 -2.05517030e+00]\n",
" [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n",
" -1.79659498e+00 -1.58266973e+00]\n",
" [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n",
" -1.55507576e+00 -9.37123299e-01]\n",
" ...\n",
" [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n",
" 1.27850962e+00 1.35487700e+00]\n",
" [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n",
" 1.50716826e-01 1.33641326e+00]\n",
" [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n",
" 1.60940981e+00 1.78091609e+00]]\n",
"\n",
" [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n",
" -2.18273020e+00 -1.98621082e+00]\n",
" [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n",
" -2.12960863e+00 -3.45255351e+00]\n",
" [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n",
" -3.72989595e-01 -1.36516929e+00]\n",
" ...\n",
" [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n",
" 2.13982654e+00 -1.81909311e+00]\n",
" [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n",
" 2.07899213e+00 -3.71748984e-01]\n",
" [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n",
" -2.81265080e-01 -3.69353056e+00]]\n",
"\n",
" ...\n",
"\n",
" [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n",
" -1.75104141e-02 -3.54720926e+00]\n",
" [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n",
" -3.58282638e+00 -2.40778995e+00]\n",
" [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n",
" -1.67449510e+00 -1.14655948e+00]\n",
" ...\n",
" [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n",
" -1.48933101e+00 -1.36807609e+00]\n",
" [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n",
" -8.00529659e-01 -8.39800835e-01]\n",
" [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n",
" 1.67839837e+00 8.47821474e-01]]\n",
"\n",
" [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n",
" 8.83528054e-01 2.57416844e-01]\n",
" [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n",
" 2.38878536e+00 1.84097290e+00]\n",
" [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n",
" 1.61432290e+00 1.13214278e+00]\n",
" ...\n",
" [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n",
" -2.18670201e+00 -3.91811800e+00]\n",
" [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n",
" -1.59393203e+00 -9.08105671e-01]\n",
" [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n",
" 9.73127842e-01 1.19991803e+00]]\n",
"\n",
" [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n",
" 2.25974560e+00 1.02396965e+00]\n",
" [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n",
" -4.19382691e-01 2.91669458e-01]\n",
" [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n",
" -1.51922226e-01 7.44265199e-01]\n",
" ...\n",
" [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n",
" 5.72486281e-01 -4.35207397e-01]\n",
" [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n",
" 1.01334639e-02 -3.20627165e+00]\n",
" [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n",
" -1.73590891e-02 -1.80156028e+00]]]]\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n",
" 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n",
" -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n",
" -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n",
" -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n",
" -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n",
" 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n",
" -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n",
" 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n",
" -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n",
" -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n",
" -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n",
" -0.24313992 0.31392363]\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n",
" 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n",
" -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n",
" -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n",
" -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n",
" -0.40964937 -1.4454535 ]\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n",
" -0.0658682 ]\n",
" [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n",
" -0.05464617]\n",
" [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n",
" -0.40110156]\n",
" ...\n",
" [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n",
" 0.37071258]\n",
" [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n",
" 0.32735693]\n",
" [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n",
" 0.0752247 ]]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n",
" 0.3020359 ]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n",
" 0.16263628]\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n",
" -0.25661892]\n",
" [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n",
" -0.17545304]\n",
" [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n",
" 0.05315325]\n",
" ...\n",
" [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n",
" -0.01544908]\n",
" [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n",
" 0.14791614]\n",
" [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n",
" 0.26846847]]\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n",
" 0.16263627]\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n",
" -0.30598596]\n",
" [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n",
" -0.5335691 ]\n",
" [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n",
" 0.25522667]\n",
" ...\n",
" [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n",
" -0.07482046]\n",
" [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n",
" -0.2011079 ]\n",
" [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n",
" -0.17761081]]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n",
" -0.10503208]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n",
" 0.11344345]\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n",
" -0.29202533]\n",
" [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n",
" 0.05170784]\n",
" [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n",
" -0.08760723]\n",
" ...\n",
" [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n",
" -0.14726155]\n",
" [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n",
" 0.02569831]\n",
" [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n",
" 0.41056633]]\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n",
" 0.11344352]\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n",
" 0.05862035]\n",
" [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n",
" -0.3197178 ]\n",
" [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n",
" -0.44295847]\n",
" ...\n",
" [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n",
" -0.15528557]\n",
" [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n",
" -0.61484563]\n",
" [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n",
" 0.17856634]]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n",
" 1.6333911 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n",
" 2.069446 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n",
" -0.00845779]\n",
" [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n",
" -0.855925 ]\n",
" [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n",
" 0.11303265]\n",
" ...\n",
" [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n",
" -0.74844164]\n",
" [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n",
" 0.71526295]\n",
" [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n",
" 1.2852117 ]]\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n",
" 2.0694466 ]\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n",
" 7.4421698e-03 -2.3925617e+01]\n",
" [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n",
" 1.8560058e-02 -5.0687141e+01]\n",
" [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n",
" 2.1536341e-02 -6.1036335e+01]\n",
" ...\n",
" [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n",
" 1.3860047e-02 -5.2543671e+01]\n",
" [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n",
" 1.7470375e-02 -4.3619675e+01]\n",
" [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n",
" 2.2168703e-02 -6.7901680e+01]]\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n",
" 4.3404289e-02 -1.3512518e+02]\n"
]
}
],
"source": [
"loss.backward(retain_graph=False)\n",
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
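  {
   "cell_type": "markdown",
   "id": "grad-after-backward-note",
   "metadata": {},
   "source": [
    "After `loss.backward()` the conv, fc, and forward-direction RNN weights all carry gradients, while `_mean`/`_variance` stay `None` as expected (`stop_gradient=True`). Note that the backward-direction cell parameters (`rnn_stacks.*.bw_cell.weight_hh` / `bias_hh`) also still report `grad: None` even though `stop_grad` is `False`; whether that is just how `p.grad` is surfaced for those parameters or a genuinely missing gradient is worth checking. A quick way to list the suspects (reuses `dp_model` and the backward pass above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grad-missing-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: list trainable parameters that received no gradient\n",
    "# from the backward pass above (assumes dp_model and loss.backward()).\n",
    "missing = [n for n, p in dp_model.named_parameters()\n",
    "           if not p.stop_gradient and p.grad is None]\n",
    "print(missing)  # here: the bw_cell weight_hh / bias_hh entries"
   ]
  },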
{
"cell_type": "code",
"execution_count": 26,
"id": "selected-crazy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.]\n"
]
}
],
"source": [
"print(loss.grad)"
]
},
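  {
   "cell_type": "markdown",
   "id": "loss-grad-note",
   "metadata": {},
   "source": [
    "`loss.grad` is `[1.]` because `backward()` seeds the root of the autograd graph with ones: d(loss)/d(loss) = 1. The same can be seen on a fresh scalar (a standalone illustrative check, not part of the training flow):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "loss-grad-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: the gradient of a scalar with respect to itself is 1.\n",
    "import paddle\n",
    "\n",
    "x = paddle.to_tensor([2.0, 3.0], stop_gradient=False)\n",
    "y = (x * x).sum()   # scalar output\n",
    "y.backward()\n",
    "print(y.grad)       # [1.]  -- the seed of the backward pass\n",
    "print(x.grad)       # [4., 6.]  -- dy/dx = 2x"
   ]
  },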
{
"cell_type": "code",
"execution_count": null,
"id": "bottom-engineer",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "stuffed-yeast",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-grade",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
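  {
   "cell_type": "markdown",
   "id": "cwd-note",
   "metadata": {},
   "source": [
    "The notebook lives in a subdirectory, so `%cd ..` first moves the working directory to the repository root (`/workspace/DeepSpeech-2.x`), so that the `deepspeech` imports in the next cell resolve against the source tree."
   ]
  },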
{
"cell_type": "code",
"execution_count": 2,
"id": "broke-broad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"register user softmax to paddle, remove this when fixed!\n",
"register user log_softmax to paddle, remove this when fixed!\n",
"register user sigmoid to paddle, remove this when fixed!\n",
"register user log_sigmoid to paddle, remove this when fixed!\n",
"register user relu to paddle, remove this when fixed!\n",
"override cat of paddle if exists or register, remove this when fixed!\n",
"override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle if exists or register, remove this when fixed!\n",
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"register user view to paddle.Tensor, remove this when fixed!\n",
"register user view_as to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"register user fill_ to paddle.Tensor, remove this when fixed!\n",
"register user repeat to paddle.Tensor, remove this when fixed!\n",
"register user softmax to paddle.Tensor, remove this when fixed!\n",
"register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"register user relu to paddle.Tensor, remove this when fixed!\n",
"register user type_as to paddle.Tensor, remove this when fixed!\n",
"register user to to paddle.Tensor, remove this when fixed!\n",
"register user float to paddle.Tensor, remove this when fixed!\n",
"register user tolist to paddle.Tensor, remove this when fixed!\n",
"register user glu to paddle.nn.functional, remove this when fixed!\n",
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"register user Module to paddle.nn, remove this when fixed!\n",
"register user ModuleList to paddle.nn, remove this when fixed!\n",
"register user GLU to paddle.nn, remove this when fixed!\n",
"register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"register user export to paddle.jit, remove this when fixed!\n"
]
}
],
"source": [
"import numpy as np\n",
"import paddle\n",
"from yacs.config import CfgNode as CN\n",
"\n",
"from deepspeech.models.u2 import U2Model\n",
"from deepspeech.utils.layer_tools import print_params\n",
"from deepspeech.utils.layer_tools import summary"
]
},
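  {
   "cell_type": "markdown",
   "id": "patch-note",
   "metadata": {},
   "source": [
    "The `register user ... to paddle` / `override ... remove this when fixed!` lines above are informational, not errors: importing `deepspeech` applies a compatibility layer that monkey-patches torch-style helpers (`view`, `masked_fill`, `repeat`, `glu`, ...) onto `paddle` and `paddle.Tensor`, to be dropped once paddle provides them natively. A minimal sketch of how such a patch could be registered (illustrative, not the project's exact code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "patch-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: give paddle.Tensor a torch-style `view` alias.\n",
    "import paddle\n",
    "\n",
    "if not hasattr(paddle.Tensor, 'view'):\n",
    "    paddle.Tensor.view = lambda self, *shape: paddle.reshape(self, shape)\n",
    "\n",
    "x = paddle.ones([2, 6])\n",
    "print(x.view(3, 4).shape)  # [3, 4]"
   ]
  },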
{
"cell_type": "code",
"execution_count": 3,
"id": "permanent-summary",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n",
"[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n",
"attention_heads: 4\n",
"dropout_rate: 0.1\n",
"input_layer: conv2d\n",
"linear_units: 2048\n",
"normalize_before: True\n",
"num_blocks: 12\n",
"output_size: 256\n",
"positional_dropout_rate: 0.1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"encoder.embed.conv.0.bias | [256] | 256 | True\n",
"encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"encoder.embed.conv.2.bias | [256] | 256 | True\n",
"encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n",
"encoder.embed.out.0.bias | [256] | 256 | True\n",
"encoder.after_norm.weight | [256] | 256 | True\n",
"encoder.after_norm.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n",
"decoder.embed.0.weight | [4233, 256] | 1083648 | True\n",
"decoder.after_norm.weight | [256] | 256 | True\n",
"decoder.after_norm.bias | [256] | 256 | True\n",
"decoder.output_layer.weight | [256, 4233] | 1083648 | True\n",
"decoder.output_layer.bias | [4233] | 4233 | True\n",
"decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n",
"ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n",
"ctc.ctc_lo.bias | [4233] | 4233 | True\n",
"Total parameters: 411.0, 32.01M elements.\n"
]
}
],
"source": [
"conf_str='examples/tiny/s1/conf/transformer.yaml'\n",
"cfg = CN().load_cfg(open(conf_str))\n",
"cfg.model.input_dim = 83\n",
"cfg.model.output_dim = 4233\n",
"cfg.model.cmvn_file = None\n",
"cfg.model.cmvn_file_type = 'json'\n",
"#cfg.model.encoder_conf.concat_after=True\n",
"cfg.freeze()\n",
"model = U2Model(cfg.model)\n",
"\n",
"print_params(model)\n"
]
},
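  {
   "cell_type": "code",
   "execution_count": null,
   "id": "param-count-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch, not the repo code: print_params above is a DeepSpeech utility.\n",
    "# A minimal re-implementation (an assumption) walks named_parameters() and sums\n",
    "# element counts, which is where the 'Total parameters: ..., 32.01M elements.'\n",
    "# line in the output above comes from.\n",
    "import numpy as np\n",
    "\n",
    "def count_params_sketch(model):\n",
    "    n_params, n_elems = 0, 0\n",
    "    for name, param in model.named_parameters():\n",
    "        size = int(np.prod(param.shape))\n",
    "        # trainable in paddle means stop_gradient is False\n",
    "        print(f'{name} | {list(param.shape)} | {size} | {not param.stop_gradient}')\n",
    "        n_params += 1\n",
    "        n_elems += size\n",
    "    print(f'Total parameters: {n_params}, {n_elems / 1e6:.2f}M elements.')\n",
    "\n",
    "# count_params_sketch(model)"
   ]
  },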
{
"cell_type": "code",
"execution_count": 4,
"id": "sapphire-agent",
"metadata": {},
"outputs": [],
"source": [
"#summary(model)\n",
"#print(model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fossil-means",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "45c2b75f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state\n",
"odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 
'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 
'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n"
]
}
],
"source": [
"#!pip install torch\n",
"import torch\n",
"\n",
"e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n",
"for k in e_model.files:\n",
" print(k)\n",
"state_dict = e_model['state']\n",
"state_dict = state_dict.tolist()\n",
"print(state_dict.keys())"
]
},
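  {
   "cell_type": "code",
   "execution_count": null,
   "id": "npz-export-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch (assumption, not the original export script): how model.npz was\n",
    "# likely produced on the espnet/torch side. np.savez pickles the OrderedDict\n",
    "# into a 0-d object array, which is why the cell above needs allow_pickle=True\n",
    "# and .tolist() to recover the dict.\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "tiny = torch.nn.Linear(4, 2)  # stand-in for the espnet model\n",
    "np.savez('/tmp/tiny_model.npz', state=tiny.state_dict())\n",
    "\n",
    "loaded = np.load('/tmp/tiny_model.npz', allow_pickle=True)\n",
    "sd = loaded['state'].tolist()  # 0-d object array -> OrderedDict\n",
    "print(type(sd), list(sd.keys()))"
   ]
  },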
{
"cell_type": "code",
"execution_count": 6,
"id": "f187bb55",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n",
"# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n",
"# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n",
"# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n",
"# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n",
"# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"# espnet transformer\tconcat_linear只是保存了,但是未使用\n",
"\t\n",
"# \tpaddle transformer"
]
},
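  {
   "cell_type": "code",
   "execution_count": null,
   "id": "linear-transpose-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: why the conversion cell below transposes every 2-D weight.\n",
    "# torch.nn.Linear stores weight as [out_features, in_features] while\n",
    "# paddle.nn.Linear stores [in_features, out_features], so an equivalent\n",
    "# paddle layer is obtained by copying the transposed matrix.\n",
    "import numpy as np\n",
    "import paddle\n",
    "import torch\n",
    "\n",
    "t_lin = torch.nn.Linear(256, 2048)   # torch weight: [2048, 256]\n",
    "p_lin = paddle.nn.Linear(256, 2048)  # paddle weight: [256, 2048]\n",
    "\n",
    "w = t_lin.weight.detach().numpy()\n",
    "b = t_lin.bias.detach().numpy()\n",
    "p_lin.weight.set_value(w.T.copy())   # transpose when crossing frameworks\n",
    "p_lin.bias.set_value(b)\n",
    "\n",
    "x = np.random.randn(4, 256).astype('float32')\n",
    "y_torch = t_lin(torch.from_numpy(x)).detach().numpy()\n",
    "y_paddle = p_lin(paddle.to_tensor(x)).numpy()\n",
    "print(np.allclose(y_torch, y_paddle, atol=1e-5))  # expect True"
   ]
  },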
{
"cell_type": "code",
"execution_count": 7,
"id": "2a0428ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> encoder.embed.conv.0.weight\n",
"-> encoder.embed.conv.0.bias\n",
"-> encoder.embed.conv.2.weight\n",
"-> encoder.embed.conv.2.bias\n",
"-> encoder.embed.out.0.weight\n",
"encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n",
"-> encoder.embed.out.0.bias\n",
"-> encoder.encoders.0.self_attn.linear_q.weight\n",
"encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_q.bias\n",
"-> encoder.encoders.0.self_attn.linear_k.weight\n",
"encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_k.bias\n",
"-> encoder.encoders.0.self_attn.linear_v.weight\n",
"encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_v.bias\n",
"-> encoder.encoders.0.self_attn.linear_out.weight\n",
"encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_out.bias\n",
"-> encoder.encoders.0.feed_forward.w_1.weight\n",
"encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.0.feed_forward.w_1.bias\n",
"-> encoder.encoders.0.feed_forward.w_2.weight\n",
"encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.0.feed_forward.w_2.bias\n",
"-> encoder.encoders.0.norm1.weight\n",
"-> encoder.encoders.0.norm1.bias\n",
"-> encoder.encoders.0.norm2.weight\n",
"-> encoder.encoders.0.norm2.bias\n",
"-> encoder.encoders.1.self_attn.linear_q.weight\n",
"encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_q.bias\n",
"-> encoder.encoders.1.self_attn.linear_k.weight\n",
"encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_k.bias\n",
"-> encoder.encoders.1.self_attn.linear_v.weight\n",
"encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_v.bias\n",
"-> encoder.encoders.1.self_attn.linear_out.weight\n",
"encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_out.bias\n",
"-> encoder.encoders.1.feed_forward.w_1.weight\n",
"encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.1.feed_forward.w_1.bias\n",
"-> encoder.encoders.1.feed_forward.w_2.weight\n",
"encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.1.feed_forward.w_2.bias\n",
"-> encoder.encoders.1.norm1.weight\n",
"-> encoder.encoders.1.norm1.bias\n",
"-> encoder.encoders.1.norm2.weight\n",
"-> encoder.encoders.1.norm2.bias\n",
"-> encoder.encoders.2.self_attn.linear_q.weight\n",
"encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_q.bias\n",
"-> encoder.encoders.2.self_attn.linear_k.weight\n",
"encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_k.bias\n",
"-> encoder.encoders.2.self_attn.linear_v.weight\n",
"encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_v.bias\n",
"-> encoder.encoders.2.self_attn.linear_out.weight\n",
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_out.bias\n",
"-> encoder.encoders.2.feed_forward.w_1.weight\n",
"encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.2.feed_forward.w_1.bias\n",
"-> encoder.encoders.2.feed_forward.w_2.weight\n",
"encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.2.feed_forward.w_2.bias\n",
"-> encoder.encoders.2.norm1.weight\n",
"-> encoder.encoders.2.norm1.bias\n",
"-> encoder.encoders.2.norm2.weight\n",
"-> encoder.encoders.2.norm2.bias\n",
"-> encoder.encoders.3.self_attn.linear_q.weight\n",
"encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_q.bias\n",
"-> encoder.encoders.3.self_attn.linear_k.weight\n",
"encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_k.bias\n",
"-> encoder.encoders.3.self_attn.linear_v.weight\n",
"encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_v.bias\n",
"-> encoder.encoders.3.self_attn.linear_out.weight\n",
"encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_out.bias\n",
"-> encoder.encoders.3.feed_forward.w_1.weight\n",
"encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.3.feed_forward.w_1.bias\n",
"-> encoder.encoders.3.feed_forward.w_2.weight\n",
"encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.3.feed_forward.w_2.bias\n",
"-> encoder.encoders.3.norm1.weight\n",
"-> encoder.encoders.3.norm1.bias\n",
"-> encoder.encoders.3.norm2.weight\n",
"-> encoder.encoders.3.norm2.bias\n",
"-> encoder.encoders.4.self_attn.linear_q.weight\n",
"encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_q.bias\n",
"-> encoder.encoders.4.self_attn.linear_k.weight\n",
"encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_k.bias\n",
"-> encoder.encoders.4.self_attn.linear_v.weight\n",
"encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_v.bias\n",
"-> encoder.encoders.4.self_attn.linear_out.weight\n",
"encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_out.bias\n",
"-> encoder.encoders.4.feed_forward.w_1.weight\n",
"encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.4.feed_forward.w_1.bias\n",
"-> encoder.encoders.4.feed_forward.w_2.weight\n",
"encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.4.feed_forward.w_2.bias\n",
"-> encoder.encoders.4.norm1.weight\n",
"-> encoder.encoders.4.norm1.bias\n",
"-> encoder.encoders.4.norm2.weight\n",
"-> encoder.encoders.4.norm2.bias\n",
"-> encoder.encoders.5.self_attn.linear_q.weight\n",
"encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_q.bias\n",
"-> encoder.encoders.5.self_attn.linear_k.weight\n",
"encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_k.bias\n",
"-> encoder.encoders.5.self_attn.linear_v.weight\n",
"encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_v.bias\n",
"-> encoder.encoders.5.self_attn.linear_out.weight\n",
"encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_out.bias\n",
"-> encoder.encoders.5.feed_forward.w_1.weight\n",
"encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.5.feed_forward.w_1.bias\n",
"-> encoder.encoders.5.feed_forward.w_2.weight\n",
"encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.5.feed_forward.w_2.bias\n",
"-> encoder.encoders.5.norm1.weight\n",
"-> encoder.encoders.5.norm1.bias\n",
"-> encoder.encoders.5.norm2.weight\n",
"-> encoder.encoders.5.norm2.bias\n",
"-> encoder.encoders.6.self_attn.linear_q.weight\n",
"encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_q.bias\n",
"-> encoder.encoders.6.self_attn.linear_k.weight\n",
"encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_k.bias\n",
"-> encoder.encoders.6.self_attn.linear_v.weight\n",
"encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_v.bias\n",
"-> encoder.encoders.6.self_attn.linear_out.weight\n",
"encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_out.bias\n",
"-> encoder.encoders.6.feed_forward.w_1.weight\n",
"encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.6.feed_forward.w_1.bias\n",
"-> encoder.encoders.6.feed_forward.w_2.weight\n",
"encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.6.feed_forward.w_2.bias\n",
"-> encoder.encoders.6.norm1.weight\n",
"-> encoder.encoders.6.norm1.bias\n",
"-> encoder.encoders.6.norm2.weight\n",
"-> encoder.encoders.6.norm2.bias\n",
"-> encoder.encoders.7.self_attn.linear_q.weight\n",
"encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_q.bias\n",
"-> encoder.encoders.7.self_attn.linear_k.weight\n",
"encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_k.bias\n",
"-> encoder.encoders.7.self_attn.linear_v.weight\n",
"encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_v.bias\n",
"-> encoder.encoders.7.self_attn.linear_out.weight\n",
"encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_out.bias\n",
"-> encoder.encoders.7.feed_forward.w_1.weight\n",
"encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.7.feed_forward.w_1.bias\n",
"-> encoder.encoders.7.feed_forward.w_2.weight\n",
"encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.7.feed_forward.w_2.bias\n",
"-> encoder.encoders.7.norm1.weight\n",
"-> encoder.encoders.7.norm1.bias\n",
"-> encoder.encoders.7.norm2.weight\n",
"-> encoder.encoders.7.norm2.bias\n",
"-> encoder.encoders.8.self_attn.linear_q.weight\n",
"encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_q.bias\n",
"-> encoder.encoders.8.self_attn.linear_k.weight\n",
"encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_k.bias\n",
"-> encoder.encoders.8.self_attn.linear_v.weight\n",
"encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_v.bias\n",
"-> encoder.encoders.8.self_attn.linear_out.weight\n",
"encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_out.bias\n",
"-> encoder.encoders.8.feed_forward.w_1.weight\n",
"encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.8.feed_forward.w_1.bias\n",
"-> encoder.encoders.8.feed_forward.w_2.weight\n",
"encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.8.feed_forward.w_2.bias\n",
"-> encoder.encoders.8.norm1.weight\n",
"-> encoder.encoders.8.norm1.bias\n",
"-> encoder.encoders.8.norm2.weight\n",
"-> encoder.encoders.8.norm2.bias\n",
"-> encoder.encoders.9.self_attn.linear_q.weight\n",
"encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_q.bias\n",
"-> encoder.encoders.9.self_attn.linear_k.weight\n",
"encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_k.bias\n",
"-> encoder.encoders.9.self_attn.linear_v.weight\n",
"encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_v.bias\n",
"-> encoder.encoders.9.self_attn.linear_out.weight\n",
"encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_out.bias\n",
"-> encoder.encoders.9.feed_forward.w_1.weight\n",
"encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.9.feed_forward.w_1.bias\n",
"-> encoder.encoders.9.feed_forward.w_2.weight\n",
"encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.9.feed_forward.w_2.bias\n",
"-> encoder.encoders.9.norm1.weight\n",
"-> encoder.encoders.9.norm1.bias\n",
"-> encoder.encoders.9.norm2.weight\n",
"-> encoder.encoders.9.norm2.bias\n",
"-> encoder.encoders.10.self_attn.linear_q.weight\n",
"encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_q.bias\n",
"-> encoder.encoders.10.self_attn.linear_k.weight\n",
"encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_k.bias\n",
"-> encoder.encoders.10.self_attn.linear_v.weight\n",
"encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_v.bias\n",
"-> encoder.encoders.10.self_attn.linear_out.weight\n",
"encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_out.bias\n",
"-> encoder.encoders.10.feed_forward.w_1.weight\n",
"encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.10.feed_forward.w_1.bias\n",
"-> encoder.encoders.10.feed_forward.w_2.weight\n",
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.10.feed_forward.w_2.bias\n",
"-> encoder.encoders.10.norm1.weight\n",
"-> encoder.encoders.10.norm1.bias\n",
"-> encoder.encoders.10.norm2.weight\n",
"-> encoder.encoders.10.norm2.bias\n",
"-> encoder.encoders.11.self_attn.linear_q.weight\n",
"encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_q.bias\n",
"-> encoder.encoders.11.self_attn.linear_k.weight\n",
"encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_k.bias\n",
"-> encoder.encoders.11.self_attn.linear_v.weight\n",
"encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_v.bias\n",
"-> encoder.encoders.11.self_attn.linear_out.weight\n",
"encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_out.bias\n",
"-> encoder.encoders.11.feed_forward.w_1.weight\n",
"encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.11.feed_forward.w_1.bias\n",
"-> encoder.encoders.11.feed_forward.w_2.weight\n",
"encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.11.feed_forward.w_2.bias\n",
"-> encoder.encoders.11.norm1.weight\n",
"-> encoder.encoders.11.norm1.bias\n",
"-> encoder.encoders.11.norm2.weight\n",
"-> encoder.encoders.11.norm2.bias\n",
"-> encoder.after_norm.weight\n",
"-> encoder.after_norm.bias\n"
]
}
],
"source": [
"# dump torch model to paddle\n",
"#state_dict = model.state_dict()\n",
"paddle_state_dict = {}\n",
"\n",
"for n, p in state_dict.items():\n",
" if 'encoder' not in n:\n",
" continue \n",
" print(f'-> {n}')\n",
" \n",
" \n",
" name_change=True\n",
" if 'norm.running_mean' in n:\n",
" new_n = n.replace('norm.running_', 'norm._')\n",
" elif 'norm.running_var' in n:\n",
" new_n = n.replace('norm.running_var', 'norm._variance')\n",
" else:\n",
" name_change=False\n",
" new_n = n\n",
" if name_change:\n",
" print(f\"{n} -> {new_n}\")\n",
" \n",
" \n",
" p = p.cpu().detach().numpy()\n",
" if n.endswith('weight') and p.ndim == 2:\n",
" new_p = p.T\n",
" print(f\"{n}: {p.shape} -> {new_p.shape}\")\n",
" else:\n",
" new_p = p\n",
" \n",
" if 'global_cmvn.mean' in n:\n",
" print(p, p.dtype)\n",
" \n",
" paddle_state_dict[new_n] = new_p\n",
" \n",
"# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n",
"# state=paddle_state_dict)"
]
},
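  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-converted-note",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged note: the next cell loads the converted dict into the Paddle model\n",
    "# (presumably via model.set_state_dict). Since only encoder weights were\n",
    "# converted, keys that exist only on the Paddle side (encoder.*.concat_linear.*\n",
    "# and everything under decoder.*) are missing from the dict, which produces\n",
    "# the 'Skip loading for ...' UserWarnings below."
   ]
  },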
{
"cell_type": "code",
"execution_count": 8,
"id": "a1d97e9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. decoder.decoders.0.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n"
]
}
],
"source": [
"model.set_state_dict(paddle_state_dict)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fc7edf1e",
"metadata": {},
"outputs": [],
"source": [
"e_state = model.encoder.state_dict()\n",
"for key, value in e_state.items():\n",
" if 'concat_linear' in key:\n",
" continue\n",
" if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n",
" print(key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "572097d0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "748250b7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "91e5deee",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fleet-despite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "abroad-oracle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 57, 83)\n",
"(8, 1, 57)\n",
"[57 50 48 38 32 31 28 25]\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n",
"xs=data['xs']\n",
"masks=data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"xs_lens = masks.sum(axis=-1).squeeze()\n",
"print(xs_lens)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "false-instrument",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[8, 13, 256]\n",
"[8, 1, 13]\n"
]
}
],
"source": [
"# ecnoder\n",
"xs = paddle.to_tensor(xs, dtype='float32')\n",
"x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n",
"model.eval()\n",
"encoder_out, encoder_mask = model.encoder(xs, x_lens)\n",
"print(encoder_out.shape)\n",
"print(encoder_mask.shape)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "arctic-proxy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 13, 256)\n",
"(8, 1, 13)\n",
"False\n",
"False\n",
"True\n",
"True\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n",
"xs = data['xs']\n",
"masks = data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"print(np.allclose(xs, encoder_out.numpy()))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n",
"print(np.allclose(masks, encoder_mask.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "seasonal-switch",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n",
" -0.6526459 ]\n",
" [ 2.1926146 2.1373641 -0.6548196 ... -0.897318 0.6044322\n",
" -0.63332295]\n",
" [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n",
" -0.05243584]\n",
" ...\n",
" [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n",
" -0.18188554]\n",
" [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n",
" -0.72916794]\n",
" [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n",
" -0.32597935]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n",
" [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n",
" [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n",
" ...,\n",
" [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n",
" [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n",
" [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n"
]
}
],
"source": [
"print(xs[0])\n",
"print(encoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "defined-brooks",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n",
" -0.16355833]\n",
" [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n",
" 0.11874156]\n",
" [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n",
" 0.27621832]\n",
" ...\n",
" [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n",
" -0.9794884 ]\n",
" [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n",
" -1.3076135 ]\n",
" [ 2.160009 0.98425585 -1.2231126 ... -0.03393313 1.9141548\n",
" -1.0099151 ]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n",
" [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n",
" [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n",
" ...,\n",
" [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n",
" [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n",
" [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n"
]
}
],
"source": [
"print(xs[1])\n",
"print(encoder_out[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0504e3f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
The source diff could not be displayed because it is too large. You can view the blob instead.
+[中文版](README_cn.md)
 # PaddlePaddle Speech to Any toolkit
 ![License](https://img.shields.io/badge/license-Apache%202-red.svg)
@@ -11,31 +9,29 @@
 ## Features
-See [feature list](doc/src/feature_list.md) for more information.
+See [feature list](docs/src/feature_list.md) for more information.
 ## Setup
 All tested under:
 * Ubuntu 16.04
 * python>=3.7
-* paddlepaddle>=2.1.2
+* paddlepaddle>=2.2.0rc
-Please see [install](doc/src/install.md).
+Please see [install](docs/src/install.md).
 ## Getting Started
-Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
+Please see [Getting Started](docs/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
 ## More Information
-* [Data Prepration](doc/src/data_preparation.md)
-* [Data Augmentation](doc/src/augmentation.md)
-* [Ngram LM](doc/src/ngram_lm.md)
-* [Server Demo](doc/src/server.md)
-* [Benchmark](doc/src/benchmark.md)
-* [Relased Model](doc/src/released_model.md)
-* [FAQ](doc/src/faq.md)
+* [Data Prepration](docs/src/data_preparation.md)
+* [Data Augmentation](docs/src/augmentation.md)
+* [Ngram LM](docs/src/ngram_lm.md)
+* [Benchmark](docs/src/benchmark.md)
+* [Relased Model](docs/src/released_model.md)
 ## Questions and Help
@@ -45,8 +41,8 @@ You are welcome to submit questions in [Github Discussions](https://github.com/P
 ## License
-DeepASR is provided under the [Apache-2.0 License](./LICENSE).
+DeepSpeech is provided under the [Apache-2.0 License](./LICENSE).
 ## Acknowledgement
-We depends on many open source repos. See [References](doc/src/reference.md) for more information.
+We depends on many open source repos. See [References](docs/src/reference.md) for more information.
[English](README.md)

# PaddlePaddle Speech to Any toolkit

![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)

*DeepSpeech* is an open-source project, built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform, for end-to-end automatic speech recognition.
Our vision is to provide easy-to-use, efficient, compact and extensible tools for speech recognition in industrial applications and academic research, covering training, inference, and deployment.

## Features

See the [feature list](doc/src/feature_list.md).

## Setup

Tested and verified under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2

See [install](doc/src/install.md).

## Getting Started

Please see [Getting Started](doc/src/getting_started.md) and the [tiny egs](examples/tiny/s0/README.md).

## More Information

* [Data Preparation](doc/src/data_preparation.md)
* [Data Augmentation](doc/src/augmentation.md)
* [Ngram LM](doc/src/ngram_lm.md)
* [Server Demo](doc/src/server.md)
* [Benchmark](doc/src/benchmark.md)
* [Relased Model](doc/src/released_model.md)
* [FAQ](doc/src/faq.md)

## Questions and Help

You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and to report bugs in [Github Issues](https://github.com/PaddlePaddle/models/issues). You are also welcome to contribute to this project.

## License

DeepASR is provided under the [Apache-2.0 License](./LICENSE).

## Acknowledgement

We referenced some excellent open-source repositories during development; see [References](doc/src/reference.md) for details.
@@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype):
 if not hasattr(paddle, 'softmax'):
-    logger.warn("register user softmax to paddle, remove this when fixed!")
+    logger.debug("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
 if not hasattr(paddle, 'log_softmax'):
-    logger.warn("register user log_softmax to paddle, remove this when fixed!")
+    logger.debug("register user log_softmax to paddle, remove this when fixed!")
     setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
 if not hasattr(paddle, 'sigmoid'):
-    logger.warn("register user sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
 if not hasattr(paddle, 'log_sigmoid'):
-    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
 if not hasattr(paddle, 'relu'):
-    logger.warn("register user relu to paddle, remove this when fixed!")
+    logger.debug("register user relu to paddle, remove this when fixed!")
     setattr(paddle, 'relu', paddle.nn.functional.relu)
@@ -105,7 +105,7 @@ def cat(xs, dim=0):
 if not hasattr(paddle, 'cat'):
-    logger.warn(
+    logger.debug(
         "override cat of paddle if exists or register, remove this when fixed!")
     paddle.cat = cat
@@ -116,7 +116,7 @@ def item(x: paddle.Tensor):
 if not hasattr(paddle.Tensor, 'item'):
-    logger.warn(
+    logger.debug(
         "override item of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.item = item
@@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor):
 if not hasattr(paddle.Tensor, 'long'):
-    logger.warn(
+    logger.debug(
         "override long of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.long = func_long
 if not hasattr(paddle.Tensor, 'numel'):
-    logger.warn(
+    logger.debug(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.numel = paddle.numel
@@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'new_full'):
-    logger.warn(
+    logger.debug(
         "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.new_full = new_full
@@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.eq = eq
 if not hasattr(paddle, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle if exists or register, remove this when fixed!")
     paddle.eq = eq
@@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'contiguous'):
-    logger.warn(
+    logger.debug(
         "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.contiguous = contiguous
@@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 #`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
-logger.warn(
+logger.debug(
     "override size of paddle.Tensor "
     "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
 )
@@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'view'):
-    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user view to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view = view
@@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'view_as'):
-    logger.warn(
+    logger.debug(
         "register user view_as to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view_as = view_as
@@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'masked_fill'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill = masked_fill
@@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'masked_fill_'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill_ = masked_fill_
@@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'fill_'):
-    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.fill_ = fill_
@@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'repeat'):
-    logger.warn(
+    logger.debug(
         "register user repeat to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.repeat = repeat
 if not hasattr(paddle.Tensor, 'softmax'):
-    logger.warn(
+    logger.debug(
         "register user softmax to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
 if not hasattr(paddle.Tensor, 'sigmoid'):
-    logger.warn(
+    logger.debug(
         "register user sigmoid to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
 if not hasattr(paddle.Tensor, 'relu'):
-    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
@@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'type_as'):
-    logger.warn(
+    logger.debug(
         "register user type_as to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'type_as', type_as)
@@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'to'):
-    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'to', to)
@@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'float'):
-    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user float to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'float', func_float)
@@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'int'):
-    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user int to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'int', func_int)
@@ -348,23 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]:
 if not hasattr(paddle.Tensor, 'tolist'):
-    logger.warn(
+    logger.debug(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
-
-########### hcak paddle.nn #############
-class GLU(nn.Layer):
-    """Gated Linear Units (GLU) Layer"""
-
-    def __init__(self, dim: int=-1):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, xs):
-        return F.glu(xs, axis=self.dim)
-
-
-if not hasattr(paddle.nn, 'GLU'):
-    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'GLU', GLU)
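The same guard-and-register idiom repeats throughout this file; as a reference, here is a minimal self-contained sketch of it (the `view` shim is taken from this file, with the logging stripped for brevity):

import paddle

def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    # torch-style alias: paddle spells this operation `reshape`.
    return xs.reshape(args)

# Patch only when the attribute is genuinely missing, so the shim
# silently disappears once upstream paddle provides a native `view`.
if not hasattr(paddle.Tensor, 'view'):
    paddle.Tensor.view = view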
@@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     size_t beam_size,
     double cutoff_prob,
     size_t cutoff_top_n,
-    Scorer *ext_scorer) {
+    Scorer *ext_scorer,
+    size_t blank_id) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
   for (size_t i = 0; i < num_time_steps; ++i) {
@@ -48,7 +49,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   // assign blank id
   // size_t blank_id = vocabulary.size();
-  size_t blank_id = 0;
+  // size_t blank_id = 0;
   // assign space id
   auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
@@ -57,7 +58,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   if ((size_t)space_id >= vocabulary.size()) {
     space_id = -2;
   }
-
   // init prefixes' root
   PathTrie root;
   root.score = root.log_prob_b_prev = 0.0;
@@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch(
     size_t num_processes,
     double cutoff_prob,
     size_t cutoff_top_n,
-    Scorer *ext_scorer) {
+    Scorer *ext_scorer,
+    size_t blank_id) {
   VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
   // thread pool
   ThreadPool pool(num_processes);
@@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch(
         beam_size,
         cutoff_prob,
         cutoff_top_n,
-        ext_scorer));
+        ext_scorer,
+        blank_id));
   }
   // get decoding results
......
@@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     size_t beam_size,
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
-    Scorer *ext_scorer = nullptr);
+    Scorer *ext_scorer = nullptr,
+    size_t blank_id = 0);

 /* CTC Beam Search Decoder for batch data
@@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch(
     size_t num_processes,
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
-    Scorer *ext_scorer = nullptr);
+    Scorer *ext_scorer = nullptr,
+    size_t blank_id = 0);

 #endif  // CTC_BEAM_SEARCH_DECODER_H_
@@ -17,17 +17,18 @@
 std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>> &probs_seq,
-    const std::vector<std::string> &vocabulary) {
+    const std::vector<std::string> &vocabulary,
+    size_t blank_id) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
   for (size_t i = 0; i < num_time_steps; ++i) {
     VALID_CHECK_EQ(probs_seq[i].size(),
-                   vocabulary.size() + 1,
+                   vocabulary.size(),
                    "The shape of probs_seq does not match with "
                    "the shape of the vocabulary");
   }

-  size_t blank_id = vocabulary.size();
+  // size_t blank_id = vocabulary.size();

   std::vector<size_t> max_idx_vec(num_time_steps, 0);
   std::vector<size_t> idx_vec;
......
@@ -29,6 +29,7 @@
  */
 std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>>& probs_seq,
-    const std::vector<std::string>& vocabulary);
+    const std::vector<std::string>& vocabulary,
+    size_t blank_id);

 #endif  // CTC_GREEDY_DECODER_H
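To make the effect of the new `blank_id` parameter concrete, here is a hedged pure-numpy sketch of what the greedy decoder computes: per-frame argmax, collapse of consecutive repeats, then removal of the blank symbol. The function name and shapes below are illustrative, not part of the repo.

import numpy as np

def greedy_ctc(probs_seq: np.ndarray, vocabulary: list, blank_id: int = 0) -> str:
    """probs_seq: (T, V) per-frame posteriors; vocabulary: V token strings."""
    best = probs_seq.argmax(axis=1)              # best label per time step
    collapsed = [k for i, k in enumerate(best)   # merge consecutive repeats
                 if i == 0 or k != best[i - 1]]
    return ''.join(vocabulary[k] for k in collapsed if k != blank_id)

# e.g. with blank at index 0: argmax path [0,1,1,0,2] over ['_','a','b'] -> "ab"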
@@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 # yapf: disable
 FILES = [
-    fn for fn in FILES
-    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
-        'unittest.cc'))
+    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
+                               or fn.endswith('unittest.cc'))
 ]
 # yapf: enable
......
@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
         swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)


-def ctc_greedy_decoder(probs_seq, vocabulary):
+def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
     """Wrapper for ctc best path decoder in swig.

     :param probs_seq: 2-D list of probability distributions over each time
@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
     :return: Decoding result string.
     :rtype: str
     """
-    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
+    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
+                                              blank_id)
     return result
@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             cutoff_prob=1.0,
                             cutoff_top_n=40,
-                            ext_scoring_func=None):
+                            ext_scoring_func=None,
+                            blank_id=0):
     """Wrapper for the CTC Beam Search Decoder.

     :param probs_seq: 2-D list of probability distributions over each time
@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
     """
     beam_results = swig_decoders.ctc_beam_search_decoder(
         probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
-        ext_scoring_func)
+        ext_scoring_func, blank_id)
     beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
     return beam_results
@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
                                   num_processes,
                                   cutoff_prob=1.0,
                                   cutoff_top_n=40,
-                                  ext_scoring_func=None):
+                                  ext_scoring_func=None,
+                                  blank_id=0):
     """Wrapper for the batched CTC beam search decoder.

     :param probs_seq: 3-D list with each element as an instance of 2-D list
@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
     batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
         probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
-        cutoff_top_n, ext_scoring_func)
+        cutoff_top_n, ext_scoring_func, blank_id)
     batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                           for beam_results in batch_beam_results]
     return batch_beam_results
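A hedged usage sketch of the updated wrappers. It assumes the compiled `swig_decoders` package is available and that the module path below is where these wrappers live; the toy vocabulary and posteriors are made up, and the blank index passed in must match the one used at training time.

import numpy as np
# Assumed import path for the wrappers defined above.
from deepspeech.decoders.swig_wrapper import (ctc_greedy_decoder,
                                              ctc_beam_search_decoder)

vocab = ['<blank>'] + list('abcdefghijklmnopqrstuvwxyz ')   # V = 28 tokens
probs = np.random.rand(50, len(vocab)).astype('float64')    # (T, V) posteriors
probs /= probs.sum(axis=1, keepdims=True)                   # normalize per frame

hyp = ctc_greedy_decoder(probs, vocab, blank_id=0)
beams = ctc_beam_search_decoder(probs, vocab, beam_size=20, blank_id=0)
best_score, best_text = beams[0]                            # beams sorted by score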
@@ -27,7 +27,7 @@ def main_sp(config, args):
 def main(config, args):
-    if args.device == "gpu" and args.nprocs > 1:
+    if args.nprocs > 0:
         dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     else:
         main_sp(config, args)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search parameters tuning for DeepSpeech2 model."""
import functools
import sys
import numpy as np
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import error_rate
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
def tune(config, args):
    """Tune parameters alpha and beta incrementally."""
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")
    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")

    config.defrost()
    config.data.manifest = config.data.dev_manifest
    config.data.augmentation_config = ""
    config.data.keep_transcription_text = True
    dev_dataset = ManifestDataset.from_config(config)

    valid_loader = DataLoader(
        dev_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=SpeechCollator(keep_transcription_text=True))

    model = DeepSpeech2Model.from_pretrained(valid_loader, config,
                                             args.checkpoint_path)
    model.eval()

    # decoders only accept string encoded in utf-8
    vocab_list = valid_loader.dataset.vocab_list
    errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors

    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    err_sum = [0.0 for i in range(len(params_grid))]
    err_ave = [0.0 for i in range(len(params_grid))]

    num_ins, len_refs, cur_batch = 0, 0, 0
    # initialize external scorer
    model.decoder.init_decode(args.alpha_from, args.beta_from,
                              config.decoding.lang_model_path, vocab_list,
                              config.decoding.decoding_method)

    ## incremental tuning parameters over multiple batches
    print("start tuning ...")
    for infer_data in valid_loader():
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

        def ordid2token(texts, texts_len):
            """ ord() id to chr() chr """
            trans = []
            for text, n in zip(texts, texts_len):
                n = n.numpy().item()
                ids = text[:n]
                trans.append(''.join([chr(i) for i in ids]))
            return trans

        audio, audio_len, text, text_len = infer_data
        target_transcripts = ordid2token(text, text_len)
        num_ins += audio.shape[0]

        # model infer
        eouts, eouts_len = model.encoder(audio, audio_len)
        probs = model.decoder.softmax(eouts)

        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            print(f"tuning: alpha={alpha} beta={beta}")
            result_transcripts = model.decoder.decode_probs(
                probs.numpy(), eouts_len, vocab_list,
                config.decoding.decoding_method,
                config.decoding.lang_model_path, alpha, beta,
                config.decoding.beam_size, config.decoding.cutoff_prob,
                config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                errors, len_ref = errors_func(target, result)
                err_sum[index] += errors
                # accumulate the length of references of every batch
                # in the first iteration
                if args.alpha_from == alpha and args.beta_from == beta:
                    len_refs += len_ref

            err_ave[index] = err_sum[index] / len_refs
            if index % 2 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            print("tuning: one grid done!")

        # output on-line tuning result at the end of current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
              " min [%s] = %f" %
              (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
               "%.3f" % params_grid[min_index][1],
               config.decoding.error_rate_type, err_ave_min))
        cur_batch += 1

    # output WER/CER at every (alpha, beta)
    print("\nFinal %s:\n" % config.decoding.error_rate_type)
    for index in range(len(params_grid)):
        print("(alpha, beta) = (%s, %s), [%s] = %f" %
              ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
               config.decoding.error_rate_type, err_ave[index]))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" %
          (cur_batch, "%.3f" % params_grid[min_index][0],
           "%.3f" % params_grid[min_index][1]))
    print("finish tuning")


def main(config, args):
    tune(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    add_arg = functools.partial(add_arguments, argparser=parser)
    add_arg('num_batches', int, -1, "# of batches tuning on. "
            "Default -1, on whole dev set.")
    add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
    add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
    add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
    add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
    add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
    add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
    add_arg('batch_size', int, 256, "# of samples per batch.")
    add_arg('beam_size', int, 500, "Beam search width.")
    add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
    add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
    add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.data.batch_size = args.batch_size
    config.decoding.beam_size = args.beam_size
    config.decoding.num_proc_bsearch = args.num_proc_bsearch
    config.decoding.cutoff_prob = args.cutoff_prob
    config.decoding.cutoff_top_n = args.cutoff_top_n
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
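For orientation, the default argument values above produce a 45 x 8 search grid; a tiny standalone illustration of the search space the script scores on the dev set:

import numpy as np

# Same construction as tune() above, with the parser's default ranges.
cand_alphas = np.linspace(1.0, 3.2, 45)
cand_betas = np.linspace(0.1, 0.45, 8)
params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas]
print(len(params_grid))  # 360 (alpha, beta) candidates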
@@ -15,9 +15,11 @@
 import os
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -34,12 +36,14 @@ from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
 from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
+from deepspeech.training.reporter import report
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Autolog
 from deepspeech.utils.log import Log
+from deepspeech.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()
@@ -65,29 +69,52 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)

     def train_batch(self, batch_index, batch_data, msg):
+        batch_size = self.config.collator.batch_size
+        accum_grad = self.config.training.accum_grad
         start = time.time()
+
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss = self.model(audio, audio_len, text, text_len)
+        losses_np = {
+            'train_loss': float(loss),
+        }
+
+        # loss backward
+        if (batch_index + 1) % accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
             loss.backward()
             layer_tools.print_grads(self.model, print_func=None)

+        # optimizer step
+        if (batch_index + 1) % accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
+            self.iteration += 1

         iteration_time = time.time() - start

-        losses_np = {
-            'train_loss': float(loss),
-        }
-        msg += "train time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.collator.batch_size)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in losses_np.items())
-        logger.info(msg)
+        for k, v in losses_np.items():
+            report(k, v)
+        report("batch_size", batch_size)
+        report("accum", accum_grad)
+        report("step_cost", iteration_time)

         if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
+                # `step -1` since we update `step` after optimizer.step().
                 self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-        self.iteration += 1
+                                           self.iteration - 1)

     @paddle.no_grad()
     def valid(self):
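The `no_sync`/`nullcontext` switch above implements gradient accumulation under data-parallel training: all-reduce only on the micro-batch that will call `optimizer.step()`. A minimal sketch of that rule, with a `parallel` guard like the U2 trainer later in this commit adds (`model.no_sync` is paddle's DataParallel context manager; the helper name is illustrative):

from contextlib import nullcontext

def accumulation_context(model, batch_index, accum_grad, parallel=True):
    """Sync gradients only on the step that will call optimizer.step()."""
    if parallel and (batch_index + 1) % accum_grad != 0:
        return model.no_sync()   # accumulate locally, skip the all-reduce
    return nullcontext()         # final micro-batch: synchronize gradients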
@@ -124,10 +151,9 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config.clone()
-        config.defrost()
-        config.model.feat_size = self.train_loader.collate_fn.feature_size
-        config.model.dict_size = self.train_loader.collate_fn.vocab_size
-        config.freeze()
+        with UpdateConfig(config):
+            config.model.feat_size = self.train_loader.collate_fn.feature_size
+            config.model.dict_size = self.train_loader.collate_fn.vocab_size

         if self.args.model_type == 'offline':
             model = DeepSpeech2Model.from_config(config.model)
@@ -280,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
+                fout.write({"utt": utt, "ref": target, "hyp": result})
+            logger.info(f"Utt: {utt}")
+            logger.info(f"Ref: {target}")
+            logger.info(f"Hyp: {result}")
             logger.info("Current error rate [%s] = %f" %
                         (cfg.error_rate_type, error_rate_func(target, result)))
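The result file thus changes from plain "utt result" lines to one JSON object per utterance. A minimal sketch of reading such a file back with the same `jsonlines` package (the file name is illustrative):

import jsonlines

# Each line is a dict like {"utt": ..., "ref": ..., "hyp": ...}.
with jsonlines.open('test.rsl') as reader:
    for record in reader:
        print(record["utt"], "|", record["ref"], "->", record["hyp"])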
@@ -325,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         cfg = self.config
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
-        with open(self.args.result_file, 'w') as fout:
+        with jsonlines.open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 utts, audio, audio_len, texts, texts_len = batch
                 metrics = self.compute_metrics(utts, audio, audio_len, texts,
@@ -378,7 +405,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def setup(self):
         """Setup the experiment.
         """
-        paddle.set_device(self.args.device)
+        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')

         self.setup_output_dir()
         self.setup_checkpointer()
@@ -610,7 +637,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
     def setup(self):
         """Setup the experiment.
         """
-        paddle.set_device(self.args.device)
+        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')

         self.setup_output_dir()
......
@@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments

+# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
+

 def main_sp(config, args):
     exp = Trainer(config, args)
@@ -30,7 +32,7 @@ def main_sp(config, args):
 def main(config, args):
-    if args.device == "gpu" and args.nprocs > 1:
+    if args.nprocs > 0:
         dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     else:
         main_sp(config, args)
......
@@ -17,9 +17,12 @@ import os
 import sys
 import time
 from collections import defaultdict
+from collections import OrderedDict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional

+import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
@@ -32,7 +35,10 @@ from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
+from deepspeech.training.reporter import ObsScope
+from deepspeech.training.reporter import report
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -41,6 +47,7 @@ from deepspeech.utils import mp_tools
 from deepspeech.utils import text_grid
 from deepspeech.utils import utility
 from deepspeech.utils.log import Log
+from deepspeech.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()
@@ -79,21 +86,36 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
-        utt, audio, audio_len, text, text_len = batch_data

+        # forward
+        utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
+
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
+
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)

+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            # When using cpu w/o DDP, model does not have `no_sync`
+            context = self.model.no_sync if self.parallel else nullcontext
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
@@ -102,14 +124,13 @@ class U2Trainer(Trainer):
         iteration_time = time.time() - start

-        if (batch_index + 1) % train_conf.log_interval == 0:
-            msg += "train time: {:>.3f}s, ".format(iteration_time)
-            msg += "batch size: {}, ".format(self.config.collator.batch_size)
-            msg += "accum: {}, ".format(train_conf.accum_grad)
-            msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                             for k, v in losses_np.items())
-            logger.info(msg)
+        for k, v in losses_np.items():
+            report(k, v)
+        report("batch_size", self.config.collator.batch_size)
+        report("accum", train_conf.accum_grad)
+        report("step_cost", iteration_time)

+        if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
                 losses_np_v = losses_np.copy()
                 losses_np_v.update({"lr": self.lr_scheduler()})
@@ -163,35 +184,47 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)

-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-
-        self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
                 try:
                     data_start_time = time.time()
                     for batch_index, batch in enumerate(self.train_loader):
                         dataload_time = time.time() - data_start_time
-                        msg = "Train: Rank: {}, ".format(dist.get_rank())
-                        msg += "epoch: {}, ".format(self.epoch)
-                        msg += "step: {}, ".format(self.iteration)
-                        msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                        len(self.train_loader))
-                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        msg = "Train:"
+                        observation = OrderedDict()
+                        with ObsScope(observation):
+                            report("Rank", dist.get_rank())
+                            report("epoch", self.epoch)
+                            report('step', self.iteration)
+                            report("lr", self.lr_scheduler())
                             self.train_batch(batch_index, batch, msg)
+                            self.after_train_batch()
+                            report('iter', batch_index + 1)
+                            report('total', len(self.train_loader))
+                            report('reader_cost', dataload_time)
+                        observation['batch_cost'] = observation[
+                            'reader_cost'] + observation['step_cost']
+                        observation['samples'] = observation['batch_size']
+                        observation['ips[sent./sec]'] = observation[
+                            'batch_size'] / observation['batch_cost']
+                        for k, v in observation.items():
+                            msg += f" {k}: "
+                            msg += f"{v:>.8f}" if isinstance(v,
+                                                             float) else f"{v}"
+                            msg += ","
+                        if (batch_index + 1
+                            ) % self.config.training.log_interval == 0:
+                            logger.info(msg)
                         data_start_time = time.time()
                 except Exception as e:
                     logger.error(e)
                     raise e

+            with Timer("Eval Time Cost: {}"):
                 total_loss, num_seen_utts = self.valid()
                 if dist.get_world_size() > 1:
                     num_seen_utts = paddle.to_tensor(num_seen_utts)
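How `ObsScope`/`report` behave is not shown in this diff; a hedged sketch of the pattern, assuming `ObsScope` routes `report(key, value)` calls into the supplied dict for the duration of the `with` block (the real `deepspeech.training.reporter` may differ in detail):

from collections import OrderedDict
from contextlib import contextmanager

_OBSERVATIONS = []  # stack of active observation dicts

def report(key, value):
    if _OBSERVATIONS:
        _OBSERVATIONS[-1][key] = value

@contextmanager
def ObsScope(observation):
    _OBSERVATIONS.append(observation)
    try:
        yield observation
    finally:
        _OBSERVATIONS.pop()

obs = OrderedDict()
with ObsScope(obs):
    report("epoch", 0)
    report("lr", 0.001)
print(obs)  # OrderedDict([('epoch', 0), ('lr', 0.001)])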
@@ -294,10 +327,11 @@ class U2Trainer(Trainer):
     def setup_model(self):
         config = self.config
         model_conf = config.model
-        model_conf.defrost()
-        model_conf.input_dim = self.train_loader.collate_fn.feature_size
-        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
-        model_conf.freeze()
+
+        with UpdateConfig(model_conf):
+            model_conf.input_dim = self.train_loader.collate_fn.feature_size
+            model_conf.output_dim = self.train_loader.collate_fn.vocab_size
+
         model = U2Model.from_config(model_conf)

         if self.parallel:
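`UpdateConfig` replaces the manual `defrost()`/`freeze()` pair in both trainers. A plausible minimal implementation for a yacs-style `CfgNode`, hedged: the real `deepspeech.utils.utility.UpdateConfig` may differ in detail.

from contextlib import contextmanager

@contextmanager
def UpdateConfig(config):
    """Temporarily make an immutable yacs CfgNode writable."""
    config.defrost()
    try:
        yield config
    finally:
        config.freeze()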
@@ -433,9 +467,10 @@ class U2Tester(U2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
+                fout.write({"utt": utt, "ref": target, "hyp": result})
+            logger.info(f"Utt: {utt}")
+            logger.info(f"Ref: {target}")
+            logger.info(f"Hyp: {result}")
             logger.info("One example error rate [%s] = %f" %
                         (cfg.error_rate_type, error_rate_func(target, result)))
@@ -460,7 +495,7 @@ class U2Tester(U2Trainer):
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
         num_time = 0.0
-        with open(self.args.result_file, 'w') as fout:
+        with jsonlines.open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 metrics = self.compute_metrics(*batch, fout=fout)
                 num_frames += metrics['num_frames']
...@@ -540,7 +575,7 @@ class U2Tester(U2Trainer): ...@@ -540,7 +575,7 @@ class U2Tester(U2Trainer):
# 1. Encoder # 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder( encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim) feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax( ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size) encoder_out) # (1, maxlen, vocab_size)
...@@ -548,26 +583,25 @@ class U2Tester(U2Trainer): ...@@ -548,26 +583,25 @@ class U2Tester(U2Trainer):
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0) target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target) alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment) logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment)) fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat # 3. gen praat
# segment alignment # segment alignment
align_segs = text_grid.segment_alignment(alignment) align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs) logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"] # IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config) subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat( tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict) align_segs, subsample, token_dict)
# write tier # write tier
align_output_path = os.path.join( align_output_path = Path(self.args.result_file).parent / "align"
os.path.dirname(self.args.result_file), "align") align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = os.path.join(align_output_path, key[0] + ".tier") tier_path = align_output_path / (key[0] + ".tier")
with open(tier_path, 'w') as f: with tier_path.open('w') as f:
f.writelines(tierformat) f.writelines(tierformat)
# write textgrid # write textgrid
textgrid_path = os.path.join(align_output_path, textgrid_path = align_output_path / (key[0] + ".TextGrid")
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. / second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride stride_ms) # 25ms window, 10ms stride
second_per_example = ( second_per_example = (
...@@ -575,7 +609,7 @@ class U2Tester(U2Trainer): ...@@ -575,7 +609,7 @@ class U2Tester(U2Trainer):
text_grid.generate_textgrid( text_grid.generate_textgrid(
maxtime=second_per_example, maxtime=second_per_example,
intervals=tierformat, intervals=tierformat,
output=textgrid_path) output=str(textgrid_path))
def run_align(self): def run_align(self):
self.resume_or_scratch() self.resume_or_scratch()
...@@ -621,7 +655,7 @@ class U2Tester(U2Trainer): ...@@ -621,7 +655,7 @@ class U2Tester(U2Trainer):
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
""" """
paddle.set_device(self.args.device) paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir() self.setup_output_dir()
self.setup_checkpointer() self.setup_checkpointer()
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Evaluator
from deepspeech.models.u2 import U2Model
from deepspeech.models.u2 import U2Updater
from deepspeech.training.extensions.snapshot import Snapshot
from deepspeech.training.extensions.visualizer import VisualDL
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.training.updaters.trainer import Trainer as NewTrainer
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
# Filtering test examples yields fewer examples but no mismatch with training,
# and allows a large batch size to save time, so filter test egs now.
config.data.min_input_len = 0.0 # second
config.data.max_input_len = float('inf') # second
config.data.min_output_len = 0.0 # tokens
config.data.max_output_len = float('inf') # tokens
config.data.min_output_input_ratio = 0.00
config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config.model
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
model.train()
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.encoder_conf.output_size,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
def optimizer_args(
config,
parameters,
lr_scheduler=None, ):
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optim_args)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
def setup_updater(self):
output_dir = self.output_dir
config = self.config.training
updater = U2Updater(
model=self.model,
optimizer=self.optimizer,
scheduler=self.lr_scheduler,
dataloader=self.train_loader,
output_dir=output_dir,
accum_grad=config.accum_grad)
trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir)
evaluator = U2Evaluator(self.model, self.valid_loader)
trainer.extend(evaluator, trigger=(1, "epoch"))
if dist.get_rank() == 0:
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
num_snapshots = config.checkpoint.kbest_n
trainer.extend(
Snapshot(
mode='kbest',
max_size=num_snapshots,
indicator='VALID/LOSS',
less_better=True),
trigger=(1, 'epoch'))
# print(trainer.extensions)
# trainer.run()
self.trainer = trainer
def run(self):
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.setup_updater()
with Timer("Training Done: {}"):
self.trainer.run()
@@ -36,7 +36,7 @@ def main_sp(config, args):
def main(config, args):
- if args.device == "gpu" and args.nprocs > 1:
+ if args.nprocs > 0:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
...
@@ -17,9 +17,11 @@ import os
import sys
import time
from collections import defaultdict
+ from contextlib import nullcontext
from pathlib import Path
from typing import Optional
+ import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
@@ -31,6 +33,7 @@ from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
+ from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
@@ -39,6 +42,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
+ from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@@ -83,20 +87,34 @@ class U2Trainer(Trainer):
train_conf = self.config.training
start = time.time()
+ # forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
- loss.backward()
- layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
+ # loss backward
+ if (batch_index + 1) % train_conf.accum_grad != 0:
+     # Disable gradient synchronizations across DDP processes.
+     # Within this context, gradients will be accumulated on module
+     # variables, which will later be synchronized.
+     context = self.model.no_sync
+ else:
+     # Used for single gpu training and DDP gradient synchronization
+     # processes.
+     context = nullcontext
+ with context():
+     loss.backward()
+     layer_tools.print_grads(self.model, print_func=None)
+ # optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
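A standalone sketch of the accumulate-then-sync pattern introduced above. The `model`, `loader`, and `optimizer` names are placeholders; `model` is assumed to be a paddle.DataParallel instance, whose no_sync() suppresses the per-step gradient all-reduce so only every accum_grad-th backward pays the communication cost.

from contextlib import nullcontext

def train_with_accumulation(model, loader, optimizer, accum_grad=4):
    for batch_index, (inputs, labels) in enumerate(loader):
        loss = model(inputs, labels) / accum_grad
        ctx = model.no_sync if (batch_index + 1) % accum_grad != 0 else nullcontext
        with ctx():
            loss.backward()  # grads accumulate locally inside no_sync
        if (batch_index + 1) % accum_grad == 0:
            optimizer.step()
            optimizer.clear_grad()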
@@ -167,14 +185,11 @@ class U2Trainer(Trainer):
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
- from_scratch = self.resume_or_scratch()
- if from_scratch:
-     # save init model, i.e. 0 epoch
-     self.save(tag='init')
- self.lr_scheduler.step(self.iteration)
+ self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
+ with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
@@ -188,11 +203,13 @@ class U2Trainer(Trainer):
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
+ self.after_train_batch()
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
+ with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
@@ -300,10 +317,10 @@ class U2Trainer(Trainer):
# model
model_conf = config.model
- model_conf.defrost()
+ with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
- model_conf.freeze()
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
@@ -429,9 +446,10 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
- fout.write(utt + " " + result + "\n")
+ fout.write({"utt": utt, "ref": target, "hyp": result})
- logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result))
+ logger.info(f"Utt: {utt}")
+ logger.info(f"Ref: {target}")
+ logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result)))
@@ -456,7 +474,7 @@ class U2Tester(U2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
- with open(self.args.result_file, 'w') as fout:
+ with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
@@ -526,9 +544,8 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
- stride_ms = self.config.collater.stride_ms
- token_dict = self.args.char_list
+ stride_ms = self.align_loader.collate_fn.stride_ms
+ token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
@@ -537,7 +554,7 @@ class U2Tester(U2Trainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(feat, feats_length)  # (B, maxlen, encoder_dim)
- maxlen = encoder_out.size(1)
+ maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
@@ -545,26 +562,25 @@ class U2Tester(U2Trainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
- logger.info("align ids", key[0], alignment)
+ logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
- logger.info("align tokens", key[0], align_segs)
+ logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample, token_dict)
# write tier
- align_output_path = os.path.join(os.path.dirname(self.args.result_file), "align")
- tier_path = os.path.join(align_output_path, key[0] + ".tier")
- with open(tier_path, 'w') as f:
+ align_output_path = Path(self.args.result_file).parent / "align"
+ align_output_path.mkdir(parents=True, exist_ok=True)
+ tier_path = align_output_path / (key[0] + ".tier")
+ with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
- textgrid_path = os.path.join(align_output_path, key[0] + ".TextGrid")
+ textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. / stride_ms)  # 25ms window, 10ms stride
second_per_example = (
@@ -572,7 +588,7 @@ class U2Tester(U2Trainer):
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
- output=textgrid_path)
+ output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
@@ -623,7 +639,7 @@ class U2Tester(U2Trainer):
def setup(self):
"""Setup the experiment."""
- paddle.set_device(self.args.device)
+ paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
...
@@ -30,7 +30,7 @@ def main_sp(config, args):
def main(config, args):
- if args.device == "gpu" and args.nprocs > 1:
+ if args.nprocs > 0:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
...
@@ -17,9 +17,11 @@ import os
import sys
import time
from collections import defaultdict
+ from contextlib import nullcontext
from pathlib import Path
from typing import Optional
+ import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
@@ -37,6 +39,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
+ from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
@@ -45,6 +48,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
+ from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@@ -83,6 +87,7 @@ class U2STTrainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
+ # forward
utt, audio, audio_len, text, text_len = batch_data
if isinstance(text, list) and isinstance(text_len, list):
# joint training with ASR. Two decoding texts [translation, transcription]
@@ -94,18 +99,30 @@ class U2STTrainer(Trainer):
else:
loss, st_loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
- loss.backward()
- layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
+ losses_np['st_loss'] = float(st_loss)
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
+ # loss backward
+ if (batch_index + 1) % train_conf.accum_grad != 0:
+     # Disable gradient synchronizations across DDP processes.
+     # Within this context, gradients will be accumulated on module
+     # variables, which will later be synchronized.
+     context = self.model.no_sync
+ else:
+     # Used for single gpu training and DDP gradient synchronization
+     # processes.
+     context = nullcontext
+ with context():
+     loss.backward()
+     layer_tools.print_grads(self.model, print_func=None)
+ # optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
@@ -182,17 +199,11 @@ class U2STTrainer(Trainer):
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
- from_scratch = self.resume_or_scratch()
- if from_scratch:
-     # save init model, i.e. 0 epoch
-     self.save(tag='init')
- self.lr_scheduler.step(self.iteration)
- if self.parallel:
-     self.train_loader.batch_sampler.set_epoch(self.epoch)
+ self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
+ with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
@@ -206,11 +217,13 @@ class U2STTrainer(Trainer):
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
+ self.after_train_batch()
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
+ with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
@@ -327,10 +340,10 @@ class U2STTrainer(Trainer):
def setup_model(self):
config = self.config
model_conf = config.model
- model_conf.defrost()
+ with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
- model_conf.freeze()
model = U2STModel.from_config(model_conf)
if self.parallel:
@@ -467,8 +480,10 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split())
num_ins += 1
if fout:
- fout.write(utt + " " + result + "\n")
+ fout.write({"utt": utt, "ref": target, "hyp": result})
- logger.info("\nReference: %s\nHypothesis: %s" % (target, result))
+ logger.info(f"Utt: {utt}")
+ logger.info(f"Ref: {target}")
+ logger.info(f"Hyp: {result}")
logger.info("One example BLEU = %s" % (bleu_func([result], [[target]]).prec_str))
@@ -496,7 +511,7 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0
num_frames = 0.0
num_time = 0.0
- with open(self.args.result_file, 'w') as fout:
+ with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics(*batch, bleu_func=bleu_func, fout=fout)
@@ -569,7 +584,7 @@ class U2STTester(U2STTrainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(feat, feats_length)  # (B, maxlen, encoder_dim)
- maxlen = encoder_out.size(1)
+ maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
@@ -577,26 +592,25 @@ class U2STTester(U2STTrainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
- logger.info("align ids", key[0], alignment)
+ logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
- logger.info("align tokens", key[0], align_segs)
+ logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample, token_dict)
# write tier
- align_output_path = os.path.join(os.path.dirname(self.args.result_file), "align")
- tier_path = os.path.join(align_output_path, key[0] + ".tier")
- with open(tier_path, 'w') as f:
+ align_output_path = Path(self.args.result_file).parent / "align"
+ align_output_path.mkdir(parents=True, exist_ok=True)
+ tier_path = align_output_path / (key[0] + ".tier")
+ with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
- textgrid_path = os.path.join(align_output_path, key[0] + ".TextGrid")
+ textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. / stride_ms)  # 25ms window, 10ms stride
second_per_example = (
@@ -604,7 +618,7 @@ class U2STTester(U2STTrainer):
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
- output=textgrid_path)
+ output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
@@ -650,7 +664,7 @@ class U2STTester(U2STTrainer):
def setup(self):
"""Setup the experiment."""
- paddle.set_device(self.args.device)
+ paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
...
@@ -196,7 +196,12 @@ class TextFeaturizer():
    [(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
    [(token, idx) for (idx, token) in enumerate(vocab_list)])
+ if UNK in vocab_list:
unk_id = vocab_list.index(UNK)
+ else:
+     unk_id = -1
+ if EOS in vocab_list:
eos_id = vocab_list.index(EOS)
+ else:
+     eos_id = -1
return token2id, id2token, vocab_list, unk_id, eos_id
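The effect of the fallback above, assuming UNK and EOS are the literal '<unk>' and '<eos>' symbols: a vocabulary that lacks either token now yields a sentinel id of -1 instead of raising ValueError from list.index.

vocab_list = ['<blank>', 'a', 'b', '<eos>']
unk_id = vocab_list.index('<unk>') if '<unk>' in vocab_list else -1
eos_id = vocab_list.index('<eos>') if '<eos>' in vocab_list else -1
print(unk_id, eos_id)  # -1 3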
@@ -130,7 +130,8 @@ class FeatureNormalizer(object):
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
- mean, istd = load_cmvn(filepath, filetype='json')
+ filetype = filepath.split(".")[-1]
+ mean, istd = load_cmvn(filepath, filetype=filetype)
self._mean = np.expand_dims(mean, axis=0)
self._istd = np.expand_dims(istd, axis=0)
...
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
- import codecs
import json
import math
from typing import List
from typing import Optional
from typing import Text
+ import jsonlines
import numpy as np
from deepspeech.utils.log import Log
@@ -92,12 +92,8 @@ def read_manifest(
"""
manifest = []
- for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
-     try:
-         json_data = json.loads(json_line)
-     except Exception as e:
-         raise IOError("Error reading manifest: %s" % str(e))
+ with jsonlines.open(manifest_path, 'r') as reader:
+     for json_data in reader:
feat_len = json_data["feat_shape"][0] if 'feat_shape' in json_data else 1.0
token_len = json_data["token_shape"][
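With jsonlines, read_manifest expects one JSON object per line. A plausible record, using only fields that read_manifest and the AudioDataset docstring later in this commit reference (real manifests carry more keys):

import jsonlines

with jsonlines.open('manifest.train', mode='w') as writer:
    writer.write({
        "utt": "utt1",
        "feat": "data/utt1.wav",
        "feat_shape": [4.95],       # duration in seconds (or [frames, dim])
        "text": "i love you",
        "token_shape": [10, 5000],  # [num_tokens, vocab_size]
    })

with jsonlines.open('manifest.train') as reader:
    for record in reader:
        print(record["feat_shape"][0])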
@@ -284,6 +280,13 @@ def load_cmvn(cmvn_file: str, filetype: str):
cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file)
+ elif filetype == "npz":
+     eps = 1e-14
+     npzfile = np.load(cmvn_file)
+     mean = np.squeeze(npzfile["mean"])
+     std = np.squeeze(npzfile["std"])
+     istd = 1 / (std + eps)
+     cmvn = [mean, istd]
else:
raise ValueError(f"cmvn file type no support: {filetype}")
return cmvn[0], cmvn[1]
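How a cmvn .npz file consumed by the new branch could be produced, assuming per-dimension statistics computed over the training features (the random features below are a stand-in):

import numpy as np

feats = np.random.randn(1000, 80)  # stacked frames, [num_frames, feat_dim]
np.savez('cmvn.npz',
         mean=feats.mean(axis=0, keepdims=True),
         std=feats.std(axis=0, keepdims=True))

npz = np.load('cmvn.npz')
istd = 1 / (np.squeeze(npz['std']) + 1e-14)  # matches load_cmvn's eps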
@@ -292,10 +292,6 @@ class SpeechCollator():
olens = np.array(text_lens).astype(np.int64)
return utts, xs_pad, ilens, ys_pad, olens
- @property
- def manifest(self):
-     return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
...
@@ -44,7 +44,7 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
def batch_collate(x):
- """de-tuple.
+ """de-minibatch, since the user composes the batch.
Args:
    x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
...
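A plausible body for batch_collate under that docstring: the dataset already yields a full (utts, xs, ilens, ys, olens) batch tuple, so the DataLoader's singleton wrapper list just needs unwrapping. This is an assumption about the elided implementation, not the repo's code.

def batch_collate(x):
    assert len(x) == 1, "dataset is expected to pre-compose the batch"
    return x[0]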
@@ -147,3 +147,131 @@ class TransformDataset(Dataset):
def __getitem__(self, idx):
"""[] operator."""
return self.converter([self.reader(self.data[idx], return_uttid=True)])
class AudioDataset(Dataset):
def __init__(self,
data_file,
max_length=10240,
min_length=0,
token_max_length=200,
token_min_length=1,
batch_type='static',
batch_size=1,
max_frames_in_batch=0,
sort=True,
raw_wav=True,
stride_ms=10):
"""Dataset for loading audio data.
Attributes:
data_file: input data file
Plain text data file, each line contains following 7 fields,
which is split by '\t':
utt:utt1
feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
text:i love you
token: i <space> l o v e <space> y o u
tokenid: int id of this token
token_shape: M,N # M is the number of token, N is vocab size
max_length: drop utterances longer than max_length (unit: 10ms frames).
min_length: drop utterances shorter than min_length (unit: 10ms frames).
token_max_length: drop utterances whose token count is greater than
token_max_length, especially when using char units for english modeling.
token_min_length: drop utterances whose token count is less than token_min_length.
batch_type: static or dynamic, see max_frames_in_batch(dynamic)
batch_size: number of utterances in a batch,
it's for static batch size.
max_frames_in_batch: max feature frames in a batch,
when batch_type is dynamic, it's for dynamic batch size.
Then batch_size is ignored, we will keep filling the
batch until the total frames in batch up to max_frames_in_batch.
sort: whether to sort all data, so utterances with similar
lengths can be filled into the same batch.
raw_wav: use raw waveform or extracted features.
If raw waveforms are used, dynamic waveform-level augmentation can be used
and the feature is extracted by torchaudio.
If extracted features (e.g. by kaldi) are used, only feature-level
augmentation such as specaug can be used.
"""
assert batch_type in ['static', 'dynamic']
# read manifest
data = read_manifest(data_file)
if sort:
data = sorted(data, key=lambda x: x["feat_shape"][0])
if raw_wav:
    # str has no splitext(); take the suffix of the wav path with os.path
    assert os.path.splitext(data[0]['feat'].split(':')[0])[-1] not in ('.ark', '.scp')
    # convert duration (seconds) into a frame count, in place
    for x in data:
        x['feat_shape'][0] = float(x['feat_shape'][0]) * 1000 / stride_ms
self.input_dim = data[0]['feat_shape'][1]
self.output_dim = data[0]['token_shape'][1]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data = []
for i in range(len(data)):
length = data[i]['feat_shape'][0]
token_length = data[i]['token_shape'][0]
# remove utterances that are too long or too short, for both input and output,
# to prevent from out of memory
if length > max_length or length < min_length:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
elif token_length > token_max_length or token_length < token_min_length:
pass
else:
valid_data.append(data[i])
data = valid_data
self.minibatch = []
num_data = len(data)
# Dynamic batch size
if batch_type == 'dynamic':
assert (max_frames_in_batch > 0)
self.minibatch.append([])
num_frames_in_batch = 0
for i in range(num_data):
length = data[i]['feat_shape'][0]
num_frames_in_batch += length
if num_frames_in_batch > max_frames_in_batch:
self.minibatch.append([])
num_frames_in_batch = length
self.minibatch[-1].append(data[i])
# Static batch size
else:
cur = 0
while cur < num_data:
end = min(cur + batch_size, num_data)
item = []
for i in range(cur, end):
item.append(data[i])
self.minibatch.append(item)
cur = end
def __len__(self):
return len(self.minibatch)
def __getitem__(self, idx):
    # each item is a pre-composed minibatch: a list of manifest dicts,
    # so index fields per utterance rather than on the list itself
    instances = self.minibatch[idx]
    return [(ins["utt"], ins["feat"], ins["text"]) for ins in instances]
@@ -106,11 +106,9 @@ class ConvBn(nn.Layer):
# reset padding part to 0
masks = make_non_pad_mask(x_len)  #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
- # TODO(Hui Zhang): not support bool multiply
- # masks = masks.type_as(x)
- masks = masks.astype(x.dtype)
- x = x.multiply(masks)
+ # https://github.com/PaddlePaddle/Paddle/pull/29265
+ # rhs will type promote to lhs
+ x = x * masks
return x, x_len
...
@@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3,  #Number of stacking RNN layers.
rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
use_gru=True,  #Use gru if set True. Use simple rnn if set False.
- share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
- ))
+ share_rnn_weights=True,  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
+ ctc_grad_norm_type='instance', ))
if config is not None:
config.merge_from_other_cfg(default)
return default
@@ -141,7 +141,9 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
- share_rnn_weights=True):
+ share_rnn_weights=True,
+ blank_id=0,
+ ctc_grad_norm_type='instance'):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
@@ -156,10 +158,11 @@ class DeepSpeech2Model(nn.Layer):
self.decoder = CTCDecoder(
odim=dict_size,  # <blank> is in vocab
enc_n_units=self.encoder.output_size,
- blank_id=0,  # first token is <blank>
+ blank_id=blank_id,
dropout_rate=0.0,
reduction=True,  # sum
- batch_average=True)  # sum / batch_size
+ batch_average=True,  # sum / batch_size
+ grad_norm_type=ctc_grad_norm_type)
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
@@ -221,7 +224,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
- share_rnn_weights=config.model.share_rnn_weights)
+ share_rnn_weights=config.model.share_rnn_weights,
+ blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters(model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@@ -246,7 +250,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
- share_rnn_weights=config.share_rnn_weights)
+ share_rnn_weights=config.share_rnn_weights,
+ blank_id=config.blank_id)
return model
@@ -258,7 +263,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
- share_rnn_weights=True):
+ share_rnn_weights=True,
+ blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
@@ -266,7 +272,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
- share_rnn_weights=share_rnn_weights)
+ share_rnn_weights=share_rnn_weights,
+ blank_id=blank_id)
def forward(self, audio, audio_len):
"""export model function
...
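Why blank_id must be plumbed through: the CTC criterion needs to know which vocabulary index is reserved for <blank>. A minimal check with Paddle's built-in loss; the repo's CTCDecoder wraps an equivalent criterion, so this is an illustration, not the repo's code path.

import paddle

blank_id = 0  # must agree with the position of <blank> in vocab.txt
ctc_loss = paddle.nn.CTCLoss(blank=blank_id, reduction='mean')

logits = paddle.randn([50, 2, 100])  # [T, B, vocab]
labels = paddle.randint(1, 100, [2, 10], dtype='int32')
loss = ctc_loss(logits, labels,
                paddle.to_tensor([50, 50], dtype='int64'),
                paddle.to_tensor([10, 10], dtype='int64'))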
@@ -308,7 +308,8 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len)  #[B, T]
masks = masks.unsqueeze(-1)  # [B, T, 1]
- # TODO(Hui Zhang): not support bool multiply
- masks = masks.astype(x.dtype)
- x = x.multiply(masks)
+ # https://github.com/PaddlePaddle/Paddle/pull/29265
+ # rhs will type promote to lhs
+ x = x * masks
return x, x_len
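The fix relies on the type promotion added in PaddlePaddle/Paddle#29265: a float lhs times a bool rhs now promotes the mask automatically, so the explicit astype is unnecessary. A quick demonstration:

import paddle

x = paddle.ones([2, 3, 4])
masks = paddle.to_tensor([[True, True, False],
                          [True, False, False]]).unsqueeze(-1)  # [B, T, 1]
x = x * masks  # padded timesteps zeroed, no manual cast needed
print(x.sum())  # 12.0 (3 valid timesteps x 4 dims)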
@@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True,  #Use gru if set True. Use simple rnn if set False.
+ blank_id=0,  # index of blank in vocab.txt
))
if config is not None:
config.merge_from_other_cfg(default)
@@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
- use_gru=False):
+ use_gru=False,
+ blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
@@ -284,10 +286,11 @@ class DeepSpeech2ModelOnline(nn.Layer):
self.decoder = CTCDecoder(
odim=dict_size,  # <blank> is in vocab
enc_n_units=self.encoder.output_size,
- blank_id=0,  # first token is <blank>
+ blank_id=blank_id,
dropout_rate=0.0,
reduction=True,  # sum
- batch_average=True)  # sum / batch_size
+ batch_average=True,  # sum / batch_size
+ grad_norm_type='instance')
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
@@ -353,7 +356,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
- use_gru=config.model.use_gru)
+ use_gru=config.model.use_gru,
+ blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters(model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@@ -380,7 +384,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
- use_gru=config.use_gru)
+ use_gru=config.use_gru,
+ blank_id=config.blank_id)
return model
@@ -394,7 +399,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
- use_gru=False):
+ use_gru=False,
+ blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
@@ -404,7 +410,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
- use_gru=use_gru)
+ use_gru=use_gru,
+ blank_id=blank_id)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box):
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2 import U2InferModel
from .u2 import U2Model
from .updater import U2Evaluator
from .updater import U2Updater
__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]
@@ -48,6 +48,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add
+ from deepspeech.utils.utility import UpdateConfig
__all__ = ["U2Model", "U2InferModel"]
@@ -115,7 +116,8 @@ class U2BaseModel(nn.Layer):
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
- length_normalized_loss: bool=False):
+ length_normalized_loss: bool=False,
+ **kwargs):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
@@ -162,10 +164,7 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
- #TODO(Hui Zhang): sum not support bool type
- #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
- encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(1)  #[B, 1, T] -> [B]
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att = None
@@ -299,8 +298,8 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming)  # (B, maxlen, encoder_dim)
- maxlen = encoder_out.size(1)
- encoder_dim = encoder_out.size(2)
+ maxlen = encoder_out.shape[1]
+ encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
@@ -320,8 +319,7 @@ class U2BaseModel(nn.Layer):
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
- # TODO(Hui Zhang): if end_flag.sum() == running_size:
- if end_flag.cast(paddle.int64).sum() == running_size:
+ if end_flag.sum() == running_size:
break
# 2.1 Forward decoder step
@@ -406,10 +404,8 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
- maxlen = encoder_out.size(1)
- # (TODO Hui Zhang): bool no support reduce_sum
- # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
- encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
+ maxlen = encoder_out.shape[1]
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
@@ -459,7 +455,7 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming)  # (B, maxlen, encoder_dim)
- maxlen = encoder_out.size(1)
+ maxlen = encoder_out.shape[1]
ctc_probs = self.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
@@ -587,7 +583,7 @@ class U2BaseModel(nn.Layer):
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
- (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
+ (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
@@ -667,9 +663,7 @@ class U2BaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
- # @jit.to_static([
- #     paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'),  # audio feat, [B,T,D]
- # ])
+ # @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
@@ -696,13 +690,13 @@ class U2BaseModel(nn.Layer):
Returns:
    paddle.Tensor: decoder output, (B, L)
"""
- assert encoder_out.size(0) == 1
- num_hyps = hyps.size(0)
- assert hyps_lens.size(0) == num_hyps
+ assert encoder_out.shape[0] == 1
+ num_hyps = hyps.shape[0]
+ assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
- [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
+ [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens)
@@ -757,7 +751,7 @@ class U2BaseModel(nn.Layer):
Returns:
    List[List[int]]: transcripts.
"""
- batch_size = feats.size(0)
+ batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1:
logger.fatal(
@@ -785,7 +779,7 @@ class U2BaseModel(nn.Layer):
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
- assert feats.size(0) == 1
+ assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
    feats,
    feats_lengths,
@@ -795,7 +789,7 @@ class U2BaseModel(nn.Layer):
simulate_streaming=simulate_streaming)
hyps = [hyp]
elif decoding_method == 'attention_rescoring':
assert feats.size(0) == 1 assert feats.shape[0] == 1
hyp = self.attention_rescoring( hyp = self.attention_rescoring(
feats, feats,
feats_lengths, feats_lengths,
...@@ -836,6 +830,7 @@ class U2Model(U2BaseModel): ...@@ -836,6 +830,7 @@ class U2Model(U2BaseModel):
Returns: Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
""" """
# cmvn
if configs['cmvn_file'] is not None: if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type']) configs['cmvn_file_type'])
...@@ -845,11 +840,13 @@ class U2Model(U2BaseModel): ...@@ -845,11 +840,13 @@ class U2Model(U2BaseModel):
else: else:
global_cmvn = None global_cmvn = None
# input & output dim
input_dim = configs['input_dim'] input_dim = configs['input_dim']
vocab_size = configs['output_dim'] vocab_size = configs['output_dim']
assert input_dim != 0, input_dim assert input_dim != 0, input_dim
assert vocab_size != 0, vocab_size assert vocab_size != 0, vocab_size
# encoder
encoder_type = configs.get('encoder', 'transformer') encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}") logger.info(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer': if encoder_type == 'transformer':
...@@ -861,16 +858,21 @@ class U2Model(U2BaseModel): ...@@ -861,16 +858,21 @@ class U2Model(U2BaseModel):
else: else:
raise ValueError(f"not support encoder type:{encoder_type}") raise ValueError(f"not support encoder type:{encoder_type}")
# decoder
decoder = TransformerDecoder(vocab_size, decoder = TransformerDecoder(vocab_size,
encoder.output_size(), encoder.output_size(),
**configs['decoder_conf']) **configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
ctc = CTCDecoder( ctc = CTCDecoder(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,
dropout_rate=0.0, dropout_rate=model_conf['ctc_dropoutrate'],
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=model_conf['ctc_grad_norm_type'])
return vocab_size, encoder, decoder, ctc return vocab_size, encoder, decoder, ctc
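With this change the builder reads two extra keys from `model_conf`; a minimal config sketch with assumed values (note the key is spelled `ctc_dropoutrate` in this hunk but `ctc_dropout_rate` in the U2ST hunk further down):

configs = {
    'model_conf': {
        'ctc_dropoutrate': 0.0,            # dropout before the CTC projection
        'ctc_grad_norm_type': 'instance',  # one of 'instance', 'batch', 'frame', None
    },
}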
...@@ -902,10 +904,10 @@ class U2Model(U2BaseModel): ...@@ -902,10 +904,10 @@ class U2Model(U2BaseModel):
Returns: Returns:
DeepSpeech2Model: The model built from pretrained result. DeepSpeech2Model: The model built from pretrained result.
""" """
config.defrost() with UpdateConfig(config):
config.input_dim = dataloader.collate_fn.feature_size config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config) model = cls.from_config(config)
if checkpoint_path: if checkpoint_path:
......
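`UpdateConfig` (imported from `deepspeech.utils.utility`) replaces the manual `defrost()`/`freeze()` pair above; a plausible sketch of such a context manager for a yacs-style config, not necessarily the repo's exact implementation:

from contextlib import contextmanager

@contextmanager
def UpdateConfig(config):
    config.defrost()      # make the frozen CfgNode mutable
    try:
        yield config
    finally:
        config.freeze()   # re-freeze even if the body raises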
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import paddle
from paddle import distributed as dist
from deepspeech.training.extensions.evaluator import StandardEvaluator
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.updaters.standard_updater import StandardUpdater
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class U2Evaluator(StandardEvaluator):
def __init__(self, model, dataloader):
super().__init__(model, dataloader)
self.msg = ""
self.num_seen_utts = 0
self.total_loss = 0.0
def evaluate_core(self, batch):
self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
losses_dict = {}
loss, attention_loss, ctc_loss = self.model(*batch[1:])
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
self.num_seen_utts += num_utts
self.total_loss += float(loss) * num_utts
losses_dict['loss'] = float(loss)
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
for k, v in losses_dict.items():
report("eval/" + k, v)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
logger.info(self.msg)
return self.total_loss, self.num_seen_utts
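`evaluate_core` deliberately returns the raw accumulators `(total_loss, num_seen_utts)` rather than a mean: the `StandardEvaluator.evaluate_sync` hook added later in this patch all-reduces the two values across DDP workers before dividing, which yields the exact global average.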
class U2Updater(StandardUpdater):
def __init__(self,
model,
optimizer,
scheduler,
dataloader,
init_state=None,
accum_grad=1,
**kwargs):
super().__init__(
model, optimizer, scheduler, dataloader, init_state=init_state)
self.accum_grad = accum_grad
self.forward_count = 0
self.msg = ""
def update_core(self, batch):
"""One Step
Args:
batch (List[Object]): utts, xs, xlens, ys, ylens
"""
losses_dict = {}
self.msg = "Rank: {}, ".format(dist.get_rank())
# forward
batch_size = batch[1].shape[0]
loss, attention_loss, ctc_loss = self.model(*batch[1:])
# loss div by `batch_size * accum_grad`
loss /= self.accum_grad
# loss backward
if (self.forward_count + 1) != self.accum_grad:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# loss info
losses_dict['loss'] = float(loss) * self.accum_grad
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
# report loss
for k, v in losses_dict.items():
report("train/" + k, v)
# loss msg
self.msg += "batch size: {}, ".format(batch_size)
self.msg += "accum: {}, ".format(self.accum_grad)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
# Truncate the graph
loss.detach()
# update parameters
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
self.optimizer.step()
self.optimizer.clear_grad()
self.scheduler.step()
def update(self):
# model is default in train mode
# training for a step is implemented here
with Timer("data time cost:{}"):
batch = self.read_batch()
with Timer("step time cost:{}"):
self.update_core(batch)
# #iterations with accum_grad > 1
# Ref.: https://github.com/espnet/espnet/issues/777
if self.forward_count == 0:
self.state.iteration += 1
if self.updates_per_epoch is not None:
if self.state.iteration % self.updates_per_epoch == 0:
self.state.epoch += 1
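The `accum_grad` bookkeeping above in miniature (a sketch, not the repo's code): every step backpropagates `loss / accum_grad`, but the optimizer advances only every `accum_grad`-th call, so the effective batch is `batch_size * accum_grad`:

import paddle

accum_grad, forward_count = 4, 0
model = paddle.nn.Linear(8, 1)
opt = paddle.optimizer.SGD(parameters=model.parameters())
for step in range(8):
    x = paddle.randn([2, 8])
    loss = model(x).mean() / accum_grad
    loss.backward()                 # gradients accumulate on the parameters
    forward_count += 1
    if forward_count == accum_grad:
        forward_count = 0
        opt.step()                  # one update per accum_grad forward passes
        opt.clear_grad()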
...@@ -42,6 +42,7 @@ from deepspeech.utils import layer_tools ...@@ -42,6 +42,7 @@ from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"] __all__ = ["U2STModel", "U2STInferModel"]
...@@ -163,10 +164,7 @@ class U2STBaseModel(nn.Layer): ...@@ -163,10 +164,7 @@ class U2STBaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}") #logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch # 2a. ST-decoder branch
start = time.time() start = time.time()
...@@ -342,8 +340,8 @@ class U2STBaseModel(nn.Layer): ...@@ -342,8 +340,8 @@ class U2STBaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim) simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.size(2) encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
...@@ -363,8 +361,7 @@ class U2STBaseModel(nn.Layer): ...@@ -363,8 +361,7 @@ class U2STBaseModel(nn.Layer):
# 2. Decoder forward step by step # 2. Decoder forward step by step
for i in range(1, maxlen + 1): for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos # Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break break
# 2.1 Forward decoder step # 2.1 Forward decoder step
...@@ -417,26 +414,26 @@ class U2STBaseModel(nn.Layer): ...@@ -417,26 +414,26 @@ class U2STBaseModel(nn.Layer):
best_hyps = best_hyps[:, 1:] best_hyps = best_hyps[:, 1:]
return best_hyps return best_hyps
@jit.to_static # @jit.to_static
def subsampling_rate(self) -> int: def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the """ Export interface for c++ call, return subsampling_rate of the
model model
""" """
return self.encoder.embed.subsampling_rate return self.encoder.embed.subsampling_rate
@jit.to_static # @jit.to_static
def right_context(self) -> int: def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model """ Export interface for c++ call, return right_context of the model
""" """
return self.encoder.embed.right_context return self.encoder.embed.right_context
@jit.to_static # @jit.to_static
def sos_symbol(self) -> int: def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model """ Export interface for c++ call, return sos symbol id of the model
""" """
return self.sos return self.sos
@jit.to_static # @jit.to_static
def eos_symbol(self) -> int: def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model """ Export interface for c++ call, return eos symbol id of the model
""" """
...@@ -472,7 +469,7 @@ class U2STBaseModel(nn.Layer): ...@@ -472,7 +469,7 @@ class U2STBaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache, xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
@jit.to_static # @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log """ Export interface for c++ call, apply linear transform and log
softmax before ctc softmax before ctc
...@@ -499,13 +496,13 @@ class U2STBaseModel(nn.Layer): ...@@ -499,13 +496,13 @@ class U2STBaseModel(nn.Layer):
Returns: Returns:
paddle.Tensor: decoder output, (B, L) paddle.Tensor: decoder output, (B, L)
""" """
assert encoder_out.size(0) == 1 assert encoder_out.shape[0] == 1
num_hyps = hyps.size(0) num_hyps = hyps.shape[0]
assert hyps_lens.size(0) == num_hyps assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1) encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T) # (B, 1, T)
encoder_mask = paddle.ones( encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size) # (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens) hyps_lens)
...@@ -560,7 +557,7 @@ class U2STBaseModel(nn.Layer): ...@@ -560,7 +557,7 @@ class U2STBaseModel(nn.Layer):
Returns: Returns:
List[List[int]]: transcripts. List[List[int]]: transcripts.
""" """
batch_size = feats.size(0) batch_size = feats.shape[0]
if decoding_method == 'fullsentence': if decoding_method == 'fullsentence':
hyps = self.translate( hyps = self.translate(
...@@ -647,13 +644,16 @@ class U2STModel(U2STBaseModel): ...@@ -647,13 +644,16 @@ class U2STModel(U2STBaseModel):
decoder = TransformerDecoder(vocab_size, decoder = TransformerDecoder(vocab_size,
encoder.output_size(), encoder.output_size(),
**configs['decoder_conf']) **configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
ctc = CTCDecoder( ctc = CTCDecoder(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,
dropout_rate=0.0, dropout_rate=model_conf['ctc_dropout_rate'],
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=model_conf['ctc_grad_norm_type'])
return vocab_size, encoder, (st_decoder, decoder, ctc) return vocab_size, encoder, (st_decoder, decoder, ctc)
else: else:
...@@ -687,10 +687,10 @@ class U2STModel(U2STBaseModel): ...@@ -687,10 +687,10 @@ class U2STModel(U2STBaseModel):
Returns: Returns:
DeepSpeech2Model: The model built from pretrained result. DeepSpeech2Model: The model built from pretrained result.
""" """
config.defrost() with UpdateConfig(config):
config.input_dim = dataloader.collate_fn.feature_size config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config) model = cls.from_config(config)
if checkpoint_path: if checkpoint_path:
......
...@@ -15,12 +15,13 @@ from collections import OrderedDict ...@@ -15,12 +15,13 @@ from collections import OrderedDict
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"] __all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
def brelu(x, t_min=0.0, t_max=24.0, name=None): def brelu(x, t_min=0.0, t_max=24.0, name=None):
...@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None): ...@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return x.maximum(t_min).minimum(t_max) return x.maximum(t_min).minimum(t_max)
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return F.glu(xs, axis=self.dim)
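GLU halves the gated dimension, since half the features gate the other half; a quick sanity check against the module defined above:

import paddle

glu = GLU(dim=-1)
y = glu(paddle.randn([4, 10, 512]))
assert y.shape == [4, 10, 256]  # width 2d in, width d out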
class LinearGLUBlock(nn.Layer): class LinearGLUBlock(nn.Layer):
"""A linear Gated Linear Units (GLU) block.""" """A linear Gated Linear Units (GLU) block."""
...@@ -133,13 +145,18 @@ def get_activation(act): ...@@ -133,13 +145,18 @@ def get_activation(act):
"""Return activation function.""" """Return activation function."""
# Lazy load to avoid unused import # Lazy load to avoid unused import
activation_funcs = { activation_funcs = {
"hardshrink": paddle.nn.Hardshrink,
"hardswish": paddle.nn.Hardswish,
"hardtanh": paddle.nn.Hardtanh, "hardtanh": paddle.nn.Hardtanh,
"tanh": paddle.nn.Tanh, "tanh": paddle.nn.Tanh,
"relu": paddle.nn.ReLU, "relu": paddle.nn.ReLU,
"relu6": paddle.nn.ReLU6,
"leakyrelu": paddle.nn.LeakyReLU,
"selu": paddle.nn.SELU, "selu": paddle.nn.SELU,
"swish": paddle.nn.Swish, "swish": paddle.nn.Swish,
"gelu": paddle.nn.GELU, "gelu": paddle.nn.GELU,
"brelu": brelu, "glu": GLU,
"elu": paddle.nn.ELU,
} }
return activation_funcs[act]() return activation_funcs[act]()
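Usage is a plain string lookup that instantiates a fresh layer on each call, e.g.:

act = get_activation("gelu")  # -> a new paddle.nn.GELU() instance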
...@@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer): ...@@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value tensor, size paddle.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k). (#batch, n_head, time2, d_k).
""" """
n_batch = query.size(0) n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
...@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer): ...@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value weighted paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model). by the attention score, (#batch, time1, d_model).
""" """
n_batch = value.size(0) n_batch = value.shape[0]
if mask is not None: if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf')) scores = scores.masked_fill(mask, -float('inf'))
...@@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer): ...@@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer):
p_attn = self.dropout(attn) p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).contiguous().view( x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h *
n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) self.d_k) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model)
...@@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
paddle.Tensor: Output tensor. (batch, head, time1, time1) paddle.Tensor: Output tensor. (batch, head, time1, time1)
""" """
zero_pad = paddle.zeros( zero_pad = paddle.zeros(
(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1) x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu: if zero_triu:
ones = paddle.ones((x.size(2), x.size(3))) ones = paddle.ones((x.shape[2], x.shape[3]))
x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
return x return x
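The zero-pad / reshape / slice above is the usual relative-shift trick: prepending a zero column and reinterpreting the buffer realigns the score columns so column j holds relative offset j - i. A tiny trace of the index shuffle (shapes follow the code above):

import paddle

B, H, T = 1, 1, 3
x = paddle.arange(T * T, dtype='float32').reshape([B, H, T, T])
zero_pad = paddle.zeros((B, H, T, 1), dtype=x.dtype)
x_padded = paddle.concat([zero_pad, x], axis=-1)          # (B, H, T, T+1)
x_shifted = x_padded.reshape([B, H, T + 1, T])[:, :, 1:]  # drop the pad row
x_shifted = x_shifted.reshape([B, H, T, T])               # same shape as x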
...@@ -205,7 +206,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -205,7 +206,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
q, k, v = self.forward_qkv(query, key, value) q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0) n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
......
...@@ -113,11 +113,9 @@ class ConvBn(nn.Layer): ...@@ -113,11 +113,9 @@ class ConvBn(nn.Layer):
# reset padding part to 0 # reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T] masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply # https://github.com/PaddlePaddle/Paddle/pull/29265
# masks = masks.type_as(x) # rhs will type promote to lhs
masks = masks.astype(x.dtype) x = x * masks
x = x.multiply(masks)
return x, x_len return x, x_len
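Both mask-multiply hunks in this patch rely on the same upstream fix: per https://github.com/PaddlePaddle/Paddle/pull/29265, elementwise ops promote a bool operand to the other operand's dtype, so `x = x * masks` zeroes the padded positions without an explicit `astype`.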
......
...@@ -16,15 +16,19 @@ from paddle import nn ...@@ -16,15 +16,19 @@ from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from typeguard import check_argument_types from typeguard import check_argument_types
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
from deepspeech.decoders.swig_wrapper import Scorer
from deepspeech.modules.loss import CTCLoss from deepspeech.modules.loss import CTCLoss
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
try:
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401
except Exception as e:
logger.info("ctcdecoder not installed!")
__all__ = ['CTCDecoder'] __all__ = ['CTCDecoder']
...@@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer): ...@@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer):
blank_id=0, blank_id=0,
dropout_rate: float=0.0, dropout_rate: float=0.0,
reduction: bool=True, reduction: bool=True,
batch_average: bool=True): batch_average: bool=True,
grad_norm_type: str="instance"):
"""CTC decoder """CTC decoder
Args: Args:
...@@ -44,6 +49,7 @@ class CTCDecoder(nn.Layer): ...@@ -44,6 +49,7 @@ class CTCDecoder(nn.Layer):
dropout_rate (float): dropout rate (0.0 ~ 1.0) dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
batch_average (bool): do batch dim wise average. batch_average (bool): do batch dim wise average.
grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
""" """
assert check_argument_types() assert check_argument_types()
super().__init__() super().__init__()
...@@ -56,7 +62,8 @@ class CTCDecoder(nn.Layer): ...@@ -56,7 +62,8 @@ class CTCDecoder(nn.Layer):
self.criterion = CTCLoss( self.criterion = CTCLoss(
blank=self.blank_id, blank=self.blank_id,
reduction=reduction_type, reduction=reduction_type,
batch_average=batch_average) batch_average=batch_average,
grad_norm_type=grad_norm_type)
# CTCDecoder LM Score handle # CTCDecoder LM Score handle
self._ext_scorer = None self._ext_scorer = None
...@@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer): ...@@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer):
results = [] results = []
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder( output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list) probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
results.append(output_transcription) results.append(output_transcription)
return results return results
...@@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer): ...@@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer):
num_processes=num_processes, num_processes=num_processes,
ext_scoring_func=self._ext_scorer, ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob, cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n) cutoff_top_n=cutoff_top_n,
blank_id=self.blank_id)
results = [result[0][1] for result in beam_search_results] results = [result[0][1] for result in beam_search_results]
return results return results
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method): decoding_method):
if decoding_method == "ctc_beam_search": if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list) vocab_list)
......
...@@ -122,11 +122,9 @@ class TransformerDecoder(nn.Layer): ...@@ -122,11 +122,9 @@ class TransformerDecoder(nn.Layer):
# tgt_mask: (B, 1, L) # tgt_mask: (B, 1, L)
tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
# m: (1, L, L) # m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor tgt_mask = tgt_mask & m
# tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt) x, _ = self.embed(tgt)
for layer in self.decoders: for layer in self.decoders:
...@@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer): ...@@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer: if self.use_output_layer:
x = self.output_layer(x) x = self.output_layer(x)
# TODO(Hui Zhang): reduce_sum not support bool type olens = tgt_mask.sum(1)
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens return x, olens
def forward_one_step( def forward_one_step(
......
...@@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer): ...@@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
""" """
T = x.shape[1] T = x.shape[1]
assert offset + x.size(1) < self.max_len assert offset + x.shape[1] < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T] pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb x = x * self.xscale + pos_emb
...@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding): ...@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`).
""" """
assert offset + x.size(1) < self.max_len assert offset + x.shape[1] < self.max_len
x = x * self.xscale x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]] pos_emb = self.pe[:, offset:offset + x.shape[1]]
......
...@@ -159,11 +159,10 @@ class BaseEncoder(nn.Layer): ...@@ -159,11 +159,10 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool) masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask( chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size, decoding_chunk_size, self.static_chunk_size,
...@@ -207,11 +206,11 @@ class BaseEncoder(nn.Layer): ...@@ -207,11 +206,11 @@ class BaseEncoder(nn.Layer):
chunk computation chunk computation
List[paddle.Tensor]: conformer cnn cache List[paddle.Tensor]: conformer cnn cache
""" """
assert xs.size(0) == 1 # batch size must be one assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility # tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor # TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32) tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None: if self.global_cmvn is not None:
...@@ -221,25 +220,25 @@ class BaseEncoder(nn.Layer): ...@@ -221,25 +220,25 @@ class BaseEncoder(nn.Layer):
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D) xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None: if subsampling_cache is not None:
cache_size = subsampling_cache.size(1) #T cache_size = subsampling_cache.shape[1] #T
xs = paddle.cat((subsampling_cache, xs), dim=1) xs = paddle.cat((subsampling_cache, xs), dim=1)
else: else:
cache_size = 0 cache_size = 0
# only used when using `RelPositionMultiHeadedAttention` # only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding( pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.size(1)) offset=offset - cache_size, size=xs.shape[1])
if required_cache_size < 0: if required_cache_size < 0:
next_cache_start = 0 next_cache_start = 0
elif required_cache_size == 0: elif required_cache_size == 0:
next_cache_start = xs.size(1) next_cache_start = xs.shape[1]
else: else:
next_cache_start = xs.size(1) - required_cache_size next_cache_start = xs.shape[1] - required_cache_size
r_subsampling_cache = xs[:, next_cache_start:, :] r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers # Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T] masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = [] r_elayers_output_cache = []
r_conformer_cnn_cache = [] r_conformer_cnn_cache = []
...@@ -303,7 +302,7 @@ class BaseEncoder(nn.Layer): ...@@ -303,7 +302,7 @@ class BaseEncoder(nn.Layer):
stride = subsampling * decoding_chunk_size stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1) num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks required_cache_size = decoding_chunk_size * num_decoding_left_chunks
subsampling_cache: Optional[paddle.Tensor] = None subsampling_cache: Optional[paddle.Tensor] = None
elayers_output_cache: Optional[List[paddle.Tensor]] = None elayers_output_cache: Optional[List[paddle.Tensor]] = None
...@@ -319,10 +318,10 @@ class BaseEncoder(nn.Layer): ...@@ -319,10 +318,10 @@ class BaseEncoder(nn.Layer):
chunk_xs, offset, required_cache_size, subsampling_cache, chunk_xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
outputs.append(y) outputs.append(y)
offset += y.size(1) offset += y.shape[1]
ys = paddle.cat(outputs, 1) ys = paddle.cat(outputs, 1)
# fake mask, just for jit script and compatibility with `forward` api # fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.size(1)], dtype=paddle.bool) masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) masks = masks.unsqueeze(1)
return ys, masks return ys, masks
......
...@@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] ...@@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
class CTCLoss(nn.Layer): class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum', batch_average=False): def __init__(self,
blank=0,
reduction='sum',
batch_average=False,
grad_norm_type=None):
super().__init__() super().__init__()
# last token id as blank id # last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average self.batch_average = batch_average
logger.info(
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
# instance for norm_by_times
# batch for norm_by_batchsize
# frame for norm_by_total_logits_len
assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False
self.norm_by_batchsize = False
self.norm_by_total_logits_len = False
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance':
self.norm_by_times = True
if grad_norm_type == 'batch':
self.norm_by_batchsize = True
if grad_norm_type == 'frame':
self.norm_by_total_logits_len = True
def forward(self, logits, ys_pad, hlens, ys_lens): def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss. """Compute CTC loss.
...@@ -46,10 +67,15 @@ class CTCLoss(nn.Layer): ...@@ -46,10 +67,15 @@ class CTCLoss(nn.Layer):
# warp-ctc needs activation with shape [T, B, V + 1] # warp-ctc needs activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D) # logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2]) logits = logits.transpose([1, 0, 2])
# (TODO:Hui Zhang) ctc loss does not support int64 labels
ys_pad = ys_pad.astype(paddle.int32) ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss( loss = self.loss(
logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average) logits,
ys_pad,
hlens,
ys_lens,
norm_by_times=self.norm_by_times,
norm_by_batchsize=self.norm_by_batchsize,
norm_by_total_logits_len=self.norm_by_total_logits_len)
if self.batch_average: if self.batch_average:
# Batch-size average # Batch-size average
loss = loss / B loss = loss / B
...@@ -124,9 +150,9 @@ class LabelSmoothingLoss(nn.Layer): ...@@ -124,9 +150,9 @@ class LabelSmoothingLoss(nn.Layer):
# use zeros_like instead of torch.no_grad() for true_dist, # use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT # since no_grad() can not be exported by JIT
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1)) true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,) ignore = (target == self.padding_idx) # (B,)
# target = target * (1 - ignore) # avoid -1 index #TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index
target = target.masked_fill(ignore, 0) # avoid -1 index target = target.masked_fill(ignore, 0) # avoid -1 index
# true_dist.scatter_(1, target.unsqueeze(1), self.confidence) # true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
target_mask = F.one_hot(target, self.size) target_mask = F.one_hot(target, self.size)
...@@ -135,10 +161,8 @@ class LabelSmoothingLoss(nn.Layer): ...@@ -135,10 +161,8 @@ class LabelSmoothingLoss(nn.Layer):
kl = self.criterion(F.log_softmax(x, axis=1), true_dist) kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
#TODO(Hui Zhang): sum not support bool type total = len(target) - int(ignore.sum())
#total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum())
denom = total if self.normalize_length else B denom = total if self.normalize_length else B
#numer = (kl * (1 - ignore)).sum() #TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
return numer / denom return numer / denom
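The smoothing arithmetic in one line: with vocabulary size V and smoothing eps, the target distribution puts 1 - eps on the gold token and eps / (V - 1) on every other token; e.g. V = 5000 and eps = 0.1 gives 0.9 versus roughly 2e-5 per wrong token. Padding positions (target == padding_idx) are masked out of the numerator and, when normalize_length is set, of the denominator as well.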
...@@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: ...@@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0], [1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]] [1, 1, 0, 0, 0]]
""" """
#TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~ return ~make_pad_mask(lengths)
return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor: def subsequent_mask(size: int) -> paddle.Tensor:
...@@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor: ...@@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]] [1, 1, 1]]
""" """
ret = paddle.ones([size, size], dtype=paddle.bool) ret = paddle.ones([size, size], dtype=paddle.bool)
#TODO(Hui Zhang): tril not support bool return paddle.tril(ret)
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask( def subsequent_chunk_mask(
...@@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor, ...@@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks) # (L, L) num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L) chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
elif static_chunk_size > 0: elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks) # (L, L) num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L) chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
else: else:
chunk_masks = masks chunk_masks = masks
return chunk_masks return chunk_masks
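Both branches end the same way: the per-utterance padding mask is ANDed with a block-triangular chunk mask, so each frame may attend only within its own chunk and the allowed number of chunks to its left.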
......
...@@ -308,7 +308,7 @@ class RNNStack(nn.Layer): ...@@ -308,7 +308,7 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len) x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T] masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1] masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply # https://github.com/PaddlePaddle/Paddle/pull/29265
masks = masks.astype(x.dtype) # rhs will type promote to lhs
x = x.multiply(masks) x = x * masks
return x, x_len return x, x_len
...@@ -14,6 +14,20 @@ ...@@ -14,6 +14,20 @@
import argparse import argparse
class ExtendAction(argparse.Action):
"""
[Since Python 3.8, the "extend" is available directly in stdlib]
(https://docs.python.org/3.8/library/argparse.html#action).
If you only have to support 3.8+ then defining it yourself is no longer required.
Usage of stdlib "extend" action is exactly the same way as this answer originally described:
"""
def __call__(self, parser, namespace, values, option_string=None):
items = getattr(namespace, self.dest) or []
items.extend(values)
setattr(namespace, self.dest, items)
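With the action registered on the parser below, repeated `--opts KEY VALUE` pairs accumulate into one flat list; a quick illustration:

import argparse

parser = argparse.ArgumentParser()
parser.register('action', 'extend', ExtendAction)
parser.add_argument('--opts', action='extend', nargs=2)
args = parser.parse_args(['--opts', 'a', '1', '--opts', 'b', '2'])
assert args.opts == ['a', '1', 'b', '2']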
def default_argument_parser(): def default_argument_parser():
r"""A simple yet genral argument parser for experiments with parakeet. r"""A simple yet genral argument parser for experiments with parakeet.
...@@ -30,7 +44,7 @@ def default_argument_parser(): ...@@ -30,7 +44,7 @@ def default_argument_parser():
The ``--checkpoint_path`` specifies the checkpoint to load from. The ``--checkpoint_path`` specifies the checkpoint to load from.
The ``--device`` and ``--nprocs`` specifies how to run the training. The ``--nprocs`` specifies how to run the training.
See Also See Also
...@@ -42,29 +56,53 @@ def default_argument_parser(): ...@@ -42,29 +56,53 @@ def default_argument_parser():
the parser the parser
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.register('action', 'extend', ExtendAction)
# yapf: disable train_group = parser.add_argument_group(
# data and output title='Train Options', description=None)
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") train_group.add_argument(
parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") "--seed",
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") type=int,
default=None,
# load from saved checkpoint help="seed to use for paddle, np and random. None or 0 for random, else set seed."
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") )
train_group.add_argument(
# running "--nprocs",
parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], type=int,
help="device type to use, cpu and gpu are supported.") default=1,
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") help="number of parallel processes. 0 for cpu.")
train_group.add_argument(
# overwrite extra config and default config "--config", metavar="CONFIG_FILE", help="config file.")
# parser.add_argument("--opts", nargs=argparse.REMAINDER, train_group.add_argument(
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
parser.add_argument("--opts", type=str, default=[], nargs='+', train_group.add_argument(
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") "--checkpoint_path", type=str, help="path to load checkpoint")
train_group.add_argument(
"--opts",
action='extend',
nargs=2,
metavar=('key', 'val'),
help="overwrite --config field, passing (KEY VALUE) pairs")
train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.")
parser.add_argument("--seed", type=int, default=None, profile_group = parser.add_argument_group(
help="seed to use for paddle, np and random. None or 0 for random, else set seed.") title='Benchmark Options', description=None)
# yapd: enable profile_group.add_argument(
'--profiler-options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
profile_group.add_argument(
'--benchmark-batch-size',
type=int,
default=None,
help='batch size for benchmark.')
profile_group.add_argument(
'--benchmark-max-step',
type=int,
default=None,
help='max iteration for benchmark.')
return parser return parser
...@@ -13,14 +13,18 @@ ...@@ -13,14 +13,18 @@
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict
import extension
import paddle import paddle
from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
from . import extension
from ..reporter import DictSummary from ..reporter import DictSummary
from ..reporter import ObsScope
from ..reporter import report from ..reporter import report
from ..reporter import scope from ..timer import Timer
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class StandardEvaluator(extension.Extension): class StandardEvaluator(extension.Extension):
...@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension): ...@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension):
def evaluate_core(self, batch): def evaluate_core(self, batch):
# compute # compute
self.model(batch) # you may report here self.model(batch) # you may report here
return
def evaluate_sync(self, data):
# dist sync `evaluate_core` outputs
if data is None:
return
numerator, denominator = data
if dist.get_world_size() > 1:
numerator = paddle.to_tensor(numerator)
denominator = paddle.to_tensor(denominator)
# the default operator in all_reduce function is sum.
dist.all_reduce(numerator)
dist.all_reduce(denominator)
value = numerator / denominator
value = float(value)
else:
value = numerator / denominator
# used for `snapshot` to do kbest save.
report("VALID/LOSS", value)
logger.info(f"Valid: all-reduce loss {value}")
def evaluate(self): def evaluate(self):
# switch to eval mode # switch to eval mode
...@@ -53,12 +78,16 @@ class StandardEvaluator(extension.Extension): ...@@ -53,12 +78,16 @@ class StandardEvaluator(extension.Extension):
summary = DictSummary() summary = DictSummary()
for batch in self.dataloader: for batch in self.dataloader:
observation = {} observation = {}
with scope(observation): with ObsScope(observation):
# main evaluation computation here. # main evaluation computation here.
with paddle.no_grad(): with paddle.no_grad():
self.evaluate_core(batch) self.evaluate_sync(self.evaluate_core(batch))
summary.add(observation) summary.add(observation)
summary = summary.compute_mean() summary = summary.compute_mean()
# switch to train mode
for model in self.models.values():
model.train()
return summary return summary
def __call__(self, trainer=None): def __call__(self, trainer=None):
...@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension): ...@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension):
# if it is used to extend a trainer, the metrics is reported to # if it is used to extend a trainer, the metrics is reported to
# to observation of the trainer # to observation of the trainer
# or otherwise, you can use your own observation # or otherwise, you can use your own observation
with Timer("Eval Time Cost: {}"):
summary = self.evaluate() summary = self.evaluate()
for k, v in summary.items(): for k, v in summary.items():
report(k, v) report(k, v)
...@@ -20,8 +20,9 @@ from typing import List ...@@ -20,8 +20,9 @@ from typing import List
import jsonlines import jsonlines
from deepspeech.training.extensions import extension from . import extension
from deepspeech.training.updaters.trainer import Trainer from ..reporter import get_observations
from ..updaters.trainer import Trainer
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only from deepspeech.utils.mp_tools import rank_zero_only
...@@ -52,8 +53,19 @@ class Snapshot(extension.Extension): ...@@ -52,8 +53,19 @@ class Snapshot(extension.Extension):
priority = -100 priority = -100
default_name = "snapshot" default_name = "snapshot"
def __init__(self, max_size: int=5, snapshot_on_error: bool=False): def __init__(self,
mode='latest',
max_size: int=5,
indicator=None,
less_better=True,
snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = [] self.records: List[Dict[str, Any]] = []
assert mode in ('latest', 'kbest'), mode
if mode == 'kbest':
assert indicator is not None
self.mode = mode
self.indicator = indicator
self.less_is_better = less_better
self.max_size = max_size self.max_size = max_size
self._snapshot_on_error = snapshot_on_error self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1) self._save_all = (max_size == -1)
...@@ -66,16 +78,17 @@ class Snapshot(extension.Extension): ...@@ -66,16 +78,17 @@ class Snapshot(extension.Extension):
# load existing records # load existing records
record_path: Path = self.checkpoint_dir / "records.jsonl" record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists(): if record_path.exists():
logger.debug("Loading from an existing checkpoint dir")
self.records = load_records(record_path) self.records = load_records(record_path)
trainer.updater.load(self.records[-1]['path']) ckpt_path = self.records[-1]['path']
logger.info(f"Loading from an existing checkpoint {ckpt_path}")
trainer.updater.load(ckpt_path)
def on_error(self, trainer, exc, tb): def on_error(self, trainer, exc, tb):
if self._snapshot_on_error: if self._snapshot_on_error:
self.save_checkpoint_and_update(trainer) self.save_checkpoint_and_update(trainer, 'latest')
def __call__(self, trainer: Trainer): def __call__(self, trainer: Trainer):
self.save_checkpoint_and_update(trainer) self.save_checkpoint_and_update(trainer, self.mode)
def full(self): def full(self):
"""Whether the number of snapshots it keeps track of is greater """Whether the number of snapshots it keeps track of is greater
...@@ -83,12 +96,12 @@ class Snapshot(extension.Extension): ...@@ -83,12 +96,12 @@ class Snapshot(extension.Extension):
return (not self._save_all) and len(self.records) > self.max_size return (not self._save_all) and len(self.records) > self.max_size
@rank_zero_only @rank_zero_only
def save_checkpoint_and_update(self, trainer: Trainer): def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
"""Saving new snapshot and remove the oldest snapshot if needed.""" """Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] == 'epoch' else iteration num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz" path = self.checkpoint_dir / f"{num}.np"
# add the new one # add the new one
trainer.updater.save(path) trainer.updater.save(path)
...@@ -97,11 +110,17 @@ class Snapshot(extension.Extension): ...@@ -97,11 +110,17 @@ class Snapshot(extension.Extension):
'path': str(path.resolve()), # use absolute path 'path': str(path.resolve()), # use absolute path
'iteration': iteration, 'iteration': iteration,
'epoch': epoch, 'epoch': epoch,
'indicator': get_observations()[self.indicator]
} }
self.records.append(record) self.records.append(record)
# remove the earliest # remove the earliest
if self.full(): if self.full():
if mode == 'kbest':
self.records = sorted(
self.records,
key=lambda record: record['indicator'],
reverse=not self.less_is_better)
eariest_record = self.records[0] eariest_record = self.records[0]
os.remove(eariest_record["path"]) os.remove(eariest_record["path"])
self.records.pop(0) self.records.pop(0)
......
...@@ -11,8 +11,10 @@ ...@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from deepspeech.training.extensions import extension from visualdl import LogWriter
from deepspeech.training.updaters.trainer import Trainer
from . import extension
from ..updaters.trainer import Trainer
class VisualDL(extension.Extension): class VisualDL(extension.Extension):
...@@ -26,8 +28,8 @@ class VisualDL(extension.Extension): ...@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
default_name = 'visualdl' default_name = 'visualdl'
priority = extension.PRIORITY_READER priority = extension.PRIORITY_READER
def __init__(self, writer): def __init__(self, output_dir):
self.writer = writer self.writer = LogWriter(str(output_dir))
def __call__(self, trainer: Trainer): def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items(): for k, v in trainer.observation.items():
......
...@@ -47,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): ...@@ -47,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
sum_square = layers.reduce_sum(square) sum_square = layers.reduce_sum(square)
sum_square_list.append(sum_square) sum_square_list.append(sum_square)
# debug log # debug log, not dump all since slow down train process
if i < 10: if i < 10:
logger.debug( logger.debug(
f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")
...@@ -76,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): ...@@ -76,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
new_grad = layers.elementwise_mul(x=g, y=clip_var) new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad)) params_and_grads.append((p, new_grad))
# debug log # debug log, not dump all since slow down train process
if i < 10: if i < 10:
logger.debug( logger.debug(
f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
......
...@@ -19,7 +19,7 @@ OBSERVATIONS = None ...@@ -19,7 +19,7 @@ OBSERVATIONS = None
@contextlib.contextmanager @contextlib.contextmanager
def scope(observations): def ObsScope(observations):
# make `observation` the target to report to. # make `observation` the target to report to.
# it is basically a dictionary that stores temporary observations # it is basically a dictionary that stores temporary observations
global OBSERVATIONS global OBSERVATIONS
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import time
from deepspeech.utils.log import Log
__all__ = ["Timer"]
logger = Log(__name__).getlog()
class Timer():
"""To be used like this:
with Timer("Message") as value:
do something
"""
def __init__(self, message=None):
self.message = message
def duration(self) -> str:
elapsed_time = time.time() - self.start
time_str = str(datetime.timedelta(seconds=elapsed_time))
return time_str
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, type, value, traceback):
if self.message:
logger.info(self.message.format(self.duration()))
def __call__(self) -> float:
return time.time() - self.start
def __str__(self):
return self.duration()
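A minimal usage sketch of this Timer (illustrative; time.sleep stands in for real work):
if __name__ == "__main__":
    with Timer("Demo cost: {}") as t:
        time.sleep(1.5)  # stand-in for real work
        print(t())       # elapsed seconds so far, via __call__
    # on exit, logger.info logs the formatted duration string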
...@@ -11,17 +11,24 @@ ...@@ -11,17 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import time import time
from collections import OrderedDict
from pathlib import Path from pathlib import Path
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils import profiler
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
__all__ = ["Trainer"] __all__ = ["Trainer"]
...@@ -79,7 +86,7 @@ class Trainer(): ...@@ -79,7 +86,7 @@ class Trainer():
>>> config.merge_from_list(args.opts) >>> config.merge_from_list(args.opts)
>>> config.freeze() >>> config.freeze()
>>> >>>
>>> if args.nprocs > 1 and args.device == "gpu": >>> if args.nprocs > 0:
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else: >>> else:
>>> main_sp(config, args) >>> main_sp(config, args)
...@@ -94,15 +101,25 @@ class Trainer(): ...@@ -94,15 +101,25 @@ class Trainer():
self.checkpoint_dir = None self.checkpoint_dir = None
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
self.rank = dist.get_rank()
logger.info(f"Rank: {self.rank}/{dist.get_world_size()}")
if args.seed: if args.seed:
seed_all(args.seed) seed_all(args.seed)
logger.info(f"Set seed {args.seed}") logger.info(f"Set seed {args.seed}")
if self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
self.config.training.log_interval = 1
logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
""" """
paddle.set_device(self.args.device) paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel: if self.parallel:
self.init_parallel() self.init_parallel()
...@@ -122,7 +139,7 @@ class Trainer(): ...@@ -122,7 +139,7 @@ class Trainer():
"""A flag indicating whether the experiment should run with """A flag indicating whether the experiment should run with
multiprocessing. multiprocessing.
""" """
return self.args.device == "gpu" and self.args.nprocs > 1 return self.args.nprocs > 0
def init_parallel(self): def init_parallel(self):
"""Init environment for multiprocess training. """Init environment for multiprocess training.
...@@ -162,56 +179,97 @@ class Trainer(): ...@@ -162,56 +179,97 @@ class Trainer():
checkpoint_dir=self.checkpoint_dir, checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
if infos: if infos:
# restore from ckpt # just restore ckpt
# lr will restore from optimizer ckpt
self.iteration = infos["step"] self.iteration = infos["step"]
self.epoch = infos["epoch"] self.epoch = infos["epoch"]
scratch = False scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
else: else:
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
scratch = True scratch = True
logger.info("Init from scratch!")
return scratch return scratch
def new_epoch(self): def maybe_batch_sampler_step(self):
"""Reset the train loader seed and increment `epoch`. """ batch_sampler seed by epoch """
""" if hasattr(self.train_loader, "batch_sampler"):
self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
batch_sampler = self.train_loader.batch_sampler batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch) batch_sampler.set_epoch(self.epoch)
def train(self): def before_train(self):
"""The training process control by epoch."""
from_scratch = self.resume_or_scratch() from_scratch = self.resume_or_scratch()
if from_scratch: if from_scratch:
# save init model, i.e. 0 epoch # scratch: save init model, i.e. 0 epoch
self.save(tag='init', infos=None) self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch) else:
if self.parallel and hasattr(self.train_loader, "batch_sampler"): # resume: train next_epoch and next_iteration
self.train_loader.batch_sampler.set_epoch(self.epoch) self.epoch += 1
self.iteration += 1
logger.info(
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
self.maybe_batch_sampler_step()
def new_epoch(self):
"""Reset the train loader seed and increment `epoch`.
"""
# `iteration` increased by train step
self.epoch += 1
self.maybe_batch_sampler_step()
def after_train_batch(self):
if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
profiler.add_profiler_step(self.args.profiler_options)
logger.info(
f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
sys.exit(
f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
def train(self):
"""The training process control by epoch."""
self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch: while self.epoch < self.config.training.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
try: try:
data_start_time = time.time() data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader): for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank()) msg = "Train:"
msg += "epoch: {}, ".format(self.epoch) observation = OrderedDict()
msg += "step: {}, ".format(self.iteration) with ObsScope(observation):
msg += "batch : {}/{}, ".format(batch_index + 1, report("Rank", dist.get_rank())
len(self.train_loader)) report("epoch", self.epoch)
msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) report('step', self.iteration)
msg += "data time: {:>.3f}s, ".format(dataload_time) report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips[sent./sec]'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += ","
logger.info(msg)
data_start_time = time.time() data_start_time = time.time()
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid() total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1: if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts) num_seen_utts = paddle.to_tensor(num_seen_utts)
...@@ -231,6 +289,7 @@ class Trainer(): ...@@ -231,6 +289,7 @@ class Trainer():
'epoch', {'cv_loss': cv_loss, 'epoch', {'cv_loss': cv_loss,
'lr': self.lr_scheduler()}, self.epoch) 'lr': self.lr_scheduler()}, self.epoch)
# after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
# step lr every epoch # step lr every epoch
self.lr_scheduler.step() self.lr_scheduler.step()
...@@ -240,14 +299,13 @@ class Trainer(): ...@@ -240,14 +299,13 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended """The routine of the experiment after setup. This method is intended
to be used by the user. to be used by the user.
""" """
with Timer("Training Done: {}"):
try: try:
self.train() self.train()
except KeyboardInterrupt: except KeyboardInterrupt:
self.save()
exit(-1) exit(-1)
finally: finally:
self.destory() self.destory()
logger.info("Training Done.")
def setup_output_dir(self): def setup_output_dir(self):
"""Create a directory used for output. """Create a directory used for output.
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from paddle import Tensor import paddle
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from paddle.nn import Layer from paddle.nn import Layer
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from timer import timer from paddle.optimizer.lr import LRScheduler
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterBase
...@@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase): ...@@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase):
def __init__(self, def __init__(self,
model: Layer, model: Layer,
optimizer: Optimizer, optimizer: Optimizer,
scheduler: LRScheduler,
dataloader: DataLoader, dataloader: DataLoader,
init_state: Optional[UpdaterState]=None): init_state: Optional[UpdaterState]=None):
super().__init__(init_state)
# it is designed to hold multiple models # it is designed to hold multiple models
models = {"main": model} models = {"main": model}
self.models: Dict[str, Layer] = models self.models: Dict[str, Layer] = models
...@@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase): ...@@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase):
self.optimizer = optimizer self.optimizer = optimizer
self.optimizers: Dict[str, Optimizer] = optimizers self.optimizers: Dict[str, Optimizer] = optimizers
# it is designed to hold multiple schedulers
schedulers = {"main": scheduler}
self.scheduler = scheduler
self.schedulers: Dict[str, LRScheduler] = schedulers
# dataloaders # dataloaders
self.dataloader = dataloader self.dataloader = dataloader
# init state
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
self.train_iterator = iter(dataloader) self.train_iterator = iter(dataloader)
def update(self): def update(self):
...@@ -103,7 +104,9 @@ class StandardUpdater(UpdaterBase): ...@@ -103,7 +104,9 @@ class StandardUpdater(UpdaterBase):
model.train() model.train()
# training for a step is implemented here # training for a step is implemented here
with Timier("data time cost:{}"):
batch = self.read_batch() batch = self.read_batch()
with Timier("step time cost:{}"):
self.update_core(batch) self.update_core(batch)
self.state.iteration += 1 self.state.iteration += 1
...@@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase): ...@@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase):
"""A simple case for a training step. Basic assumptions are: """A simple case for a training step. Basic assumptions are:
Single model; Single model;
Single optimizer; Single optimizer;
Single scheduler, and update learning rate each step;
A batch from the dataloader is just the input of the model; A batch from the dataloader is just the input of the model;
The model returns a single loss, or a dict containing several losses. The model returns a single loss, or a dict containing several losses.
Parameters updates at every batch, no gradient accumulation. Parameters updates at every batch, no gradient accumulation.
""" """
loss = self.model(*batch) loss = self.model(*batch)
if isinstance(loss, Tensor): if isinstance(loss, paddle.Tensor):
loss_dict = {"main": loss} loss_dict = {"main": loss}
else: else:
# Dict[str, Tensor] # Dict[str, Tensor]
...@@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase): ...@@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase):
for name, loss_item in loss_dict.items(): for name, loss_item in loss_dict.items():
report(name, float(loss_item)) report(name, float(loss_item))
self.optimizer.clear_gradient() self.optimizer.clear_grad()
loss_dict["main"].backward() loss_dict["main"].backward()
self.optimizer.update() self.optimizer.step()
self.scheduler.step()
@property @property
def updates_per_epoch(self): def updates_per_epoch(self):
"""Number of updater per epoch, determined by the length of the """Number of steps per epoch,
dataloader.""" determined by the length of the dataloader."""
length_of_dataloader = None length_of_dataloader = None
try: try:
length_of_dataloader = len(self.dataloader) length_of_dataloader = len(self.dataloader)
...@@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase): ...@@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase):
def read_batch(self): def read_batch(self):
"""Read a batch from the data loader, auto renew when data is exhausted.""" """Read a batch from the data loader, auto renew when data is exhausted."""
with timer() as t:
try: try:
batch = next(self.train_iterator) batch = next(self.train_iterator)
except StopIteration: except StopIteration:
self.new_epoch() self.new_epoch()
batch = next(self.train_iterator) batch = next(self.train_iterator)
logger.debug(
f"Read a batch takes {t.elapse}s.") # replace it with logger
return batch return batch
def state_dict(self): def state_dict(self):
"""State dict of a Updater, model, optimizer and updater state are included.""" """State dict of a Updater, model, optimizers/schedulers
and updater state are included."""
state_dict = super().state_dict() state_dict = super().state_dict()
for name, model in self.models.items(): for name, model in self.models.items():
state_dict[f"{name}_params"] = model.state_dict() state_dict[f"{name}_params"] = model.state_dict()
...@@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase): ...@@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase):
def set_state_dict(self, state_dict): def set_state_dict(self, state_dict):
"""Set state dict for a Updater. Parameters of models, states for """Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored.""" optimizers/schedulers and UpdaterState are restored."""
for name, model in self.models.items(): for name, model in self.models.items():
model.set_state_dict(state_dict[f"{name}_params"]) model.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items(): for name, optim in self.optimizers.items():
......
...@@ -24,7 +24,7 @@ import tqdm ...@@ -24,7 +24,7 @@ import tqdm
from deepspeech.training.extensions.extension import Extension from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope from deepspeech.training.reporter import ObsScope
from deepspeech.training.triggers import get_trigger from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterBase
...@@ -140,11 +140,11 @@ class Trainer(): ...@@ -140,11 +140,11 @@ class Trainer():
try: try:
while not stop_trigger(self): while not stop_trigger(self):
self.observation = {} self.observation = {}
# set observation as the report target # set observation as the `report` target
# you can use report freely in Updater.update() # you can use `report` freely in Updater.update()
# updating parameters and state # updating parameters and state
with scope(self.observation): with ObsScope(self.observation):
update() update()
p.update() p.update()
......
...@@ -52,6 +52,7 @@ class UpdaterBase(): ...@@ -52,6 +52,7 @@ class UpdaterBase():
""" """
def __init__(self, init_state=None): def __init__(self, init_state=None):
# init state
if init_state is None: if init_state is None:
self.state = UpdaterState() self.state = UpdaterState()
else: else:
......
...@@ -114,13 +114,13 @@ class Checkpoint(): ...@@ -114,13 +114,13 @@ class Checkpoint():
params_path = checkpoint_path + ".pdparams" params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path) model_dict = paddle.load(params_path)
model.set_state_dict(model_dict) model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path)) logger.info("Rank {}: Restore model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt" optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path): if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path) optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict) optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format( logger.info("Rank {}: Restore optimizer state from {}".format(
rank, optimizer_path)) rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path) info_path = re.sub('.pdparams$', '.json', params_path)
......
...@@ -84,19 +84,19 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ...@@ -84,19 +84,19 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y_insert_blank = insert_blank(y, blank_id) #(2L+1) y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros( log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) (ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_path = (paddle.zeros( state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 (ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1)) ) # state path, Tuple((T, 2L+1))
# init start state # init start state
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)): # T for t in range(1, ctc_probs.shape[0]): # T
for s in range(len(y_insert_blank)): # 2L+1 for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]: s] == y_insert_blank[s - 2]:
...@@ -110,13 +110,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ...@@ -110,13 +110,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2], log_alpha[t - 1, s - 2],
]) ])
prev_state = [s, s - 1, s - 2] prev_state = [s, s - 1, s - 2]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( y_insert_blank[s]]
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)] state_path[t, s] = prev_state[paddle.argmax(candidates)]
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
# TODO(Hui Zhang): zeros not support paddle.int16 state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32)
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
candidates = paddle.to_tensor([ candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb log_alpha[-1, len(y_insert_blank) - 1], # Sb
...@@ -124,11 +122,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ...@@ -124,11 +122,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
]) ])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[paddle.argmax(candidates)] state_seq[-1] = prev_state[paddle.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1): for t in range(ctc_probs.shape[0] - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = [] output_alignment = []
for t in range(0, ctc_probs.size(0)): for t in range(0, ctc_probs.shape[0]):
output_alignment.append(y_insert_blank[state_seq[t, 0]]) output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment return output_alignment
...@@ -12,19 +12,13 @@ ...@@ -12,19 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import getpass import getpass
import logging
import os import os
import socket import socket
import sys import sys
from loguru import logger
from paddle import inference from paddle import inference
FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
logging.basicConfig(
level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)
def find_log_dir(log_dir=None): def find_log_dir(log_dir=None):
"""Returns the most suitable directory to put log files into. """Returns the most suitable directory to put log files into.
...@@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None): ...@@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None):
class Log(): class Log():
"""Default Logger for all."""
log_name = None logger.remove()
logger.add(
def __init__(self, logger=None): sys.stdout,
self.logger = logging.getLogger(logger) level='INFO',
self.logger.setLevel(logging.DEBUG) enqueue=True,
filter=lambda record: record['level'].no >= 20)
file_dir = os.getcwd() + '/log' _, file_prefix, _ = find_log_dir_and_names()
if not os.path.exists(file_dir): sink_prefix = os.path.join("exp/log", file_prefix)
os.mkdir(file_dir) sink_path = sink_prefix[:-3] + "{time}.log"
self.log_dir = file_dir logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB")
actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( def __init__(self, name=None):
program_name=None, log_dir=self.log_dir)
basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
filename = os.path.join(actual_log_dir, basename)
if Log.log_name is None:
Log.log_name = filename
# Create a symlink to the log file with a canonical name.
symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
try:
if os.path.islink(symlink):
os.unlink(symlink)
os.symlink(os.path.basename(Log.log_name), symlink)
except EnvironmentError:
# If it fails, we're sad but it's no error. Commonly, this
# fails because the symlink was created by another user and so
# we can't modify it
pass pass
if not self.logger.hasHandlers():
formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
fh = logging.FileHandler(Log.log_name)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
self.logger.addHandler(ch)
# stop propagate for propagating may print
# log multiple times
self.logger.propagate = False
def getlog(self): def getlog(self):
return self.logger return logger
class Autolog: class Autolog:
"""Just used by fullchain project"""
def __init__(self, def __init__(self,
batch_size, batch_size,
model_name="DeepSpeech", model_name="DeepSpeech",
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
'''
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
"profile_path=model.profile"
"batch_range=[50, 60]; profile_path=model.profile"
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports the following key-value pairs:
batch_range - an integer list, e.g. [100, 110].
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
'max', 'min' or 'ave'.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
'AllOpDetail'.
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
'''
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
'batch_range': [10, 20],
'state': 'All',
'sorted_key': 'total',
'tracer_option': 'Default',
'profile_path': '/tmp/profile',
'exit_on_finished': True
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
if not options_str:
return
for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=')
if key == 'batch_range':
value_list = value.replace('[', '').replace(']', '').split(',')
value_list = list(map(int, value_list))
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
1] > value_list[0]:
self._options[key] = value_list
elif key == 'exit_on_finished':
self._options[key] = value.lower() in ("yes", "true", "t", "1")
elif key in [
'state', 'sorted_key', 'tracer_option', 'profile_path'
]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
'''
Enable the operator-level timing using PaddlePaddle's profiler.
The profiler uses an independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
options_str - a string used to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
'''
if options_str is None:
return
global _profiler_step_id
global _profiler_options
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
logger.info(f"Profiler: {options_str}")
logger.info(f"Profiler: {_profiler_options._options}")
if _profiler_step_id == _profiler_options['batch_range'][0]:
paddle.utils.profiler.start_profiler(_profiler_options['state'],
_profiler_options['tracer_option'])
elif _profiler_step_id == _profiler_options['batch_range'][1]:
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
_profiler_options['profile_path'])
if _profiler_options['exit_on_finished']:
sys.exit(0)
_profiler_step_id += 1
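A minimal usage sketch of add_profiler_step in a training loop (illustrative; the options string follows the format documented above, and the loop body is a stand-in for a real training step):
if __name__ == "__main__":
    options = "batch_range=[10, 20]; profile_path=/tmp/model.profile"
    for step in range(100):
        # ... one training step would run here ...
        add_profiler_step(options)  # profiling starts at step 10; at step 20 it stops and, by default, exits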
...@@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
# (TODO Hui Zhang): slice does not support `end==start` # (TODO Hui Zhang): slice does not support `end==start`
# trailing_dims = max_size[1:] # trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else () trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
max_len = max([s.size(0) for s in sequences]) max_len = max([s.shape[0] for s in sequences])
if batch_first: if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims out_dims = (len(sequences), max_len) + trailing_dims
else: else:
...@@ -91,12 +91,22 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -91,12 +91,22 @@ def pad_sequence(sequences: List[paddle.Tensor],
out_tensor = sequences[0].new_full(out_dims, padding_value) out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences): for i, tensor in enumerate(sequences):
length = tensor.size(0) length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op does not support `end==start`
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length, ...] = tensor out_tensor[i, :length, ...] = tensor
else: else:
out_tensor[i, length, ...] = tensor
else:
# TODO (Hui Zhang): set_value op does not support `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i, ...] = tensor out_tensor[:length, i, ...] = tensor
else:
out_tensor[length, i, ...] = tensor
return out_tensor return out_tensor
...@@ -139,7 +149,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, ...@@ -139,7 +149,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.size(0) B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1) ys_in = paddle.cat([_sos, ys_pad], dim=1)
...@@ -165,16 +175,10 @@ def th_accuracy(pad_outputs: paddle.Tensor, ...@@ -165,16 +175,10 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns: Returns:
float: Accuracy value (0.0 - 1.0). float: Accuracy value (0.0 - 1.0).
""" """
pad_pred = pad_outputs.view( pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type numerator = paddle.sum(
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets)) denominator = paddle.sum(mask)
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator) return float(numerator) / float(denominator)
...@@ -16,15 +16,27 @@ import distutils.util ...@@ -16,15 +16,27 @@ import distutils.util
import math import math
import os import os
import random import random
from contextlib import contextmanager
from typing import List from typing import List
import numpy as np import numpy as np
import paddle import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"] __all__ = [
"UpdateConfig", "seed_all", 'print_arguments', 'add_arguments', "log_add"
]
@contextmanager
def UpdateConfig(config):
"""Update yacs config"""
config.defrost()
yield
config.freeze()
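A minimal usage sketch of UpdateConfig (assuming a yacs CfgNode, as used elsewhere in this repo):
if __name__ == "__main__":
    from yacs.config import CfgNode
    config = CfgNode({'training': {'log_interval': 100}})
    config.freeze()
    with UpdateConfig(config):
        config.training.log_interval = 1  # mutation is allowed only inside the context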
def seed_all(seed: int=210329): def seed_all(seed: int=210329):
"""freeze random generator seed."""
np.random.seed(seed) np.random.seed(seed)
random.seed(seed) random.seed(seed)
paddle.seed(seed) paddle.seed(seed)
......
# Benchmarks
## Acceleration with Multi-GPUs
We compare the training time on 1, 2, 4 and 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds), and observe a **near-linear** speedup with multiple GPUs. In the following figure, the training time in seconds is printed on the blue bars.
<img src="../images/multi_gpu_speedup.png" width=450>
| # of GPU | Acceleration Rate |
| -------- | --------------: |
| 1 | 1.00 X |
| 2 | 1.98 X |
| 4 | 3.73 X |
| 8 | 6.95 X |
`utils/profile.sh` provides such a demo profiling tool; you can adapt it as needed.
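As a rough illustration, such a profiling script might loop over GPU counts and time one training run per setting (the paths and script arguments below are assumptions, not the actual contents of `utils/profile.sh`):

```bash
#!/bin/bash
# Hypothetical sketch: time the same training config on 1, 2, 4 and 8 GPUs.
for ngpu in 1 2 4 8; do
    gpus=$(seq -s, 0 $((ngpu - 1)))
    start=$(date +%s)
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh conf/deepspeech2.yaml profile_${ngpu}gpu
    echo "ngpu=${ngpu} time_cost=$(($(date +%s) - start))s"
done
```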
# FAQ
1. To what extent does audio speed perturbation affect the recognition rate?
Speed perturbation improves recognition; perturbation factors of 0.9, 1.0 and 1.1 are commonly used.
2. At what level does volume start to affect the recognition rate?
Training usually normalizes volume into a fixed range; excessive fluctuation hurts training, roughly beyond 10 dB ~ 20 dB.
3. What is the minimum duration of training data required for a speech model?
Aishell-1 has about 178 h of data; the more data, the better.
4. What kinds of noise or background sound affect the recognition rate?
Mainly interfering speech and a low signal-to-noise ratio degrade the recognition rate.
5. What is the length limit for a single utterance?
Training utterances are usually limited to between 1 s and 6 s, depending on the training configuration.
6. Does background sound need to be separated out or denoised before recognition?
Yes, separation is needed; the details depend on the specific scenario.
7. Does the model include VAD (voice activity detection)?
VAD is a separate model or module; this model does not include that capability.
8. Is long-form speech recognition supported?
Long audio is generally segmented by VAD before recognition.
9. What hardware configuration does the Mandarin LM Large language model require?
Enough memory to hold the LM is sufficient.
# Reference
* [wenet](https://github.com/mobvoi/wenet)
# Released Models
## Language Model Released
Language Model | Training Data | Token-based | Size | Descriptions
:-------------:| :------------:| :-----: | -----: | :-----------------
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1; <br/> About 1.85 billion n-grams; <br/> 'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4; <br/> About 0.13 billion n-grams; <br/> 'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning; <br/> About 3.7 billion n-grams; <br/> 'probing' binary with default settings
# Trying Live Demo with Your Own Voice
So far, the ASR model has been trained and tested qualitatively (`infer`) and quantitatively (`test`) with existing audio files, but not yet with your own speech. We build a real-time demo ASR engine with the trained model, enabling you to test and play around with the demo using your own voice.
First, change your directory to `examples/aishell` and `source path.sh`.
To start the demo's server, please run this in one console:
```bash
CUDA_VISIBLE_DEVICES=0 bash local/server.sh
```
For the machine (might not be the same machine) to run the demo's client, please do the following installation before moving on.
For example, on MAC OS X:
```bash
brew install portaudio
pip install pyaudio
pip install keyboard
```
Then to start the client, please run this in another console:
```bash
CUDA_VISIBLE_DEVICES=0 bash local/client.sh
```
Now, in the client console, press and hold the `whitespace` key and start speaking. When you finish your utterance, release the key and the speech-to-text results will be shown in the console. To quit the client, just press the `ESC` key.
Notice that `deepspeech/exps/deepspeech2/deploy/client.py` must be run on a machine with a microphone device, while `deepspeech/exps/deepspeech2/deploy/server.py` can be run on one without any audio recording hardware, e.g. any remote server machine. If the server and client run on two separate machines, just be careful to set the `host_ip` and `host_port` arguments to an actually accessible IP address and port; nothing needs to be changed if they run on a single machine.
Please also refer to `examples/aishell/local/server.sh`, which will first download a pre-trained Chinese model (trained with AISHELL1) and then start the demo server with the model. By running `examples/aishell/local/client.sh`, you can speak Chinese to test it. If you would like to try some other models, just update the `--checkpoint_path` argument in the script.
# Deepspeech2
## Streaming
The architecture of the DeepSpeech2 online model implemented here is based on the [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes.
The model is mainly composed of a 2D convolution subsampling layer and stacked unidirectional RNN layers.
To illustrate the model implementation clearly, 3 parts are described in detail.
- Data Preparation
- Encoder
- Decoder
In addition, the training process and the testing process are also introduced.
The architecture of the model is shown in Fig.1.
<p align="center">
<img src="../images/ds2onlineModel.png" width=800>
<br/>Fig.1 The architecture of the DeepSpeech2 online model
</p>
### Data Preparation
#### Vocabulary
For English data, the vocabulary dictionary is composed of the 26 English characters with " ' ", space, \<blank\>, \<unk\> and \<eos\>. The \<blank\> represents the blank label of CTC, the \<unk\> represents the unknown character and the \<eos\> represents the start and end characters. For Mandarin, the vocabulary dictionary is composed of the Chinese characters collected from the training set, plus the same three additional tokens: \<blank\>, \<unk\> and \<eos\>. For both English and Mandarin data, the default indices are \<blank\>=0, \<unk\>=1 and \<eos\>= last index.
```
# The code to build vocabulary
cd examples/aishell/s0
python3 ../../../utils/build_vocab.py \
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
# vocabulary for aishell dataset (Mandarin)
vi examples/aishell/s0/data/vocab.txt
# vocabulary for librispeech dataset (English)
vi examples/librispeech/s0/data/vocab.txt
```
#### CMVN
For CMVN, a subset of the training set (or the full set) is chosen and used to compute the feature mean and std.
```
# The code to compute the feature mean and std
cd examples/aishell/s0
python3 ../../../utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--specgram_type="linear" \
--delta_delta=false \
--stride_ms=10.0 \
--window_ms=20.0 \
--sample_rate=16000 \
--use_dB_normalization=True \
--num_samples=2000 \
--num_workers=10 \
--output_path="data/mean_std.json"
```
#### Feature Extraction
For feature extraction, three methods are implemented: linear (FFT without a filter bank), fbank and mfcc.
Currently, the released deepspeech2 online model uses the linear feature extraction method.
```
The code for feature extraction
vi deepspeech/frontend/featurizer/audio_featurizer.py
```
### Encoder
The encoder is composed of two 2D convolution subsampling layers and a number of stacked unidirectional RNN layers. The 2D convolution subsampling layers extract feature representations from the raw audio features and reduce their length at the same time. After the convolution subsampling layers, the feature representations are fed into the stacked RNN layers, for which both LSTM and GRU cells are available. Adding one fully connected (fc) layer after the stacked RNN layers is optional; if the number of stacked RNN layers is less than 5, adding one fc layer after them is recommended. A minimal sketch of this structure is given after the code pointer below.
The code of the encoder is in:
```
vi deepspeech/models/ds2_online/deepspeech2.py
```
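For illustration, here is a minimal sketch of this encoder structure in Paddle. The kernel sizes, channel counts and default hyperparameters are assumptions chosen for clarity, not the released configuration:

```python
import paddle
from paddle import nn

class Ds2OnlineEncoderSketch(nn.Layer):
    def __init__(self, feat_dim=161, rnn_size=1024, num_rnn_layers=5, use_gru=True):
        super().__init__()
        # Two 2D conv layers subsample the feature map along time and frequency.
        self.conv = nn.Sequential(
            nn.Conv2D(1, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2D(32, 32, kernel_size=3, stride=2, padding=1), nn.ReLU())
        conv_out_dim = 32 * ((feat_dim + 3) // 4)  # channels * subsampled freq bins
        rnn_cls = nn.GRU if use_gru else nn.LSTM
        # Stacked unidirectional (forward-only) RNN layers keep the model streamable.
        self.rnn = rnn_cls(conv_out_dim, rnn_size, num_layers=num_rnn_layers,
                           direction='forward')
        self.fc = nn.Linear(rnn_size, rnn_size)  # the optional fc layer

    def forward(self, x):  # x: (B, T, feat_dim)
        x = self.conv(x.unsqueeze(1))  # -> (B, 32, T', F')
        B, C, T, F = x.shape
        x = x.transpose([0, 2, 1, 3]).reshape([B, T, C * F])
        x, _ = self.rnn(x)  # -> (B, T', rnn_size)
        return self.fc(x)
```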
### Decoder
To get the character probabilities of each frame, the frame-level feature representations output by the encoder are fed into a projection layer, implemented as a dense layer whose output dimension equals the vocabulary size. After the projection layer, a softmax turns each frame-level representation into a distribution over characters. At inference time, the per-frame character probabilities are passed to the CTC decoder to get the final speech recognition result. A minimal sketch is given after the code pointers below.
The code of the decoder is in:
```
# The code of constructing the decoder in model
vi deepspeech/models/ds2_online/deepspeech2.py
# The code of CTC Decoder
vi deepspeech/modules/ctc.py
```
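For illustration, a minimal sketch of this decoder path: a dense projection to the vocabulary size, a softmax over characters, and (for brevity) greedy CTC decoding in place of the full beam-search decoder. The sizes and the blank id are assumptions:

```python
import paddle
import paddle.nn.functional as F
from paddle import nn

class CTCDecoderSketch(nn.Layer):
    def __init__(self, enc_dim=1024, vocab_size=4334, blank_id=0):
        super().__init__()
        self.proj = nn.Linear(enc_dim, vocab_size)  # output dim == vocab size
        self.blank_id = blank_id

    def forward(self, enc_out):  # enc_out: (B, T, enc_dim)
        # Per-frame character probabilities via projection + softmax.
        return F.softmax(self.proj(enc_out), axis=-1)

    def greedy_decode(self, probs):  # probs: (B, T, vocab_size)
        ids = probs.argmax(axis=-1).numpy()  # best character per frame
        results = []
        for seq in ids:
            out, prev = [], None
            for tok in seq:  # collapse repeats, then drop blanks
                if tok != prev and tok != self.blank_id:
                    out.append(int(tok))
                prev = tok
            results.append(out)
        return results
```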
## Training Process
Using the command below, you can train the deepspeech2 online model.
```
cd examples/aishell/s0
bash run.sh --stage 0 --stop_stage 2 --model_type online --conf_path conf/deepspeech2_online.yaml
```
The detail commands are:
```
# The code for training in run.sh
set -e
source path.sh
gpus=2,3,5,7
stage=0
stop_stage=5
conf_path=conf/deepspeech2_online.yaml # conf/deepspeech2.yaml | conf/deepspeech2_online.yaml
avg_num=1
model_type=online # online | offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
```
By using the command above, the training process can be started. There are 5 stages in `run.sh`, and the first 3 cover the training process. Stage 0 prepares the data: the dataset is downloaded, and the manifest files, the vocabulary dictionary and the CMVN file are generated in `./data/`. Stage 1 trains the model; the log files and model checkpoints are saved in `exp/deepspeech2_online/`. Stage 2 generates the final model for prediction by averaging the top-k checkpoints ranked by validation loss.
## Testing Process
Using the command below, you can test the deepspeech2 online model.
```
bash run.sh --stage 3 --stop_stage 5 --model_type online --conf_path conf/deepspeech2_online.yaml
```
The detail commands are:
```
conf_path=conf/deepspeech2_online.yaml
avg_num=1
model_type=online
avg_ckpt=avg_${avg_num}
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=5 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1
fi
```
After training, stages 3, 4 and 5 cover the testing process. Stage 3 tests the model generated in stage 2 and reports the CER on the test set. Stage 4 transforms the model from a dynamic graph to a static graph using the `paddle.jit` library. Stage 5 tests the exported static-graph model.
## Non-Streaming
The deepspeech2 offline model is similar to the deepspeech2 online model. The main difference is that the offline model uses stacked bidirectional RNN layers, while the online model uses unidirectional RNN layers without the fc layer. For the stacked bidirectional RNN layers in the offline model, both vanilla RNN and GRU cells are available.
The architecture of the model is shown in Fig.2.
<p align="center">
<img src="../images/ds2offlineModel.png" width=800>
<br/>Fig.2 The architecture of the DeepSpeech2 offline model
</p>
For data preparation and decoding, the deepspeech2 offline model is the same as the deepspeech2 online model.
The code of encoder and decoder for deepspeech2 offline model is in:
```
vi deepspeech/models/ds2/deepspeech2.py
```
The training and testing processes of the deepspeech2 offline model are very similar to those of the deepspeech2 online model;
only a few changes need to be noticed.
For both training and testing, the "model_type" and the "conf_path" must be set accordingly.
```
# Training offline
cd examples/aishell/s0
bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deepspeech2.yaml
```
```
# Testing offline
cd examples/aishell/s0
bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml
```
# Features # Features
### Dataset
* Aishell
* Librispeech
* THCHS30
* TIMIT
### Speech Recognition ### Speech Recognition
* Offline * Non-Streaming
* [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf) * [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf)
* [Transformer](https://arxiv.org/abs/1706.03762) * [Transformer](https://arxiv.org/abs/1706.03762)
* [Conformer](https://arxiv.org/abs/2005.08100) * [Conformer](https://arxiv.org/abs/2005.08100)
* Online * Streaming
* [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf)
* [U2](https://arxiv.org/pdf/2012.05481.pdf) * [U2](https://arxiv.org/pdf/2012.05481.pdf)
### Language Model ### Language Model
...@@ -22,6 +29,15 @@ ...@@ -22,6 +29,15 @@
* beam search * beam search
* attention rescore * attention rescore
### Deployment
* Paddle Inference
### Alignment
* MFA
* CTC Alignment
### Speech Frontend ### Speech Frontend
* Audio * Audio
......
...@@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin ...@@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin
## Prerequisites ## Prerequisites
- Python >= 3.7 - Python >= 3.7
- PaddlePaddle 2.0.0 or later (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html)) - the latest version of PaddlePaddle (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
## Setup ## Setup (Important)
- Make sure these libraries or tools are installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox`, and `swig`, e.g. installing them via `apt-get`: - Make sure these libraries or tools are installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox`, and `swig`, e.g. installing them via `apt-get`:
```bash ```bash
sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
``` ```
The version of `swig` should be >= 3.0.
or, installing them via `yum`: or, installing them via `yum`:
......
...@@ -35,52 +35,3 @@ Different from the English language model, Mandarin language model is character- ...@@ -35,52 +35,3 @@ Different from the English language model, Mandarin language model is character-
* A whitespace character is inserted between two tokens. * A whitespace character is inserted between two tokens.
Please note that the released language models only contain simplified Chinese characters. After preprocessing is done, we can begin to train the language model. The key training arguments are '-o 5 --prune 0 1 2 4 4' for the small LM and '-o 5' for the large LM; please refer to the section above for the meaning of each argument. We also convert the arpa file to a binary file using default settings. Please note that the released language models only contain simplified Chinese characters. After preprocessing is done, we can begin to train the language model. The key training arguments are '-o 5 --prune 0 1 2 4 4' for the small LM and '-o 5' for the large LM; please refer to the section above for the meaning of each argument. We also convert the arpa file to a binary file using default settings.
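For reference, a hedged sketch of the KenLM commands these arguments correspond to (the binaries are KenLM's; the corpus and output file names are placeholders):

```
# small Mandarin LM: 5-gram with pruning
bin/lmplz -o 5 --prune 0 1 2 4 4 --text corpus_small.txt --arpa zh_small.arpa
# large Mandarin LM: 5-gram, no pruning
bin/lmplz -o 5 --text corpus_large.txt --arpa zh_large.arpa
# convert the arpa files to binary with default settings
bin/build_binary zh_small.arpa zh_small.klm
```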
## [KenLM](http://kheafield.com/code/kenlm/)
Among statistical language-model toolkits, SRILM and KenLM are the most widely used. KenLM appeared later than SRILM, trains faster, and supports training on large corpora on a single machine. Its usage is described below.
1. Download the toolkit from: http://kheafield.com/code/kenlm.tar.gz
2. Usage. The tool is convenient to use on Linux. First make sure Boost 1.36.0 (or later) and zlib are installed:
```
boost:
yum install boost
yum install boost-devel
zlib:
yum install zlib
yum install zlib-devel
```
A gcc of version 4.8.2 or later is also required.
```
wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz
mkdir kenlm/build
cd kenlm/build
cmake ..
make -j2
```
3. Training. Train with the following command:
```
build/bin/lmplz -o 3 --verbose_header --text people2014corpus_words.txt --arpa result/people2014corpus_words.arps
```
Here,
1) people2014corpus_words.txt must be a word-segmented text file.
The training corpus (People's Daily 2014) includes: 1) manually word-segmented and POS-tagged data, people2014.tar.gz; 2) unsegmented text data, people2014_words.txt; 3) a char-level KenLM language model and its binary, people2014corpus_chars.arps/klm; and 4) a word-level KenLM language model and its binary, people2014corpus_words.arps/klm.
2) The number after -o is the n-gram order (e.g. 5 for a 5-gram model); 3 is usually sufficient, but choose it according to your own situation.
4. Binarization. Convert the model to a binary file for fast loading:
```
build/bin/build_binary ./result/people2014corpus_words.arps ./result/people2014corpus_words.klm
```
# Reference
We refer these repos to build `model` and `engine`:
* [delta](https://github.com/Delta-ML/delta.git)
* [espnet](https://github.com/espnet/espnet.git)
* [kaldi](https://github.com/kaldi-asr/kaldi.git)
* [wenet](https://github.com/mobvoi/wenet)
# Released Models
## Acoustic Model Released in paddle 2.X
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech
:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :---------
[Ds2 Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds_online.5rnn.debug.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.0824 | 151 h
[Ds2 Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 | 151 h
[Conformer Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention + CTC | 0.0594 | 151 h
[Conformer Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0547 | 151 h
[Conformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | Word-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0325 | 960 h
[Transformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/transformer.release.tar.gz) | Librispeech Dataset | Word-based | 195 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0544 | 960 h
## Acoustic Model Transformed from paddle 1.8
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech
:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :---------
[Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz)|Aishell Dataset| Char-based| 234 MB| 2 Conv + 3 bidirectional GRU layers| 0.0804 | 151 h|
[Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz)|Librispeech Dataset| Word-based| 307 MB| 2 Conv + 3 bidirectional sharing weight RNN layers | 0.0685| 960 h|
[Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz)|Baidu Internal English Dataset| Word-based| 273 MB| 2 Conv + 3 bidirectional GRU layers | 0.0541 | 8628 h|
## Language Model Released
Language Model | Training Data | Token-based | Size | Descriptions
:-------------:| :------------:| :-----: | -----: | :-----------------
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1; <br/> About 1.85 billion n-grams; <br/> 'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4; <br/> About 0.13 billion n-grams; <br/> 'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning; <br/> About 3.7 billion n-grams; <br/> 'probing' binary with default settings
# 1xt2x
Convert Deepspeech 1.8 released model to 2.x.
## Model
* Deepspeech2x
## Exp
* baidu_en8k
* aishell
* librispeech
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 4333
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
tar xzvf aishell_model_v1.8_to_v2.x.tar.gz
mv aishell_v1.8.pdparams exp/deepspeech2/checkpoints/
mv README.md exp/deepspeech2/
mv mean_std.npz data/
mv vocab.txt data/
rm aishell_model_v1.8_to_v2.x.tar.gz -f
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated."
exit 1
fi
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--specgram_type="linear" \
--delta_delta=false \
--stride_ms=10.0 \
--window_ms=20.0 \
--sample_rate=16000 \
--use_dB_normalization=True \
--num_samples=2000 \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
fi
echo "Aishell data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=4ade113c69ea291b8ce5ec6a03296659
TARGET=./aishell_model_v1.8_to_v2.x.tar.gz
echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!"
exit 1
fi
tar -zxvf $TARGET
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=aishell_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: .inf # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 28
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.4
beta: 0.35
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz
mv baidu_en8k_v1.8.pdparams exp/deepspeech2/checkpoints/
mv README.md exp/deepspeech2/
mv mean_std.npz data/
mv vocab.txt data/
rm baidu_en8k_v1.8_to_v2.x.tar.gz -f
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=2000 \
--specgram_type="linear" \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=20.0 \
--use_dB_normalization=True \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=fdabeb6c96963ac85d9188f0275c6a1b
TARGET=./baidu_en8k_v1.8_to_v2.x.tar.gz
echo "Download BaiduEn8k model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download BaiduEn8k model!"
exit 1
fi
tar -zxvf $TARGET
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=baidu_en8k_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddle.fluid import core
from paddle.nn import functional as F
from deepspeech.utils.log import Log
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()
########### hack logging #############
logger.warn = logger.warning
########### hack paddle #############
paddle.half = 'float16'
paddle.float = 'float32'
paddle.double = 'float64'
paddle.short = 'int16'
paddle.int = 'int32'
paddle.long = 'int64'
paddle.uint16 = 'uint16'
paddle.cdouble = 'complex128'
def convert_dtype_to_string(tensor_dtype):
"""
Convert the data type in numpy to the data type in Paddle
Args:
tensor_dtype(core.VarDesc.VarType): the data type in numpy.
Returns:
core.VarDesc.VarType: the data type in Paddle.
"""
dtype = tensor_dtype
if dtype == core.VarDesc.VarType.FP32:
return paddle.float32
elif dtype == core.VarDesc.VarType.FP64:
return paddle.float64
elif dtype == core.VarDesc.VarType.FP16:
return paddle.float16
elif dtype == core.VarDesc.VarType.INT32:
return paddle.int32
elif dtype == core.VarDesc.VarType.INT16:
return paddle.int16
elif dtype == core.VarDesc.VarType.INT64:
return paddle.int64
elif dtype == core.VarDesc.VarType.BOOL:
return paddle.bool
elif dtype == core.VarDesc.VarType.BF16:
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
return paddle.uint16
elif dtype == core.VarDesc.VarType.UINT8:
return paddle.uint8
elif dtype == core.VarDesc.VarType.INT8:
return paddle.int8
elif dtype == core.VarDesc.VarType.COMPLEX64:
return paddle.complex64
elif dtype == core.VarDesc.VarType.COMPLEX128:
return paddle.complex128
else:
raise ValueError("Not supported tensor dtype %s" % dtype)
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'log_softmax'):
logger.warn("register user log_softmax to paddle, remove this when fixed!")
setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'log_sigmoid'):
logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
def cat(xs, dim=0):
return paddle.concat(xs, axis=dim)
if not hasattr(paddle, 'cat'):
logger.warn(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
########### hack paddle.Tensor #############
def item(x: paddle.Tensor):
return x.numpy().item()
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
def func_long(x: paddle.Tensor):
return paddle.cast(x, paddle.long)
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
def new_full(x: paddle.Tensor,
size: Union[List[int], Tuple[int], paddle.Tensor],
fill_value: Union[float, int, bool, paddle.Tensor],
dtype=None):
    return paddle.full(size, fill_value, dtype=dtype if dtype is not None else x.dtype)
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if convert_dtype_to_string(xs.dtype) == paddle.bool:
xs = xs.astype(paddle.int)
return xs.equal(
paddle.to_tensor(
ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
if not hasattr(paddle, 'eq'):
logger.warn(
"override eq of paddle if exists or register, remove this when fixed!")
paddle.eq = eq
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
# `to_static` does not process the `size` property; some `paddle` APIs may depend on it.
logger.warn(
    "override size of paddle.Tensor "
    "(`to_static` does not process the `size` property; some `paddle` APIs may depend on it), remove this when fixed!"
)
paddle.Tensor.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args)
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
return xs.reshape(ys.size())
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
pass
else:
return False
return True
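# For example (sketch): is_broadcastable((2, 3, 1), (3, 4)) -> True, while
# is_broadcastable((2, 3), (4, 3)) -> False, mirroring numpy's
# trailing-dimension broadcasting rule.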
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]) -> paddle.Tensor:
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
val = paddle.full_like(xs, value)
paddle.assign(val.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
return paddle.tile(xs, size)
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
return x.astype(other.dtype)
if not hasattr(paddle.Tensor, 'type_as'):
logger.warn(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1
if isinstance(args[0], str): # dtype
return x.astype(args[0])
elif isinstance(args[0], paddle.Tensor): #Tensor
return x.astype(args[0].dtype)
else: # Device
return x
if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.float)
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
def func_int(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.int)
if not hasattr(paddle.Tensor, 'int'):
logger.warn("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
def tolist(x: paddle.Tensor) -> List[Any]:
return x.numpy().tolist()
if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn #############
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return F.glu(xs, axis=self.dim)
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
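# Usage sketch: GLU halves the gated axis, e.g.
#   glu = GLU(dim=-1)
#   y = glu(paddle.randn([2, 4, 6]))  # y.shape == [2, 4, 3]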
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from deepspeech2x.model import DeepSpeech2Tester as Tester
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
    # save asr result to this file
    parser.add_argument(
        "--result_file", type=str, help="path to save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from deepspeech2x.models.ds2 import DeepSpeech2InferModel
from deepspeech2x.models.ds2 import DeepSpeech2Model
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log
#from deepspeech.utils.log import Autolog
logger = Log(__name__).getlog()
class DeepSpeech2Trainer(Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# training config
default = CfgNode(
dict(
lr=5e-4, # learning rate
lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
losses_np = {
'train_loss': float(loss),
}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
# `step -1` since we update `step` after optimizer.step().
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration - 1)
@paddle.no_grad()
def valid(self):
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
self.model.eval()
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
if (i + 1) % self.config.training.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def setup_model(self):
config = self.config.clone()
config.defrost()
config.model.feat_size = self.train_loader.collate_fn.feature_size
#config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.model.dict_size = len(self.train_loader.collate_fn.vocab_list)
config.freeze()
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config.model)
else:
raise Exception("wrong model type")
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
grad_clip = ClipGradByGlobalNormWithLog(
config.training.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=config.training.lr,
gamma=config.training.lr_decay,
verbose=True)
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
if "<eos>" in self.test_loader.collate_fn.vocab_list:
self.test_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.valid_loader.collate_fn.vocab_list:
self.valid_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.train_loader.collate_fn.vocab_list:
self.train_loader.collate_fn.vocab_list.remove("<eos>")
logger.info("Setup train/valid/test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# testing config
default = CfgNode(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
return trans
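    # Example: ids [20013, 22269] are the Unicode code points of "中国",
    # so the decoded transcript is the string "中国".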
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
if "" in vocab_list:
space_id = vocab_list.index("")
vocab_list[space_id] = " "
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins,
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
return result_transcripts
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
# self.autolog.report()
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self):
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
elif self.args.model_type == 'online':
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
from typing import Optional
import paddle
from deepspeech2x.models.ds2.rnn import RNNStack
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack
self.rnn = RNNStack(
i_size=i_size,
h_size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
@property
def output_size(self):
return self.rnn_size * 2
def forward(self, audio, audio_len):
"""Compute Encoder outputs
Args:
audio (Tensor): [B, Tmax, D]
text (Tensor): [B, Umax]
audio_len (Tensor): [B]
text_len (Tensor): [B]
Returns:
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
"""
# [B, T, D] -> [B, D, T]
audio = audio.transpose([0, 2, 1])
# [B, D, T] -> [B, C=1, D, T]
x = audio.unsqueeze(1)
x_lens = audio_len
# convolution group
x, x_lens = self.conv(x, x_lens)
# convert data from convolution feature map to sequence of vectors
#B, C, D, T = paddle.shape(x) # not work under jit
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
#x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit
x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part
x, x_lens = self.rnn(x, x_lens) #[B, T, D]
return x, x_lens
class DeepSpeech2Model(nn.Layer):
"""The DeepSpeech2 network structure.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param masks: Masks data layer to reset padding.
:type masks: Variable
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward direction RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
assert (self.encoder.output_size == rnn_size * 2)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
            blank_id=blank_id,  # index of <blank> in the vocabulary
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
            audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
            loss (Tensor): [1]
"""
eouts, eouts_len = self.encoder(audio, audio_len)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2Model
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=len(dataloader.collate_fn.vocab_list),
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2Model from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def forward(self, audio, audio_len):
"""export model function
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs, eouts_len
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, self.encoder.feat_size],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
return static_model
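# Usage sketch (assumed path): once DeepSpeech2Tester.export() has saved the
# static graph with paddle.jit.save, it can be reloaded without this class:
#   infer = paddle.jit.load("exp/deepspeech2/checkpoints/avg_1.jit")
#   audio = paddle.randn([1, 100, 161])                 # [B, T, D]; D=161 for linear feats
#   audio_len = paddle.to_tensor([100], dtype='int64')  # [B]
#   probs, probs_len = infer(audio, audio_len)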
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
        if activation not in ["tanh", "relu", "brelu"]:
            raise ValueError(
                "activation for SimpleRNNCell should be tanh, relu or brelu, "
                "but got {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
    where :math:`\sigma` is the sigmoid function, and * is the elementwise
    multiplication operator. Note: in this implementation the candidate
    activation is ReLU rather than tanh (see ``_activation`` below).
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.relu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states # shape [batch_size, hidden_size]
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
bias_u, bias_r, bias_c = paddle.split(
self.bias_hh, num_or_sections=3, axis=0)
weight_hh = paddle.transpose(
self.weight_hh,
perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size]
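        # Reinterpret the (3*hidden, hidden) parameter in the fluid
        # dynamic_gru layout: a (hidden, 2*hidden) block for the update and
        # reset gates followed by a (hidden, hidden) candidate block (see the
        # doc link below).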
w_u_r_c = paddle.flatten(weight_hh)
size_u_r = self.hidden_size * 2 * self.hidden_size
w_u_r = paddle.reshape(w_u_r_c[:size_u_r],
(self.hidden_size, self.hidden_size * 2))
w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)
w_c = paddle.reshape(w_u_r_c[size_u_r:],
(self.hidden_size, self.hidden_size))
h_u = paddle.matmul(
pre_hidden, w_u,
transpose_y=False) + bias_u #shape [batch_size, hidden_size]
h_r = paddle.matmul(
pre_hidden, w_r,
transpose_y=False) + bias_r #shape [batch_size, hidden_size]
x_u, x_r, x_c = paddle.split(
x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size]
u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size]
r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size]
c = self._activation(
x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) +
bias_c) # [batch_size, hidden_size]
h = (1 - u) * pre_hidden + u * c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
                # default: GRU (candidate activation is ReLU in this implementation)
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
        x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
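# Usage sketch (illustrative sizes): stacking bidirectional layers doubles
# the feature dimension at the output:
#   stack = RNNStack(i_size=128, h_size=64, num_stacks=3,
#                    use_gru=True, share_rnn_weights=False)
#   x = paddle.randn([4, 50, 128])                       # [B, T, D]
#   x_len = paddle.to_tensor([50, 48, 40, 30], 'int64')  # [B]
#   y, y_len = stack(x, x_len)                           # y: [4, 50, 128] (h_size * 2)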
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 1000.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 28
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh
if [ $? -ne 0 ]; then
exit 1
fi
tar xzvf librispeech_v1.8_to_v2.x.tar.gz
mv librispeech_v1.8.pdparams exp/deepspeech2/checkpoints/
mv README.md exp/deepspeech2/
mv mean_std.npz data/
mv vocab.txt data/
rm librispeech_v1.8_to_v2.x.tar.gz -f
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=2000 \
--specgram_type="linear" \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=20.0 \
--use_dB_normalization=True \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=7b0f582fe2f5a840b840e7ee52246bc5
TARGET=./librispeech_v1.8_to_v2.x.tar.gz
echo "Download LibriSpeech model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download LibriSpeech model!"
exit 1
fi
tar -zxvf $TARGET
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=librispeech_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
fi
...@@ -10,8 +10,11 @@ ...@@ -10,8 +10,11 @@
| Model | Params | Release | Config | Test set | Loss | CER | | Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | | DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | | DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |
...@@ -40,9 +40,12 @@ model: ...@@ -40,9 +40,12 @@ model:
rnn_layer_size: 1024 rnn_layer_size: 1024
use_gru: True use_gru: True
share_rnn_weights: False share_rnn_weights: False
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 80 n_epoch: 80
accum_grad: 1
lr: 2e-3 lr: 2e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -36,17 +36,20 @@ collator: ...@@ -36,17 +36,20 @@ collator:
model: model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 5
rnn_layer_size: 1024 rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect] rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1 num_fc_layers: 0
fc_layers_size_list: 512, fc_layers_size_list: -1,
use_gru: False use_gru: False
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 1
lr: 2e-3 lr: 2e-3
lr_decay: 0.91 # 0.83 lr_decay: 0.9 # 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 3.0 global_grad_clip: 3.0
log_interval: 100 log_interval: 100
...@@ -59,7 +62,7 @@ decoding: ...@@ -59,7 +62,7 @@ decoding:
error_rate_type: cer error_rate_type: cer
decoding_method: ctc_beam_search decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 1.9 alpha: 2.2 #1.9
beta: 5.0 beta: 5.0
beam_size: 300 beam_size: 300
cutoff_prob: 0.99 cutoff_prob: 0.99
......
#!/bin/bash
source path.sh
# run on MacOS
# brew install portaudio
# pip install pyaudio
# pip install keyboard
# start demo client
python3 -u ${BIN_DIR}/deploy/client.py \
--host_ip="localhost" \
--host_port=8086 \
if [ $? -ne 0 ]; then
echo "Failed in starting demo client!"
exit 1
fi
exit 0
...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2 ...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
model_type=$4 model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
#!/bin/bash
# TODO: replace the model with a Mandarin model
if [[ $# != 1 ]];then
echo "usage: ${0} checkpoint_path"
exit -1
fi
source path.sh
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
# download well-trained model
#bash local/download_model.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
# start demo server
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/deploy/server.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--host_ip="localhost" \
--host_port=8086 \
--speech_save_dir="demo_cache" \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in starting demo server!"
exit 1
fi
exit 0
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
model_type=$3 model_type=$3
...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then ...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.rsl \ --result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
jit_model_export_path=$2 jit_model_export_path=$2
model_type=$3 model_type=$3
...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then ...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then
fi fi
python3 -u ${BIN_DIR}/test_export.py \ python3 -u ${BIN_DIR}/test_export.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${jit_model_export_path}.rsl \ --result_file ${jit_model_export_path}.rsl \
--export_path ${jit_model_export_path} \ --export_path ${jit_model_export_path} \
......
...@@ -12,27 +12,22 @@ config_path=$1 ...@@ -12,27 +12,22 @@ config_path=$1
ckpt_name=$2 ckpt_name=$2
model_type=$3 model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
mkdir -p exp mkdir -p exp
# seed may break model convergence
seed=10086 seed=10086
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--model_type ${model_type} \ --model_type ${model_type} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
#!/bin/bash
# grid-search for hyper-parameters in language model
python3 -u ${BIN_DIR}/tune.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--num_batches=10 \
--batch_size=128 \
--beam_size=300 \
--num_proc_bsearch=8 \
--num_alphas=10 \
--num_betas=10 \
--alpha_from=0.0 \
--alpha_to=5.0 \
--beta_from=-6 \
--beta_to=6 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in tuning!"
exit 1
fi
exit 0
...@@ -27,7 +27,7 @@ fi ...@@ -27,7 +27,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
...@@ -76,6 +76,8 @@ model: ...@@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -71,6 +71,8 @@ model: ...@@ -71,6 +71,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -22,8 +18,7 @@ mkdir -p ${output_dir} ...@@ -22,8 +18,7 @@ mkdir -p ${output_dir}
# align dump in `result_file` # align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file` # .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \ python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.align \ --result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,13 +12,7 @@ config_path=$1 ...@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,11 +8,6 @@ fi ...@@ -8,11 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -39,8 +34,7 @@ for type in attention ctc_greedy_search; do ...@@ -39,8 +34,7 @@ for type in attention ctc_greedy_search; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.rsl \ --result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
...@@ -58,8 +52,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do ...@@ -58,8 +52,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.rsl \ --result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
#!/bin/bash #!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 2 ];then if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1 exit -1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \ --seed ${seed} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -20,7 +16,6 @@ ckpt_name=$(basename ${ckpt_prefxi}) ...@@ -20,7 +16,6 @@ ckpt_name=$(basename ${ckpt_prefxi})
mkdir -p exp mkdir -p exp
batch_size=1 batch_size=1
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
...@@ -28,8 +23,7 @@ mkdir -p ${output_dir} ...@@ -28,8 +23,7 @@ mkdir -p ${output_dir}
# align dump in `result_file` # align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file` # .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \ python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.align \ --result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,13 +12,7 @@ config_path=$1 ...@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -32,8 +28,7 @@ for type in attention ctc_greedy_search; do ...@@ -32,8 +28,7 @@ for type in attention ctc_greedy_search; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.rsl \ --result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
...@@ -51,8 +46,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do ...@@ -51,8 +46,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.rsl \ --result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -11,27 +11,23 @@ echo "using $ngpu gpus..." ...@@ -11,27 +11,23 @@ echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..." echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# [CC-CEDICT](https://cc-cedict.org/wiki/)
What is CC-CEDICT?
CC-CEDICT is a continuation of the CEDICT project.
The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary.
CEDICT was started by Paul Andrew Denisowski in October 1997.
For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort,
with users providing entries and corrections to the main file.
## Parse CC-CEDICT to JSON format
1. Parse to JSON
```
run.sh
```
2. Result
```
exp/
|-- cedict
`-- cedict.json
0 directories, 2 files
```
```
4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict
3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json
```
```
==> exp/cedict <==
# CC-CEDICT
# Community maintained free Chinese-English dictionary.
#
# Published by MDBG
#
# License:
# Creative Commons Attribution-ShareAlike 4.0 International License
# https://creativecommons.org/licenses/by-sa/4.0/
#
# Referenced works:
==> exp/cedict.json <==
{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"}
{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"}
{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"}
{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"}
{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"}
{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"}
{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"}
{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"}
{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"}
{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"}
```
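For reference, a minimal parser that produces JSON lines like those above from the standard CC-CEDICT entry format `TRAD SIMP [pin1 yin1] /gloss1/gloss2/` (a hypothetical sketch, not the recipe's actual script):

```python
import json
import re

# Standard CC-CEDICT entry: "TRAD SIMP [pin1 yin1] /gloss1/gloss2/"
LINE_RE = re.compile(r'^(\S+) (\S+) \[([^\]]+)\] /(.+)/$')

def parse_line(line):
    match = LINE_RE.match(line.strip())
    if match is None:
        return None  # comment or malformed line
    trad, simp, pinyin, english = match.groups()
    # Multiple glosses stay '/'-separated, as captured.
    return {"traditional": trad, "simplified": simp,
            "pinyin": pinyin, "english": english}

with open("exp/cedict", encoding="utf-8") as fin, \
     open("exp/cedict.json", "w", encoding="utf-8") as fout:
    for raw in fin:
        if raw.startswith("#"):
            continue  # skip the header comments shown above
        entry = parse_line(raw)
        if entry is not None:
            fout.write(json.dumps(entry) + "\n")  # default ensure_ascii gives the \u escapes above
```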
# Download Baker dataset
The Baker dataset has to be downloaded manually and moved to 'data/', because you have to pass a CAPTCHA in a browser to download it.
Download URL: https://test.data-baker.com/#/data/index/source.
# G2P
* zh - Chinese G2P
# G2P
* WS: jieba
* G2P: pypinyin
* Tone sandhi: simple
We recommend using the [Parakeet](https://github.com/PaddlePaddle/Parakeet) [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P.
The phoneme set should be changed accordingly; see `examples/thchs30/a0/data/dict/syllable.lexicon` for reference.
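A minimal sketch of the WS + G2P pipeline described above (jieba for word segmentation, pypinyin for pinyin; the simple tone-sandhi step is omitted):

```python
import jieba
from pypinyin import lazy_pinyin, Style

text = "卡尔普陪外孙玩滑梯"
syllables = []
for word in jieba.cut(text):  # WS: word segmentation
    # G2P: pinyin with trailing tone numbers, e.g. 'ka3'
    syllables.extend(lazy_pinyin(word, style=Style.TONE3))
print(" ".join(syllables))  # ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
```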
## Download Baker dataset
[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded manually and moved to './data',
because you have to pass a `CAPTCHA` in a browser to download it.
## RUN
```
. path.sh
./run.sh
```
## Result
```
exp/
|-- 000001-010000.txt
|-- ref.pinyin
|-- trans.jieba.pinyin
`-- trans.pinyin
0 directories, 4 files
```
```
4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt
01707896391c2de9b6fc4a39654be942 exp/ref.pinyin
43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin
8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin
```
```
==> exp/000001-010000.txt <==
000001 卡尔普#2陪外孙#1玩滑梯#4。
ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 假语村言#2别再#1拥抱我#4。
jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 宝马#1配挂#1跛骡鞍#3,貂蝉#1怨枕#2董翁榻#4。
bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 邓小平#2与#1撒切尔#2会晤#4。
deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。
lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
==> exp/ref.pinyin <==
000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.jieba.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
```
export MAIN_ROOT=`realpath ${PWD}/../../` export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C export LC_ALL=C
......
...@@ -6,16 +6,19 @@ stage=-1 ...@@ -6,16 +6,19 @@ stage=-1
stop_stage=100 stop_stage=100
exp_dir=exp exp_dir=exp
data_dir=data data=data
source ${MAIN_ROOT}/utils/parse_options.sh || exit -1 source ${MAIN_ROOT}/utils/parse_options.sh || exit -1
mkdir -p ${exp_dir} mkdir -p ${exp_dir}
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then
test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in ${data}"; exit -1; }
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then
echo "stage 0: Extracting Prosody Labeling" echo "stage 0: Extracting Prosody Labeling"
bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir} bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data}
fi fi
# convert Chinese transcriptions into pinyin with pypinyin or jieba+pypinyin # convert Chinese transcriptions into pinyin with pypinyin or jieba+pypinyin
......
# LibriSpeech # LibriSpeech
## Data
| Data Subset | Duration in Seconds |
| --- | --- |
| data/manifest.train | 0.83s ~ 29.735s |
| data/manifest.dev | 1.065s ~ 35.155s |
| data/manifest.test-clean | 1.285s ~ 34.955s |
## Deepspeech2 ## Deepspeech2
| Model | Params | release | Config | Test set | Loss | WER | | Model | Params | release | Config | Test set | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | 14.49190807 | test-clean | 0.067283 | | DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 |
| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | | DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | | DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | | DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 |
...@@ -4,7 +4,7 @@ data: ...@@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.dev-clean dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.test-clean test_manifest: data/manifest.test-clean
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 # second max_input_len: 30.0 # second
min_output_len: 0.0 min_output_len: 0.0
max_output_len: .inf max_output_len: .inf
min_output_input_ratio: 0.00 min_output_input_ratio: 0.00
...@@ -40,9 +40,12 @@ model: ...@@ -40,9 +40,12 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 1
lr: 1e-3 lr: 1e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -4,14 +4,14 @@ data: ...@@ -4,14 +4,14 @@ data:
dev_manifest: data/manifest.dev-clean dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.test-clean test_manifest: data/manifest.test-clean
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 # second max_input_len: 30.0 # second
min_output_len: 0.0 min_output_len: 0.0
max_output_len: .inf max_output_len: .inf
min_output_input_ratio: 0.00 min_output_input_ratio: 0.00
max_output_input_ratio: .inf max_output_input_ratio: .inf
collator: collator:
batch_size: 20 batch_size: 15
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
...@@ -42,9 +42,12 @@ model: ...@@ -42,9 +42,12 @@ model:
num_fc_layers: 2 num_fc_layers: 2
fc_layers_size_list: 512, 256 fc_layers_size_list: 512, 256
use_gru: False use_gru: False
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 4
lr: 1e-3 lr: 1e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2 ...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
model_type=$4 model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
model_type=$3 model_type=$3
...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then ...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.rsl \ --result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,28 +12,22 @@ config_path=$1 ...@@ -12,28 +12,22 @@ config_path=$1
ckpt_name=$2 ckpt_name=$2
model_type=$3 model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--model_type ${model_type} \ --model_type ${model_type} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
#!/bin/bash
if [ $# != 1 ];then
echo "usage: tune ckpt_path"
exit 1
fi
# grid-search for hyper-parameters in language model
python3 -u ${BIN_DIR}/tune.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--num_batches=-1 \
--batch_size=128 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in tuning!"
exit 1
fi
exit 0
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# --max-jobs-run <njob>: Limit the number of parallel jobs. This is ignored for non-array jobs.
# --num-threads <nthread>: Specify the number of CPU cores.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have unified interface, not depending on its backend.
# These options are mapping to specific options for each backend and
# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs failed, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# =========================================================
# Select the backend used by run.sh from "local", "sge", "slurm", or "ssh"
cmd_backend='local'
# Local machine, without any Job scheduling system
if [ "${cmd_backend}" = local ]; then
# The other usage
export train_cmd="run.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="run.pl"
# Used for "*_recog.py"
export decode_cmd="run.pl"
# "qsub" (SGE, Torque, PBS, etc.)
elif [ "${cmd_backend}" = sge ]; then
# The default setting is written in conf/queue.conf.
# You must change "-q g.q" for the "queue" for your environment.
# To know the "queue" names, type "qhost -q"
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.
export train_cmd="queue.pl"
export cuda_cmd="queue.pl"
export decode_cmd="queue.pl"
# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
# The default setting is written in conf/slurm.conf.
# You must change "-p cpu" and "-p gpu" for the "partion" for your environment.
# To know the "partion" names, type "sinfo".
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".
export train_cmd="slurm.pl"
export cuda_cmd="slurm.pl"
export decode_cmd="slurm.pl"
elif [ "${cmd_backend}" = ssh ]; then
# You have to create ".queue/machines" to specify the host to execute jobs.
# e.g. .queue/machines
# host1
# host2
# host3
# Assuming you can log in to them without a password, i.e. you have set up ssh keys.
export train_cmd="ssh.pl"
export cuda_cmd="ssh.pl"
export decode_cmd="ssh.pl"
# This is an example of specifying several unique options in the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then
export train_cmd="queue.pl --mem 2G"
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/gpu.conf"
export decode_cmd="queue.pl --mem 4G"
else
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
return 1
fi
...@@ -19,17 +19,17 @@ ...@@ -19,17 +19,17 @@
{ {
"type": "specaug", "type": "specaug",
"params": { "params": {
"W": 0,
"warp_mode": "PIL",
"F": 10, "F": 10,
"T": 50,
"n_freq_masks": 2, "n_freq_masks": 2,
"T": 50,
"n_time_masks": 2, "n_time_masks": 2,
"p": 1.0, "p": 1.0,
"W": 80,
"adaptive_number_ratio": 0, "adaptive_number_ratio": 0,
"adaptive_size_ratio": 0, "adaptive_size_ratio": 0,
"max_n_time_masks": 20, "max_n_time_masks": 20,
"replace_with_zero": true, "replace_with_zero": true
"warp_mode": "PIL"
}, },
"prob": 1.0 "prob": 1.0
} }
......
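For reference, a numpy sketch of the frequency/time masking that this `specaug` entry enables (the `W: 80` time-warp step is omitted; parameter names mirror the JSON fields, and this is an illustration rather than the augmentation module itself):

```python
import numpy as np

def spec_mask(spec, F=10, n_freq_masks=2, T=50, n_time_masks=2,
              replace_with_zero=True, rng=np.random):
    """spec: [num_frames, num_freq_bins] feature matrix; returns a masked copy."""
    spec = spec.copy()
    fill = 0.0 if replace_with_zero else spec.mean()
    n_frames, n_bins = spec.shape
    for _ in range(n_freq_masks):  # mask up to F consecutive frequency bins
        f = rng.randint(0, F + 1)
        f0 = rng.randint(0, max(1, n_bins - f))
        spec[:, f0:f0 + f] = fill
    for _ in range(n_time_masks):  # mask up to T consecutive frames
        t = rng.randint(0, T + 1)
        t0 = rng.randint(0, max(1, n_frames - t))
        spec[t0:t0 + t, :] = fill
    return spec
```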
...@@ -76,6 +76,8 @@ model: ...@@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -69,6 +69,8 @@ model: ...@@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -72,6 +72,8 @@ model: ...@@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -67,6 +67,8 @@ model: ...@@ -67,6 +67,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -22,8 +18,7 @@ mkdir -p ${output_dir} ...@@ -22,8 +18,7 @@ mkdir -p ${output_dir}
# align dump in `result_file` # align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file` # .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \ python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.align \ --result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,13 +12,7 @@ config_path=$1 ...@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
#!/bin/bash #!/bin/bash
if [ $# != 2 ];then set -e
echo "usage: ${0} config_path ckpt_path_prefix"
expdir=exp
datadir=data
nj=32
lmtag=
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1 exit -1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 dict=$2
ckpt_prefix=$3
chunk_mode=false chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
...@@ -29,44 +42,46 @@ echo "chunk mode ${chunk_mode}" ...@@ -29,44 +42,46 @@ echo "chunk mode ${chunk_mode}"
# exit 1 # exit 1
#fi #fi
for type in attention ctc_greedy_search; do pids=() # initialize pids
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
echo "Failed in evaluation!" (
exit 1 for rtask in ${recog_set}; do
fi (
done decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
feat_recog_dir=${datadir}
mkdir -p ${expdir}/${decode_dir}
mkdir -p ${feat_recog_dir}
for type in ctc_prefix_beam_search attention_rescoring; do # split data
echo "decoding ${type}" split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
#### use CPU for decoding
ngpu=0
# set batchsize 0 to disable batch decoding
batch_size=1 batch_size=1
${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${expdir}/${decode_dir}/data.JOB.json \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} --opts decoding.decoding_method ${dmethd} \
--opts decoding.batch_size ${batch_size} \
--opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
if [ $? -ne 0 ]; then ) &
echo "Failed in evaluation!" pids+=($!) # store background pids
exit 1 done
fi ) &
pids+=($!) # store background pids
done done
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
echo "Finished"
exit 0 exit 0
...@@ -11,27 +11,24 @@ echo "using $ngpu gpus..." ...@@ -11,27 +11,24 @@ echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
# export FLAGS_cudnn_exhaustive_search=true
# export FLAGS_conv_workspace_size_limit=4000
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
export MAIN_ROOT=`realpath ${PWD}/../../../` export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${PWD}/utils:${PATH} export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
......
#!/bin/bash #!/bin/bash
set -e set -e
source path.sh
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;
stage=0 stage=0
stop_stage=100 stop_stage=100
conf_path=conf/transformer.yaml conf_path=conf/transformer.yaml
avg_num=5 avg_num=5
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num} avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
...@@ -24,7 +27,7 @@ fi ...@@ -24,7 +27,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# --max-jobs-run <njob>: Limit the number of parallel jobs. This is ignored for non-array jobs.
# --num-threads <nthread>: Specify the number of CPU cores.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have unified interface, not depending on its backend.
# These options are mapping to specific options for each backend and
# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs failed, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# =========================================================
# Select the backend used by run.sh from "local", "sge", "slurm", or "ssh"
cmd_backend='local'
# Local machine, without any Job scheduling system
if [ "${cmd_backend}" = local ]; then
# The other usage
export train_cmd="run.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="run.pl"
# Used for "*_recog.py"
export decode_cmd="run.pl"
# "qsub" (SGE, Torque, PBS, etc.)
elif [ "${cmd_backend}" = sge ]; then
# The default setting is written in conf/queue.conf.
# You must change "-q g.q" for the "queue" for your environment.
# To know the "queue" names, type "qhost -q"
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.
export train_cmd="queue.pl"
export cuda_cmd="queue.pl"
export decode_cmd="queue.pl"
# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
# The default setting is written in conf/slurm.conf.
# You must change "-p cpu" and "-p gpu" for the "partion" for your environment.
# To know the "partion" names, type "sinfo".
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".
export train_cmd="slurm.pl"
export cuda_cmd="slurm.pl"
export decode_cmd="slurm.pl"
elif [ "${cmd_backend}" = ssh ]; then
# You have to create ".queue/machines" to specify the host to execute jobs.
# e.g. .queue/machines
# host1
# host2
# host3
# Assuming you can log in to them without a password, i.e. you have set up ssh keys.
export train_cmd="ssh.pl"
export cuda_cmd="ssh.pl"
export decode_cmd="ssh.pl"
# This is an example of specifying several unique options in the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then
export train_cmd="queue.pl --mem 2G"
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/gpu.conf"
export decode_cmd="queue.pl --mem 4G"
else
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
return 1
fi
...@@ -76,6 +76,8 @@ model: ...@@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -69,6 +69,8 @@ model: ...@@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -72,6 +72,8 @@ model: ...@@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -12,7 +12,7 @@ collator: ...@@ -12,7 +12,7 @@ collator:
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32 batch_size: 30
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug minibatches: 0 # for debug
...@@ -22,7 +22,7 @@ collator: ...@@ -22,7 +22,7 @@ collator:
batch_frames_out: 0 batch_frames_out: 0
batch_frames_inout: 0 batch_frames_inout: 0
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
num_workers: 2 num_workers: 0
subsampling_factor: 1 subsampling_factor: 1
num_encs: 1 num_encs: 1
...@@ -58,6 +58,8 @@ model: ...@@ -58,6 +58,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: batch
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
...@@ -81,7 +83,7 @@ scheduler_conf: ...@@ -81,7 +83,7 @@ scheduler_conf:
lr_decay: 1.0 lr_decay: 1.0
decoding: decoding:
batch_size: 64 batch_size: 1
error_rate_type: wer error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
dict_path=$2 dict_path=$2
ckpt_prefix=$3 ckpt_prefix=$3
...@@ -26,8 +22,7 @@ python3 -u ${BIN_DIR}/test.py \ ...@@ -26,8 +22,7 @@ python3 -u ${BIN_DIR}/test.py \
--model-name 'u2_kaldi' \ --model-name 'u2_kaldi' \
--run-mode 'align' \ --run-mode 'align' \
--dict-path ${dict_path} \ --dict-path ${dict_path} \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result-file ${output_dir}/${type}.align \ --result-file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,15 +12,9 @@ config_path=$1 ...@@ -12,15 +12,9 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--model-name 'u2_kaldi' \ --model-name 'u2_kaldi' \
--run-mode 'export' \ --run-mode 'export' \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
#!/bin/bash #!/bin/bash
set -e
expdir=exp
datadir=data
nj=32
lmtag=
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model
if [ $# != 3 ];then if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix" echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1 exit -1
...@@ -8,13 +25,8 @@ fi ...@@ -8,13 +25,8 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
dict_path=$2 dict=$2
ckpt_prefix=$3 ckpt_prefix=$3
chunk_mode=false chunk_mode=false
...@@ -30,50 +42,49 @@ echo "chunk mode ${chunk_mode}" ...@@ -30,50 +42,49 @@ echo "chunk mode ${chunk_mode}"
# exit 1 # exit 1
#fi #fi
for type in attention ctc_greedy_search; do pids=() # initialize pids
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--model-name u2_kaldi \
--run-mode test \
--dict-path ${dict_path} \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
echo "Failed in evaluation!" (
exit 1 for rtask in ${recog_set}; do
fi (
done decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
feat_recog_dir=${datadir}
mkdir -p ${expdir}/${decode_dir}
mkdir -p ${feat_recog_dir}
# split data
split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
for type in ctc_prefix_beam_search attention_rescoring; do #### use CPU for decoding
echo "decoding ${type}" ngpu=0
# set batchsize 0 to disable batch decoding
batch_size=1 batch_size=1
${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--model-name u2_kaldi \ --model-name u2_kaldi \
--run-mode test \ --run-mode test \
--dict-path ${dict_path} \ --nproc ${ngpu} \
--device ${device} \ --dict-path ${dict} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result-file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} --result-file ${expdir}/${decode_dir}/data.JOB.json \
--opts decoding.decoding_method ${dmethd} \
--opts decoding.batch_size ${batch_size} \
--opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
if [ $? -ne 0 ]; then ) &
echo "Failed in evaluation!" pids+=($!) # store background pids
exit 1 done
fi ) &
pids+=($!) # store background pids
done done
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
echo "Finished"
exit 0 exit 0
...@@ -11,28 +11,22 @@ echo "using $ngpu gpus..." ...@@ -11,28 +11,22 @@ echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--model-name u2_kaldi \ --model-name u2_kaldi \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
export MAIN_ROOT=`realpath ${PWD}/../../../` export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${PWD}/utils:${PATH} export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
......
#!/bin/bash #!/bin/bash
set -e set -e
source path.sh
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;
stage=0 stage=0
stop_stage=100 stop_stage=100
conf_path=conf/transformer.yaml conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt dict_path=data/train_960_unigram5000_units.txt
avg_num=5 avg_num=10
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num} avg_ckpt=avg_${avg_num}
...@@ -20,12 +22,12 @@ fi ...@@ -20,12 +22,12 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir # train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt} CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
fi fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh latest exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# Ngram LM
* s0 - kenlm ngram lm
...@@ -2,6 +2,95 @@ ...@@ -2,6 +2,95 @@
Train a Chinese character n-gram LM with [kenlm](https://github.com/kpu/kenlm). Train a Chinese character n-gram LM with [kenlm](https://github.com/kpu/kenlm).
## Run
``` ```
. path.sh
bash run.sh bash run.sh
``` ```
## Results
```
exp/
|-- text
|-- text.char.tn
|-- text.word.tn
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
|-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
`-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
0 directories, 7 files
```
```
3ae083627b9b6cef1a82d574d8483f97 exp/text
d97da252d2a63a662af22f98af30cb8c exp/text.char.tn
c18b03005bd094dbfd9b46442be361fd exp/text.word.tn
73dbf50097896eda33985e11e1ba9a3a exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
01334e2044c474b99c4f2ffbed790626 exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
36a42de548045b54662411ae7982c77f exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
332422803ffd73dd7ffd16cd2b0abcd5 exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
```
```
==> exp/text <==
少先队员因该为老人让坐
祛痘印可以吗?有效果吗?
不知这款牛奶口感怎样? 小孩子喝行吗!
是转基因油?
我家宝宝13斤用多大码的
会起坨吗?
请问给送上楼吗?
亲是送赁上门吗
送货时候有外包装没有还是直接发货过来
会不会有坏的?
==> exp/text.char.tn <==
少 先 队 员 因 该 为 老 人 让 坐
祛 痘 印 可 以 吗 有 效 果 吗
不 知 这 款 牛 奶 口 感 怎 样 小 孩 子 喝 行 吗
是 转 基 因 油
我 家 宝 宝 十 三 斤 用 多 大 码 的
会 起 坨 吗
请 问 给 送 上 楼 吗
亲 是 送 赁 上 门 吗
送 货 时 候 有 外 包 装 没 有 还 是 直 接 发 货 过 来
会 不 会 有 坏 的
==> exp/text.word.tn <==
少先队员 因该 为 老人 让 坐
祛痘 印 可以 吗 有 效果 吗
不知 这 款 牛奶 口感 怎样 小孩子 喝行 吗
是 转基因 油
我家 宝宝 十三斤 用多大码 的
会起 坨 吗
请问 给 送 上楼 吗
亲是 送赁 上门 吗
送货 时候 有 外包装 没有 还是 直接 发货 过来
会 不会 有坏 的
==> exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa <==
\data\
ngram 1=587
ngram 2=395
ngram 3=100
ngram 4=2
ngram 5=0
\1-grams:
-3.272324 <unk> 0
0 <s> -0.36706257
==> exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa <==
\data\
ngram 1=689
ngram 2=1398
ngram 3=1506
\1-grams:
-3.1755018 <unk> 0
0 <s> -0.23069073
-1.2318869 </s> 0
-3.067262 少先队员 -0.051341705
```
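Once trained, either the ARPA file or the faster-loading `.klm.bin` binary can be queried from Python with the `kenlm` package. A minimal sketch, assuming the `exp/` layout above (the scored sentence is only an illustration):
```
import kenlm

# Binary models load much faster than the plain-text ARPA files.
model = kenlm.Model("exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin")

# This is a character-level LM, so score whitespace-separated characters,
# in the same format as exp/text.char.tn above.
sentence = "请 问 给 送 上 楼 吗"
print(model.score(sentence, bos=True, eos=True))  # total log10 probability
print(model.perplexity(sentence))
```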
# Punctuation Restoration # Punctuation Restoration
Please use `https://github.com/745165806/PaddleSpeechTask` to do this task. Please use [PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask) to do this task.
# [SentencePiece Model](https://github.com/google/sentencepiece) # [SentencePiece Model](https://github.com/google/sentencepiece)
## Run
Train a `spm` model for English tokenizer. Train a `spm` model for English tokenizer.
``` ```
. path.sh
bash run.sh bash run.sh
``` ```
## Results
```
data/
└── lang_char
├── input.bpe
├── input.decode
├── input.txt
├── train_unigram100.model
├── train_unigram100_units.txt
└── train_unigram100.vocab
1 directory, 6 files
```
```
b5a230c26c61db5c36f34e503102f936 data/lang_char/input.bpe
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.decode
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.txt
124bf3fe7ce3b73b1994234c15268577 data/lang_char/train_unigram100.model
0df2488cc8eaace95eb12713facb5cf0 data/lang_char/train_unigram100_units.txt
46360cac35c751310e8e8ffd3a034cb5 data/lang_char/train_unigram100.vocab
```
```
==> data/lang_char/input.bpe <==
▁mi ster ▁quilter ▁ is ▁the ▁a p ost le ▁o f ▁the ▁mi d d le ▁c las s es ▁ and ▁we ▁ar e ▁g l a d ▁ to ▁we l c om e ▁h is ▁g o s pe l
▁ n or ▁ is ▁mi ster ▁quilter ' s ▁ma nne r ▁ l ess ▁in ter es t ing ▁tha n ▁h is ▁ma t ter
▁h e ▁ t e ll s ▁us ▁tha t ▁ at ▁ t h is ▁f es t ive ▁ s e ason ▁o f ▁the ▁ y e ar ▁w ith ▁ ch r is t m a s ▁ and ▁ro a s t ▁be e f ▁ l o om ing ▁be fore ▁us ▁ s i mile s ▁d r a w n ▁f r om ▁ e at ing ▁ and ▁it s ▁re s u l t s ▁o c c ur ▁m ost ▁re a di l y ▁ to ▁the ▁ mind
▁h e ▁ ha s ▁g r a v e ▁d o u b t s ▁w h e t h er ▁ s i r ▁f r e d er ic k ▁ l eig h to n ' s ▁w or k ▁ is ▁re all y ▁gre e k ▁a f ter ▁ all ▁ and ▁c a n ▁di s c o v er ▁in ▁it ▁b u t ▁li t t le ▁o f ▁ro ck y ▁it ha c a
▁li nne ll ' s ▁ p ic tur es ▁ar e ▁a ▁ s or t ▁o f ▁ u p ▁g u ar d s ▁ and ▁ at ▁ em ▁painting s ▁ and ▁m ason ' s ▁ e x q u is i t e ▁ i d y ll s ▁ar e ▁a s ▁ n at ion a l ▁a s ▁a ▁ j ing o ▁ p o em ▁mi ster ▁b i r k e t ▁f o ster ' s ▁ l and s c a pe s ▁ s mile ▁ at ▁on e ▁m u ch ▁in ▁the ▁ s a m e ▁w a y ▁tha t ▁mi ster ▁c ar k er ▁us e d ▁ to ▁f las h ▁h is ▁ t e e t h ▁ and ▁mi ster ▁ j o h n ▁c o ll i er ▁g ive s ▁h is ▁ s i t ter ▁a ▁ ch e er f u l ▁ s l a p ▁on ▁the ▁b a ck ▁be fore ▁h
e ▁ s a y s ▁li k e ▁a ▁ s ha m p o o er ▁in ▁a ▁ tur k is h ▁b at h ▁ n e x t ▁ma n
▁it ▁ is ▁o b v i o u s l y ▁ u nne c ess ar y ▁for ▁us ▁ to ▁ p o i n t ▁o u t ▁h o w ▁ l u m i n o u s ▁the s e ▁c rit ic is m s ▁ar e ▁h o w ▁d e l ic at e ▁in ▁ e x p r ess ion
▁on ▁the ▁g e n er a l ▁ p r i n c i p l es ▁o f ▁ar t ▁mi ster ▁quilter ▁w rit es ▁w ith ▁ e qual ▁ l u c i di t y
▁painting ▁h e ▁ t e ll s ▁us ▁ is ▁o f ▁a ▁di f f er e n t ▁ qual i t y ▁ to ▁ma t h em at ic s ▁ and ▁f i nish ▁in ▁ar t ▁ is ▁a d d ing ▁m or e ▁f a c t
▁a s ▁for ▁ e t ch ing s ▁the y ▁ar e ▁o f ▁ t w o ▁ k i n d s ▁b rit is h ▁ and ▁for eig n
▁h e ▁ l a ment s ▁m ost ▁b i t ter l y ▁the ▁di v or c e ▁tha t ▁ ha s ▁be e n ▁ma d e ▁be t w e e n ▁d e c or at ive ▁ar t ▁ and ▁w ha t ▁we ▁us u all y ▁c all ▁ p ic tur es ▁ma k es ▁the ▁c u s t om ar y ▁a p pe a l ▁ to ▁the ▁ las t ▁ j u d g ment ▁ and ▁re mind s ▁us ▁tha t ▁in ▁the ▁gre at ▁d a y s ▁o f ▁ar t ▁mi c ha e l ▁a n g e l o ▁w a s ▁the ▁f ur nish ing ▁ u p h o l ster er
==> data/lang_char/input.decode <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/input.txt <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/train_unigram100_units.txt <==
<blank> 0
<unk> 1
' 2
a 3
all 4
and 5
ar 6
ason 7
at 8
b 9
==> data/lang_char/train_unigram100.vocab <==
<unk> 0
<s> 0
</s> 0
▁ -2.01742
e -2.7203
s -2.82989
t -2.99689
l -3.53267
n -3.84935
o -3.88229
```
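The trained model can then be loaded with the `sentencepiece` Python package to tokenize and detokenize text. A minimal sketch, assuming the `data/lang_char` layout above:
```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("data/lang_char/train_unigram100.model")

text = "mister quilter is the apostle of the middle classes"
pieces = sp.EncodeAsPieces(text)  # e.g. ['▁mi', 'ster', '▁quilter', ...]
print(pieces)
print(sp.DecodePieces(pieces))    # round-trips back to the original text
```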
...@@ -68,6 +68,8 @@ model: ...@@ -68,6 +68,8 @@ model:
model_conf: model_conf:
asr_weight: 0.0 asr_weight: 0.0
ctc_weight: 0.0 ctc_weight: 0.0
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -68,6 +68,8 @@ model: ...@@ -68,6 +68,8 @@ model:
model_conf: model_conf:
asr_weight: 0.5 asr_weight: 0.5
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -19,8 +15,7 @@ for type in fullsentence; do ...@@ -19,8 +15,7 @@ for type in fullsentence; do
echo "decoding ${type}" echo "decoding ${type}"
batch_size=32 batch_size=32
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -11,27 +11,21 @@ echo "using $ngpu gpus..." ...@@ -11,27 +11,21 @@ echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -26,7 +26,7 @@ fi ...@@ -26,7 +26,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example.
...@@ -66,6 +66,8 @@ model: ...@@ -66,6 +66,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -22,8 +18,7 @@ mkdir -p ${output_dir} ...@@ -22,8 +18,7 @@ mkdir -p ${output_dir}
# align dump in `result_file` # align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file` # .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \ python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.align \ --result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,13 +12,7 @@ config_path=$1 ...@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,11 +8,6 @@ fi ...@@ -8,11 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -37,8 +32,7 @@ for type in attention ctc_greedy_search; do ...@@ -37,8 +32,7 @@ for type in attention ctc_greedy_search; do
batch_size=64 batch_size=64
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
...@@ -54,8 +48,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do ...@@ -54,8 +48,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}" echo "decoding ${type}"
batch_size=1 batch_size=1
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -11,27 +11,21 @@ echo "using $ngpu gpus..." ...@@ -11,27 +11,21 @@ echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 # seed may break model convergence
if [ ${seed} ]; then seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -26,7 +26,7 @@ fi ...@@ -26,7 +26,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
...@@ -4,7 +4,7 @@ data: ...@@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny test_manifest: data/manifest.tiny
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 30.0
min_output_len: 0.0 min_output_len: 0.0
max_output_len: 400.0 max_output_len: 400.0
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
...@@ -41,11 +41,14 @@ model: ...@@ -41,11 +41,14 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 10 n_epoch: 10
accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 0.8
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 5.0 global_grad_clip: 5.0
log_interval: 1 log_interval: 1
......
...@@ -4,7 +4,7 @@ data: ...@@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny test_manifest: data/manifest.tiny
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 30.0
min_output_len: 0.0 min_output_len: 0.0
max_output_len: 400.0 max_output_len: 400.0
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
...@@ -43,9 +43,12 @@ model: ...@@ -43,9 +43,12 @@ model:
num_fc_layers: 2 num_fc_layers: 2
fc_layers_size_list: 512, 256 fc_layers_size_list: 512, 256
use_gru: True use_gru: True
blank_id: 0
ctc_grad_norm_type: instance
training: training:
n_epoch: 10 n_epoch: 10
accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 1.0
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2 ...@@ -13,13 +13,7 @@ ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
model_type=$4 model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
model_type=$3 model_type=$3
...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then ...@@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.rsl \ --result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
#!/bin/bash #!/bin/bash
profiler_options=
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 3 ];then if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1 exit -1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
model_type=$3 model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
mkdir -p exp mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--model_type ${model_type} \ --model_type ${model_type} \
--profiler-options "${profiler_options}" \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
#!/bin/bash
if [ $# != 1 ];then
echo "usage: tune ckpt_path"
exit 1
fi
# grid-search for hyper-parameters in language model
python3 -u ${BIN_DIR}/tune.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--num_batches=-1 \
--batch_size=128 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in tuning!"
exit 1
fi
exit 0
...@@ -27,7 +27,7 @@ fi ...@@ -27,7 +27,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
...@@ -19,17 +19,17 @@ ...@@ -19,17 +19,17 @@
{ {
"type": "specaug", "type": "specaug",
"params": { "params": {
"W": 0,
"warp_mode": "PIL",
"F": 10, "F": 10,
"T": 50,
"n_freq_masks": 2, "n_freq_masks": 2,
"T": 50,
"n_time_masks": 2, "n_time_masks": 2,
"p": 1.0, "p": 1.0,
"W": 80,
"adaptive_number_ratio": 0, "adaptive_number_ratio": 0,
"adaptive_size_ratio": 0, "adaptive_size_ratio": 0,
"max_n_time_masks": 20, "max_n_time_masks": 20,
"replace_with_zero": true, "replace_with_zero": true
"warp_mode": "PIL"
}, },
"prob": 1.0 "prob": 1.0
} }
......
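The reordered config above groups the frequency-mask parameters (`F`, `n_freq_masks`) with the time-mask parameters (`T`, `n_time_masks`); the main value change is the time-warp window `W` (0 → 80). As a rough sketch of what the masking parameters control (a hypothetical NumPy helper, not the augmentor used in this repo):
```
import numpy as np

def specaug(spec: np.ndarray, F=10, T=50, n_freq_masks=2, n_time_masks=2):
    """Zero out random frequency and time bands of a (time, freq) spectrogram."""
    n_t, n_f = spec.shape
    for _ in range(n_freq_masks):           # mask up to F consecutive freq bins
        f = np.random.randint(0, F + 1)
        f0 = np.random.randint(0, max(1, n_f - f))
        spec[:, f0:f0 + f] = 0.0            # replace_with_zero: true
    for _ in range(n_time_masks):           # mask up to T consecutive frames
        t = np.random.randint(0, T + 1)
        t0 = np.random.randint(0, max(1, n_t - t))
        spec[t0:t0 + t, :] = 0.0
    return spec
```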
...@@ -4,7 +4,7 @@ data: ...@@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny test_manifest: data/manifest.tiny
min_input_len: 0.5 # second min_input_len: 0.5 # second
max_input_len: 20.0 # second max_input_len: 30.0 # second
min_output_len: 0.0 # tokens min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
...@@ -76,6 +76,8 @@ model: ...@@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -69,6 +69,8 @@ model: ...@@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -72,6 +72,8 @@ model: ...@@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
......
...@@ -66,6 +66,8 @@ model: ...@@ -66,6 +66,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false
...@@ -84,7 +86,7 @@ training: ...@@ -84,7 +86,7 @@ training:
lr_decay: 1.0 lr_decay: 1.0
log_interval: 1 log_interval: 1
checkpoint: checkpoint:
kbest_n: 10 kbest_n: 2
latest_n: 1 latest_n: 1
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -22,8 +18,7 @@ mkdir -p ${output_dir} ...@@ -22,8 +18,7 @@ mkdir -p ${output_dir}
# align dump in `result_file` # align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file` # .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \ python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${output_dir}/${type}.align \ --result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
...@@ -12,13 +12,7 @@ config_path=$1 ...@@ -12,13 +12,7 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--device ${device} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
......
...@@ -8,10 +8,6 @@ fi ...@@ -8,10 +8,6 @@ fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1 config_path=$1
ckpt_prefix=$2 ckpt_prefix=$2
...@@ -35,8 +31,7 @@ for type in attention ctc_greedy_search; do ...@@ -35,8 +31,7 @@ for type in attention ctc_greedy_search; do
batch_size=64 batch_size=64
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
...@@ -52,8 +47,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do ...@@ -52,8 +47,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}" echo "decoding ${type}"
batch_size=1 batch_size=1
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device ${device} \ --nproc ${ngpu} \
--nproc 1 \
--config ${config_path} \ --config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \ --result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
......
#!/bin/bash #!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 2 ];then if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1 exit -1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1 config_path=$1
ckpt_name=$2 ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
mkdir -p exp mkdir -p exp
seed=1024
if [ ${seed} ]; then
export FLAGS_cudnn_deterministic=True
fi
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device ${device} \ --seed ${seed} \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -25,7 +25,7 @@ fi ...@@ -25,7 +25,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example.
## Run
```
. path.sh
bash run.sh
```
## Results
```
exp/
`-- normalized.txt
0 directories, 1 file
```
```
aff31f8aa08e2a7360228c9ce5886b98 exp/normalized.txt
```
```
今天的最低气温达到零下十度.
只要有四分之三十三的人同意,就可以通过决议。
一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
如果剩下的百分之三十点六是过去,那么还有百分之六十九点四.
事情发生在二零二零年三月三十一日的上午八点.
警方正在找一支点二二口径的手枪。
欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务
充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。
快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按九,查询话费、套餐余量、积分及活动返款请按一,手机上网流量开通及取消请按二,查询本机号码及本号所使用套餐请按四,密码修改及重置请按五,紧急开机请按六,挂失请按七,查询充值记录请按八,其它自助服务及人工服务请按零
```
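Each rule is essentially a regex plus a verbalization function over a digit table. A minimal sketch of one such rule (year numbers are read digit by digit, as in the examples above; this is an illustration, not the rule set used here):
```
import re

DIGITS = dict(zip("0123456789", "零一二三四五六七八九"))

def verbalize_digits(s: str) -> str:
    """Read a digit string out digit by digit, e.g. '2020' -> '二零二零'."""
    return "".join(DIGITS[c] for c in s)

def normalize_year(text: str) -> str:
    # Rewrite '2020年' -> '二零二零年'; the 年 marker is kept as-is.
    return re.sub(r"(\d{2,4})年", lambda m: verbalize_digits(m.group(1)) + "年", text)

print(normalize_year("事情发生在2020年3月31日的上午8点."))
# -> 事情发生在二零二零年3月31日的上午8点.
```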
...@@ -2,6 +2,7 @@ coverage ...@@ -2,6 +2,7 @@ coverage
gpustat gpustat
jsonlines jsonlines
kaldiio kaldiio
loguru
Pillow Pillow
pre-commit pre-commit
pybind11 pybind11
...@@ -14,5 +15,7 @@ SoundFile==0.9.0.post1 ...@@ -14,5 +15,7 @@ SoundFile==0.9.0.post1
sox sox
tensorboardX tensorboardX
textgrid textgrid
tqdm
typeguard typeguard
visualdl==2.2.0
yacs yacs
old-pd_env.txt
pd_env.txt
# Benchmark Test
## Data
* Aishell
## Docker
```
registry.baidubce.com/paddlepaddle/paddle 2.1.1-gpu-cuda10.2-cudnn7 59d5ec1de486
```
#!/bin/bash
CUR_DIR=${PWD}
ROOT_DIR=../../
# A script that reproduces performance stably; by default it runs with py37 inside the standard docker environment:
# collect env info
bash ${ROOT_DIR}/utils/pd_env_collect.sh
#cat pd_env.txt
# 1. Install the dependencies this model needs (note it here if any optimization strategy is enabled)
#pushd ${ROOT_DIR}/tools; make; popd
#source ${ROOT_DIR}/tools/venv/bin/activate
#pushd ${ROOT_DIR}; bash setup.sh; popd
# 2. Copy the data and pretrained models this model needs
# Working directory: must be stated
#pushd ${ROOT_DIR}/examples/aishell/s1
pushd ${ROOT_DIR}/examples/tiny/s1
mkdir -p exp/log
. path.sh
#bash local/data.sh &> exp/log/data.log
# 3. Batch runs (if batching is inconvenient, move steps 1 and 2 into each individual model)
model_mode_list=(conformer transformer)
fp_item_list=(fp32)
bs_item_list=(32 64 96)
for model_mode in ${model_mode_list[@]}; do
for fp_item in ${fp_item_list[@]}; do
for bs_item in ${bs_item_list[@]}
do
echo "index is speed, 1gpus, begin, ${model_name}"
run_mode=sp
CUDA_VISIBLE_DEVICES=0 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode} # (5min)
sleep 60
echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_name}"
run_mode=mp
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
sleep 60
done
done
done
popd # aishell/s1
#!/bin/bash
set -xe
# Usage example: CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
# Parameter descriptions
function _set_params(){
run_mode=${1:-"sp"} # 单卡sp|多卡mp
batch_size=${2:-"64"}
fp_item=${3:-"fp32"} # fp32|fp16
max_iter=${4:-"500"} # 可选,如果需要修改代码提前中断
model_name=${5:-"model_name"}
run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR is set later by QA
# No changes needed below this line
device=${CUDA_VISIBLE_DEVICES//,/ }
arr=(${device})
num_gpu_devices=${#arr[*]}
log_file=${run_log_path}/${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
}
function _train(){
echo "Train on ${num_gpu_devices} GPUs"
echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
train_cmd="--benchmark-batch-size ${batch_size}
--benchmark-max-step ${max_iter}
conf/${model_name}.yaml ${model_name}"
case ${run_mode} in
sp) train_cmd="bash local/train.sh "${train_cmd}"" ;;
mp)
train_cmd="bash local/train.sh "${train_cmd}"" ;;
*) echo "choose run_mode(sp or mp)"; exit 1;
esac
# No changes needed below this line
timeout 15m ${train_cmd} > ${log_file} 2>&1
if [ $? -ne 0 ];then
echo -e "${model_name}, FAIL"
export job_fail_flag=1
else
echo -e "${model_name}, SUCCESS"
export job_fail_flag=0
fi
trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM
if [ $run_mode = "mp" -a -d mylog ]; then
rm ${log_file}
cp mylog/workerlog.0 ${log_file}
fi
}
_set_params $@
_train
...@@ -13,7 +13,7 @@ null:null ...@@ -13,7 +13,7 @@ null:null
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline --device gpu norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -21,7 +21,7 @@ null:null ...@@ -21,7 +21,7 @@ null:null
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --result_file tests/9.rsl --model_type offline --device gpu eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --result_file tests/9.rsl --model_type offline
null:null null:null
## ##
===========================infer_params=========================== ===========================infer_params===========================
......
...@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase): ...@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase):
def test_make_non_pad_mask(self): def test_make_non_pad_mask(self):
res = make_non_pad_mask(self.lengths) res = make_non_pad_mask(self.lengths)
res2 = make_pad_mask(self.lengths).logical_not() res2 = ~make_pad_mask(self.lengths)
self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist()) self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())
def test_make_pad_mask(self): def test_make_pad_mask(self):
res = make_pad_mask(self.lengths) res = make_pad_mask(self.lengths)
res1 = make_non_pad_mask(self.lengths).logical_not() res1 = ~make_non_pad_mask(self.lengths)
self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res1.tolist()) self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
......
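For context, `make_pad_mask` marks the padded positions of each utterance and `make_non_pad_mask` is its logical negation, which is exactly what the rewritten assertions exercise via `~`. A minimal sketch, assuming the same semantics as the functions under test:
```
import paddle

def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """lengths: (B,) int tensor -> (B, max_len) bool, True at padded positions."""
    max_len = int(lengths.max())
    steps = paddle.arange(max_len, dtype=lengths.dtype).unsqueeze(0)  # (1, max_len)
    return steps >= lengths.unsqueeze(-1)  # broadcasts to (B, max_len)

def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    return ~make_pad_mask(lengths)

# e.g. lengths [2, 3] -> pad mask [[False, False, True], [False, False, False]]
```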
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
def frame(x: Tensor,
num_samples: Tensor,
win_length: int,
hop_length: int,
clip: bool = True) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (N, T), batched waveform.
num_samples : Tensor
Shape (N, ), number of samples of each waveform.
win_length : int
Window length.
hop_length : int
Number of samples shifted between adjacent frames.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default True
Returns
-------
frames : Tensor
Shape (N, T', win_length).
num_frames : Tensor
Shape (N, ) number of valid frames
"""
assert hop_length <= win_length
num_frames = (num_samples - win_length) // hop_length
padding = (0, 0)
if not clip:
num_frames += 1
# NOTE: pad hop_length - 1 to the right to ensure that there is at most
# one frame dangling off the right edge
padding = (0, hop_length - 1)
weight = paddle.eye(win_length).unsqueeze(1)
frames = F.conv1d(x.unsqueeze(1),
weight,
padding=padding,
stride=(hop_length, ))
return frames, num_frames
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
Parameters
------------
n_fft : int
Number of samples in a frame.
hop_length : int
Number of samples shifted between adjacent frames.
win_length : int
Length of the window.
clip: bool
Whether to clip audio that does not fit into the last frame.
"""
def __init__(self,
n_fft: int,
hop_length: int,
win_length: int,
window_type: str = None,
clip: bool = True):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length # forward() needs this to compute num_frames
self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft
self.clip = clip
# calculate window
if window_type is None:
window = np.ones(win_length)
elif window_type == "hann":
window = np.hanning(win_length)
elif window_type == "hamming":
window = np.hamming(win_length)
else:
raise ValueError("Not supported yet!")
if win_length < n_fft:
window = np.pad(window, (0, n_fft - win_length)) # window is a numpy array, so pad with np.pad
elif win_length > n_fft:
window = window[:n_fft]
# (n_bins, n_fft) complex
kernel_size = min(n_fft, win_length)
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
w = np.expand_dims(w, 1)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor
Number of samples of each waveform.
Returns
------------
D : Tensor
Shape(N, T', n_bins, 2) Spectrogram.
num_frames: Tensor
Shape (N,) number of samples of each spectrogram
"""
num_frames = (num_samples - self.win_length) // self.hop_length
padding = (0, 0)
if not self.clip:
num_frames += 1
padding = (0, self.hop_length - 1)
batch_size = paddle.shape(x)[0]
x = x.unsqueeze(-1) # (B, T) -> (B, T, 1) for the NLC-layout conv
D = F.conv1d(x,
self.weight,
stride=(self.hop_length, ),
padding=padding,
data_format="NLC")
D = paddle.reshape(D, [batch_size, -1, self.n_bin, 2])
return D, num_frames
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import paddle
# https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/src/feat/feature-window.cc#L109
def povey_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = (0.5 - 0.5 * np.cos(a * i) )**0.85
return win
def hann_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.5 - 0.5 * np.cos(a * i)
return win
def sine_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = np.sin(0.5 * a * i)
return win
def hamm_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.54 - 0.46 * np.cos(a * i)
return win
def get_window(wintype:Optional[str], winlen:int) -> np.ndarray:
"""get window function
Args:
wintype (Optional[str]): window type.
winlen (int): window length in samples.
Raises:
ValueError: not support window.
Returns:
np.ndarray: window coeffs.
"""
# calculate window
if not wintype or wintype == 'rectangular':
window = np.ones(winlen)
elif wintype == "hann":
window = hann_window(winlen)
elif wintype == "hamm":
window = hamm_window(winlen)
elif wintype == "povey":
window = povey_window(winlen)
else:
msg = f"{wintype} Not supported yet!"
raise ValueError(msg)
return window
def dft_matrix(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# https://en.wikipedia.org/wiki/Discrete_Fourier_transform
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
n = np.arange(0, n_fft, 1.)
wsin = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
wcos = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
for k in range(n_bin): # Only half of the bins contain useful info
wsin[k,:] = -np.sin(2*np.pi*k*n/n_fft)[:kernel_size]
wcos[k,:] = np.cos(2*np.pi*k*n/n_fft)[:kernel_size]
w_real = wcos
w_imag = wsin
return w_real, w_imag, kernel_size
def dft_matrix_fast(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
# https://en.wikipedia.org/wiki/DFT_matrix
# https://ccrma.stanford.edu/~jos/st/Matrix_Formulation_DFT.html
weight = np.fft.fft(np.eye(n_fft))[:n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
return w_real, w_imag, kernel_size
def bin2hz(bin:Union[List[int], np.ndarray], N:int, sr:int)->List[float]:
"""FFT bins to Hz.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Args:
bins (List[int] or np.ndarray): bin index.
N (int): the number of samples, or FFT points.
sr (int): sampling rate.
Returns:
List[float]: Hz's.
"""
hz = bin * float(sr) / N
return hz
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identical sized array is returned.
"""
return 1127 * np.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identical sized array is returned.
"""
return 700 * (np.exp(mel/1127.0)-1)
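# Quick sanity check for the pair above (illustrative, not part of the API):
# hz2mel and mel2hz are inverses, and on this 1127*ln(1 + f/700) scale
# 1000 Hz maps to ~1000 mel:
#   mel2hz(hz2mel(440.0))  -> ~440.0
#   hz2mel(1000.0)         -> ~999.99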
def rms_to_db(rms: float):
"""Root Mean Square to dB.
Args:
rms ([float]): root mean square
Returns:
float: dB
"""
return 20.0 * math.log10(max(1e-16, rms))
def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is treated as a mix of sine waves; a full-scale (amplitude 1.0) sine wave has an RMS of 0.7071, i.e. -3.0103 dB, so:
dB = dBFS + 3.0103
dBFS = dB - 3.0103
e.g. 0 dB = -3.0103 dBFS
Args:
rms ([float]): root mean square
Returns:
float: dBFS
"""
return rms_to_db(rms) - 3.0103
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
def gain_db_to_ratio(gain_db: float):
"""dB to ratio
Args:
gain_db (float): gain in dB
Returns:
float: scale in amp
"""
return math.pow(10.0, gain_db / 20.0)
\ No newline at end of file
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
import soundfile as sf
from .common import get_window
from .common import dft_matrix
def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2d=True)->Tuple[int, np.ndarray]:
"""load wav file.
Args:
wavpath (str): wav path.
sr (int, optional): expect sample rate. Defaults to None.
dtype (str, optional): wav data bits. Defaults to 'int16'.
Returns:
Tuple[int, np.ndarray]: sr (int), wav (int16) [T, C].
"""
wav, r_sr = sf.read(wavpath, start=start, stop=stop, dtype=dtype, always_2d=always_2d)
if sr:
assert sr == r_sr
return r_sr, wav
def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'):
"""write wav file.
Args:
wavpath (str): file path to save.
wav (np.ndarray): wav data.
sr (int): data samplerate.
dtype (str, optional): wav bit format. Defaults to 'PCM_16'.
"""
sf.write(wavpath, wav, sr, subtype=dtype)
def frames(x: Tensor,
num_samples: Tensor,
sr: int,
win_length: float,
stride_length: float,
clip: bool = False) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (B, T), batched waveform.
num_samples : Tensor
Shape (B, ), number of samples of each waveform.
sr: int
Sampling Rate.
win_length : float
Window length in ms.
stride_length : float
Stride length in ms.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default False
Returns
-------
frames : Tensor
Shape (B, T', win_length).
num_frames : Tensor
Shape (B, ) number of valid frames
"""
assert stride_length <= win_length
stride_length = int(stride_length * sr)
win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length
padding = (0, 0)
if not clip:
num_frames += 1
need_samples = num_frames * stride_length + win_length
padding = (0, need_samples - num_samples - 1)
weight = paddle.eye(win_length).unsqueeze(1) #[win_length, 1, win_length]
frames = F.conv1d(x.unsqueeze(-1),
weight,
padding=padding,
stride=(stride_length, ),
data_format='NLC')
return frames, num_frames
def dither(signal:Tensor, dither_value=1.0)->Tensor:
"""dither frames for log compute.
Args:
signal (Tensor): [B, T, D]
dither_value (float, optional): [scalar]. Defaults to 1.0.
Returns:
Tensor: [B, T, D]
"""
D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal
def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc.
Args:
signal (Tensor): [B, T, D]
Returns:
Tensor: [B, T, D]
"""
signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
"""perform preemphasis on the input signal.
Args:
signal (Tensor): [B, T, D], The signal to filter.
coeff (float, optional): [scalar].The preemphasis coefficient. 0 is no filter, Defaults to 0.97.
Returns:
Tensor: [B, T, D]
"""
return paddle.concat([
(1-coeff)*signal[:, :, 0:1],
signal[:, :, 1:] - coeff * signal[:, :, :-1]
], axis=-1)
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters
------------
n_fft : int
Number of samples in a frame.
sr: int
Sampling rate.
stride_length : float
Number of samples shifted between adjacent frames.
win_length : float
Length of the window.
clip: bool
Whether to clip audio that does not fit into the last frame.
"""
def __init__(self,
n_fft: int,
sr: int,
win_length: float,
stride_length: float,
dither:float=0.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
clip: bool = False):
super().__init__()
self.sr = sr
self.win_length = win_length
self.stride_length = stride_length
self.dither = dither
self.preemph_coeff = preemph_coeff
self.remove_dc_offset = remove_dc_offset
self.window_type = window_type
self.clip = clip
self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin
)
# calculate window
window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (kernel_size, 2 * n_bins)
w = np.transpose(w)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor [shape=(B,)]
Number of samples of each waveform.
Returns
------------
C : Tensor
Shape(B, T', n_bins, 2) Spectrogram.
num_frames: Tensor
Shape (B,) number of samples of each spectrogram
"""
batch_size = paddle.shape(x)[0]
# use `feat` rather than `F` so paddle.nn.functional (imported as F) is not shadowed
feat, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither:
feat = dither(feat, self.dither)
if self.remove_dc_offset:
feat = remove_dc_offset(feat)
if self.preemph_coeff:
feat = preemphasis(feat)
C = paddle.matmul(feat, self.weight) # [B, T, K] [K, 2 * n_bins]
C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin])
C = C.transpose([0, 1, 3, 2])
return C, nframe
def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum |X_k|^2.
Args:
C (Tensor): [B, T, C, 2]
Returns:
Tensor: [B, T, C]
"""
real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum |X_k|.
Args:
C (Tensor): [B, T, C, 2]
eps (float): epsilon.
Returns:
Tensor: [B, T, C]
"""
pspec = powspec(C)
return paddle.sqrt(pspec + eps)
def logspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute log-spectrum 20log10∣X_k∣.
Args:
C (Tensor): [description]
eps ([type], optional): [description]. Defaults to 1e-10.
Returns:
Tensor: [description]
"""
spec = magspec(C)
return 20 * paddle.log10(spec + eps)
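# Usage sketch for the frontend above (illustrative; the 16 kHz rate and the
# 25 ms / 10 ms window settings are assumptions, not values from this file):
#
#   import paddle
#   stft = STFT(n_fft=512, sr=16000, win_length=0.025, stride_length=0.01,
#               window_type='povey')
#   x = paddle.randn([4, 16000])                     # (B, T) waveform batch
#   num_samples = paddle.full([4], 16000, 'int64')   # valid samples per utt
#   C, num_frames = stft(x, num_samples)             # (B, T', n_bins, 2)
#   log_spec = logspec(C)                            # 20*log10|X_k|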
from typing import Tuple
import numpy as np
import paddle
import unittest
import decimal
import numpy
import math
import logging
from pathlib import Path
from scipy.fftpack import dct
from third_party.paddle_audio.frontend import kaldi
def round_half_up(number):
return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))
def rolling_window(a, window, step=1):
# http://ellisvalentiner.com/post/2017-03-21-np-strides-trick
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]
def do_dither(signal, dither_value=1.0):
signal += numpy.random.normal(size=signal.shape) * dither_value
return signal
def do_remove_dc_offset(signal):
signal -= numpy.mean(signal)
return signal
def do_preemphasis(signal, coeff=0.97):
"""perform preemphasis on the input signal.
:param signal: The signal to filter.
:param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.
:returns: the filtered signal.
"""
return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])
def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param wintype: the window type applied to each frame, e.g. 'povey' or 'hamming'.
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: an array of frames. Size is NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype == 'povey': # compare strings by value, not identity
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = numpy.zeros(frames.shape)
for frm in range(frames.shape[0]):
frames[frm,:] = do_dither(frames[frm,:], dither) # dither
frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset
raw_frames[frm,:] = frames[frm,:]
frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize
return frames * win, raw_frames
def magspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warning(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return numpy.absolute(complex_spec)
def powspec(frames, NFFT):
"""Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.
"""
return numpy.square(magspec(frames, NFFT))
def mfcc(signal,samplerate=16000,winlen=0.025,winstep=0.01,numcep=13,
nfilt=23,nfft=512,lowfreq=20,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,
ceplifter=22,useEnergy=True,wintype='povey'):
"""Compute MFCC features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param numcep: the number of cepstrum to return, default 13
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param ceplifter: apply a lifter to final cepstral coefficients. 0 is no lifter. Default is 22.
:param useEnergy: if this is true, the zeroth cepstral coefficient is replaced with the log of the total frame energy.
:param wintype: the analysis window type to apply to each frame, e.g. 'povey' or 'hamming'.
:returns: A numpy array of size (NUMFRAMES by numcep) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither,remove_dc_offset,preemph,wintype)
feat = numpy.log(feat)
feat = dct(feat, type=2, axis=1, norm='ortho')[:,:numcep]
feat = lifter(feat,ceplifter)
if useEnergy: feat[:,0] = numpy.log(energy) # replace first cepstral coefficient with log of frame energy
return feat
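# A minimal usage sketch (the wav file name is illustrative; cf. the worked
# example near the bottom of this file):
# (rate, sig) = scipy.io.wavfile.read("english.wav")
# feat = mfcc(sig, samplerate=rate, dither=0, useEnergy=True, wintype='povey')
# feat.shape  # -> (num_frames, 13), matching kaldi's compute-mfcc-feats --dither=0.0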
def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
"""Compute Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 40.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param wintype: the analysis window type to apply to each frame, e.g. 'hamming' or 'povey'. Default is 'hamming'.
:returns: 2 values. The first is a numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector. The
second return value is the energy in each frame (total energy, unwindowed)
"""
highfreq= highfreq or samplerate/2
frames,raw_frames = sigproc.framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)
pspec = sigproc.powspec(frames,nfft) # nearly the same until this part
energy = numpy.sum(raw_frames**2,1) # this stores the raw energy in each frame
energy = numpy.where(energy == 0,numpy.finfo(float).eps,energy) # if energy is zero, we get problems with log
fb = get_filterbanks(nfilt,nfft,samplerate,lowfreq,highfreq)
feat = numpy.dot(pspec,fb.T) # compute the filterbank energies
feat = numpy.where(feat == 0,numpy.finfo(float).eps,feat) # if feat is zero, we get problems with log
return feat,energy
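# A minimal usage sketch (kaldi-compatible fbank; cf. compute-fbank-feats --dither=0.0):
# feat, energy = fbank(sig, samplerate=16000, nfilt=40, dither=0, wintype='povey')
# logfbank_feat = numpy.log(feat)  # what logfbank() below returns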
def logfbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=64,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,wintype='hamming'):
"""Compute log Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 40.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 64.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:returns: A numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither, remove_dc_offset,preemph,wintype)
return numpy.log(feat)
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identically sized array is returned.
"""
return 1127 * numpy.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identically sized array is returned.
"""
return 700 * (numpy.exp(mel/1127.0)-1)
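# Sanity check (a sketch): hz2mel/mel2hz are exact inverses. Note these use
# kaldi's natural-log convention 1127*ln(1+hz/700), not HTK's 2595*log10(1+hz/700).
# f = numpy.array([0., 440., 8000.])
# numpy.allclose(mel2hz(hz2mel(f)), f)  # -> True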
def get_filterbanks(nfilt=26,nfft=512,samplerate=16000,lowfreq=0,highfreq=None):
"""Compute a Mel-filterbank. The filters are stored in the rows, the columns correspond
to fft bins. The filters are returned as an array of size nfilt * (nfft/2 + 1)
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param samplerate: the samplerate of the signal we are working with. Affects mel spacing.
:param lowfreq: lowest band edge of mel filters, default 0 Hz
:param highfreq: highest band edge of mel filters, default samplerate/2
:returns: A numpy array of size nfilt * (nfft/2 + 1) containing filterbank. Each row holds 1 filter.
"""
highfreq= highfreq or samplerate/2
assert highfreq <= samplerate/2, "highfreq is greater than samplerate/2"
# compute points evenly spaced in mels
lowmel = hz2mel(lowfreq)
highmel = hz2mel(highfreq)
# check kaldi/src/feat/Mel-computations.h
fbank = numpy.zeros([nfilt,nfft//2+1])
mel_freq_delta = (highmel-lowmel)/(nfilt+1)
for j in range(0,nfilt):
leftmel = lowmel+j*mel_freq_delta
centermel = lowmel+(j+1)*mel_freq_delta
rightmel = lowmel+(j+2)*mel_freq_delta
for i in range(0,nfft//2):
mel=hz2mel(i*samplerate/nfft)
if mel>leftmel and mel<rightmel:
if mel<centermel:
fbank[j,i]=(mel-leftmel)/(centermel-leftmel)
else:
fbank[j,i]=(rightmel-mel)/(rightmel-centermel)
return fbank
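# A minimal usage sketch: 23 filters for 16 kHz audio with a 512-point FFT give
# a (23, 257) matrix; each row is one filter, triangular on the mel scale:
# fb = get_filterbanks(nfilt=23, nfft=512, samplerate=16000, lowfreq=20)
# fb.shape  # -> (23, 257)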
def lifter(cepstra, L=22):
"""Apply a cepstral lifter the the matrix of cepstra. This has the effect of increasing the
magnitude of the high frequency DCT coeffs.
:param cepstra: the matrix of mel-cepstra, will be numframes * numcep in size.
:param L: the liftering coefficient to use. Default is 22. L <= 0 disables lifter.
"""
if L > 0:
nframes,ncoeff = numpy.shape(cepstra)
n = numpy.arange(ncoeff)
lift = 1 + (L/2.)*numpy.sin(numpy.pi*n/L)
return lift*cepstra
else:
# values of L <= 0, do nothing
return cepstra
def delta(feat, N):
"""Compute delta features from a feature vector sequence.
:param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector.
:param N: For each frame, calculate delta features based on preceding and following N frames
:returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector.
"""
if N < 1:
raise ValueError('N must be an integer >= 1')
NUMFRAMES = len(feat)
denominator = 2 * sum([i**2 for i in range(1, N+1)])
delta_feat = numpy.empty_like(feat)
padded = numpy.pad(feat, ((N, N), (0, 0)), mode='edge') # padded version of feat
for t in range(NUMFRAMES):
delta_feat[t] = numpy.dot(numpy.arange(-N, N+1), padded[t : t+2*N+1]) / denominator # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1]
return delta_feat
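# A minimal usage sketch: deltas and delta-deltas (N=2 is the common choice)
# are typically appended to the static features:
# d1 = delta(feat, 2)
# d2 = delta(d1, 2)
# full = numpy.concatenate([feat, d1, d2], axis=1)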
##### modified for testing ######
def framesig_without_dither_dc_preemphasize(sig, frame_len, frame_step, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param wintype: the analysis window type to apply to each frame: '', 'hann', 'povey' or 'hamming' (default).
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: 2 arrays: the windowed frames and the raw frames. Each has size NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype == 'povey':
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
elif wintype == '':
win = numpy.ones(frame_len)
elif wintype == 'hann':
win = numpy.hanning(frame_len)
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = frames
return frames * win, raw_frames
def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None, wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
return frames_with_win, raw_frames
def complexspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the complex spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warning(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return complex_spec
def stft_with_window(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
spec = magspec(frames_with_win, nfft) # nearly the same until this part
scomplex = complexspec(frames_with_win, nfft)
rspec = magspec(raw_frames, nfft)
rcomplex = complexspec(raw_frames, nfft)
return spec, scomplex, rspec, rcomplex
class TestKaldiFE(unittest.TestCase):
def setUp(self):
self.this_dir = Path(__file__).parent
self.wavpath = str(self.this_dir / 'english.wav')
self.winlen=0.025 # ms
self.winstep=0.01 # ms
self.nfft=512
self.lowfreq = 0
self.highfreq = None
self.wintype='hamm'
self.nfilt=40
paddle.set_device('cpu')
def test_read(self):
import scipy.io.wavfile as wav
rate, sig = wav.read(self.wavpath)
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
self.assertTrue(np.all(sig == wav))
self.assertEqual(rate, sr)
def test_frames(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
_, fs = frames(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
t_fs, t_nframe = kaldi.frames(t_wav, t_wavlen, sr, self.winlen, self.winstep, clip=False)
t_fs = t_fs.astype(fs.dtype)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1]
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magspec_winprocess(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
fs, _= framesig(wav, self.winlen*sr, self.winstep*sr,
dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True)
spec = magspec(fs, self.nfft) # nearly the same until this part
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(
self.nfft, sr, self.winlen, self.winstep,
window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(spec.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1)
self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1))
def test_powspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
stft_win = np.square(stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))
# from python_speech_features import mfcc
# from python_speech_features import delta
# from python_speech_features import logfbank
# import scipy.io.wavfile as wav
# (rate,sig) = wav.read("english.wav")
# # note that generally nfilt=40 is used for speech recognition
# fbank_feat = logfbank(sig,nfilt=23,lowfreq=20,dither=0,wintype='povey')
# # the computed fbank coefficients of english.wav with dimension [110,23]
# # [ 12.2865 12.6906 13.1765 15.714 16.064 15.7553 16.5746 16.9205 16.6472 16.1302 16.4576 16.7326 16.8864 17.7215 18.88 19.1377 19.1495 18.6683 18.3886 20.3506 20.2772 18.8248 18.1899
# # 11.9198 13.146 14.7215 15.8642 17.4288 16.394 16.8238 16.1095 16.4297 16.6331 16.3163 16.5093 17.4981 18.3429 19.6555 19.6263 19.8435 19.0534 19.001 20.0287 19.7707 19.5852 19.1112
# # ...
# # ...
# # the same as the output of the kaldi command: compute-fbank-feats --dither=0.0
# mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')
# # the computed mfcc coefficients of english.wav with dimension [110,13]
# # [ 17.1337 -23.3651 -7.41751 -7.73686 -21.3682 -8.93884 -3.70843 4.68346 -16.0676 12.782 -7.24054 8.25089 10.7292
# # 17.1692 -23.3028 -5.61872 -4.0075 -23.287 -20.6101 -5.51584 -6.15273 -14.4333 8.13052 -0.0345329 2.06274 -0.564298
# # ...
# # ...
# # the same as the output of the kaldi command: compute-mfcc-feats --dither=0.0
if __name__ == '__main__':
unittest.main()
SHELL:= /bin/bash
PYTHON:= python3.7
CXX ?= g++
CC ?= gcc # used for sph2pipe
# CXX = clang++ # Uncomment these lines...
# CC = clang # ...to build with Clang.
WGET ?= wget
.PHONY: all clean
all: virtualenv kenlm.done sox.done soxbindings.done mfa.done sclite.done
virtualenv:
	test -d venv || virtualenv -p $(PYTHON) venv
@@ -39,3 +47,50 @@ mfa.done:
	test -d montreal-forced-aligner || wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz
	tar xvf montreal-forced-aligner_linux.tar.gz
	touch mfa.done
#== SCTK ===============================================================================
# SCTK official repo does not have version tags. Here's the mapping:
# 2.4.9 = 659bc36; 2.4.10 = d914e1b; 2.4.11 = 20159b5.
SCTK_GITHASH = 20159b5
SCTK_CXFLAGS = -w -march=native
SCTK_MKENV = CFLAGS="$(CFLAGS) $(SCTK_CXFLAGS)" \
CXXFLAGS="$(CXXFLAGS) -std=c++11 $(SCTK_CXFLAGS)" \
# Keep the existing target 'sclite' to avoid breaking the users who might have
# scripted it in.
.PHONY: sclite.done sctk_cleaned sctk_made
sclite.done sctk_made: sctk/.compiled
touch sclite.done
sctk/.compiled: sctk
rm -f sctk/.compiled
$(SCTK_MKENV) $(MAKE) -C sctk config
$(SCTK_MKENV) $(MAKE) -C sctk all doc
$(MAKE) -C sctk install
touch sctk/.compiled
# The GitHub archive unpacks into SCTK-{40-character-long-hash}/
sctk: sctk-$(SCTK_GITHASH).tar.gz
tar zxvf sctk-$(SCTK_GITHASH).tar.gz
rm -rf sctk-$(SCTK_GITHASH) sctk
mv SCTK-$(SCTK_GITHASH)* sctk-$(SCTK_GITHASH)
ln -s sctk-$(SCTK_GITHASH) sctk
touch sctk-$(SCTK_GITHASH).tar.gz
sctk-$(SCTK_GITHASH).tar.gz:
if [ -d '$(DOWNLOAD_DIR)' ]; then \
cp -p '$(DOWNLOAD_DIR)/sctk-$(SCTK_GITHASH).tar.gz' .; \
else \
$(WGET) -nv -T 10 -t 3 -O sctk-$(SCTK_GITHASH).tar.gz \
https://github.com/usnistgov/SCTK/archive/$(SCTK_GITHASH).tar.gz; \
fi
sctk_cleaned:
-for d in sctk/ sctk-*/; do \
[ ! -f $$d/.compiled ] || $(MAKE) -C $$d clean; \
rm -f $$d/.compiled; \
done
# Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils](https://github.com/espnet/espnet/tree/master/utils)
@@ -5,8 +5,8 @@ if [ $# != 3 ]; then
 exit -1
 fi
-ckpt_dir=${1}
-avg_mode=${2} # best,latest
+avg_mode=${1} # best,latest
+ckpt_dir=${2}
 average_num=${3}
 decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
...
@@ -27,8 +27,9 @@ def main(args):
 val_scores = []
 beat_val_scores = []
 selected_epochs = []
+if args.val_best:
 jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
+jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
 for y in jsons:
 with open(y, 'r') as f:
 dic_json = json.load(f)
@@ -36,24 +37,23 @@ def main(args):
 epoch = dic_json['epoch']
 if epoch >= args.min_epoch and epoch <= args.max_epoch:
 val_scores.append((epoch, loss))
 val_scores = np.array(val_scores)
-if args.val_best:
 sort_idx = np.argsort(val_scores[:, 1])
 sorted_val_scores = val_scores[sort_idx]
-path_list = [
-    args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
-    for epoch in sorted_val_scores[:args.num, 0]
-]
+else:
+    sorted_val_scores = val_scores
 beat_val_scores = sorted_val_scores[:args.num, 1]
 selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
-print("best val scores = " + str(beat_val_scores))
+print("selected val scores = " + str(beat_val_scores))
 print("selected epochs = " + str(selected_epochs))
+else:
+    path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
+    path_list = sorted(path_list, key=os.path.getmtime)
+    path_list = path_list[-args.num:]
+path_list = [
+    args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
+    for epoch in sorted_val_scores[:args.num, 0]
+]
 print(path_list)
 avg = None
@@ -78,6 +78,7 @@ def main(args):
 meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
 with open(meta_path, 'w') as f:
 data = json.dumps({
+    "mode": 'val_best' if args.val_best else 'latest',
 "avg_ckpt": args.dst_model,
 "ckpt": path_list,
 "epoch": selected_epochs.tolist(),
...
#!/usr/bin/env bash
# 2020 author Jiayu DU
# Apache 2.0
# This script reads in an Arpa format language model, and converts it into the
# KenLM format language model.
[ -f path.sh ] && . ./path.sh;
# begin configuration section
kenlm_opts="" # e.g. "-q 8 -b 8" for 8bits quantization
model_type="trie" # "trie" or "probing". trie is smaller, probing is faster.
# end configuration section
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "Usage: "
echo " $0 [options] <arpa-lm-path> <kenlm-path>"
echo "e.g.:"
echo " $0 data/local/lm/4gram.arpa data/lang_test/G.trie"
echo "Options:"
echo " --model-type can be either \"trie\" or \"probing\""
echo " --kenlm-opts directly pass through to kenlm"
echo " e.g. for 8bits quantization, feed \"-q 8 -b 8\""
exit 1;
fi
export LC_ALL=C
arpa_lm=$1
kenlm=$2
if ! which build_binary >& /dev/null ; then
echo "$0: cannot find KenLM's build_binary tool,"
echo "check kenlm installation (tools/extras/install_kenlm_query_only.sh)."
exit 1
fi
mkdir -p $(dirname $kenlm)
build_binary $kenlm_opts $model_type $arpa_lm $kenlm
echo "$0: Successfully built arpa into kenlm format: $kenlm"
exit 0
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/usr/bin/env python3
# Apache 2.0
import argparse
import codecs
import sys
is_python2 = sys.version_info[0] == 2
def get_parser():
parser = argparse.ArgumentParser(
description="filter words in a text file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--exclude",
"-v",
dest="exclude",
action="store_true",
help="exclude filter words", )
parser.add_argument("filt", type=str, help="filter list")
parser.add_argument("infile", type=str, help="input file")
return parser
def main(args):
args = get_parser().parse_args(args)
filter_file(args.infile, args.filt, args.exclude)
def filter_file(infile, filt, exclude):
vocab = set()
with codecs.open(filt, "r", encoding="utf-8") as vocabfile:
for line in vocabfile:
vocab.add(line.strip())
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
if is_python2 else sys.stdout.buffer)
with codecs.open(infile, "r", encoding="utf-8") as textfile:
for line in textfile:
if exclude:
print(" ".join(
map(
lambda word: word if word not in vocab else "",
line.strip().split(), )))
else:
print(" ".join(
map(
lambda word: word if word in vocab else "<UNK>",
line.strip().split(), )))
if __name__ == "__main__":
main(sys.argv[1:])
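# A minimal usage sketch (file names illustrative). The sclite scoring script
# later in this dump calls this as `filt.py -v ${nlsyms} in.trn > out.trn`:
#   python filt.py -v nlsyms.txt hyp.trn > hyp.filt.trn   # blank out listed symbols
#   python filt.py vocab.txt text.txt > text.filt.txt     # map out-of-vocab words to <UNK>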
File mode changed from 100644 to 100755
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
# if the command appears multiple times uses the smallest option.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
$max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, "</proc/cpuinfo")) { # Linux
while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
# Invoked before each job is started; helps to run jobs selectively.
#
# Given the name of the output logfile decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# process calling exit)
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
# This logic could have been implemented in the main execution loop,
# but it is kept in a subroutine to preserve the readability of
# that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
# Lets wait for a change in any child's status
# Then we have to work out which child finished
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now we have submitted all the jobs, lets wait until all the jobs finish
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # record the exit code; count the job as failed if nonzero
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);
File mode changed from 100644 to 100755
#!/usr/bin/env bash
unset GREP_OPTIONS
set -u # Check for undefined variables
die() {
# Print a message and exit with code 1.
#
# Usage: die <error_message>
# e.g., die "Something bad happened."
echo "$@"
exit 1
}
echo "Collecting system information..."
OUTPUT_FILE=pd_env.txt
python_bin_path=$(which python || which python3 || die "Cannot find Python binary")
{
echo
echo '== check python ==================================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_python.py
import platform
print("""python version: %s
python branch: %s
python build version: %s
python compiler version: %s
python implementation: %s
""" % (
platform.python_version(),
platform.python_branch(),
platform.python_build(),
platform.python_compiler(),
platform.python_implementation(),
))
EOF
${python_bin_path} /tmp/check_python.py >> ${OUTPUT_FILE} 2>&1
{
echo
echo '== check os platform ==============================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_os.py
import platform
print("""os: %s
os kernel version: %s
os release version: %s
os platform: %s
linux distribution: %s
linux os distribution: %s
mac version: %s
uname: %s
architecture: %s
machine: %s
""" % (
platform.system(),
platform.version(),
platform.release(),
platform.platform(),
platform.linux_distribution(),
platform.dist(),
platform.mac_ver(),
platform.uname(),
platform.architecture(),
platform.machine(),
))
EOF
${python_bin_path} /tmp/check_os.py >> ${OUTPUT_FILE} 2>&1
{
echo
echo '== are we in docker ============================================='
num=`cat /proc/1/cgroup | grep docker | wc -l`;
if [ $num -ge 1 ]; then
echo "Yes"
else
echo "No"
fi
echo
echo '== compiler ====================================================='
c++ --version 2>&1
echo
echo '== check pips ==================================================='
pip list 2>&1 | grep "proto\|numpy\|paddlepaddle"
echo
echo '== check for virtualenv ========================================='
${python_bin_path} -c "import sys;print(hasattr(sys, \"real_prefix\"))"
echo
echo '== paddlepaddle import ============================================'
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_pd.py
import paddle as pd;
pd.set_device('cpu')
print("pd.version.full_version = %s" % pd.version.full_version)
print("pd.version.commit = %s" % pd.version.commit)
print("pd.__version__ = %s" % pd.__version__)
print("Sanity check: %r" % pd.zeros([1,2,3])[:1])
EOF
${python_bin_path} /tmp/check_pd.py >> ${OUTPUT_FILE} 2>&1
LD_DEBUG=libs ${python_bin_path} -c "import paddle" 2>>${OUTPUT_FILE} > /tmp/loadedlibs
{
grep libcudnn.so /tmp/loadedlibs
echo
echo '== env =========================================================='
if [ -z ${LD_LIBRARY_PATH+x} ]; then
echo "LD_LIBRARY_PATH is unset";
else
echo LD_LIBRARY_PATH ${LD_LIBRARY_PATH} ;
fi
if [ -z ${DYLD_LIBRARY_PATH+x} ]; then
echo "DYLD_LIBRARY_PATH is unset";
else
echo DYLD_LIBRARY_PATH ${DYLD_LIBRARY_PATH} ;
fi
echo
echo '== nvidia-smi ==================================================='
nvidia-smi 2>&1
echo
echo '== cuda libs ==================================================='
} >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libcudart*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libcudnn*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
{
echo
echo '== paddlepaddle installed from info =================='
pip show paddlepaddle-gpu
echo
echo '== python version =============================================='
echo '(major, minor, micro, releaselevel, serial)'
python -c 'import sys; print(sys.version_info[:])'
echo
echo '== bazel version ==============================================='
bazel version
echo '== cmake version ==============================================='
cmake --version
} >> ${OUTPUT_FILE}
# Remove any words with google.
mv $OUTPUT_FILE old-$OUTPUT_FILE
grep -v -i google old-${OUTPUT_FILE} > $OUTPUT_FILE
echo "Wrote environment to ${OUTPUT_FILE}. You can review the contents of that file."
echo "and use it to populate the fields in the github issue template."
echo
echo "cat ${OUTPUT_FILE}"
echo
File mode changed from 100644 to 100755
parallel/run.pl
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
[ -f ./path.sh ] && . ./path.sh
# non language symbol
nlsyms=""
wer=false
bpe=""
bpemodel=""
remove_blank=true
filter=""
num_spkrs=1
help_message="Usage: $0 <data-dir> <dict>"
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "${help_message}"
exit 1;
fi
dir=$1
dic=$2
cat ${dir}/data.*.json > ${dir}/data.json
if [ $num_spkrs -eq 1 ]; then
json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${dir}/ref.trn --hyps ${dir}/hyp.trn
if ${remove_blank}; then
sed -i.bak2 -r 's/<blank> //g' ${dir}/hyp.trn
fi
if [ -n "${nlsyms}" ]; then
cp ${dir}/ref.trn ${dir}/ref.trn.org
cp ${dir}/hyp.trn ${dir}/hyp.trn.org
filt.py -v ${nlsyms} ${dir}/ref.trn.org > ${dir}/ref.trn
filt.py -v ${nlsyms} ${dir}/hyp.trn.org > ${dir}/hyp.trn
fi
if [ -n "${filter}" ]; then
sed -i.bak3 -f ${filter} ${dir}/hyp.trn
sed -i.bak3 -f ${filter} ${dir}/ref.trn
fi
sclite -r ${dir}/ref.trn trn -h ${dir}/hyp.trn trn -i rm -o all stdout > ${dir}/result.txt
echo "write a CER (or TER) result in ${dir}/result.txt"
grep -e Avg -e SPKR -m 2 ${dir}/result.txt
if ${wer}; then
if [ -n "$bpe" ]; then
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref.trn | sed -e "s/▁/ /g" > ${dir}/ref.wrd.trn
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp.trn | sed -e "s/▁/ /g" > ${dir}/hyp.wrd.trn
else
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref.trn > ${dir}/ref.wrd.trn
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp.trn > ${dir}/hyp.wrd.trn
fi
sclite -r ${dir}/ref.wrd.trn trn -h ${dir}/hyp.wrd.trn trn -i rm -o all stdout > ${dir}/result.wrd.txt
echo "write a WER result in ${dir}/result.wrd.txt"
grep -e Avg -e SPKR -m 2 ${dir}/result.wrd.txt
fi
elif [ ${num_spkrs} -lt 4 ]; then
ref_trns=""
hyp_trns=""
for i in $(seq ${num_spkrs}); do
ref_trns=${ref_trns}"${dir}/ref${i}.trn "
hyp_trns=${hyp_trns}"${dir}/hyp${i}.trn "
done
json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${ref_trns} --hyps ${hyp_trns}
for n in $(seq ${num_spkrs}); do
if ${remove_blank}; then
sed -i.bak2 -r 's/<blank> //g' ${dir}/hyp${n}.trn
fi
if [ -n "${nlsyms}" ]; then
cp ${dir}/ref${n}.trn ${dir}/ref${n}.trn.org
cp ${dir}/hyp${n}.trn ${dir}/hyp${n}.trn.org
filt.py -v ${nlsyms} ${dir}/ref${n}.trn.org > ${dir}/ref${n}.trn
filt.py -v ${nlsyms} ${dir}/hyp${n}.trn.org > ${dir}/hyp${n}.trn
fi
if [ -n "${filter}" ]; then
sed -i.bak3 -f ${filter} ${dir}/hyp${n}.trn
sed -i.bak3 -f ${filter} ${dir}/ref${n}.trn
fi
done
results_str=""
for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
ind_r=$((i / num_spkrs + 1))
ind_h=$((i % num_spkrs + 1))
results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.txt "
sclite -r ${dir}/ref${ind_r}.trn trn -h ${dir}/hyp${ind_h}.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.txt
done
echo "write CER (or TER) results in ${dir}/result_r*h*.txt"
eval_perm_free_error.py --num-spkrs ${num_spkrs} \
${results_str} > ${dir}/min_perm_result.json
sed -n '2,4p' ${dir}/min_perm_result.json
if ${wer}; then
for n in $(seq ${num_spkrs}); do
if [ -n "$bpe" ]; then
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref${n}.trn | sed -e "s/▁/ /g" > ${dir}/ref${n}.wrd.trn
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp${n}.trn | sed -e "s/▁/ /g" > ${dir}/hyp${n}.wrd.trn
else
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref${n}.trn > ${dir}/ref${n}.wrd.trn
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp${n}.trn > ${dir}/hyp${n}.wrd.trn
fi
done
results_str=""
for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
ind_r=$((i / num_spkrs + 1))
ind_h=$((i % num_spkrs + 1))
results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.wrd.txt "
sclite -r ${dir}/ref${ind_r}.wrd.trn trn -h ${dir}/hyp${ind_h}.wrd.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.wrd.txt
done
echo "write WER results in ${dir}/result_r*h*.wrd.txt"
eval_perm_free_error.py --num-spkrs ${num_spkrs} \
${results_str} > ${dir}/min_perm_result.wrd.json
sed -n '2,4p' ${dir}/min_perm_result.wrd.json
fi
fi
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
while(<>){
@A = split(" ", $_);
@A > 1 || die "Invalid line in spk2utt file: $_";
$s = shift @A;
foreach $u ( @A ) {
print "$u $s\n";
}
}
#!/usr/bin/env bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
set -o errexit
if [ $# != 2 ]; then
echo "Usage: split_data.sh data-dir num-to-split"
exit 1
fi
data=$1
numsplit=$2
if [ $numsplit -le 0 ]; then
echo "Invalid num-split argument $numsplit";
exit 1;
fi
n=0;
feats=""
wavs=""
utt2spks=""
texts=""
nu=`cat $data/utt2spk | wc -l`
nf=`cat $data/feats.scp | wc -l`
nt=`cat $data/text | wc -l`
if [ $nu -ne $nf ]; then
echo "split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf);"
echo "this script may produce incorrectly split data."
echo "use utils/fix_data_dir.sh to fix this."
fi
if [ $nt -ne 0 -a $nu -ne $nt ]; then
echo "split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt);"
echo "this script may produce incorrectly split data."
echo "use utils/fix_data_dir.sh to fix this."
fi
# get_splits.pl returns "0 1 2 3" or "00 01 .. 18 19" or whatever.
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
mkdir -p $data/split$numsplit/$n
feats="$feats $data/split$numsplit/$n/feats.scp"
wavs="$wavs $data/split$numsplit/$n/wav.scp"
texts="$texts $data/split$numsplit/$n/text"
utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk"
done
split_scp.pl --utt2spk=$data/utt2spk $data/utt2spk $utt2spks
split_scp.pl --utt2spk=$data/utt2spk $data/feats.scp $feats
[ -f $data/wav.scp ] && \
split_scp.pl --utt2spk=$data/utt2spk $data/wav.scp $wavs
[ -f $data/text ] && \
split_scp.pl --utt2spk=$data/utt2spk $data/text $texts
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
utt2spk_to_spk2utt.pl $data/split$numsplit/$n/utt2spk \
> $data/split$numsplit/$n/spk2utt
# for completeness, also split the spk2gender file
[ -f $data/spk2gender ] && \
filter_scp.pl $data/split$numsplit/$n/spk2utt $data/spk2gender \
> $data/split$numsplit/$n/spk2gender
done
exit 0
#!/usr/bin/env bash
set -o errexit
if [ $# != 2 ]; then
echo "Usage: split_json.sh manifest num-to-split"
exit 1
fi
data=data
jsonfile=$1
numsplit=$2
if [ $numsplit -le 0 ]; then
echo "Invalid num-split argument $numsplit";
exit 1;
fi
n=0;
jsons=""
# get_splits.pl returns "0 1 2 3" or "00 01 .. 18 19" or whatever.
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
mkdir -p $data/split$numsplit/$n
jsons="$jsons $data/split$numsplit/$n/${jsonfile}"
done
split_scp.pl $data/${jsonfile} $jsons
exit 0
#!/usr/bin/env bash
# 2020 Author Jiayu DU
# Apache 2.0
# This script uses kenlm to estimate an arpa model from plain text,
# it is a resort when you hit memory limit dealing with large corpus
# kenlm estimates arpa using on-disk structure,
# as long as you have big enough hard disk, memory shouldn't be a problem.
# by default, kenlm use up to 50% of your local memory,
# you can control this through -S option
[ -f path.sh ] && . ./path.sh;
kenlm_opts="" # e.g. "-o 4 -S 50% --prune 0 5 7 7"
if [ $# != 4 ]; then
echo "$0 <text> <kaldi_symbol_table> <working_dir> <arpa_name>"
echo "e.g. $0 train.txt words.txt wdir 4gram"
exit 1
fi
text=$1
symbol_table=$2
dir=$3
arpa_name=$4
if ! which lmplz >& /dev/null ; then
echo "$0: cannot find training tool *lmplz*."
echo "tools/extras/install_kenlm_query_only.sh installs kenlm at tools/kenlm"
echo "it only supports runtime mode, to actually train an arpa using KenLM,"
echo "you need a complete KenLM installation(depends on EIGEN and BOOST),"
echo "follow KenLM's building instructions at (https://github.com/kpu/kenlm)"
exit 1
fi
# the text should be properly pre-processed, e.g.:
# cleaned, normalized and possibly word-segmented
# get rid of irrelevant symbols
grep -v '<eps>' $symbol_table \
| grep -v '#0' \
| grep -v '<unk>' | grep -v '<UNK>' \
| grep -v '<s>' | grep -v '</s>' \
| awk '{print $1}' \
> $dir/ngram.vocab
# To make sure that kenlm & kaldi have strictly the same vocabulary:
# 1. feed vocabulary into kenlm via --limit_vocab_file
# 2. cat vocabulary to training text, so each word appears at least once
#
# TL;DR reason:
# Unlike SRILM's -limit-vocab, kenlm's --limit_vocab_file option
# specifies a *valid* set of vocabulary, whereas *valid but unseen*
# words are discarded in the final arpa.
# So the trick is,
# we explicitly add kaldi's vocab (one word per line) to the training text,
# making each word appear at least once.
# kenlm never prunes unigrams,
# so this always generates a kenlm vocabulary consistent with kaldi's.
# The effect of this is like add-one smoothing on unigram counts,
# which shouldn't have a significant impact in practice.
cat $dir/ngram.vocab $text \
| lmplz $kenlm_opts --limit_vocab_file $dir/ngram.vocab \
> $dir/${arpa_name}.arpa
echo "$0: Done training arpa to: $dir/${arpa_name}.arpa"
File mode changed from 100644 to 100755
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# converts an utt2spk file to a spk2utt file.
# Takes input from the stdin or from a file argument;
# output goes to the standard out.
if ( @ARGV > 1 ) {
die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
}
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
push (@{$spk_hash{$s}}, "$u");
}
foreach $s (@spklist) {
$l = join(' ',@{$spk_hash{$s}});
print "$s $l\n";
}