Commit 10cd6560 authored by Hui Zhang

remove notebook

Parent 4e9bc9ed
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"L = nn.Linear(256, 2048)\n",
"L2 = nn.Linear(2048, 256)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n",
"Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n",
" [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n",
" [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n",
" ...,\n",
" [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n",
" [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n",
" [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n",
"\n",
" [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n",
" [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n",
" [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n",
" ...,\n",
" [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n",
" [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n",
" [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n",
"tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n",
" [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n",
" [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n",
" ...,\n",
" [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n",
" [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n",
" [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n",
"\n",
" [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n",
" [ 0.1022, 0.8130, 2.3687, ..., 0.1682, 0.2924, 1.0541],\n",
" [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n",
" ...,\n",
" [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n",
" [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n",
" [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"print(px)\n",
"print(tx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "mechanical-prisoner",
"metadata": {},
"outputs": [],
"source": [
"data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"t_norm_ff = data['norm_ff']\n",
"t_ff_out = data['ff_out']\n",
"t_ff_l_x = data['ff_l_x']\n",
"t_ff_l_a_x = data['ff_l_a_x']\n",
"t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"t_ps = data['ps']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "indie-marriage",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "assured-zambia",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n",
"L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n",
"\n",
"ps = []\n",
"for n, p in L.named_parameters():\n",
" ps.append(p)\n",
"\n",
"for n, p in L2.state_dict().items():\n",
" ps.append(p)\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p.numpy(), tp.T))"
]
},
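{
"cell_type": "markdown",
"id": "weight-layout-note",
"metadata": {},
"source": [
"Why the `.T` above: `paddle.nn.Linear` stores its weight as `[in_features, out_features]`, while `torch.nn.Linear` stores `[out_features, in_features]`, so the Torch-dumped weights in `t_ps` must be transposed before `set_state_dict`. A minimal shape check (sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "weight-layout-check",
"metadata": {},
"outputs": [],
"source": [
"# Paddle keeps weight as [in_features, out_features]; the Torch dump is the transpose.\n",
"print(L.weight.shape)   # Paddle: [256, 2048]\n",
"print(t_ps[0].shape)    # Torch dump: (2048, 256)\n",
"assert list(L.weight.shape) == list(t_ps[0].shape)[::-1]"
]
},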
{
"cell_type": "code",
"execution_count": null,
"id": "committed-jacob",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "extreme-traffic",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"# t_norm_ff = data['norm_ff']\n",
"# t_ff_out = data['ff_out']\n",
"# t_ff_l_x = data['ff_l_x']\n",
"# t_ff_l_a_x = data['ff_l_a_x']\n",
"# t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"# t_ps = data['ps']\n",
"TL = torch.nn.Linear(256, 2048)\n",
"TL2 = torch.nn.Linear(2048, 256)\n",
"TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n",
"TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n",
"\n",
"# for n, p in TL.named_parameters():\n",
"# print(n, p)\n",
"# for n, p in TL2.named_parameters():\n",
"# print(n, p)\n",
"\n",
"ps = []\n",
"for n, p in TL.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for n, p in TL2.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p, tp))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n",
" -0.5787626 ]\n",
" [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n",
" -0.03350811]\n",
" [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n",
" -1.012989 ]\n",
" ...\n",
" [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n",
" 0.04070369]\n",
" [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n",
" -0.8450228 ]\n",
" [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n",
" -0.06663495]]\n",
"\n",
" [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n",
" -0.8161046 ]\n",
" [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n",
" -0.20766167]\n",
" [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n",
" -0.26811332]\n",
" ...\n",
" [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n",
" 0.06020184]\n",
" [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n",
" 0.15847842]\n",
" [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n",
" -0.6460344 ]]]\n",
"[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n",
" -0.5787625 ]\n",
" [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n",
" -0.03350812]\n",
" [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n",
" -1.0129892 ]\n",
" ...\n",
" [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n",
" 0.04070365]\n",
" [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n",
" -0.84502304]\n",
" [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n",
" -0.066635 ]]\n",
"\n",
" [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n",
" -0.8161046 ]\n",
" [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n",
" -0.20766166]\n",
" [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n",
" -0.26811326]\n",
" ...\n",
" [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n",
" 0.06020182]\n",
" [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n",
" 0.15847899]\n",
" [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n",
" -0.6460344 ]]]\n",
"True\n",
"False\n"
]
}
],
"source": [
"y = L(px)\n",
"print(y.numpy())\n",
"\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "incorrect-allah",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.04476918 0.554463 -0.3027508 ... -0.49600336 0.3751858\n",
" 0.8254095 ]\n",
" [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n",
" -0.06974669]]\n",
"[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n",
" 0.8254096 ]\n",
" [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n",
" -0.0697467 ]]\n",
"True\n",
"False\n",
"True\n"
]
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"y = L(px)\n",
"print(y.numpy())\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))"
]
},
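{
"cell_type": "markdown",
"id": "tolerance-note",
"metadata": {},
"source": [
"The `False` at the default tolerance is float32 rounding rather than a weight mismatch: the two matmul kernels accumulate in different orders. A minimal sketch quantifying the gap, reusing `y` and `ty` from the cell above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tolerance-check",
"metadata": {},
"outputs": [],
"source": [
"# Largest elementwise discrepancy between the Paddle and Torch outputs.\n",
"diff = np.abs(y.numpy() - ty.detach().numpy())\n",
"print(diff.max())  # typically on the order of 1e-6, which is why atol=1e-5 passes"
]
},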
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "improved-civilization",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5e7e7c9fde8350084abf1898cf52651cfc84b17a\n"
]
}
],
"source": [
"print(paddle.version.commit)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d1e2d3b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c880c719",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.1.0\n"
]
}
],
"source": [
"print(paddle.version.full_version)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f26977bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n",
"None\n"
]
}
],
"source": [
"print(paddle.version.show())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "04ad47f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.6.0\n"
]
}
],
"source": [
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e1e03830",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" '__version__',\n",
" 'cuda',\n",
" 'debug',\n",
" 'git_version',\n",
" 'hip']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(torch.version)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4ad0389b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'b31f58de6fa8bbda5353b3c77d9be4914399724d'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.git_version"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7870ea10",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'10.2'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.cuda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db8ee5a7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6321ec2a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d6a0e098",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"import torch\n",
"from torch.optim.lr_scheduler import _LRScheduler\n",
"\n",
"from typeguard import check_argument_types\n",
"\n",
"\n",
"class WarmupLR(_LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" optimizer: torch.optim.Optimizer,\n",
" warmup_steps: Union[int, float] = 25000,\n",
" last_epoch: int = -1,\n",
" ):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
"\n",
" # __init__() must be invoked before setting field\n",
" # because step() is also invoked in __init__()\n",
" super().__init__(optimizer, last_epoch)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return [\n",
" lr\n",
" * self.warmup_steps ** 0.5\n",
" * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n",
" for lr in self.base_lrs\n",
" ]\n",
"\n",
" def set_step(self, step: int):\n",
" self.last_epoch = step"
]
},
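{
"cell_type": "markdown",
"id": "warmup-formula-note",
"metadata": {},
"source": [
"A quick sanity check of the schedule formula: at `step == warmup_steps` the two arguments of `min()` coincide, so `lr = base_lr * warmup_steps**0.5 * warmup_steps**-0.5 = base_lr`, i.e. the peak lr is exactly the optimizer's lr. The numbers below are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "warmup-formula-check",
"metadata": {},
"outputs": [],
"source": [
"# lr rises linearly to base_lr at step == warmup, then decays as step ** -0.5.\n",
"base_lr, warmup = 0.001, 25000\n",
"for step in [1, 100, 25000, 100000]:\n",
"    lr = base_lr * warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)\n",
"    print(step, lr)"
]
},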
{
"cell_type": "code",
"execution_count": 4,
"id": "0d496677",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"model = torch.nn.Linear(10, 200)\n",
"optimizer = optim.Adam(model.parameters())\n",
"scheduler = WarmupLR(optimizer, warmup_steps=25000)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3e3f3dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.0 -1\n"
]
}
],
"source": [
"infos = {}\n",
"start_epoch = infos.get('epoch', -1) + 1\n",
"cv_loss = infos.get('cv_loss', 0.0)\n",
"step = infos.get('step', -1)\n",
"print(start_epoch, cv_loss, step)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc3d550c",
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_step(step)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e527634e",
"metadata": {},
"outputs": [],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" scheduler.step()\n",
" lrs.append(scheduler.get_lr())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f1452db9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting matplotlib\n",
" Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n",
"\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n",
"\u001b[?25hCollecting kiwisolver>=1.0.1\n",
" Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n",
"\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n",
"Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n",
"Collecting cycler>=0.10\n",
" Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n",
"Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n",
"Installing collected packages: kiwisolver, cycler, matplotlib\n",
"Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n"
]
}
],
"source": [
"!pip install matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0f36d04f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0c39aa82e0>]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wz
uUATwG/bquTc+4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3
N+3l3Y2VvLiiAoDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHS
hYwUAr1IBQBGJIgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz
7cQmDt4lj7zd7VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4f4e282c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"from typing import Union\n",
"\n",
"from paddle.optimizer.lr import LRScheduler\n",
"from typeguard import check_argument_types\n",
"\n",
"class WarmupLR(LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" warmup_steps: Union[int, float]=25000,\n",
" learning_rate=1.0,\n",
" last_epoch=-1,\n",
" verbose=False):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
" super().__init__(learning_rate, last_epoch, verbose)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return self.base_lr * self.warmup_steps**0.5 * min(\n",
" step_num**-0.5, step_num * self.warmup_steps**-1.5)\n",
"\n",
" def set_step(self, step: int):\n",
" self.step(step)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8c40b202",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1\n"
]
}
],
"source": [
"sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n",
"print(step)\n",
"#sc.set_step(step)\n",
"sc.set_step(0)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ecbc7e37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0ba6dd9c40>]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA
/gj8JumOjnnnnLO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0g
QVGKf87ghAyvrak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBR
t6GQBwLEqACgi4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfp
N39cQkoMo7/FwIXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" sc.step()\n",
" lrs.append(sc.get_lr())\n",
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e613fe16",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0fd9f40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 94,
"id": "matched-camera",
"metadata": {},
"outputs": [],
"source": [
"from nnAudio import Spectrogram\n",
"from scipy.io import wavfile\n",
"import torch\n",
"import soundfile as sf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "quarterly-solution",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n",
"[43 75 69 ... 7 6 3]\n"
]
}
],
"source": [
"import scipy.io.wavfile as wav\n",
"\n",
"rate,sig = wav.read('./BAC009S0764W0124.wav')\n",
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sig)\n",
"print(song)\n",
"print(sample)"
]
},
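{
"cell_type": "markdown",
"id": "wav-readers-note",
"metadata": {},
"source": [
"All three readers return identical int16 samples. The only traps are the argument order (`scipy.io.wavfile.read` returns `(rate, data)` while `soundfile.read` returns `(data, rate)`) and soundfile's float64 default dtype, overridden here with `dtype='int16'`. A minimal check (sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "wav-readers-check",
"metadata": {},
"outputs": [],
"source": [
"assert rate == sr == 16000\n",
"assert np.array_equal(sig, song) and np.array_equal(song, sample)"
]
},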
{
"cell_type": "code",
"execution_count": 96,
"id": "middle-salem",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2733 seconds\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n"
]
}
],
"source": [
"sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n",
"print(sr)\n",
"print(song)\n",
"print(song.shape)\n",
"print(song.dtype)\n",
"x = song\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(spec)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "finished-sterling",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"[43 75 69 ... 7 6 3]\n",
"(83792,)\n",
"int16\n",
"True\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.2001 seconds\n",
"torch.Size([1, 1025, 164, 2])\n",
"tensor([[[[-4.0940e+03, 1.2600e+04],\n",
" [ 8.5108e+03, -5.4930e+03],\n",
" [-3.3631e+03, -1.7904e+03],\n",
" ...,\n",
" [ 8.2279e+03, -9.3340e+03],\n",
" [-3.1990e+03, 2.0969e+03],\n",
" [-1.2669e+03, 4.4488e+03]],\n",
"\n",
" [[ 3.4886e+03, -9.9620e+03],\n",
" [-4.5364e+03, 4.1907e+02],\n",
" [ 2.5074e+03, 7.1339e+03],\n",
" ...,\n",
" [-5.4819e+03, 3.9258e+01],\n",
" [ 4.7221e+03, 6.5887e+01],\n",
" [ 9.6492e+02, -3.4386e+03]],\n",
"\n",
" [[-3.4947e+03, 9.2981e+03],\n",
" [-7.5164e+03, 8.1856e+02],\n",
" [-5.3766e+03, -9.0889e+03],\n",
" ...,\n",
" [ 1.4317e+03, 5.7447e+03],\n",
" [-3.1178e+03, 3.0740e+03],\n",
" [-3.4351e+03, 5.6900e+02]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.7112e+01, -4.5737e+00],\n",
" [-9.6295e+00, 3.5554e+01],\n",
" [ 1.8527e+00, -1.0491e+01],\n",
" ...,\n",
" [-1.1157e+01, 3.4423e+00],\n",
" [ 3.1193e+00, -4.4388e+00],\n",
" [-8.8242e+00, 8.0324e+00]],\n",
"\n",
" [[-6.5080e+01, 2.9543e+00],\n",
" [ 3.9992e+01, -1.3836e+01],\n",
" [-9.2803e+00, 1.0318e+01],\n",
" ...,\n",
" [ 4.2928e+00, 9.2397e+00],\n",
" [ 3.6642e+00, 9.4680e+00],\n",
" [ 4.8932e+00, -2.5199e+01]],\n",
"\n",
" [[ 4.7264e+01, -1.0721e+00],\n",
" [-6.0516e+00, -1.4589e+01],\n",
" [ 1.3127e+01, 1.4995e+00],\n",
" ...,\n",
" [ 1.7333e+01, -1.4380e+01],\n",
" [-3.6046e+00, -6.1019e+00],\n",
" [ 1.3321e+01, 2.3184e+01]]]])\n",
"True\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav)\n",
"print(wav.shape)\n",
"print(wav.dtype)\n",
"print(np.allclose(wav, song))\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n",
" window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n",
" fmin=50,fmax=8000, sr=sr) # Initializing the model\n",
"\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"print(wav_spec.shape)\n",
"print(wav_spec)\n",
"print(np.allclose(wav_spec, spec))"
]
},
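{
"cell_type": "markdown",
"id": "stft-magnitude-note",
"metadata": {},
"source": [
"With the default `output_format='Complex'`, `Spectrogram.STFT` returns shape `(batch, freq_bins, time_steps, 2)`, the trailing axis holding the real and imaginary parts. A minimal sketch deriving the magnitude spectrogram from it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stft-magnitude-check",
"metadata": {},
"outputs": [],
"source": [
"# Magnitude = L2 norm over the (real, imag) axis.\n",
"mag = torch.sqrt(wav_spec[..., 0] ** 2 + wav_spec[..., 1] ** 2)\n",
"print(mag.shape)  # torch.Size([1, 1025, 164])"
]
},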
{
"cell_type": "code",
"execution_count": 98,
"id": "running-technology",
"metadata": {},
"outputs": [],
"source": [
"import decimal\n",
"\n",
"import numpy\n",
"import math\n",
"import logging\n",
"def round_half_up(number):\n",
" return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n",
"\n",
"\n",
"def rolling_window(a, window, step=1):\n",
" # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n",
" shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n",
" strides = a.strides + (a.strides[-1],)\n",
" return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n",
"\n",
"\n",
"def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n",
" \"\"\"Frame a signal into overlapping frames.\n",
"\n",
" :param sig: the audio signal to frame.\n",
" :param frame_len: length of each frame measured in samples.\n",
" :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n",
" :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n",
" :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n",
" :returns: an array of frames. Size is NUMFRAMES by frame_len.\n",
" \"\"\"\n",
" slen = len(sig)\n",
" frame_len = int(round_half_up(frame_len))\n",
" frame_step = int(round_half_up(frame_step))\n",
" if slen <= frame_len:\n",
" numframes = 1\n",
" else:\n",
" numframes = 1 + (( slen - frame_len) // frame_step)\n",
"\n",
" # check kaldi/src/feat/feature-window.h\n",
" padsignal = sig[:(numframes-1)*frame_step+frame_len]\n",
" if wintype is 'povey':\n",
" win = numpy.empty(frame_len)\n",
" for i in range(frame_len):\n",
" win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n",
" else: # the hamming window\n",
" win = numpy.hamming(frame_len)\n",
" \n",
" if stride_trick:\n",
" frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n",
" else:\n",
" indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n",
" numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n",
" indices = numpy.array(indices, dtype=numpy.int32)\n",
" frames = padsignal[indices]\n",
" win = numpy.tile(win, (numframes, 1))\n",
" \n",
" frames = frames.astype(numpy.float32)\n",
" raw_frames = numpy.zeros(frames.shape)\n",
" for frm in range(frames.shape[0]):\n",
" raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n",
" frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n",
" # raw_frames[frm,:] = frames[frm,:]\n",
" frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n",
"\n",
" return frames * win, raw_frames\n",
"\n",
"\n",
"def magspec(frames, NFFT):\n",
" \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n",
" \"\"\"\n",
" if numpy.shape(frames)[1] > NFFT:\n",
" logging.warn(\n",
" 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n",
" numpy.shape(frames)[1], NFFT)\n",
" complex_spec = numpy.fft.rfft(frames, NFFT)\n",
" return numpy.absolute(complex_spec)\n",
"\n",
"\n",
"def powspec(frames, NFFT):\n",
" \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n",
"\n",
" :param frames: the array of frames. Each row is a frame.\n",
" :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n",
" :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n",
" \"\"\"\n",
" return numpy.square(magspec(frames, NFFT))\n",
"\n",
"\n",
"def do_dither(signal, dither_value=1.0):\n",
" signal += numpy.random.normal(size=signal.shape) * dither_value\n",
" return signal\n",
" \n",
"def do_remove_dc_offset(signal):\n",
" signal -= numpy.mean(signal)\n",
" return signal\n",
"\n",
"def do_preemphasis(signal, coeff=0.97):\n",
" \"\"\"perform preemphasis on the input signal.\n",
"\n",
" :param signal: The signal to filter.\n",
" :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n",
" :returns: the filtered signal.\n",
" \"\"\"\n",
" return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "ignored-retreat",
"metadata": {},
"outputs": [],
"source": [
"def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" spec = magspec(frames, nfft) # nearly the same until this part\n",
" rspec = magspec(raw_frames, nfft)\n",
" return spec, rspec\n",
"\n",
"\n",
"\n",
"def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n",
" wintype='hamming'):\n",
" highfreq= highfreq or samplerate/2\n",
" frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n",
" return raw_frames"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "federal-teacher",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"import scipy # used only in CFP\n",
"\n",
"import numpy as np\n",
"from time import time\n",
"\n",
"def pad_center(data, size, axis=-1, **kwargs):\n",
"\n",
" kwargs.setdefault('mode', 'constant')\n",
"\n",
" n = data.shape[axis]\n",
"\n",
" lpad = int((size - n) // 2)\n",
"\n",
" lengths = [(0, 0)] * data.ndim\n",
" lengths[axis] = (lpad, int(size - n - lpad))\n",
"\n",
" if lpad < 0:\n",
" raise ParameterError(('Target size ({:d}) must be '\n",
" 'at least input size ({:d})').format(size, n))\n",
"\n",
" return np.pad(data, lengths, **kwargs)\n",
"\n",
"\n",
"\n",
"sz_float = 4 # size of a float\n",
"epsilon = 10e-8 # fudge factor for normalization\n",
"\n",
"def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n",
" freq_scale='linear', window='hann', verbose=True):\n",
"\n",
" if freq_bins==None: freq_bins = n_fft//2+1\n",
" if win_length==None: win_length = n_fft\n",
"\n",
" s = np.arange(0, n_fft, 1.)\n",
" wsin = np.empty((freq_bins,1,n_fft))\n",
" wcos = np.empty((freq_bins,1,n_fft))\n",
" start_freq = fmin\n",
" end_freq = fmax\n",
" bins2freq = []\n",
" binslist = []\n",
"\n",
" # num_cycles = start_freq*d/44000.\n",
" # scaling_ind = np.log(end_freq/start_freq)/k\n",
"\n",
" # Choosing window shape\n",
"\n",
" #window_mask = get_window(window, int(win_length), fftbins=True)\n",
" window_mask = np.hamming(int(win_length))\n",
" window_mask = pad_center(window_mask, n_fft)\n",
"\n",
" if freq_scale == 'linear':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" \n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n",
" bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n",
" binslist.append((k*scaling_ind+start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'log':\n",
" if verbose==True:\n",
" print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n",
" f\"get a valid freq range\")\n",
" start_bin = start_freq*n_fft/sr\n",
" scaling_ind = np.log(end_freq/start_freq)/freq_bins\n",
"\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n",
" bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n",
" binslist.append((np.exp(k*scaling_ind)*start_bin))\n",
" wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n",
"\n",
" elif freq_scale == 'no':\n",
" for k in range(freq_bins): # Only half of the bins contain useful info\n",
" bins2freq.append(k*sr/n_fft)\n",
" binslist.append(k)\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
" else:\n",
" print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n",
" return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n",
"\n",
"\n",
"\n",
"def broadcast_dim(x):\n",
" \"\"\"\n",
" Auto broadcast input so that it can fits into a Conv1d\n",
" \"\"\"\n",
"\n",
" if x.dim() == 2:\n",
" x = x[:, None, :]\n",
" elif x.dim() == 1:\n",
" # If nn.DataParallel is used, this broadcast doesn't work\n",
" x = x[None, None, :]\n",
" elif x.dim() == 3:\n",
" pass\n",
" else:\n",
" raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n",
" return x\n",
"\n",
"\n",
"\n",
"### --------------------------- Spectrogram Classes ---------------------------###\n",
"class STFT(torch.nn.Module):\n",
"\n",
" def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n",
" freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n",
" fmin=50, fmax=6000, sr=22050, trainable=False,\n",
" output_format=\"Complex\", verbose=True):\n",
"\n",
" super().__init__()\n",
"\n",
" # Trying to make the default setting same as librosa\n",
" if win_length==None: win_length = n_fft\n",
" if hop_length==None: hop_length = int(win_length // 4)\n",
"\n",
" self.output_format = output_format\n",
" self.trainable = trainable\n",
" self.stride = hop_length\n",
" self.center = center\n",
" self.pad_mode = pad_mode\n",
" self.n_fft = n_fft\n",
" self.freq_bins = freq_bins\n",
" self.trainable = trainable\n",
" self.pad_amount = self.n_fft // 2\n",
" self.window = window\n",
" self.win_length = win_length\n",
" self.iSTFT = iSTFT\n",
" self.trainable = trainable\n",
" start = time()\n",
"\n",
"\n",
"\n",
" # Create filter windows for stft\n",
" kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n",
" win_length=win_length,\n",
" freq_bins=freq_bins,\n",
" window=window,\n",
" freq_scale=freq_scale,\n",
" fmin=fmin,\n",
" fmax=fmax,\n",
" sr=sr,\n",
" verbose=verbose)\n",
"\n",
"\n",
" kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n",
" kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n",
" \n",
" # In this way, the inverse kernel and the forward kernel do not share the same memory...\n",
" kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n",
" kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n",
" \n",
" if iSTFT:\n",
" self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n",
" self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n",
"\n",
" # Applying window functions to the Fourier kernels\n",
" if window:\n",
" window_mask = torch.tensor(window_mask)\n",
" wsin = kernel_sin * window_mask\n",
" wcos = kernel_cos * window_mask\n",
" else:\n",
" wsin = kernel_sin\n",
" wcos = kernel_cos\n",
" \n",
" if self.trainable==False:\n",
" self.register_buffer('wsin', wsin)\n",
" self.register_buffer('wcos', wcos) \n",
" \n",
" if self.trainable==True:\n",
" wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n",
" wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n",
" self.register_parameter('wsin', wsin)\n",
" self.register_parameter('wcos', wcos) \n",
" \n",
" # Prepare the shape of window mask so that it can be used later in inverse\n",
" # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n",
" \n",
" if verbose==True:\n",
" print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n",
" else:\n",
" pass\n",
"\n",
" def forward(self, x, output_format=None):\n",
" \"\"\"\n",
" Convert a batch of waveforms to spectrograms.\n",
" \n",
" Parameters\n",
" ----------\n",
" x : torch tensor\n",
" Input signal should be in either of the following shapes.\\n\n",
" 1. ``(len_audio)``\\n\n",
" 2. ``(num_audio, len_audio)``\\n\n",
" 3. ``(num_audio, 1, len_audio)``\n",
" It will be automatically broadcast to the right shape\n",
" \n",
" output_format : str\n",
" Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n",
" Default value is ``Complex``. \n",
" \n",
" \"\"\"\n",
" output_format = output_format or self.output_format\n",
" self.num_samples = x.shape[-1]\n",
" \n",
" x = broadcast_dim(x)\n",
" if self.center:\n",
" if self.pad_mode == 'constant':\n",
" padding = nn.ConstantPad1d(self.pad_amount, 0)\n",
"\n",
" elif self.pad_mode == 'reflect':\n",
" if self.num_samples < self.pad_amount:\n",
" raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n",
" padding = nn.ReflectionPad1d(self.pad_amount)\n",
"\n",
" x = padding(x)\n",
" spec_imag = conv1d(x, self.wsin, stride=self.stride)\n",
" spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n",
"\n",
" # remove redundant parts\n",
" spec_real = spec_real[:, :self.freq_bins, :]\n",
" spec_imag = spec_imag[:, :self.freq_bins, :]\n",
"\n",
" if output_format=='Magnitude':\n",
" spec = spec_real.pow(2) + spec_imag.pow(2)\n",
" if self.trainable==True:\n",
" return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n",
" else:\n",
" return torch.sqrt(spec)\n",
"\n",
" elif output_format=='Complex':\n",
" return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n",
"\n",
" elif output_format=='Phase':\n",
" return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n",
"\n",
" def inverse(self, X, onesided=True, length=None, refresh_win=True):\n",
" \"\"\"\n",
" This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n",
" which is to convert spectrograms back to waveforms. \n",
" It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n",
" please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n",
" \n",
" Parameters\n",
" ----------\n",
" onesided : bool\n",
" If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n",
" else use ``onesided=False``\n",
"\n",
" length : int\n",
" To make sure the inverse STFT has the same output length of the original waveform, please\n",
" set `length` as your intended waveform length. By default, ``length=None``,\n",
" which will remove ``n_fft//2`` samples from the start and the end of the output.\n",
" \n",
" refresh_win : bool\n",
" Recalculating the window sum square. If you have an input with fixed number of timesteps,\n",
" you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n",
" \n",
" \n",
" \"\"\"\n",
" if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n",
" raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n",
" \n",
" assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n",
" \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n",
" \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n",
" if onesided:\n",
" X = extend_fbins(X) # extend freq\n",
"\n",
" \n",
" X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n",
"\n",
" # broadcast dimensions to support 2D convolution\n",
" X_real_bc = X_real.unsqueeze(1)\n",
" X_imag_bc = X_imag.unsqueeze(1)\n",
" a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n",
" b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n",
" \n",
" # compute real and imag part. signal lies in the real part\n",
" real = a1 - b2\n",
" real = real.squeeze(-2)*self.window_mask\n",
"\n",
" # Normalize the amplitude with n_fft\n",
" real /= (self.n_fft)\n",
"\n",
" # Overlap and Add algorithm to connect all the frames\n",
" real = overlap_add(real, self.stride)\n",
" \n",
" # Prepare the window sumsqure for division\n",
" # Only need to create this window once to save time\n",
" # Unless the input spectrograms have different time steps\n",
" if hasattr(self, 'w_sum')==False or refresh_win==True:\n",
" self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n",
" self.nonzero_indices = (self.w_sum>1e-10) \n",
" else:\n",
" pass\n",
" real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n",
" # Remove padding\n",
" if length is None: \n",
" if self.center:\n",
" real = real[:, self.pad_amount:-self.pad_amount]\n",
"\n",
" else:\n",
" if self.center:\n",
" real = real[:, self.pad_amount:self.pad_amount + length] \n",
" else:\n",
" real = real[:, :length] \n",
" \n",
" return real\n",
" \n",
" def extra_repr(self) -> str:\n",
" return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n",
" self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n",
" ) "
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "unusual-baker",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000\n",
"(83792,)\n",
"sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n",
"STFT kernels created, time used = 0.0153 seconds\n",
"torch.Size([521, 257])\n",
"(522, 257)\n",
"[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n",
" 2.32206573e+01 1.90274487e+01]\n",
" [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n",
" 1.80534729e+02 1.84179596e+02]\n",
" [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n",
" 1.21321175e+02 1.08345497e+02]\n",
" ...\n",
" [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n",
" 1.25670158e+02 1.20691467e+02]\n",
" [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n",
" 9.01831894e+01 9.84055099e+01]\n",
" [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n",
" 8.94473553e+00 9.61348629e+00]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n"
]
}
],
"source": [
"wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n",
"print(sr)\n",
"print(wav.shape)\n",
"\n",
"x = wav\n",
"x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n",
"\n",
"spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n",
" window='', freq_scale='linear', center=False, pad_mode='constant',\n",
" fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n",
"wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n",
"wav_spec = wav_spec[0].T\n",
"print(wav_spec.shape)\n",
"\n",
"\n",
"spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(spec.shape)\n",
"\n",
"print(wav_spec.numpy())\n",
"print(rspec)\n",
"# print(spec)\n",
"\n",
"# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n",
"# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
"# dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
"# wintype='hamming')\n",
"# print(rspec)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "white-istanbul",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 129,
"id": "modern-rescue",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n",
" 0.75 0.41317591 0.11697778 0. ]\n"
]
},
{
"data": {
"text/plain": [
"array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n",
" 0.9045085, 0.6545085, 0.3454915, 0.0954915])"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(np.hanning(10))\n",
"from scipy.signal import get_window\n",
"get_window('hann', 10, fftbins=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "professional-journalism",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 153,
"id": "involved-motion",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(522, 400)\n",
"[[ 43. 75. 69. ... 46. 46. 45.]\n",
" [ 210. 215. 216. ... -86. -89. -91.]\n",
" [ 128. 128. 128. ... -154. -151. -151.]\n",
" ...\n",
" [ -60. -61. -61. ... 112. 109. 110.]\n",
" [ 20. 22. 24. ... 91. 87. 87.]\n",
" [ 111. 107. 108. ... -6. -4. -8.]]\n",
"torch.Size([1, 1, 83792])\n",
"torch.Size([400, 1, 512])\n",
"torch.Size([1, 400, 521])\n",
"conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n",
" [ 210., 215., 216., ..., -86., -89., -91.],\n",
" [ 128., 128., 128., ..., -154., -151., -151.],\n",
" ...,\n",
" [-143., -141., -142., ..., 96., 101., 101.],\n",
" [ -60., -61., -61., ..., 112., 109., 110.],\n",
" [ 20., 22., 24., ..., 91., 87., 87.]])\n",
"xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"torch.Size([521, 257])\n",
"yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"yy (522, 257)\n",
"[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n",
" 2.4064583e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n",
" 9.5781029e+01 1.4200000e+02]\n",
" ...\n",
" [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n",
" 9.1511757e+01 1.1500000e+02]\n",
" [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n",
" 7.8405365e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n",
" 5.1310158e+01 3.5000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n",
" 9.15117270e+01 1.15000000e+02]\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"print(f.shape)\n",
"print(f)\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,n_fft))\n",
"wcos = np.empty((freq_bins,1,n_fft))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n",
"\n",
"\n",
"wsin = np.empty((n_fft,1,n_fft))\n",
"wcos = np.empty((n_fft,1,n_fft))\n",
"for k in range(n_fft): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n",
" \n",
" \n",
"wsin = np.empty((400,1,n_fft))\n",
"wcos = np.empty((400,1,n_fft))\n",
"for k in range(400): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.eye(400, n_fft)[k]\n",
" wcos[k,0,:] = np.eye(400, n_fft)[k]\n",
" \n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :]\n",
"print(x.size())\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160)\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"print(spec_imag.size())\n",
"print(\"conv frame\", spec_imag[0].T)\n",
"# print(spec_imag[0].T[:, :400])\n",
"\n",
"# remove redundant parts\n",
"# spec_real = spec_real[:, :freq_bins, :]\n",
"# spec_imag = spec_imag[:, :freq_bins, :]\n",
"# spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"# spec = torch.sqrt(spec)\n",
"# print(spec)\n",
"\n",
"\n",
"\n",
"s = np.arange(0, 512, 1.)\n",
"# s = s[::-1]\n",
"wsin = np.empty((freq_bins, 400))\n",
"wcos = np.empty((freq_bins, 400))\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
"spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n",
"spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n",
"\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"\n",
"print('xx', spec.numpy())\n",
"print(spec.size())\n",
"print('yy', rspec[:521, :])\n",
"print('yy', rspec.shape)\n",
"\n",
"\n",
"x = spec.numpy()\n",
"y = rspec[:-1, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 160,
"id": "mathematical-traffic",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([257, 1, 400])\n",
"tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n",
" 3.8693e+04, 3.1020e+04],\n",
" [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n",
" 1.3598e+04, 1.5920e+04],\n",
" [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n",
" 6.7707e+03, 4.3020e+03],\n",
" ...,\n",
" [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n",
" 6.1071e+01, 5.3685e+01],\n",
" [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n",
" 5.1310e+01, 6.3620e+01],\n",
" [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n",
" 3.5000e+01, 4.4000e+01]]])\n",
"[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n",
" 2.4064537e+01 2.2000000e+01]\n",
" [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n",
" 1.1877571e+02 1.6200000e+02]\n",
" [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n",
" 9.5781006e+01 1.4200000e+02]\n",
" ...\n",
" [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n",
" 7.8405350e+01 9.0000000e+01]\n",
" [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n",
" 5.1310150e+01 3.5000000e+01]\n",
" [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n",
" 6.3619797e+01 4.4000000e+01]]\n",
"[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n",
" 2.40645984e+01 2.20000000e+01]\n",
" [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n",
" 1.18775735e+02 1.62000000e+02]\n",
" [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n",
" 9.57810428e+01 1.42000000e+02]\n",
" ...\n",
" [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n",
" 7.84053656e+01 9.00000000e+01]\n",
" [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n",
" 5.13101944e+01 3.50000000e+01]\n",
" [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n",
" 6.36197377e+01 4.40000000e+01]]\n",
"False\n"
]
}
],
"source": [
"f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n",
" nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n",
" dither=0.0,remove_dc_offset=False, preemph=1.0, \n",
" wintype='hamming')\n",
"\n",
"n_fft=512\n",
"freq_bins = n_fft//2+1\n",
"s = np.arange(0, n_fft, 1.)\n",
"wsin = np.empty((freq_bins,1,400))\n",
"wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n",
"for k in range(freq_bins): # Only half of the bins contain useful info\n",
" wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n",
" wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n",
"\n",
" \n",
"x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n",
"x = x[None, None, :] #[B, C, T]\n",
"\n",
"kernel_sin = torch.tensor(wsin, dtype=torch.float)\n",
"kernel_cos = torch.tensor(wcos, dtype=torch.float)\n",
"print(kernel_sin.size())\n",
"\n",
"from torch.nn.functional import conv1d, conv2d, fold\n",
"spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n",
"spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n",
"\n",
"# remove redundant parts\n",
"spec = spec_real.pow(2) + spec_imag.pow(2)\n",
"spec = torch.sqrt(spec)\n",
"print(spec)\n",
"\n",
"x = spec[0].T.numpy()\n",
"y = rspec[:, :]\n",
"print(x)\n",
"print(y)\n",
"print(np.allclose(x, y))"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "olive-nicaragua",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"27241"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(np.abs(x -y) / np.abs(y))"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "ultimate-assault",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "institutional-stock",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.2412265e-10"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x[np.unravel_index(27241, y.shape)]"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "integrated-courage",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y, x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "different-operation",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "purple-consequence",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defensive-mason",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "patient-convention",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n"
]
}
],
"source": [
"import argparse\n",
"import functools\n",
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.utils.utility import add_arguments\n",
"from deepspeech.utils.utility import print_arguments\n",
"\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool,\n",
" False,\n",
" \"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train.raw',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('num_workers',\n",
" default=16,\n",
" type=int,\n",
" help='num of subprocess workers for processing')\n",
"add_arg('output_path', str,\n",
" 'data/librispeech/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(args)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "enormous-currency",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for feat in batch:\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n",
" self.feature_func = feature_func\n",
" self._rng = rng\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.sample(manifest, num_samples)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" key = self.items[idx]['feat']\n",
" audioseg = AudioSegment.from_file(key)\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
" return feat"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "armed-semester",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"process 1000 wavs,450739 frames\n",
"process 2000 wavs,887447 frames\n",
"process 3000 wavs,1354148 frames\n",
"process 4000 wavs,1816494 frames\n",
"process 5000 wavs,2359211 frames\n",
"process 6000 wavs,2828455 frames\n",
"process 7000 wavs,3276186 frames\n",
"process 8000 wavs,3692234 frames\n",
"process 9000 wavs,4139360 frames\n",
"process 10000 wavs,4591528 frames\n",
"process 11000 wavs,5020114 frames\n",
"process 12000 wavs,5459523 frames\n",
"process 13000 wavs,5899534 frames\n",
"process 14000 wavs,6323242 frames\n",
"process 15000 wavs,6736597 frames\n",
"process 16000 wavs,7207686 frames\n",
"process 17000 wavs,7637800 frames\n",
"process 18000 wavs,8093004 frames\n",
"process 19000 wavs,8529518 frames\n",
"process 20000 wavs,8906022 frames\n",
"process 21000 wavs,9352652 frames\n",
"process 22000 wavs,9807495 frames\n",
"process 23000 wavs,10247938 frames\n",
"process 24000 wavs,10700011 frames\n",
"process 25000 wavs,11126134 frames\n",
"process 26000 wavs,11558061 frames\n",
"process 27000 wavs,12010359 frames\n",
"process 28000 wavs,12470938 frames\n",
"process 29000 wavs,12916013 frames\n",
"process 30000 wavs,13345816 frames\n",
"process 31000 wavs,13752365 frames\n",
"process 32000 wavs,14174801 frames\n",
"process 33000 wavs,14642170 frames\n",
"process 34000 wavs,15053557 frames\n",
"process 35000 wavs,15531890 frames\n",
"process 36000 wavs,16022711 frames\n",
"process 37000 wavs,16437688 frames\n",
"process 38000 wavs,16859517 frames\n",
"process 39000 wavs,17307676 frames\n",
"process 40000 wavs,17796629 frames\n",
"process 41000 wavs,18264151 frames\n",
"process 42000 wavs,18711898 frames\n",
"process 43000 wavs,19159890 frames\n",
"process 44000 wavs,19576435 frames\n",
"process 45000 wavs,19992793 frames\n",
"process 46000 wavs,20464449 frames\n",
"process 47000 wavs,20886021 frames\n",
"process 48000 wavs,21317318 frames\n",
"process 49000 wavs,21738034 frames\n",
"process 50000 wavs,22171890 frames\n",
"process 51000 wavs,22622238 frames\n",
"process 52000 wavs,23100734 frames\n",
"process 53000 wavs,23526901 frames\n",
"process 54000 wavs,23969746 frames\n",
"process 55000 wavs,24418691 frames\n",
"process 56000 wavs,24862546 frames\n",
"process 57000 wavs,25336448 frames\n",
"process 58000 wavs,25778435 frames\n",
"process 59000 wavs,26216199 frames\n",
"process 60000 wavs,26694692 frames\n",
"process 61000 wavs,27148978 frames\n",
"process 62000 wavs,27617088 frames\n",
"process 63000 wavs,28064946 frames\n",
"process 64000 wavs,28519843 frames\n",
"process 65000 wavs,28989722 frames\n",
"process 66000 wavs,29470156 frames\n",
"process 67000 wavs,29952931 frames\n",
"process 68000 wavs,30360555 frames\n",
"process 69000 wavs,30797929 frames\n",
"process 70000 wavs,31218227 frames\n",
"process 71000 wavs,31663934 frames\n",
"process 72000 wavs,32107468 frames\n",
"process 73000 wavs,32541943 frames\n",
"process 74000 wavs,33010702 frames\n",
"process 75000 wavs,33448082 frames\n",
"process 76000 wavs,33886812 frames\n",
"process 77000 wavs,34338108 frames\n",
"process 78000 wavs,34761495 frames\n",
"process 79000 wavs,35199730 frames\n",
"process 80000 wavs,35669630 frames\n",
"process 81000 wavs,36122402 frames\n",
"process 82000 wavs,36604561 frames\n",
"process 83000 wavs,37085552 frames\n",
"process 84000 wavs,37517500 frames\n",
"process 85000 wavs,37987196 frames\n",
"process 86000 wavs,38415721 frames\n",
"process 87000 wavs,38889467 frames\n",
"process 88000 wavs,39337809 frames\n",
"process 89000 wavs,39792342 frames\n",
"process 90000 wavs,40287946 frames\n",
"process 91000 wavs,40719461 frames\n",
"process 92000 wavs,41178919 frames\n",
"process 93000 wavs,41659635 frames\n",
"process 94000 wavs,42132985 frames\n",
"process 95000 wavs,42584564 frames\n",
"process 96000 wavs,43018598 frames\n",
"process 97000 wavs,43480662 frames\n",
"process 98000 wavs,43973670 frames\n",
"process 99000 wavs,44448190 frames\n",
"process 100000 wavs,44935034 frames\n",
"process 101000 wavs,45379812 frames\n",
"process 102000 wavs,45821207 frames\n",
"process 103000 wavs,46258420 frames\n",
"process 104000 wavs,46743733 frames\n",
"process 105000 wavs,47206922 frames\n",
"process 106000 wavs,47683041 frames\n",
"process 107000 wavs,48122809 frames\n",
"process 108000 wavs,48594623 frames\n",
"process 109000 wavs,49086358 frames\n",
"process 110000 wavs,49525568 frames\n",
"process 111000 wavs,49985820 frames\n",
"process 112000 wavs,50428262 frames\n",
"process 113000 wavs,50897957 frames\n",
"process 114000 wavs,51344589 frames\n",
"process 115000 wavs,51774621 frames\n",
"process 116000 wavs,52243372 frames\n",
"process 117000 wavs,52726025 frames\n",
"process 118000 wavs,53170026 frames\n",
"process 119000 wavs,53614141 frames\n",
"process 120000 wavs,54071271 frames\n"
]
}
],
"source": [
"\n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc()\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" augment_and_featurize, \n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader()):\n",
" #for batch in data_loader():\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "danish-executive",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, -1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 
15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n"
]
}
],
"source": [
"print(cmvn_info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "accurate-terminal",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dominant-abuse",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 1000 wavs,450240 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 2000 wavs,886411 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 3000 wavs,1352580 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 4000 wavs,1814397 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 5000 wavs,2356587 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 6000 wavs,2825310 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 7000 wavs,3272506 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 8000 wavs,3688045 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 9000 wavs,4134669 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 10000 wavs,4586357 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 11000 wavs,5014429 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 12000 wavs,5453334 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 13000 wavs,5892888 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 14000 wavs,6316059 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 15000 wavs,6728870 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 16000 wavs,7199442 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 17000 wavs,7629055 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 18000 wavs,8083729 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 19000 wavs,8519732 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 20000 wavs,8895694 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 21000 wavs,9341778 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 22000 wavs,9796126 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 23000 wavs,10236057 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 24000 wavs,10687461 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 25000 wavs,11113082 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 26000 wavs,11544482 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 27000 wavs,11996273 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 28000 wavs,12456350 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 29000 wavs,12900895 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 30000 wavs,13330353 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 31000 wavs,13736568 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 32000 wavs,14158472 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 33000 wavs,14625316 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 34000 wavs,15036206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 35000 wavs,15514001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 36000 wavs,16004323 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 37000 wavs,16418799 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 38000 wavs,16840100 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 39000 wavs,17287752 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 40000 wavs,17776206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 41000 wavs,18243209 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 42000 wavs,18690449 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 43000 wavs,19137940 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 44000 wavs,19553966 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 45000 wavs,19969813 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 46000 wavs,20440963 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 47000 wavs,20862022 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 48000 wavs,21292801 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 49000 wavs,21713004 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 50000 wavs,22146346 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 51000 wavs,22596172 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 52000 wavs,23074160 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 53000 wavs,23499823 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 54000 wavs,23942151 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 55000 wavs,24390566 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 56000 wavs,24833905 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 57000 wavs,25307270 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 58000 wavs,25748720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 59000 wavs,26185964 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 60000 wavs,26663953 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 61000 wavs,27117720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 62000 wavs,27585349 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 63000 wavs,28032693 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 64000 wavs,28487074 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 65000 wavs,28956462 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 66000 wavs,29436358 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 67000 wavs,29918569 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 68000 wavs,30325682 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 69000 wavs,30762528 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 70000 wavs,31182319 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 71000 wavs,31627526 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 72000 wavs,32070556 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 73000 wavs,32504534 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 74000 wavs,32972775 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 75000 wavs,33409637 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 76000 wavs,33847861 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 77000 wavs,34298647 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 78000 wavs,34721536 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 79000 wavs,35159236 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 80000 wavs,35628628 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 81000 wavs,36080909 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 82000 wavs,36562496 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 83000 wavs,37042976 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 84000 wavs,37474403 frames\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 85000 wavs,37943596 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 86000 wavs,38371620 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 87000 wavs,38844874 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 88000 wavs,39292686 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 89000 wavs,39746715 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 90000 wavs,40241800 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 91000 wavs,40672817 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 92000 wavs,41131773 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 93000 wavs,41612001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 94000 wavs,42084822 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 95000 wavs,42535878 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 96000 wavs,42969365 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 97000 wavs,43430890 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 98000 wavs,43923378 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 99000 wavs,44397370 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 100000 wavs,44883695 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 101000 wavs,45327968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 102000 wavs,45768860 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 103000 wavs,46205602 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 104000 wavs,46690407 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 105000 wavs,47153089 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 106000 wavs,47628699 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 107000 wavs,48067945 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 108000 wavs,48539256 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 109000 wavs,49030485 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 110000 wavs,49469189 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 111000 wavs,49928968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 112000 wavs,50370921 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 113000 wavs,50840090 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 114000 wavs,51286249 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 115000 wavs,51715786 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 116000 wavs,52184017 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 117000 wavs,52666156 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 118000 wavs,53109645 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 119000 wavs,53553253 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 120000 wavs,54009877 frames\n",
"{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n"
]
}
],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"# https://github.com/PaddlePaddle/Paddle/pull/31481\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self, feature_func):\n",
" self.feature_func = feature_func\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for item in batch:\n",
" audioseg = AudioSegment.from_file(item['feat'])\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
"\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n",
" self._rng = rng if rng else np.random.RandomState(random_seed)\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.items[idx]\n",
" \n",
" \n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc(augment_and_featurize)\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader):\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}\n",
"print(cmvn_info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unlike-search",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n",
"from data_utils.utility import read_manifest\n",
"from data_utils.augmentor.augmentation import AugmentationPipeline\n",
"from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from data_utils.speech import SpeechSegment\n",
"from data_utils.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from data_utils.dataset import (\n",
" DeepSpeech2Dataset,\n",
" DeepSpeech2DistributedBatchSampler,\n",
" DeepSpeech2BatchSampler,\n",
" SpeechCollator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader(manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config='{}',\t\n",
" max_duration=float('inf'),\t\n",
" min_duration=0.0,\t\n",
" stride_ms=10.0,\t\n",
" window_ms=20.0,\t\n",
" max_freq=None,\t\n",
" specgram_type='linear',\t\n",
" use_dB_normalization=True,\t\n",
" random_seed=0,\t\n",
" keep_transcription_text=False,\t\n",
" is_training=False,\t\n",
" batch_size=1,\t\n",
" num_workers=0,\t\n",
" sortagrad=False,\t\n",
" shuffle_method=None,\t\n",
" dist=False):\t\n",
"\n",
" dataset = DeepSpeech2Dataset(\t\n",
" manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config=augmentation_config,\t\n",
" max_duration=max_duration,\t\n",
" min_duration=min_duration,\t\n",
" stride_ms=stride_ms,\t\n",
" window_ms=window_ms,\t\n",
" max_freq=max_freq,\t\n",
" specgram_type=specgram_type,\t\n",
" use_dB_normalization=use_dB_normalization,\t\n",
" random_seed=random_seed,\t\n",
" keep_transcription_text=keep_transcription_text)\t\n",
"\n",
" if dist:\t\n",
" batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n",
" dataset,\t\n",
" batch_size,\t\n",
" num_replicas=None,\t\n",
" rank=None,\t\n",
" shuffle=is_training,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
" else:\t\n",
" batch_sampler = DeepSpeech2BatchSampler(\t\n",
" dataset,\t\n",
" shuffle=is_training,\t\n",
" batch_size=batch_size,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
"\n",
" def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n",
" \"\"\"\t\n",
" Padding audio features with zeros to make them have the same shape (or\t\n",
" a user-defined shape) within one bach.\t\n",
"\n",
" If ``padding_to`` is -1, the maximun shape in the batch will be used\t\n",
" as the target shape for padding. Otherwise, `padding_to` will be the\t\n",
" target shape (only refers to the second axis).\t\n",
"\n",
" If `flatten` is True, features will be flatten to 1darray.\t\n",
" \"\"\"\t\n",
" new_batch = []\t\n",
" # get target shape\t\n",
" max_length = max([audio.shape[1] for audio, text in batch])\t\n",
" if padding_to != -1:\t\n",
" if padding_to < max_length:\t\n",
" raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n",
" \"than any instance's shape in the batch\")\t\n",
" max_length = padding_to\t\n",
" max_text_length = max([len(text) for audio, text in batch])\t\n",
" # padding\t\n",
" padded_audios = []\t\n",
" audio_lens = []\t\n",
" texts, text_lens = [], []\t\n",
" for audio, text in batch:\t\n",
" padded_audio = np.zeros([audio.shape[0], max_length])\t\n",
" padded_audio[:, :audio.shape[1]] = audio\t\n",
" if flatten:\t\n",
" padded_audio = padded_audio.flatten()\t\n",
" padded_audios.append(padded_audio)\t\n",
" audio_lens.append(audio.shape[1])\t\n",
"\n",
" padded_text = np.zeros([max_text_length])\n",
" if is_training:\n",
" padded_text[:len(text)] = text\t# ids\n",
" else:\n",
" padded_text[:len(text)] = [ord(t) for t in text] # string\n",
" \n",
" texts.append(padded_text)\t\n",
" text_lens.append(len(text))\t\n",
"\n",
" padded_audios = np.array(padded_audios).astype('float32')\t\n",
" audio_lens = np.array(audio_lens).astype('int64')\t\n",
" texts = np.array(texts).astype('int32')\t\n",
" text_lens = np.array(text_lens).astype('int64')\t\n",
" return padded_audios, texts, audio_lens, text_lens\t\n",
"\n",
" loader = DataLoader(\t\n",
" dataset,\t\n",
" batch_sampler=batch_sampler,\t\n",
" collate_fn=partial(padding_batch, is_training=is_training),\t\n",
" num_workers=num_workers)\t\n",
" return loader"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bearing-physics",
"metadata": {},
"outputs": [],
"source": [
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[22823, 26102, 20195, 37324, 0 , 0 ],\n",
" [22238, 26469, 23601, 22909, 0 , 0 ],\n",
" [20108, 26376, 22235, 26085, 0 , 0 ],\n",
" [36824, 35201, 20445, 25345, 32654, 24863],\n",
" [29042, 27748, 21463, 23456, 0 , 0 ]])\n",
"test raw 大时代里\n",
"test raw 煲汤受宠\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "minus-modern",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "medieval-monday",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/tiny/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/tiny/s1/data/manifest.tiny',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/tiny/s1/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/tiny/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/tiny/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "wired-principal",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/aishell/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/s1/data/manifest.test',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" '',\n",
" \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bearing-physics",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from deepspeech.frontend.speech import SpeechSegment\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from deepspeech.io.collator import SpeechCollator\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"from deepspeech.io.sampler import (\n",
" SortagradDistributedBatchSampler,\n",
" SortagradBatchSampler,\n",
")\n",
"from deepspeech.io import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[ -750 -1254 -1107 ... 2276 1889 2067] int16\n",
"fbank\n",
"[ -127 -199 -149 ... -5243 -5065 -5398] int16\n",
"fbank\n",
"[ 465 783 677 ... 980 903 1008] int16\n",
"fbank\n",
"[ 90 160 157 ... -2 -16 -21] int16\n",
"fbank\n",
"[ 213 345 295 ... 2483 2246 2501] int16\n",
"fbank\n",
"[ -86 -159 -131 ... 270 258 290] int16\n",
"fbank\n",
"[-1023 -1714 -1505 ... 1532 1596 1575] int16\n",
"fbank\n",
"[-366 -602 -527 ... 374 370 379] int16\n",
"fbank\n",
"[ 761 1275 1127 ... 369 413 295] int16\n",
"fbank\n",
"[382 621 550 ... 161 161 174] int16\n",
"fbank\n",
"[ -28 -91 -120 ... 28 34 11] int16\n",
"fbank\n",
"[ -5 -5 -5 ... 268 294 341] int16\n",
"fbank\n",
"[240 417 684 ... 267 262 219] int16\n",
"fbank\n",
"[131 206 194 ... 383 320 343] int16\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n",
" [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n",
" [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n",
" [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n",
" [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n",
"fbank\n",
"[1155 1890 1577 ... 1092 989 1130] int16\n",
"fbank\n",
"[296 358 296 ... 140 140 168] int16\n",
"fbank\n",
"[-50 -91 -63 ... 104 104 86] int16\n",
"fbank\n",
"[-37 -66 -50 ... -31 -45 -52] int16\n",
"fbank\n",
"[-401 -652 -547 ... -339 -307 -344] int16\n",
"fbank\n",
"[-21 -47 -51 ... 94 81 107] int16\n",
"fbank\n",
"[ 533 887 755 ... 3074 2853 3254] int16\n",
"fbank\n",
"[ 44 71 66 ... -628 -733 -601] int16\n",
"fbank\n",
"[ 50 86 79 ... 129 116 138] int16\n",
"fbank\n",
"[ 92 146 126 ... -208 -193 -179] int16\n",
"test raw: 祝可爱的你\n",
"test raw: 去行政化\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n",
" [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n",
" [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n",
" [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n",
" [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n",
" [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n",
" [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n",
" [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n",
" [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n",
" [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n",
" [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n",
" ...,\n",
" [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n",
" [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n",
" [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "minus-modern",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n",
" [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n",
" [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n",
" [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n",
" [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n",
"test raw: ઇǹज৹©\n",
"test raw: ǝണٕƜ\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n",
" [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n",
" [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n",
" [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n",
" [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n",
" [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n",
" [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n",
" [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n",
" [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n",
" [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n",
" [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n",
" ...,\n",
" [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n",
" [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n",
" [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n"
]
}
],
"source": [
"keep_transcription_text=False\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=keep_transcription_text,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" print('audio len:', audio_len)\n",
" break"
]
},
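{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-depad",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (not part of the original run): the collated `text`\n",
"# tensor above is padded with -1, so `text_len` is needed to strip the\n",
"# padding before mapping ids back to tokens. A minimal de-padding helper,\n",
"# assuming `text` and `text_len` from the previous cell are in scope:\n",
"def depad(ids, length, pad_id=-1):\n",
"    ids = [int(i) for i in ids[:int(length)]]\n",
"    assert pad_id not in ids\n",
"    return ids\n",
"\n",
"# depad(text[0], text_len[0]) -> [2695, 505, 2332, 2553, 169]\n",
"# depad(text[3], text_len[3]) -> [70, 355, 70]"
]
},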
{
"cell_type": "code",
"execution_count": null,
"id": "competitive-mounting",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "knowing-military",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"\n",
"add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str, 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('output_path', str,\n",
" 'examples/aishell/s1/data/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"args = parser.parse_args([])\n",
"print(vars(args))\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "unnecessary-province",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"\n",
"\n",
"def mean(args):\n",
" augmentation_pipeline = AugmentationPipeline('{}')\n",
" audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"\n",
" def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
" normalizer = FeatureNormalizer(\n",
" mean_std_filepath=None,\n",
" manifest_path=args.manifest_path,\n",
" featurize_func=augment_and_featurize,\n",
" num_samples=args.num_samples)\n",
" normalizer.write_to_file(args.output_path)\n",
"\n"
]
},
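{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-mean-usage",
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch: with the args parsed above, mean(args) featurizes\n",
"# args.num_samples utterances from args.manifest_path and writes the CMVN\n",
"# statistics to args.output_path. Left commented out here since it reads\n",
"# the full manifest from disk.\n",
"# mean(args)"
]
},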
{
"cell_type": "code",
"execution_count": 16,
"id": "interested-camping",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.]\n",
"29746\n",
"fbank\n",
"[54 90 77 ... 58 58 61] int16\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n"
]
}
],
"source": [
"wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n",
"test='祝可爱的你'\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=False,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"samples = AudioSegment.from_file(wav)\n",
"print(samples._samples)\n",
"print(samples._samples * 2**15)\n",
"print(len(samples._samples))\n",
"feat = audio_featurizer.featurize(samples, False, False)\n",
"feat = feat.T\n",
"print(feat.shape, feat.dtype)\n",
"print(feat)\n",
"\n",
"from python_speech_features import logfbank\n",
"max_freq = args.sample_rate / 2\n",
"fbank_feat = logfbank(\n",
" signal=samples.to('int16'),\n",
" samplerate=args.sample_rate,\n",
" winlen=0.001 * args.window_ms,\n",
" winstep=0.001 * args.stride_ms,\n",
" nfilt=args.feat_dim,\n",
" nfft=512,\n",
" lowfreq=20,\n",
" highfreq=max_freq,\n",
" preemph=0.97,\n",
" dither=0.0,\n",
" wintype='povey')\n",
"print(fbank_feat.shape, fbank_feat.dtype)\n",
"print(fbank_feat)"
]
},
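{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-fbank-check",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): the two matrices printed above are numerically\n",
"# identical, so with dither=0 and dB normalization off, AudioFeaturizer's\n",
"# fbank path should match python_speech_features.logfbank.\n",
"print(np.allclose(feat, fbank_feat))  # expected: True"
]
},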
{
"cell_type": "code",
"execution_count": 17,
"id": "numeric-analyst",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 160)\n",
"[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n",
" 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n",
" 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n",
" 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n",
" 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n",
" 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n",
" 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n",
" 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n",
" 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n",
" 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n",
" 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n",
" 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n",
" 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n",
" 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n",
" 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n",
" 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n",
" 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n",
" 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n",
" 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n",
" 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n",
" 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n",
" 13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n",
" 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n",
" 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n",
" 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n",
" 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n",
" 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n",
" 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n",
" 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n",
" 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n",
" 6.94653756 7.19874283 7.28515661 7.17574078]\n",
"(184,)\n",
"(184,)\n",
"[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n",
" 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n",
" 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n",
" 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n",
" 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n",
" 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n",
" 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n",
" 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n",
" 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n",
" 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n",
" 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n",
" 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n",
" 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n",
" 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n",
" 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n",
" 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n",
" 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n",
" 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n",
" 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n",
" 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n",
" 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n",
" 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n",
" 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n",
" 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n",
" 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n",
" 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n",
" 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n",
" 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n",
" 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n",
" 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n",
" 1.67784678 1.72180861 1.60703743 1.64850255]\n"
]
}
],
"source": [
"a = np.hstack([feat, feat])\n",
"print(a.shape)\n",
"m = np.mean(a, axis=1)\n",
"print(m)\n",
"print(m.shape)\n",
"std = np.std(a, axis=1)\n",
"print(std.shape)\n",
"print(std)"
]
},
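{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-cmvn-apply",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: `m` and `std` above are per-frame statistics taken over\n",
"# axis=1 of the stacked (184, 160) array. Applying mean-variance\n",
"# normalization with them, a minimal sketch:\n",
"a_norm = (a - m[:, None]) / std[:, None]\n",
"print(a_norm.mean(axis=1)[:3])  # ~0 for each frame\n",
"print(a_norm.std(axis=1)[:3])   # ~1 for each frame"
]
},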
{
"cell_type": "code",
"execution_count": null,
"id": "nonprofit-potato",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "hispanic-ethics",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"import torchaudio.compliance.kaldi as kaldi\n",
"import torchaudio.sox_effects as sox_effects\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"torchaudio.set_audio_backend(\"sox\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "changing-calvin",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 29746])\n",
"tensor([[54., 90., 77., ..., 58., 58., 61.]])\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n",
"-----------\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"(184, 80)\n",
"[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n",
" [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n",
" [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n",
" ...\n",
" [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n",
" [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n",
" [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n",
"**************\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.] float32\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use \"torchaudio.load\".\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
}
],
"source": [
"waveform, sample_rate = torchaudio.load_wav(wav)\n",
"print(waveform.shape)\n",
"print(waveform)\n",
"mat = kaldi.fbank(\n",
" waveform,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('-----------')\n",
"print(samples._samples)\n",
"aud = torch.tensor(samples._samples).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('**************')\n",
"print(samples._samples)\n",
"tmp = samples.to('int16').astype('float32')\n",
"print(tmp, tmp.dtype)\n",
"aud = torch.tensor(tmp).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)"
]
},
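{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-scale-check",
"metadata": {},
"outputs": [],
"source": [
"# Observation (sketch): the first and third matrices printed above agree,\n",
"# i.e. torchaudio's kaldi.fbank reproduces the earlier results only when\n",
"# the waveform is on the int16 scale; the float [-1, 1) samples (second\n",
"# run) come out shifted. Reusing `waveform` and the int16-scaled `aud`:\n",
"mat_wav = kaldi.fbank(waveform, num_mel_bins=80, frame_length=25,\n",
"                      frame_shift=10, dither=0, energy_floor=0.0,\n",
"                      sample_frequency=sample_rate).numpy()\n",
"mat_int16 = kaldi.fbank(aud, num_mel_bins=80, frame_length=25,\n",
"                        frame_shift=10, dither=0, energy_floor=0.0,\n",
"                        sample_frequency=sample_rate).numpy()\n",
"print(np.allclose(mat_wav, mat_int16))  # expected: True"
]
},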
{
"cell_type": "code",
"execution_count": null,
"id": "buried-dependence",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "silver-printing",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "outer-space",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(29746,)\n",
"[54 90 77 ... 58 58 61]\n",
"(184, 80)\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 13)\n",
"[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ 11.92813809 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"from python_speech_features import mfcc\n",
"from python_speech_features import delta\n",
"from python_speech_features import logfbank\n",
"import scipy.io.wavfile as iowav\n",
"\n",
"(rate,sig) = iowav.read(wav)\n",
"print(sig.shape)\n",
"print(sig)\n",
"\n",
"# note that generally nfilt=40 is used for speech recognition\n",
"fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"# the computed fbank coefficents of english.wav with dimension [110,23]\n",
"# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n",
"# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n",
"\n",
"mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)\n",
"\n",
"# the computed mfcc coefficents of english.wav with dimension [110,13]\n",
"# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n",
"# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0"
]
},
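{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-cross-check",
"metadata": {},
"outputs": [],
"source": [
"# Cross-check (sketch): with wintype='povey', dither=0 and int16 input,\n",
"# python_speech_features.logfbank above matches torchaudio's kaldi.fbank\n",
"# (`mat` from the earlier cell) up to float32 precision.\n",
"print(np.allclose(fbank_feat, mat, atol=1e-4))  # expected: True"
]
},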
{
"cell_type": "code",
"execution_count": 21,
"id": "sporting-school",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 80)\n",
"[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n",
" -12.98770428]\n",
" [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n",
" -11.95581463]\n",
" [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n",
" -12.83273051]\n",
" ...\n",
" [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n",
" -12.03825021]\n",
" [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n",
" -13.10154358]\n",
" [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n",
" -12.63434434]]\n",
"(184, 13)\n",
"[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ -8.86627732 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)"
]
},
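{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-offset-check",
"metadata": {},
"outputs": [],
"source": [
"# Why the float-sample results above sit a constant below the int16 ones:\n",
"# logfbank takes the log of filterbank energies computed from the power\n",
"# spectrum, so scaling the waveform by 2**15 scales every energy by 2**30\n",
"# and shifts each log-mel value by ln(2**30) ~= 20.79 -- exactly the gap\n",
"# between e.g. 10.617 and -10.177 above. A quick check (sketch):\n",
"offset = np.log(2.0 ** 30)\n",
"print(offset)  # ~20.794\n",
"fbank_int16 = logfbank(samples.to('int16'), nfilt=80, lowfreq=20, dither=0,\n",
"                       wintype='povey')\n",
"print(np.allclose(fbank_feat + offset, fbank_int16))  # expected: True"
]
},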
{
"cell_type": "code",
"execution_count": null,
"id": "restricted-license",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "specialized-threat",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 147,
"id": "extensive-venice",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/\n"
]
},
{
"data": {
"text/plain": [
"'/'"
]
},
"execution_count": 147,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 148,
"id": "correct-window",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"manifest.dev\t manifest.test-clean\t manifest.train\r\n",
"manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n"
]
}
],
"source": [
"!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/"
]
},
{
"cell_type": "code",
"execution_count": 149,
"id": "exceptional-cheese",
"metadata": {},
"outputs": [],
"source": [
"dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'"
]
},
{
"cell_type": "code",
"execution_count": 150,
"id": "extraordinary-orleans",
"metadata": {},
"outputs": [],
"source": [
"from deepspeech.frontend.utility import read_manifest"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "returning-lighter",
"metadata": {},
"outputs": [],
"source": [
"dev_json = read_manifest(dev_data)"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "western-founder",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n",
" 'name': 'input1',\n",
" 'shape': [1063, 83]}],\n",
" 'output': [{'name': 'target1',\n",
" 'shape': [41, 5002],\n",
" 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n",
" 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n",
" 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n",
" 'HITHER AND THITHER',\n",
" 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n",
" 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n",
" 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n",
" '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n",
" '▁THITHER',\n",
" 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n",
" '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n",
" '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n",
" '302 206 4504 4843 2394 627 4526'}],\n",
" 'utt': '116-288045-0000',\n",
" 'utt2spk': '116-288045'}\n",
"5542\n",
"<class 'list'>\n"
]
}
],
"source": [
"from pprint import pprint\n",
"pprint(dev_json[0])\n",
"print(len(dev_json))\n",
"print(type(dev_json))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "motivated-receptor",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"import itertools\n",
"\n",
"import numpy as np\n",
"\n",
"from deepspeech.utils.log import Log\n",
"\n",
"__all__ = [\"make_batchset\"]\n",
"\n",
"logger = Log(__name__).getlog()\n",
"\n",
"\n",
"def batchfy_by_seq(\n",
" sorted_data,\n",
" batch_size,\n",
" max_length_in,\n",
" max_length_out,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" iaxis=0,\n",
" okey=\"output\",\n",
" oaxis=0, ):\n",
" \"\"\"Make batch set from json dictionary\n",
"\n",
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
" :param int batch_size: batch size\n",
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
" :param int min_batch_size: mininum batch size (for multi-gpu)\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
" :param str ikey: key to access input\n",
" (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n",
" :param int iaxis: dimension to access input\n",
" (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
" :param str okey: key to access output\n",
" (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n",
" :param int oaxis: dimension to access output\n",
" (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n",
" :return: List[List[Tuple[str, dict]]] list of batches\n",
" \"\"\"\n",
" if batch_size <= 0:\n",
" raise ValueError(f\"Invalid batch_size={batch_size}\")\n",
"\n",
" # check #utts is more than min_batch_size\n",
" if len(sorted_data) < min_batch_size:\n",
" raise ValueError(\n",
" f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n",
" )\n",
"\n",
" # make list of minibatches\n",
" minibatches = []\n",
" start = 0\n",
" while True:\n",
" _, info = sorted_data[start]\n",
" ilen = int(info[ikey][iaxis][\"shape\"][0])\n",
" olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n",
" max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n",
" factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n",
" # change batchsize depending on the input and output length\n",
" # if ilen = 1000 and max_length_in = 800\n",
" # then b = batchsize / 2\n",
" # and max(min_batches, .) avoids batchsize = 0\n",
" bs = max(min_batch_size, int(batch_size / (1 + factor)))\n",
" end = min(len(sorted_data), start + bs)\n",
" minibatch = sorted_data[start:end]\n",
" if shortest_first:\n",
" minibatch.reverse()\n",
"\n",
" # check each batch is more than minimum batchsize\n",
" if len(minibatch) < min_batch_size:\n",
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
" additional_minibatch = [\n",
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
" ]\n",
" if shortest_first:\n",
" additional_minibatch.reverse()\n",
" minibatch.extend(additional_minibatch)\n",
" minibatches.append(minibatch)\n",
"\n",
" if end == len(sorted_data):\n",
" break\n",
" start = end\n",
"\n",
" # batch: List[List[Tuple[str, dict]]]\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_by_bin(\n",
" sorted_data,\n",
" batch_bins,\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" okey=\"output\", ):\n",
" \"\"\"Make variably sized batch set, which maximizes\n",
"\n",
" the number of bins up to `batch_bins`.\n",
"\n",
" :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n",
" :param int batch_bins: Maximum frames of a batch\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param int test: Return only every `test` batches\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
"\n",
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
"\n",
" :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
" \"\"\"\n",
" if batch_bins <= 0:\n",
" raise ValueError(f\"invalid batch_bins={batch_bins}\")\n",
" length = len(sorted_data)\n",
" idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n",
" odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" minibatches = []\n",
" start = 0\n",
" n = 0\n",
" while True:\n",
" # Dynamic batch size depending on size of samples\n",
" b = 0\n",
" next_size = 0\n",
" max_olen = 0\n",
" while next_size < batch_bins and (start + b) < length:\n",
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n",
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n",
" if olen > max_olen:\n",
" max_olen = olen\n",
" next_size = (max_olen + ilen) * (b + 1)\n",
" if next_size <= batch_bins:\n",
" b += 1\n",
" elif next_size == 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n",
" f\"Please increase the value\")\n",
" end = min(length, start + max(min_batch_size, b))\n",
" batch = sorted_data[start:end]\n",
" if shortest_first:\n",
" batch.reverse()\n",
" minibatches.append(batch)\n",
" # Check for min_batch_size and fixes the batches if needed\n",
" i = -1\n",
" while len(minibatches[i]) < min_batch_size:\n",
" missing = min_batch_size - len(minibatches[i])\n",
" if -i == len(minibatches):\n",
" minibatches[i + 1].extend(minibatches[i])\n",
" minibatches = minibatches[1:]\n",
" break\n",
" else:\n",
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
" i -= 1\n",
" if end == length:\n",
" break\n",
" start = end\n",
" n += 1\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" lengths = [len(x) for x in minibatches]\n",
" logger.info(\n",
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
" + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n",
" int(np.mean(lengths))) + \" samples).\")\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_by_frame(\n",
" sorted_data,\n",
" max_frames_in,\n",
" max_frames_out,\n",
" max_frames_inout,\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" ikey=\"input\",\n",
" okey=\"output\", ):\n",
" \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n",
"\n",
" :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n",
" :param int max_frames_in: Maximum input frames of a batch\n",
" :param int max_frames_out: Maximum output frames of a batch\n",
" :param int max_frames_inout: Maximum input+output frames of a batch\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param int test: Return only every `test` batches\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
"\n",
" :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n",
" :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n",
"\n",
" :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n",
" \"\"\"\n",
" if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n",
" raise ValueError(\n",
" \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n",
" \"`--batch-frames-inout` should be > 0\")\n",
" length = len(sorted_data)\n",
" minibatches = []\n",
" start = 0\n",
" end = 0\n",
" while end != length:\n",
" # Dynamic batch size depending on size of samples\n",
" b = 0\n",
" max_olen = 0\n",
" max_ilen = 0\n",
" while (start + b) < length:\n",
" ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n",
" if ilen > max_frames_in and max_frames_in != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n",
" f\"Please increase the value\")\n",
" olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n",
" if olen > max_frames_out and max_frames_out != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n",
" f\"Please increase the value\")\n",
" if ilen + olen > max_frames_inout and max_frames_inout != 0:\n",
" raise ValueError(\n",
" f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n",
" f\"Please increase the value\")\n",
" max_olen = max(max_olen, olen)\n",
" max_ilen = max(max_ilen, ilen)\n",
" in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n",
" out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n",
" inout_ok = (max_ilen + max_olen) * (\n",
" b + 1) <= max_frames_inout or max_frames_inout == 0\n",
" if in_ok and out_ok and inout_ok:\n",
" # add more seq in the minibatch\n",
" b += 1\n",
" else:\n",
" # no more seq in the minibatch\n",
" break\n",
" end = min(length, start + b)\n",
" batch = sorted_data[start:end]\n",
" if shortest_first:\n",
" batch.reverse()\n",
" minibatches.append(batch)\n",
" # Check for min_batch_size and fixes the batches if needed\n",
" i = -1\n",
" while len(minibatches[i]) < min_batch_size:\n",
" missing = min_batch_size - len(minibatches[i])\n",
" if -i == len(minibatches):\n",
" minibatches[i + 1].extend(minibatches[i])\n",
" minibatches = minibatches[1:]\n",
" break\n",
" else:\n",
" minibatches[i].extend(minibatches[i - 1][:missing])\n",
" minibatches[i - 1] = minibatches[i - 1][missing:]\n",
" i -= 1\n",
" start = end\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" lengths = [len(x) for x in minibatches]\n",
" logger.info(\n",
" str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n",
" + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n",
" int(np.mean(lengths))) + \" samples).\")\n",
"\n",
" return minibatches\n",
"\n",
"\n",
"def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n",
" shortest_first):\n",
" import random\n",
"\n",
" logger.info(\"use shuffled batch.\")\n",
" sorted_data = random.sample(data.items(), len(data.items()))\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" # make list of minibatches\n",
" minibatches = []\n",
" start = 0\n",
" while True:\n",
" end = min(len(sorted_data), start + batch_size)\n",
" # check each batch is more than minimum batchsize\n",
" minibatch = sorted_data[start:end]\n",
" if shortest_first:\n",
" minibatch.reverse()\n",
" if len(minibatch) < min_batch_size:\n",
" mod = min_batch_size - len(minibatch) % min_batch_size\n",
" additional_minibatch = [\n",
" sorted_data[i] for i in np.random.randint(0, start, mod)\n",
" ]\n",
" if shortest_first:\n",
" additional_minibatch.reverse()\n",
" minibatch.extend(additional_minibatch)\n",
" minibatches.append(minibatch)\n",
" if end == len(sorted_data):\n",
" break\n",
" start = end\n",
"\n",
" # for debugging\n",
" if num_batches > 0:\n",
" minibatches = minibatches[:num_batches]\n",
" logger.info(\"# minibatches: \" + str(len(minibatches)))\n",
" return minibatches\n",
"\n",
"\n",
"BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n",
"BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n",
"\n",
"\n",
"def make_batchset(\n",
" data,\n",
" batch_size=0,\n",
" max_length_in=float(\"inf\"),\n",
" max_length_out=float(\"inf\"),\n",
" num_batches=0,\n",
" min_batch_size=1,\n",
" shortest_first=False,\n",
" batch_sort_key=\"input\",\n",
" count=\"auto\",\n",
" batch_bins=0,\n",
" batch_frames_in=0,\n",
" batch_frames_out=0,\n",
" batch_frames_inout=0,\n",
" iaxis=0,\n",
" oaxis=0, ):\n",
" \"\"\"Make batch set from json dictionary\n",
"\n",
" if utts have \"category\" value,\n",
"\n",
" >>> data = {'utt1': {'category': 'A', 'input': ...},\n",
" ... 'utt2': {'category': 'B', 'input': ...},\n",
" ... 'utt3': {'category': 'B', 'input': ...},\n",
" ... 'utt4': {'category': 'A', 'input': ...}}\n",
" >>> make_batchset(data, batchsize=2, ...)\n",
" [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n",
"\n",
" Note that if any utts doesn't have \"category\",\n",
" perform as same as batchfy_by_{count}\n",
"\n",
" :param List[Dict[str, Any]] data: dictionary loaded from data.json\n",
" :param int batch_size: maximum number of sequences in a minibatch.\n",
" :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n",
" :param int batch_frames_in: maximum number of input frames in a minibatch.\n",
" :param int batch_frames_out: maximum number of output frames in a minibatch.\n",
" :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n",
" :param str count: strategy to count maximum size of batch.\n",
" For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n",
"\n",
" :param int max_length_in: maximum length of input to decide adaptive batch size\n",
" :param int max_length_out: maximum length of output to decide adaptive batch size\n",
" :param int num_batches: # number of batches to use (for debug)\n",
" :param int min_batch_size: minimum batch size (for multi-gpu)\n",
" :param bool shortest_first: Sort from batch with shortest samples\n",
" to longest if true, otherwise reverse\n",
" :param str batch_sort_key: how to sort data before creating minibatches\n",
" [\"input\", \"output\", \"shuffle\"]\n",
" :param bool swap_io: if True, use \"input\" as output and \"output\"\n",
" as input in `data` dict\n",
" :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n",
" as input in `data` dict\n",
" :param int iaxis: dimension to access input\n",
" (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n",
" :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n",
" reserved for future research, -1 means all axis.)\n",
" :return: List[List[Tuple[str, dict]]] list of batches\n",
" \"\"\"\n",
"\n",
" # check args\n",
" if count not in BATCH_COUNT_CHOICES:\n",
" raise ValueError(\n",
" f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n",
" if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n",
" raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n",
" f\"one of {BATCH_SORT_KEY_CHOICES}\")\n",
"\n",
" ikey = \"input\"\n",
" okey = \"output\"\n",
" batch_sort_axis = 0 # index of list \n",
"\n",
" if count == \"auto\":\n",
" if batch_size != 0:\n",
" count = \"seq\"\n",
" elif batch_bins != 0:\n",
" count = \"bin\"\n",
" elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n",
" count = \"frame\"\n",
" else:\n",
" raise ValueError(\n",
" f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n",
" )\n",
" logger.info(f\"count is auto detected as {count}\")\n",
"\n",
" if count != \"seq\" and batch_sort_key == \"shuffle\":\n",
" raise ValueError(\n",
" \"batch_sort_key=shuffle is only available if batch_count=seq\")\n",
"\n",
" category2data = {} # Dict[str, dict]\n",
" for v in data:\n",
" k = v['utt']\n",
" category2data.setdefault(v.get(\"category\"), {})[k] = v\n",
"\n",
" batches_list = [] # List[List[List[Tuple[str, dict]]]]\n",
" for d in category2data.values():\n",
" if batch_sort_key == \"shuffle\":\n",
" batches = batchfy_shuffle(d, batch_size, min_batch_size,\n",
" num_batches, shortest_first)\n",
" batches_list.append(batches)\n",
" continue\n",
"\n",
" # sort it by input lengths (long to short)\n",
" sorted_data = sorted(\n",
" d.items(),\n",
" key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n",
" reverse=not shortest_first, )\n",
" logger.info(\"# utts: \" + str(len(sorted_data)))\n",
" \n",
" if count == \"seq\":\n",
" batches = batchfy_by_seq(\n",
" sorted_data,\n",
" batch_size=batch_size,\n",
" max_length_in=max_length_in,\n",
" max_length_out=max_length_out,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" iaxis=iaxis,\n",
" okey=okey,\n",
" oaxis=oaxis, )\n",
" if count == \"bin\":\n",
" batches = batchfy_by_bin(\n",
" sorted_data,\n",
" batch_bins=batch_bins,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" okey=okey, )\n",
" if count == \"frame\":\n",
" batches = batchfy_by_frame(\n",
" sorted_data,\n",
" max_frames_in=batch_frames_in,\n",
" max_frames_out=batch_frames_out,\n",
" max_frames_inout=batch_frames_inout,\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=shortest_first,\n",
" ikey=ikey,\n",
" okey=okey, )\n",
" batches_list.append(batches)\n",
"\n",
" if len(batches_list) == 1:\n",
" batches = batches_list[0]\n",
" else:\n",
" # Concat list. This way is faster than \"sum(batch_list, [])\"\n",
" batches = list(itertools.chain(*batches_list))\n",
"\n",
" # for debugging\n",
" if num_batches > 0:\n",
" batches = batches[:num_batches]\n",
" logger.info(\"# minibatches: \" + str(len(batches)))\n",
"\n",
" # batch: List[List[Tuple[str, dict]]]\n",
" return batches\n"
]
},
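{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-seq-sizing",
"metadata": {},
"outputs": [],
"source": [
"# Worked example (illustrative) of the adaptive sizing in batchfy_by_seq:\n",
"# with batch_size=32, max_length_in=800 and a leading utterance of\n",
"# ilen=1000 frames (olen=20, max_length_out=150),\n",
"#   factor = max(int(1000 / 800), int(20 / 150)) = 1\n",
"#   bs = max(min_batch_size, int(32 / (1 + factor))) = 16\n",
"# so long utterances halve the minibatch to keep memory roughly constant.\n",
"factor = max(int(1000 / 800), int(20 / 150))\n",
"print(max(1, int(32 / (1 + factor))))  # 16"
]
},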
{
"cell_type": "code",
"execution_count": 98,
"id": "acquired-hurricane",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n",
"[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n",
"[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"555\n"
]
}
],
"source": [
"batch_size=10\n",
"maxlen_in=300\n",
"maxlen_out=400\n",
"minibatches=0 # for debug\n",
"min_batch_size=2\n",
"use_sortagrad=True\n",
"batch_count='seq'\n",
"batch_bins=0\n",
"batch_frames_in=3000\n",
"batch_frames_out=0\n",
"batch_frames_inout=0\n",
" \n",
"dev_data = make_batchset(\n",
" dev_json,\n",
" batch_size,\n",
" maxlen_in,\n",
" maxlen_out,\n",
" minibatches, # for debug\n",
" min_batch_size=min_batch_size,\n",
" shortest_first=use_sortagrad,\n",
" batch_sort_key=\"shuffle\",\n",
" count=batch_count,\n",
" batch_bins=batch_bins,\n",
" batch_frames_in=batch_frames_in,\n",
" batch_frames_out=batch_frames_out,\n",
" batch_frames_inout=batch_frames_inout,\n",
" iaxis=0,\n",
" oaxis=0, )\n",
"print(len(dev_data))\n",
"# for i in range(len(dev_data)):\n",
"# print(len(dev_data[i]))\n"
]
},
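{
"cell_type": "code",
"execution_count": null,
"id": "hypothetical-batch-peek",
"metadata": {},
"outputs": [],
"source": [
"# Each element of dev_data is one minibatch: a list of (utt_id, info)\n",
"# tuples. A quick look at the first one (sketch):\n",
"print(len(dev_data[0]))\n",
"print([utt for utt, _ in dev_data[0]])"
]
},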
{
"cell_type": "code",
"execution_count": 99,
"id": "warming-malpractice",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n",
"Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n",
"\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n",
"You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"!pip install kaldiio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "equipped-subject",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 100,
"id": "superb-methodology",
"metadata": {},
"outputs": [],
"source": [
"from collections import OrderedDict\n",
"import kaldiio\n",
"\n",
"class LoadInputsAndTargets():\n",
" \"\"\"Create a mini-batch from a list of dicts\n",
"\n",
" >>> batch = [('utt1',\n",
" ... dict(input=[dict(feat='some.ark:123',\n",
" ... filetype='mat',\n",
" ... name='input1',\n",
" ... shape=[100, 80])],\n",
" ... output=[dict(tokenid='1 2 3 4',\n",
" ... name='target1',\n",
" ... shape=[4, 31])]]))\n",
" >>> l = LoadInputsAndTargets()\n",
" >>> feat, target = l(batch)\n",
"\n",
" :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n",
" :param: str preprocess_conf: The path of a json file for pre-processing\n",
" :param: bool load_input: If False, not to load the input data\n",
" :param: bool load_output: If False, not to load the output data\n",
" :param: bool sort_in_input_length: Sort the mini-batch in descending order\n",
" of the input length\n",
" :param: bool use_speaker_embedding: Used for tts mode only\n",
" :param: bool use_second_target: Used for tts mode only\n",
" :param: dict preprocess_args: Set some optional arguments for preprocessing\n",
" :param: Optional[dict] preprocess_args: Used for tts mode only\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" mode=\"asr\",\n",
" preprocess_conf=None,\n",
" load_input=True,\n",
" load_output=True,\n",
" sort_in_input_length=True,\n",
" preprocess_args=None,\n",
" keep_all_data_on_mem=False, ):\n",
" self._loaders = {}\n",
"\n",
" if mode not in [\"asr\"]:\n",
" raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n",
"\n",
" if preprocess_conf is not None:\n",
" self.preprocessing = AugmentationPipeline(preprocess_conf)\n",
" logging.warning(\n",
" \"[Experimental feature] Some preprocessing will be done \"\n",
" \"for the mini-batch creation using {}\".format(\n",
" self.preprocessing))\n",
" else:\n",
" # If conf doesn't exist, this function don't touch anything.\n",
" self.preprocessing = None\n",
"\n",
" self.mode = mode\n",
" self.load_output = load_output\n",
" self.load_input = load_input\n",
" self.sort_in_input_length = sort_in_input_length\n",
" if preprocess_args is None:\n",
" self.preprocess_args = {}\n",
" else:\n",
" assert isinstance(preprocess_args, dict), type(preprocess_args)\n",
" self.preprocess_args = dict(preprocess_args)\n",
"\n",
" self.keep_all_data_on_mem = keep_all_data_on_mem\n",
"\n",
" def __call__(self, batch, return_uttid=False):\n",
" \"\"\"Function to load inputs and targets from list of dicts\n",
"\n",
" :param List[Tuple[str, dict]] batch: list of dict which is subset of\n",
" loaded data.json\n",
" :param bool return_uttid: return utterance ID information for visualization\n",
" :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n",
" :return: list of input feature sequences\n",
" [(T_1, D), (T_2, D), ..., (T_B, D)]\n",
" :rtype: list of float ndarray\n",
" :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n",
" :rtype: list of int ndarray\n",
"\n",
" \"\"\"\n",
" x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
" y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n",
" uttid_list = [] # List[str]\n",
"\n",
" for uttid, info in batch:\n",
" uttid_list.append(uttid)\n",
"\n",
" if self.load_input:\n",
" # Note(kamo): This for-loop is for multiple inputs\n",
" for idx, inp in enumerate(info[\"input\"]):\n",
" # {\"input\":\n",
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # \"name\": \"input1\", ...}], ...}\n",
" x = self._get_from_loader(\n",
" filepath=inp[\"feat\"],\n",
" filetype=inp.get(\"filetype\", \"mat\"))\n",
" x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
"\n",
" if self.load_output:\n",
" for idx, inp in enumerate(info[\"output\"]):\n",
" if \"tokenid\" in inp:\n",
" # ======= Legacy format for output =======\n",
" # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n",
" x = np.fromiter(\n",
" map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n",
" else:\n",
" # ======= New format =======\n",
" # {\"input\":\n",
" # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # \"name\": \"target1\", ...}], ...}\n",
" x = self._get_from_loader(\n",
" filepath=inp[\"feat\"],\n",
" filetype=inp.get(\"filetype\", \"mat\"))\n",
"\n",
" y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n",
"\n",
" if self.mode == \"asr\":\n",
" return_batch, uttid_list = self._create_batch_asr(\n",
" x_feats_dict, y_feats_dict, uttid_list)\n",
" else:\n",
" raise NotImplementedError(self.mode)\n",
"\n",
" if self.preprocessing is not None:\n",
" # Apply pre-processing all input features\n",
" for x_name in return_batch.keys():\n",
" if x_name.startswith(\"input\"):\n",
" return_batch[x_name] = self.preprocessing(\n",
" return_batch[x_name], uttid_list,\n",
" **self.preprocess_args)\n",
"\n",
" if return_uttid:\n",
" return tuple(return_batch.values()), uttid_list\n",
"\n",
" # Doesn't return the names now.\n",
" return tuple(return_batch.values())\n",
"\n",
" def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n",
" \"\"\"Create a OrderedDict for the mini-batch\n",
"\n",
" :param OrderedDict x_feats_dict:\n",
" e.g. {\"input1\": [ndarray, ndarray, ...],\n",
" \"input2\": [ndarray, ndarray, ...]}\n",
" :param OrderedDict y_feats_dict:\n",
" e.g. {\"target1\": [ndarray, ndarray, ...],\n",
" \"target2\": [ndarray, ndarray, ...]}\n",
" :param: List[str] uttid_list:\n",
" Give uttid_list to sort in the same order as the mini-batch\n",
" :return: batch, uttid_list\n",
" :rtype: Tuple[OrderedDict, List[str]]\n",
" \"\"\"\n",
" # handle single-input and multi-input (paralell) asr mode\n",
" xs = list(x_feats_dict.values())\n",
"\n",
" if self.load_output:\n",
" ys = list(y_feats_dict.values())\n",
" assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n",
"\n",
" # get index of non-zero length samples\n",
" nonzero_idx = list(\n",
" filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n",
" for n in range(1, len(y_feats_dict)):\n",
" nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n",
" else:\n",
" # Note(kamo): Be careful not to make nonzero_idx to a generator\n",
" nonzero_idx = list(range(len(xs[0])))\n",
"\n",
" if self.sort_in_input_length:\n",
" # sort in input lengths based on the first input\n",
" nonzero_sorted_idx = sorted(\n",
" nonzero_idx, key=lambda i: -len(xs[0][i]))\n",
" else:\n",
" nonzero_sorted_idx = nonzero_idx\n",
"\n",
" if len(nonzero_sorted_idx) != len(xs[0]):\n",
" logging.warning(\n",
" \"Target sequences include empty tokenid (batch {} -> {}).\".\n",
" format(len(xs[0]), len(nonzero_sorted_idx)))\n",
"\n",
" # remove zero-length samples\n",
" xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n",
" uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n",
"\n",
" x_names = list(x_feats_dict.keys())\n",
" if self.load_output:\n",
" ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n",
" y_names = list(y_feats_dict.keys())\n",
"\n",
" # Keeping x_name and y_name, e.g. input1, for future extension\n",
" return_batch = OrderedDict([\n",
" * [(x_name, x) for x_name, x in zip(x_names, xs)],\n",
" * [(y_name, y) for y_name, y in zip(y_names, ys)],\n",
" ])\n",
" else:\n",
" return_batch = OrderedDict(\n",
" [(x_name, x) for x_name, x in zip(x_names, xs)])\n",
" return return_batch, uttid_list\n",
"\n",
" def _get_from_loader(self, filepath, filetype):\n",
" \"\"\"Return ndarray\n",
"\n",
" In order to make the fds to be opened only at the first referring,\n",
" the loader are stored in self._loaders\n",
"\n",
" >>> ndarray = loader.get_from_loader(\n",
" ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n",
"\n",
" :param: str filepath:\n",
" :param: str filetype:\n",
" :return:\n",
" :rtype: np.ndarray\n",
" \"\"\"\n",
" if filetype == \"hdf5\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"hdf5\",\n",
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = h5py.File(filepath, \"r\")\n",
" self._loaders[filepath] = loader\n",
" return loader[key][()]\n",
" elif filetype == \"sound.hdf5\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"sound.hdf5\",\n",
" # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n",
" self._loaders[filepath] = loader\n",
" array, rate = loader[key]\n",
" return array\n",
" elif filetype == \"sound\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.wav\",\n",
" # \"filetype\": \"sound\"},\n",
" # Assume PCM16\n",
" if not self.keep_all_data_on_mem:\n",
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
" return array\n",
" if filepath not in self._loaders:\n",
" array, _ = soundfile.read(filepath, dtype=\"int16\")\n",
" self._loaders[filepath] = array\n",
" return self._loaders[filepath]\n",
" elif filetype == \"npz\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"npz\",\n",
" filepath, key = filepath.split(\":\", 1)\n",
"\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = np.load(filepath)\n",
" self._loaders[filepath] = loader\n",
" return loader[key]\n",
" elif filetype == \"npy\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.npy\",\n",
" # \"filetype\": \"npy\"},\n",
" if not self.keep_all_data_on_mem:\n",
" return np.load(filepath)\n",
" if filepath not in self._loaders:\n",
" self._loaders[filepath] = np.load(filepath)\n",
" return self._loaders[filepath]\n",
" elif filetype in [\"mat\", \"vec\"]:\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.ark:123\",\n",
" # \"filetype\": \"mat\"}]},\n",
" # In this case, \"123\" indicates the starting points of the matrix\n",
" # load_mat can load both matrix and vector\n",
" if not self.keep_all_data_on_mem:\n",
" return kaldiio.load_mat(filepath)\n",
" if filepath not in self._loaders:\n",
" self._loaders[filepath] = kaldiio.load_mat(filepath)\n",
" return self._loaders[filepath]\n",
" elif filetype == \"scp\":\n",
" # e.g.\n",
" # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n",
" # \"filetype\": \"scp\",\n",
" filepath, key = filepath.split(\":\", 1)\n",
" loader = self._loaders.get(filepath)\n",
" if loader is None:\n",
" # To avoid disk access, create loader only for the first time\n",
" loader = kaldiio.load_scp(filepath)\n",
" self._loaders[filepath] = loader\n",
" return loader[key]\n",
" else:\n",
" raise NotImplementedError(\n",
" \"Not supported: loader_type={}\".format(filetype))\n"
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "monthly-muscle",
"metadata": {},
"outputs": [],
"source": [
"preprocess_conf=None\n",
"train_mode=True\n",
"load = LoadInputsAndTargets(\n",
" mode=\"asr\",\n",
" load_output=True,\n",
" preprocess_conf=preprocess_conf,\n",
" preprocess_args={\"train\":\n",
" train_mode}, # Switch the mode of preprocessing\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "periodic-senegal",
"metadata": {},
"outputs": [],
"source": [
"res = load(dev_data[0])"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "502d3f4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'tuple'>\n",
"2\n",
"10\n",
"10\n",
"(1174, 83) float32\n",
"(29,) int64\n"
]
}
],
"source": [
"print(type(res))\n",
"print(len(res))\n",
"print(len(res[0]))\n",
"print(len(res[1]))\n",
"print(res[0][0].shape, res[0][0].dtype)\n",
"print(res[1][0].shape, res[1][0].dtype)\n",
"# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n",
"# 2[10, 10]\n",
"# feats, labels"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "humanitarian-container",
"metadata": {},
"outputs": [],
"source": [
"(inputs, outputs), utts = load(dev_data[0], return_uttid=True)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "heard-prize",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n",
"10\n"
]
}
],
"source": [
"print(utts, len(utts))\n",
"print(len(inputs))"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "convinced-animation",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from deepspeech.io.utility import pad_list\n",
"class CustomConverter():\n",
" \"\"\"Custom batch converter.\n",
"\n",
" Args:\n",
" subsampling_factor (int): The subsampling factor.\n",
" dtype (paddle.dtype): Data type to convert.\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, subsampling_factor=1, dtype=np.float32):\n",
" \"\"\"Construct a CustomConverter object.\"\"\"\n",
" self.subsampling_factor = subsampling_factor\n",
" self.ignore_id = -1\n",
" self.dtype = dtype\n",
"\n",
" def __call__(self, batch):\n",
" \"\"\"Transform a batch and send it to a device.\n",
"\n",
" Args:\n",
" batch (list): The batch to transform.\n",
"\n",
" Returns:\n",
" tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n",
"\n",
" \"\"\"\n",
" # batch should be located in list\n",
" assert len(batch) == 1\n",
" (xs, ys), utts = batch[0]\n",
"\n",
" # perform subsampling\n",
" if self.subsampling_factor > 1:\n",
" xs = [x[::self.subsampling_factor, :] for x in xs]\n",
"\n",
" # get batch of lengths of input sequences\n",
" ilens = np.array([x.shape[0] for x in xs])\n",
"\n",
" # perform padding and convert to tensor\n",
" # currently only support real number\n",
" if xs[0].dtype.kind == \"c\":\n",
" xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n",
" xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n",
" # Note(kamo):\n",
" # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n",
" # Don't create ComplexTensor and give it E2E here\n",
" # because torch.nn.DataParellel can't handle it.\n",
" xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n",
" else:\n",
" xs_pad = pad_list(xs, 0).astype(self.dtype)\n",
"\n",
" # NOTE: this is for multi-output (e.g., speech translation)\n",
" ys_pad = pad_list(\n",
" [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n",
" self.ignore_id)\n",
"\n",
" olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n",
" return utts, xs_pad, ilens, ys_pad, olens"
]
},
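{
"cell_type": "code",
"execution_count": null,
"id": "pad-list-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: a minimal numpy version of what pad_list is assumed to\n",
"# do inside the converter above (pad variable-length arrays along axis 0\n",
"# with a fill value). deepspeech.io.utility.pad_list is authoritative;\n",
"# this cell only illustrates the padding behaviour.\n",
"import numpy as np\n",
"\n",
"def pad_list_np(xs, pad_value):\n",
"    n_batch = len(xs)\n",
"    max_len = max(x.shape[0] for x in xs)\n",
"    pad = np.full((n_batch, max_len) + xs[0].shape[1:], pad_value,\n",
"                  dtype=xs[0].dtype)\n",
"    for i, x in enumerate(xs):\n",
"        pad[i, :x.shape[0]] = x  # copy; the tail stays at pad_value\n",
"    return pad\n",
"\n",
"demo = pad_list_np([np.ones((2, 3)), np.ones((4, 3))], 0)\n",
"print(demo.shape)  # (2, 4, 3): padded to the longest sequence"
]
},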
{
"cell_type": "code",
"execution_count": 107,
"id": "0b92ade5",
"metadata": {},
"outputs": [],
"source": [
"convert = CustomConverter()"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "8dbd847c",
"metadata": {},
"outputs": [],
"source": [
"utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "31c085f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n",
"(10, 1174, 83)\n",
"(10,)\n",
"[1174 821 716 628 597 473 463 441 419 358]\n",
"(10, 32)\n",
"[[4502 2404 4223 3204 4502 587 1018 3861 2932 713 2458 2916 253 4508\n",
" 627 1395 713 4504 957 2761 209 2967 3173 3918 2598 4100 3 2816\n",
" 4990 -1 -1 -1]\n",
" [1005 451 210 278 3411 206 482 2307 573 4502 3848 4577 4273 2388\n",
" 4444 89 4919 278 1264 4501 2371 3 139 113 2603 4962 3158 3325\n",
" 4577 814 4587 1422]\n",
" [2345 4144 2291 200 713 2345 532 999 2458 3076 545 2458 4832 3038\n",
" 4499 482 2812 1260 3080 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2345 832 4577 4920 4501 2345 2298 1236 381 288 389 101 2495 4172\n",
" 4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n",
" 4723 2371 -1 -1]\n",
" [4502 4207 432 3204 4502 2396 125 935 433 2598 483 18 327 2\n",
" 389 627 4512 2340 713 482 1981 4525 4031 269 2030 1340 101 2495\n",
" 4013 4844 -1 -1]\n",
" [4502 4892 3204 1892 3780 389 482 2774 3013 89 192 2495 4502 3475\n",
" 389 66 370 343 404 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2458 2314 4577 2340 2863 1254 303 269 2 389 932 2079 4577 299\n",
" 195 3233 4508 2 89 814 3144 1091 3204 3250 2193 3414 -1 -1\n",
" -1 -1 -1 -1]\n",
" [2391 1785 443 78 39 4962 2340 829 599 4593 278 4681 202 407\n",
" 269 194 182 4577 482 4308 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [ 627 4873 2175 363 202 404 1018 4577 4502 3412 4875 2286 107 122\n",
" 4832 2345 3896 89 2368 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]\n",
" [ 481 174 474 599 1881 3252 2842 742 4502 2545 107 88 3204 4525\n",
" 4517 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n",
" -1 -1 -1 -1]]\n",
"[29 32 19 30 30 19 26 20 19 15]\n",
"float32\n",
"int64\n",
"int64\n",
"int64\n"
]
}
],
"source": [
"print(utts)\n",
"print(xs.shape)\n",
"print(ilen.shape)\n",
"print(ilen)\n",
"print(ys.shape)\n",
"print(ys)\n",
"print(olen)\n",
"print(xs.dtype)\n",
"print(ilen.dtype)\n",
"print(ys.dtype)\n",
"print(olen.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "72e9ba60",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 230,
"id": "64593e5f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from paddle.io import DataLoader\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.io.batchfy import make_batchset\n",
"from deepspeech.io.converter import CustomConverter\n",
"from deepspeech.io.dataset import TransformDataset\n",
"from deepspeech.io.reader import LoadInputsAndTargets\n",
"from deepspeech.utils.log import Log\n",
"\n",
"\n",
"logger = Log(__name__).getlog()\n",
"\n",
"\n",
"class BatchDataLoader():\n",
" def __init__(self,\n",
" json_file: str,\n",
" train_mode: bool,\n",
" sortagrad: bool=False,\n",
" batch_size: int=0,\n",
" maxlen_in: float=float('inf'),\n",
" maxlen_out: float=float('inf'),\n",
" minibatches: int=0,\n",
" mini_batch_size: int=1,\n",
" batch_count: str='auto',\n",
" batch_bins: int=0,\n",
" batch_frames_in: int=0,\n",
" batch_frames_out: int=0,\n",
" batch_frames_inout: int=0,\n",
" preprocess_conf=None,\n",
" n_iter_processes: int=1,\n",
" subsampling_factor: int=1,\n",
" num_encs: int=1):\n",
" self.json_file = json_file\n",
" self.train_mode = train_mode\n",
" self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n",
" self.batch_size = batch_size\n",
" self.maxlen_in = maxlen_in\n",
" self.maxlen_out = maxlen_out\n",
" self.batch_count = batch_count\n",
" self.batch_bins = batch_bins\n",
" self.batch_frames_in = batch_frames_in\n",
" self.batch_frames_out = batch_frames_out\n",
" self.batch_frames_inout = batch_frames_inout\n",
" self.subsampling_factor = subsampling_factor\n",
" self.num_encs = num_encs\n",
" self.preprocess_conf = preprocess_conf\n",
" self.n_iter_processes = n_iter_processes\n",
"\n",
" \n",
" # read json data\n",
" self.data_json = read_manifest(json_file)\n",
"\n",
" # make minibatch list (variable length)\n",
" self.minibaches = make_batchset(\n",
" self.data_json,\n",
" batch_size,\n",
" maxlen_in,\n",
" maxlen_out,\n",
" minibatches, # for debug\n",
" min_batch_size=mini_batch_size,\n",
" shortest_first=self.use_sortagrad,\n",
" count=batch_count,\n",
" batch_bins=batch_bins,\n",
" batch_frames_in=batch_frames_in,\n",
" batch_frames_out=batch_frames_out,\n",
" batch_frames_inout=batch_frames_inout,\n",
" iaxis=0,\n",
" oaxis=0, )\n",
"\n",
" # data reader\n",
" self.reader = LoadInputsAndTargets(\n",
" mode=\"asr\",\n",
" load_output=True,\n",
" preprocess_conf=preprocess_conf,\n",
" preprocess_args={\"train\":\n",
" train_mode}, # Switch the mode of preprocessing\n",
" )\n",
"\n",
" # Setup a converter\n",
" if num_encs == 1:\n",
" self.converter = CustomConverter(\n",
" subsampling_factor=subsampling_factor, dtype=np.float32)\n",
" else:\n",
" assert NotImplementedError(\"not impl CustomConverterMulEnc.\")\n",
"\n",
" # hack to make batchsize argument as 1\n",
" # actual bathsize is included in a list\n",
" # default collate function converts numpy array to pytorch tensor\n",
" # we used an empty collate function instead which returns list\n",
" self.dataset = TransformDataset(self.minibaches, \n",
" lambda data: self.converter([self.reader(data, return_uttid=True)]))\n",
" self.dataloader = DataLoader(\n",
" dataset=self.dataset,\n",
" batch_size=1,\n",
" shuffle=not use_sortagrad if train_mode else False,\n",
" collate_fn=lambda x: x[0],\n",
" num_workers=n_iter_processes, )\n",
"\n",
" def __repr__(self):\n",
" echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n",
" echo += f\"train_mode: {self.train_mode}, \"\n",
" echo += f\"sortagrad: {self.use_sortagrad}, \"\n",
" echo += f\"batch_size: {self.batch_size}, \"\n",
" echo += f\"maxlen_in: {self.maxlen_in}, \"\n",
" echo += f\"maxlen_out: {self.maxlen_out}, \"\n",
" echo += f\"batch_count: {self.batch_count}, \"\n",
" echo += f\"batch_bins: {self.batch_bins}, \"\n",
" echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n",
" echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n",
" echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n",
" echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n",
" echo += f\"num_encs: {self.num_encs}, \"\n",
" echo += f\"num_workers: {self.n_iter_processes}, \"\n",
" echo += f\"file: {self.json_file}\"\n",
" return echo\n",
" \n",
" def __len__(self):\n",
" return len(self.dataloader)\n",
" \n",
" def __iter__(self):\n",
" return self.dataloader.__iter__()\n",
" \n",
" def __call__(self):\n",
" return self.__iter__()\n"
]
},
{
"cell_type": "code",
"execution_count": 231,
"id": "fcea3fd0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n",
"[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n",
"[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n"
]
}
],
"source": [
"train = BatchDataLoader(dev_data, True, batch_size=20)"
]
},
{
"cell_type": "code",
"execution_count": 232,
"id": "e2a2c9a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"278\n",
"['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n",
"<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n",
"278\n"
]
}
],
"source": [
"print(len(train.dataloader))\n",
"print(dir(train.dataloader))\n",
"print(train)\n",
"print(len(train))"
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "a5ba7d6e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n",
"Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[-1.99415934, -1.80315673, -1.88801885, ..., 0.86933994, -0.59853148, 0.02596200],\n",
" [-1.95346808, -1.84891188, -2.17492867, ..., 0.83640492, -0.59853148, -0.11333394],\n",
" [-2.27899861, -2.21495342, -2.58480024, ..., 0.91874266, -0.59853148, -0.31453922],\n",
" ...,\n",
" [-2.64522028, -2.35221887, -2.91269732, ..., 1.48994756, -0.16100442, 0.36646330],\n",
" [-2.40107250, -2.21495342, -2.37986445, ..., 1.44072104, -0.13220564, 0.12656468],\n",
" [-2.15692472, -1.89466715, -2.25690317, ..., 1.31273174, -0.09620714, -0.15202725]],\n",
"\n",
" [[-0.28859532, -0.29033494, -0.86576819, ..., 1.37753224, -0.30570769, 0.25806731],\n",
" [-0.20149794, -0.17814466, -0.59891301, ..., 1.35188794, -0.30570769, -0.02964944],\n",
" [-0.34947991, -0.33597648, -0.96877253, ..., 1.38394332, -0.30570769, -0.38376236],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.44914246, -0.33902276, -0.78237975, ..., 1.38218808, 0.29214793, -0.16815147],\n",
" [-0.55490732, -0.41596055, -0.84425378, ..., 1.34530187, 0.25002354, -0.04004869],\n",
" [-0.83694696, -0.62112784, -1.07112527, ..., 1.19160914, 0.20789915, 0.37984371],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[-1.24343657, -0.94188881, -1.41092563, ..., 0.96716309, 0.60345763, 0.15360183],\n",
" [-1.19466043, -0.80585432, -0.49723154, ..., 1.06735480, 0.60345763, 0.14511746],\n",
" [-0.94079566, -0.59330046, -0.40948665, ..., 0.82244170, 0.55614340, 0.28086722],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.21757117, 0.11361472, -0.33262897, ..., 0.76338506, -0.10711290, -0.57754958],\n",
" [-1.00205481, -0.61152041, -0.47124696, ..., 1.11897349, -0.10711290, 0.24931324],\n",
" [-1.03929281, -1.20336759, -1.16433656, ..., 0.88888687, -0.10711290, -0.04115745],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-1.25289667, -1.05046368, -0.82881606, ..., 1.23991334, 0.61702502, 0.05275881],\n",
" [-1.19659519, -0.78677225, -0.80407262, ..., 1.27644968, 0.61702502, -0.35079369],\n",
" [-1.49687004, -1.01750231, -0.82881606, ..., 1.29106426, 0.65006059, 0.17958963],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]]])\n",
"Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n",
"Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[3098, 1595, 389, ..., -1 , -1 , -1 ],\n",
" [2603, 4832, 482, ..., -1 , -1 , -1 ],\n",
" [2796, 303, 269, ..., -1 , -1 , -1 ],\n",
" ...,\n",
" [3218, 3673, 206, ..., -1 , -1 , -1 ],\n",
" [2371, 4832, 4031, ..., -1 , -1 , -1 ],\n",
" [2570, 2433, 4285, ..., -1 , -1 , -1 ]])\n",
"Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n"
]
}
],
"source": [
"for batch in train:\n",
" utts, xs, ilens, ys, olens = batch\n",
" print(utts)\n",
" print(xs)\n",
" print(ilens)\n",
" print(ys)\n",
" print(olens)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c974a1e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "breeding-haven",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "appropriate-theta",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LICENSE deepspeech examples\t\t requirements.txt tools\r\n",
"README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n",
"README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "entire-bloom",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n",
"WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n"
]
}
],
"source": [
"from deepspeech.modules import loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "governmental-aircraft",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import paddle"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "proprietary-disaster",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function deepspeech.modules.repeat(xs: paddle.VarBase, *size: Any) -> paddle.VarBase>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.repeat"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "first-diagram",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<property at 0x7fb515eeeb88>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.size"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "intelligent-david",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function paddle.tensor.manipulation.concat(x, axis=0, name=None)>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.cat"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bronze-tenant",
"metadata": {},
"outputs": [],
"source": [
"a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "balanced-bearing",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "extreme-republic",
"metadata": {},
"outputs": [],
"source": [
"def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n",
" nargs = len(args)\n",
" assert (nargs <= 1)\n",
" s = paddle.shape(xs)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"\n",
"# logger.warn(\n",
"# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n",
"# )\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "gross-addiction",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size(0)\n",
"a.size()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "adverse-dining",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "popular-potato",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"import argparse\n",
"import functools\n",
"import paddle\n",
"import numpy as np\n",
"\n",
"from deepspeech.utils.socket_server import warm_up_test\n",
"from deepspeech.utils.socket_server import AsrTCPServer\n",
"from deepspeech.utils.socket_server import AsrRequestHandler\n",
"\n",
"from deepspeech.training.cli import default_argument_parser\n",
"from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"\n",
"from deepspeech.models.ds2 import DeepSpeech2Model\n",
"from deepspeech.models.ds2 import DeepSpeech2InferModel\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"\n",
"\n",
"\n",
"from deepspeech.frontend.utility import read_manifest"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0.0\n",
"e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"OFF\n",
"OFF\n",
"commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"None\n",
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(paddle.__version__)\n",
"print(paddle.version.commit)\n",
"print(paddle.version.with_mkl)\n",
"print(paddle.version.mkl())\n",
"print(paddle.version.show())\n",
"print(paddle.version.patch)\n",
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data:\n",
" augmentation_config: conf/augmentation.config\n",
" batch_size: 64\n",
" dev_manifest: data/manifest.dev\n",
" keep_transcription_text: False\n",
" max_duration: 27.0\n",
" max_freq: None\n",
" mean_std_filepath: examples/aishell/data/mean_std.npz\n",
" min_duration: 0.0\n",
" n_fft: None\n",
" num_workers: 0\n",
" random_seed: 0\n",
" shuffle_method: batch_shuffle\n",
" sortagrad: True\n",
" specgram_type: linear\n",
" stride_ms: 10.0\n",
" target_dB: -20\n",
" target_sample_rate: 16000\n",
" test_manifest: examples/aishell/data/manifest.test\n",
" train_manifest: data/manifest.train\n",
" use_dB_normalization: True\n",
" vocab_filepath: examples/aishell/data/vocab.txt\n",
" window_ms: 20.0\n",
"decoding:\n",
" alpha: 2.6\n",
" batch_size: 128\n",
" beam_size: 300\n",
" beta: 5.0\n",
" cutoff_prob: 0.99\n",
" cutoff_top_n: 40\n",
" decoding_method: ctc_beam_search\n",
" error_rate_type: cer\n",
" lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n",
" num_proc_bsearch: 10\n",
"model:\n",
" num_conv_layers: 2\n",
" num_rnn_layers: 3\n",
" rnn_layer_size: 1024\n",
" share_rnn_weights: False\n",
" use_gru: True\n",
"training:\n",
" global_grad_clip: 5.0\n",
" lr: 0.0005\n",
" lr_decay: 0.83\n",
" n_epoch: 30\n",
" weight_decay: 1e-06\n",
"----------- Configuration Arguments -----------\n",
"checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n",
"config: examples/aishell/conf/deepspeech2.yaml\n",
"device: gpu\n",
"dump_config: None\n",
"export_path: None\n",
"host_ip: localhost\n",
"host_port: 8086\n",
"model_dir: None\n",
"model_file: examples/aishell/jit.model.pdmodel\n",
"nprocs: 1\n",
"opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n",
"output: None\n",
"params_file: examples/aishell/jit.model.pdiparams\n",
"speech_save_dir: demo_cache\n",
"use_gpu: False\n",
"warmup_manifest: examples/aishell/data/manifest.test\n",
"------------------------------------------------\n"
]
}
],
"source": [
"parser = default_argument_parser()\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"add_arg('host_ip', str,\n",
" 'localhost',\n",
" \"Server's IP address.\")\n",
"add_arg('host_port', int, 8086, \"Server's IP port.\")\n",
"add_arg('speech_save_dir', str,\n",
" 'demo_cache',\n",
" \"Directory to save demo audios.\")\n",
"add_arg('warmup_manifest', \n",
" str, \n",
" \"examples/aishell/data/manifest.test\", \n",
" \"Filepath of manifest to warm up.\")\n",
"add_arg(\n",
" \"--model_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdmodel\",\n",
" help=\"Model filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--params_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdiparams\",\n",
" help=\n",
" \"Parameter filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--model_dir\",\n",
" type=str,\n",
" default=None,\n",
" help=\n",
" \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n",
")\n",
"add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n",
"\n",
"\n",
"args = parser.parse_args(\n",
" \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n",
")\n",
"\n",
"\n",
"config = get_cfg_defaults()\n",
"if args.config:\n",
" config.merge_from_file(args.config)\n",
"if args.opts:\n",
" config.merge_from_list(args.opts)\n",
"config.freeze()\n",
"print(config)\n",
"\n",
"args.warmup_manifest = config.data.test_manifest\n",
"\n",
"print_arguments(args)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dataset = ManifestDataset(\n",
" config.data.test_manifest,\n",
" config.data.unit_type,\n",
" config.data.vocab_filepath,\n",
" config.data.mean_std_filepath,\n",
" augmentation_config=\"{}\",\n",
" max_duration=config.data.max_duration,\n",
" min_duration=config.data.min_duration,\n",
" stride_ms=config.data.stride_ms,\n",
" window_ms=config.data.window_ms,\n",
" n_fft=config.data.n_fft,\n",
" max_freq=config.data.max_freq,\n",
" target_sample_rate=config.data.target_sample_rate,\n",
" specgram_type=config.data.specgram_type,\n",
" feat_dim=config.data.feat_dim,\n",
" delta_delta=config.data.delat_delta,\n",
" use_dB_normalization=config.data.use_dB_normalization,\n",
" target_dB=config.data.target_dB,\n",
" random_seed=config.data.random_seed,\n",
" keep_transcription_text=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"layer summary:\n",
"encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n",
"encoder.conv.conv_in.bn.weight|[32]|32\n",
"encoder.conv.conv_in.bn.bias|[32]|32\n",
"encoder.conv.conv_in.bn._mean|[32]|32\n",
"encoder.conv.conv_in.bn._variance|[32]|32\n",
"encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n",
"encoder.conv.conv_stack.0.bn.weight|[32]|32\n",
"encoder.conv.conv_stack.0.bn.bias|[32]|32\n",
"encoder.conv.conv_stack.0.bn._mean|[32]|32\n",
"encoder.conv.conv_stack.0.bn._variance|[32]|32\n",
"encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n",
"decoder.ctc_lo.weight|[2048, 4300]|8806400\n",
"decoder.ctc_lo.bias|[4300]|4300\n",
"layer has 66 parameters, 80148012 elements.\n"
]
}
],
"source": [
"model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n",
" args.checkpoint_path)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"examples/aishell/jit.model.pdmodel\n",
"examples/aishell/jit.model.pdiparams\n",
"0\n",
"False\n"
]
}
],
"source": [
"\n",
"from paddle.inference import Config\n",
"from paddle.inference import PrecisionType\n",
"from paddle.inference import create_predictor\n",
"\n",
"args.use_gpu=False\n",
"paddle.set_device('cpu')\n",
"\n",
"def init_predictor(args):\n",
" if args.model_dir is not None:\n",
" config = Config(args.model_dir)\n",
" else:\n",
" config = Config(args.model_file, args.params_file)\n",
"\n",
" if args.use_gpu:\n",
" config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n",
"# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n",
"# use_calib_mode=True) # 开启TensorRT预测,精度为fp32,开启int8离线量化\n",
" else:\n",
" # If not specific mkldnn, you can set the blas thread.\n",
" # The thread num should not be greater than the number of cores in the CPU.\n",
" config.set_cpu_math_library_num_threads(1)\n",
" config.enable_mkldnn()\n",
" \n",
" config.enable_memory_optim()\n",
" config.switch_ir_optim(True)\n",
" \n",
" print(config.model_dir())\n",
" print(config.prog_file())\n",
" print(config.params_file())\n",
" print(config.gpu_device_id())\n",
" print(args.use_gpu)\n",
" predictor = create_predictor(config)\n",
" return predictor\n",
"\n",
"def run(predictor, audio, audio_len):\n",
" # copy img data to input tensor\n",
" input_names = predictor.get_input_names()\n",
" for i, name in enumerate(input_names):\n",
" print(\"input:\", i, name)\n",
" \n",
" audio_tensor = predictor.get_input_handle('audio')\n",
" audio_tensor.reshape(audio.shape)\n",
" audio_tensor.copy_from_cpu(audio.copy())\n",
" \n",
" audiolen_tensor = predictor.get_input_handle('audio_len')\n",
" audiolen_tensor.reshape(audio_len.shape)\n",
" audiolen_tensor.copy_from_cpu(audio_len.copy())\n",
"\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" print(\"output:\", i, name)\n",
"\n",
" # do the inference\n",
" predictor.run()\n",
"\n",
" results = []\n",
" # get out data from output tensor\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" output_tensor = predictor.get_output_handle(name)\n",
" output_data = output_tensor.copy_to_cpu()\n",
" results.append(output_data)\n",
"\n",
" return results\n",
"\n",
"\n",
"predictor = init_predictor(args)\n",
"\n",
"def file_to_transcript(filename):\n",
" print(filename)\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" \n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0], type(i_probs[0]))\n",
" \n",
" audio = paddle.to_tensor(audio)\n",
" audio_len = paddle.to_tensor(audio_len)\n",
" print(audio.shape)\n",
" print(audio_len.shape)\n",
" \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.softmax(eouts)\n",
" probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n",
" \n",
" flag = np.allclose(i_probs[0], probs.numpy())\n",
" print(flag)\n",
" \n",
" return probs\n",
"\n",
"# result_transcript = model.decode(\n",
"# audio,\n",
"# audio_len,\n",
"# vocab_list=dataset.vocab_list,\n",
"# decoding_method=config.decoding.decoding_method,\n",
"# lang_model_path=config.decoding.lang_model_path,\n",
"# beam_alpha=config.decoding.alpha,\n",
"# beam_beta=config.decoding.beta,\n",
"# beam_size=config.decoding.beam_size,\n",
"# cutoff_prob=config.decoding.cutoff_prob,\n",
"# cutoff_top_n=config.decoding.cutoff_top_n,\n",
"# num_processes=config.decoding.num_proc_bsearch)\n",
"# return result_transcript[0]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n",
" 8.91573707e-12 4.64317296e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n",
" 1.24587264e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n",
" 3.89252270e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]] <class 'numpy.ndarray'>\n",
"[1, 161, 522]\n",
"[1]\n",
"paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n",
"False\n"
]
}
],
"source": [
"manifest = read_manifest(args.warmup_manifest)\n",
"\n",
"for idx, sample in enumerate(manifest[:1]):\n",
" print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n",
" start_time = time.time()\n",
" transcript = file_to_transcript(sample['audio_filepath'])\n",
" finish_time = time.time()\n",
"# print(\"Response Time: %f, Transcript: %s\" %\n",
"# (finish_time - start_time, transcript))\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 161, 522) (1,)\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n"
]
}
],
"source": [
"def test(filename):\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" print(audio.shape, audio_len.shape)\n",
"\n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0])\n",
" return i_probs\n",
" \n",
"probs = test(sample['audio_filepath'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 32,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n",
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n"
]
}
],
"source": [
"L = nn.LayerNorm(256, epsilon=1e-12)\n",
"for p in L.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": [
"y = L(paddle.to_tensor(x, dtype='float32'))"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1.], requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" requires_grad=True)\n"
]
}
],
"source": [
"TL = torch.nn.LayerNorm(256, eps=1e-12)\n",
"for p in TL.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [],
"source": [
"ty = TL(torch.tensor(x, dtype=torch.float32))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "incorrect-allah",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
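{
"cell_type": "code",
"execution_count": null,
"id": "layernorm-tolerance-check",
"metadata": {},
"outputs": [],
"source": [
"# Follow-up to the False above (a hedged check, not from the original\n",
"# run): both LayerNorms normalize the last axis with the same epsilon,\n",
"# so the 3-D mismatch is plausibly small float32 differences. Print the\n",
"# largest elementwise gap and retry allclose with a looser tolerance to\n",
"# see how far apart the two implementations actually are.\n",
"diff = np.abs(y.numpy() - ty.detach().numpy())\n",
"print(diff.max())\n",
"print(np.allclose(y.numpy(), ty.detach().numpy(), rtol=1e-5, atol=1e-5))"
]
},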
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 52,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"y = L(paddle.to_tensor(x, dtype='float32'))\n",
"ty = TL(torch.tensor(x, dtype=torch.float32))\n",
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "primary-organic",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "stopped-semester",
"metadata": {},
"outputs": [],
"source": [
"def mask_finished_scores(score: torch.Tensor,\n",
" flag: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.size(-1)\n",
" zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
" if beam_size > 1:\n",
" unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "agreed-portuguese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ True],\n",
" [False]])\n",
"tensor([[-0.8841, 0.7381, -0.9986],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ True, True],\n",
" [False, False]])\n"
]
}
],
"source": [
"score = torch.randn((2, 3))\n",
"flag = torch.ones((2, 1), dtype=torch.bool)\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.repeat([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "clean-aspect",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[False, True, True],\n",
" [False, False, False]])\n",
"tensor([[ True, False, False],\n",
" [False, False, False]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n"
]
}
],
"source": [
"r = mask_finished_scores(score, flag)\n",
"print(r)\n",
"print(score)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "thrown-airline",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True ],\n",
" [False]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"import paddle\n",
"\n",
"score = paddle.randn((2, 3))\n",
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.tile([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "internal-patent",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[False, True , True ],\n",
" [False, False, False]])\n",
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , False, False],\n",
" [False, False, False]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n"
]
}
],
"source": [
"paddle.bool = 'bool'\n",
"\n",
"def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print(xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" xs = paddle.where(mask, trues, xs)\n",
" return xs\n",
"\n",
"def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print('x', xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" ret = paddle.where(mask, trues, xs)\n",
" print('2', xs)\n",
" paddle.assign(ret, output=xs)\n",
" print('3', xs)\n",
"\n",
"paddle.Tensor.masked_fill = masked_fill\n",
"paddle.Tensor.masked_fill_ = masked_fill_\n",
"\n",
"def mask_finished_scores_pd(score: paddle.Tensor,\n",
" flag: paddle.Tensor) -> paddle.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.shape[-1]\n",
" zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
" if beam_size > 1:\n",
" unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" \n",
" #score.masked_fill_(unfinished, -float('inf'))\n",
" #score.masked_fill_(finished, 0)\n",
"# infs = paddle.ones_like(score) * -float('inf')\n",
"# score = paddle.where(unfinished, infs, score)\n",
"# score = paddle.where(finished, paddle.zeros_like(score), score)\n",
"\n",
"# score = score.masked_fill(unfinished, -float('inf'))\n",
"# score = score.masked_fill(finished, 0)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score\n",
"\n",
"r = mask_finished_scores_pd(score, flag)\n",
"print(r)"
]
},
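{
"cell_type": "code",
"execution_count": null,
"id": "parity-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A minimal parity sketch: the paddle result r above should show the same\n",
"# pattern as the torch version -- the finished row keeps one alive branch\n",
"# at 0 and masks the rest to -inf.\n",
"import numpy as np\n",
"out = r.numpy()\n",
"assert out[0][0] == 0.0 # finished beam: single alive branch scored 0\n",
"assert np.isinf(out[0][1:]).all() # other branches pushed to -inf\n",
"assert np.isfinite(out[1]).all() # unfinished beam left untouched"
]
},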
{
"cell_type": "code",
"execution_count": 57,
"id": "vocal-prime",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method PyCapsule.value of Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])>"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"score.value"
]
},
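{
"cell_type": "code",
"execution_count": null,
"id": "value-note",
"metadata": {},
"outputs": [],
"source": [
"# The cell above printed a bound method, not the data; the usual way to\n",
"# inspect the contents is via numpy():\n",
"print(score.numpy())"
]
},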
{
"cell_type": "code",
"execution_count": 71,
"id": "bacterial-adolescent",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union, Any"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "absent-fiber",
"metadata": {},
"outputs": [],
"source": [
"def repeat(xs : paddle.Tensor, *size: Any):\n",
" print(size)\n",
" return paddle.tile(xs, size)\n",
"paddle.Tensor.repeat = repeat"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "material-harbor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "acute-brighton",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [1]), 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(paddle.to_tensor(1), 2))"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "european-rugby",
"metadata": {},
"outputs": [],
"source": [
"def size(xs, *args: int):\n",
" nargs = len(args)\n",
" s = paddle.shape(xs)\n",
" assert(nargs <= 1)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "moral-special",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2, 1])"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "ahead-coach",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [1])"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(1)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "incomplete-fitness",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "upset-connectivity",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "designing-borough",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"True\n",
"True\n"
]
}
],
"source": [
"import torch\n",
"import math\n",
"import numpy as np\n",
"\n",
"max_len=100\n",
"d_model=256\n",
"\n",
"pe = torch.zeros(max_len, d_model)\n",
"position = torch.arange(0, max_len,\n",
" dtype=torch.float32).unsqueeze(1)\n",
"toruch_position = position\n",
"div_term = torch.exp(\n",
" torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
" -(math.log(10000.0) / d_model))\n",
"tourch_div_term = div_term.cpu().detach().numpy()\n",
"\n",
"\n",
"\n",
"torhc_sin = torch.sin(position * div_term)\n",
"torhc_cos = torch.cos(position * div_term)\n",
"print(torhc_sin.cpu().detach().numpy())\n",
"np_sin = np.sin((position * div_term).cpu().detach().numpy())\n",
"np_cos = np.cos((position * div_term).cpu().detach().numpy())\n",
"print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n",
"pe[:, 0::2] = torhc_sin\n",
"pe[:, 1::2] = torhc_cos\n",
"tourch_pe = pe.cpu().detach().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "swiss-referral",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"False\n",
"False\n",
"False\n",
"False\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
"----\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
")))))))\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"----\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n"
]
}
],
"source": [
"import paddle\n",
"paddle.set_device('cpu')\n",
"ppe = paddle.zeros((max_len, d_model), dtype='float32')\n",
"position = paddle.arange(0, max_len,\n",
" dtype='float32').unsqueeze(1)\n",
"print(np.allclose(position.numpy(), toruch_position))\n",
"div_term = paddle.exp(\n",
" paddle.arange(0, d_model, 2, dtype='float32') *\n",
" -(math.log(10000.0) / d_model))\n",
"print(np.allclose(div_term.numpy(), tourch_div_term))\n",
"\n",
"\n",
"\n",
"p_sin = paddle.sin(position * div_term)\n",
"p_cos = paddle.cos(position * div_term)\n",
"print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n",
"print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n",
"ppe[:, 0::2] = p_sin\n",
"ppe[:, 1::2] = p_cos\n",
"print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n",
"print(p_cos.numpy())\n",
"print(\"----\")\n",
"print(torhc_cos.cpu().detach().numpy())\n",
"print(\")))))))\")\n",
"print(p_sin.numpy())\n",
"print(\"----\")\n",
"print(torhc_sin.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "integrated-boards",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"print(np.allclose(ppe.numpy(), pe.numpy()))"
]
},
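{
"cell_type": "code",
"execution_count": null,
"id": "pe-diff-sketch",
"metadata": {},
"outputs": [],
"source": [
"# The strict comparisons above come out False; a minimal diagnostic sketch to\n",
"# quantify the float32 gap between the paddle and torch tables (it may well\n",
"# sit within a looser tolerance):\n",
"diff = np.abs(ppe.numpy() - pe.numpy())\n",
"print(diff.max())\n",
"print(np.allclose(ppe.numpy(), pe.numpy(), rtol=1e-5, atol=1e-6))"
]
},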
{
"cell_type": "code",
"execution_count": null,
"id": "flying-reserve",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "revised-divide",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cloudy-glass",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISISBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "grand-stephen",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "isolated-prize",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "romance-samuel",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
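{
"cell_type": "code",
"execution_count": null,
"id": "args-override-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: parse_args takes an explicit argv list here, so any flag\n",
"# registered above can be overridden the same way (assuming add_arguments\n",
"# registers plain --name options):\n",
"print(parser.parse_args(['--num_samples', '3']).num_samples)"
]
},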
{
"cell_type": "code",
"execution_count": 4,
"id": "timely-bikini",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "organized-warrior",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[14 , 34 , 322 , 233 , 0 , 0 ],\n",
" [238 , 38 , 122 , 164 , 0 , 0 ],\n",
" [8 , 52 , 49 , 42 , 0 , 0 ],\n",
" [109 , 47 , 146 , 193 , 210 , 479 ],\n",
" [3330, 1751, 208 , 1923, 0 , 0 ]])\n",
"test raw 大时代里的的\n",
"test raw 煲汤受宠的的\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
" for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "confidential-radius",
"metadata": {},
"outputs": [],
"source": [
"# reader = batch_reader()\n",
"# audio, test , audio_len, text_len = reader.next()\n",
"# print('test', text)\n",
"# print('t len', text_len) #[B, T]\n",
"# print('audio len', audio_len)\n",
"# print(audio)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "future-vermont",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"煲汤受宠\n"
]
}
],
"source": [
"print(u'\\u7172\\u6c64\\u53d7\\u5ba0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dental-sweden",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "sunrise-contact",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hispanic-asthma",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hearing-leadership",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "skilled-friday",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "copyrighted-measure",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "employed-lightweight",
"metadata": {},
"outputs": [],
"source": [
"from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n",
"\n",
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"from paddle.nn import functional as F\n",
"from paddle.nn import initializer as I\n",
"\n",
"import math\n",
"\n",
"def brelu(x, t_min=0.0, t_max=24.0, name=None):\n",
" t_min = paddle.to_tensor(t_min)\n",
" t_max = paddle.to_tensor(t_max)\n",
" return x.maximum(t_min).minimum(t_max)\n",
"\n",
"def sequence_mask(x_len, max_len=None, dtype='float32'):\n",
" max_len = max_len or x_len.max()\n",
" x_len = paddle.unsqueeze(x_len, -1)\n",
" row_vector = paddle.arange(max_len)\n",
" mask = row_vector > x_len # maybe a bug\n",
" mask = paddle.cast(mask, dtype)\n",
" print(f'seq mask: {mask}')\n",
" return mask\n",
"\n",
"\n",
"class ConvBn(nn.Layer):\n",
" def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n",
" padding, act):\n",
"\n",
" super().__init__()\n",
" self.kernel_size = kernel_size\n",
" self.stride = stride\n",
" self.padding = padding\n",
"\n",
" self.conv = nn.Conv2D(\n",
" num_channels_in,\n",
" num_channels_out,\n",
" kernel_size=kernel_size,\n",
" stride=stride,\n",
" padding=padding,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
"\n",
" self.bn = nn.BatchNorm2D(\n",
" num_channels_out,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
" self.act = F.relu if act == 'relu' else brelu\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x(Tensor): audio, shape [B, C, D, T]\n",
" \"\"\"\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.act(x)\n",
"\n",
" x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n",
" ) // self.stride[1] + 1\n",
"\n",
" # reset padding part to 0\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n",
" x = x.multiply(masks)\n",
"\n",
" return x, x_len\n",
"\n",
"\n",
"class ConvStack(nn.Layer):\n",
" def __init__(self, feat_size, num_stacks):\n",
" super().__init__()\n",
" self.feat_size = feat_size # D\n",
" self.num_stacks = num_stacks\n",
"\n",
" self.conv_in = ConvBn(\n",
" num_channels_in=1,\n",
" num_channels_out=32,\n",
" kernel_size=(41, 11), #[D, T]\n",
" stride=(2, 3),\n",
" padding=(20, 5),\n",
" act='brelu')\n",
"\n",
" out_channel = 32\n",
" self.conv_stack = nn.Sequential([\n",
" ConvBn(\n",
" num_channels_in=32,\n",
" num_channels_out=out_channel,\n",
" kernel_size=(21, 11),\n",
" stride=(2, 1),\n",
" padding=(10, 5),\n",
" act='brelu') for i in range(num_stacks - 1)\n",
" ])\n",
"\n",
" # conv output feat_dim\n",
" output_height = (feat_size - 1) // 2 + 1\n",
" for i in range(self.num_stacks - 1):\n",
" output_height = (output_height - 1) // 2 + 1\n",
" self.output_height = out_channel * output_height\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, C, D, T]\n",
" x_len : shape [B]\n",
" \"\"\"\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = self.conv_in(x, x_len)\n",
" for i, conv in enumerate(self.conv_stack):\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = conv(x, x_len)\n",
" print(f\"conv out: {x_len}\")\n",
" return x, x_len\n",
" \n",
" \n",
"\n",
"class RNNCell(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n",
" computes the outputs and updates states.\n",
" The formula used is as follows:\n",
" .. math::\n",
" h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`act` is for :attr:`activation`.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" hidden_size,\n",
" activation=\"tanh\",\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n",
" raise ValueError(\n",
" \"activation for SimpleRNNCell should be tanh or relu, \"\n",
" \"but get {}\".format(activation))\n",
" self.activation = activation\n",
" self._activation_fn = paddle.tanh \\\n",
" if activation == \"tanh\" \\\n",
" else F.relu\n",
" if activation == 'brelu':\n",
" self._activation_fn = brelu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
" pre_h = states\n",
" i2h = inputs\n",
" if self.bias_ih is not None:\n",
" i2h += self.bias_ih\n",
" h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h2h += self.bias_hh\n",
" h = self._activation_fn(i2h + h2h)\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class GRUCellShare(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n",
" it computes the outputs and updates states.\n",
" The formula for GRU used is as follows:\n",
" .. math::\n",
" r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n",
" z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n",
" \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n",
" h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n",
" multiplication operator.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" input_size,\n",
" hidden_size,\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (3 * hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (3 * hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (3 * hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" self.input_size = input_size\n",
" self._gate_activation = F.sigmoid\n",
" #self._activation = paddle.tanh\n",
" self._activation = F.relu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
"\n",
" pre_hidden = states\n",
" x_gates = inputs\n",
" if self.bias_ih is not None:\n",
" x_gates = x_gates + self.bias_ih\n",
" h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h_gates = h_gates + self.bias_hh\n",
"\n",
" x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n",
" h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n",
"\n",
" r = self._gate_activation(x_r + h_r)\n",
" z = self._gate_activation(x_z + h_z)\n",
" c = self._activation(x_c + r * h_c) # apply reset gate after mm\n",
" h = (pre_hidden - c) * z + c\n",
" # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n",
" #h = (1-z) * pre_hidden + z * c\n",
"\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" r\"\"\"\n",
" The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n",
" size would be automatically inserted into shape). The shape corresponds\n",
" to the shape of :math:`h_{t-1}`.\n",
" \"\"\"\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class BiRNNWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer parameters.\n",
" :type name: string\n",
" :param size: Dimension of RNN cells.\n",
" :type size: int\n",
" :param share_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" :type share_weights: bool\n",
" :return: Bidirectional simple rnn layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, share_weights):\n",
" super().__init__()\n",
" self.share_weights = share_weights\n",
" if self.share_weights:\n",
" #input-hidden weights shared between bi-directional rnn.\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" # batch norm is only performed on input-state projection\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = self.fw_fc\n",
" self.bw_bn = self.fw_bn\n",
" else:\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class BiGRUWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer.\n",
" :type name: string\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of GRU cells.\n",
" :type size: int\n",
" :param act: Activation type.\n",
" :type act: string\n",
" :return: Bidirectional GRU layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, act):\n",
" super().__init__()\n",
" hidden_size = h_size * 3\n",
" self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class RNNStack(nn.Layer):\n",
" \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n",
"\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of RNN cells in each layer.\n",
" :type size: int\n",
" :param num_stacks: Number of stacked rnn layers.\n",
" :type num_stacks: int\n",
" :param use_gru: Use gru if set True. Use simple rnn if set False.\n",
" :type use_gru: bool\n",
" :param share_rnn_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" It is only available when use_gru=False.\n",
" :type share_weights: bool\n",
" :return: Output layer of the RNN group.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n",
" super().__init__()\n",
" self.rnn_stacks = nn.LayerList()\n",
" for i in range(num_stacks):\n",
" if use_gru:\n",
" #default:GRU using tanh\n",
" self.rnn_stacks.append(\n",
" BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n",
" else:\n",
" self.rnn_stacks.append(\n",
" BiRNNWithBN(\n",
" i_size=i_size,\n",
" h_size=h_size,\n",
" share_weights=share_rnn_weights))\n",
" i_size = h_size * 2\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, T, D]\n",
" x_len: shpae [B]\n",
" \"\"\"\n",
" for i, rnn in enumerate(self.rnn_stacks):\n",
" x, x_len = rnn(x, x_len)\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(-1) # [B, T, 1]\n",
" x = x.multiply(masks)\n",
" return x, x_len\n",
"\n",
" \n",
"class DeepSpeech2Test(DeepSpeech2):\n",
" def __init__(self,\n",
" feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True):\n",
" super().__init__(feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True)\n",
" self.feat_size = feat_size # 161 for linear\n",
" self.dict_size = dict_size\n",
"\n",
" self.conv = ConvStack(feat_size, num_conv_layers)\n",
" \n",
"# self.fc = nn.Linear(1312, dict_size + 1)\n",
"\n",
" i_size = self.conv.output_height # H after conv stack\n",
" self.rnn = RNNStack(\n",
" i_size=i_size,\n",
" h_size=rnn_size,\n",
" num_stacks=num_rnn_layers,\n",
" use_gru=use_gru,\n",
" share_rnn_weights=share_rnn_weights)\n",
" \n",
" self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n",
" \n",
" def infer(self, audio, audio_len):\n",
" # [B, D, T] -> [B, C=1, D, T]\n",
" audio = audio.unsqueeze(1)\n",
"\n",
" # convolution group\n",
" x, audio_len = self.conv(audio, audio_len)\n",
" print('conv out', x.shape)\n",
"\n",
" # convert data from convolution feature map to sequence of vectors\n",
" B, C, D, T = paddle.shape(x)\n",
" x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n",
" x = x.reshape([B, T, C * D]) #[B, T, C*D]\n",
" print('rnn input', x.shape)\n",
"\n",
" # remove padding part\n",
" x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n",
" print('rnn output', x.shape)\n",
"\n",
" logits = self.fc(x) #[B, T, V + 1]\n",
"\n",
" #ctcdecoder need probs, not log_probs\n",
" probs = F.softmax(logits)\n",
"\n",
" return logits, probs, audio_len\n",
"\n",
" def forward(self, audio, audio_len, text, text_len):\n",
" \"\"\"\n",
" audio: shape [B, D, T]\n",
" text: shape [B, T]\n",
" audio_len: shape [B]\n",
" text_len: shape [B]\n",
" \"\"\"\n",
" return self.infer(audio, audio_len)\n",
" \n",
"\n",
"feat_dim=161\n",
"\n",
"model = DeepSpeech2Test(\n",
" feat_size=feat_dim,\n",
" dict_size=batch_reader.dataset.vocab_size,\n",
" num_conv_layers=args.num_conv_layers,\n",
" num_rnn_layers=args.num_rnn_layers,\n",
" rnn_size=1024,\n",
" use_gru=args.use_gru,\n",
" share_rnn_weights=args.share_rnn_weights,\n",
" )\n",
"dp_model = model\n",
"#dp_model = paddle.DataParallel(model)\n",
"\n",
"loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)"
]
},
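{
"cell_type": "code",
"execution_count": null,
"id": "helpers-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity sketches for the helpers defined above. brelu clamps into\n",
"# [t_min, t_max]; sequence_mask should give 1 for valid steps and 0 for\n",
"# padding, e.g. lengths [2, 4] with max_len=4 -> [[1,1,0,0],[1,1,1,1]].\n",
"print(brelu(paddle.to_tensor([-1.0, 5.0, 30.0]))) # -> [0., 5., 24.]\n",
"sequence_mask(paddle.to_tensor([2, 4]), max_len=4)"
]
},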
{
"cell_type": "code",
"execution_count": null,
"id": "divided-incentive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"id": "discrete-conjunction",
"metadata": {},
"outputs": [],
"source": [
"audio, audio_len, text, text_len = None, None, None, None\n",
"\n",
"for idx, inputs in enumerate(batch_reader):\n",
" audio, audio_len, text, text_len = inputs\n",
"# print(idx)\n",
"# print('a', audio.shape, audio.place)\n",
"# print('t', text)\n",
"# print('al', audio_len)\n",
"# print('tl', text_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "protected-announcement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"conv out [5, 32, 41, 62]\n",
"rnn input [5, 62, 1312]\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"rnn output [5, 62, 2048]\n",
"logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [2316.82153320])\n"
]
}
],
"source": [
"outputs = dp_model(audio, audio_len, text, text_len)\n",
"logits, _, logits_len = outputs\n",
"print('logits len', logits_len)\n",
"loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
"print('loss', loss)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "universal-myrtle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n"
]
}
],
"source": [
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
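{
"cell_type": "code",
"execution_count": null,
"id": "backward-sketch",
"metadata": {},
"outputs": [],
"source": [
"# The gradients above are all None; a backward pass on the CTC loss is the\n",
"# step that fills them in before the next inspection (a reasonable\n",
"# reconstruction of the missing cell):\n",
"loss.backward()"
]
},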
{
"cell_type": "code",
"execution_count": 25,
"id": "referenced-double",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n",
" 4.757795 ]\n",
" [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n",
" 5.1787405 ]\n",
" [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n",
" 3.5026932 ]\n",
" ...\n",
" [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n",
" 2.359191 ]\n",
" [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n",
" 4.829253 ]\n",
" [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n",
" 3.7477999 ]]]\n",
"\n",
"\n",
" [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n",
" 1.0245132 ]\n",
" [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n",
" -0.41603735]\n",
" [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n",
" -0.96788657]\n",
" ...\n",
" [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n",
" 2.6364415 ]\n",
" [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n",
" 3.3839993 ]\n",
" [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n",
" 4.1078367 ]]]\n",
"\n",
"\n",
" [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n",
" -6.139402 ]\n",
" [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n",
" -6.158736 ]\n",
" [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n",
" -5.416684 ]\n",
" ...\n",
" [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n",
" -0.8772557 ]\n",
" [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n",
" -3.6650617 ]\n",
" [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n",
" -2.5240796 ]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n",
" 2.101063 ]\n",
" [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n",
" 1.4128599 ]\n",
" [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n",
" 1.6222539 ]\n",
" ...\n",
" [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n",
" 2.7494807 ]\n",
" [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n",
" 2.2483966 ]\n",
" [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n",
" 1.7467536 ]]]\n",
"\n",
"\n",
" [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n",
" -2.6987422 ]\n",
" [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n",
" -3.6595795 ]\n",
" [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n",
" -3.481941 ]\n",
" ...\n",
" [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n",
" -1.2167063 ]\n",
" [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n",
" 0.0715822 ]\n",
" [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n",
" 1.4772662 ]]]\n",
"\n",
"\n",
" [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n",
" -3.5771027 ]\n",
" [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n",
" -2.0113711 ]\n",
" [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n",
" -4.3225455 ]\n",
" ...\n",
" [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n",
" 2.6247358 ]\n",
" [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n",
" -0.9618193 ]\n",
" [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n",
" -0.79297894]]]]\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n",
" 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n",
" 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n",
" -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n",
" 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n",
" 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n",
" -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n",
" 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n",
" -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n",
" 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n",
" 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n",
" 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n",
" -5.628145 -1.0894046 ]\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n",
" -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n",
" -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n",
" 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n",
" 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n",
" -4.116946 -0.9909375 ]\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n",
" 6.71104014e-01 -1.95339048e+00]\n",
" [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n",
" -9.42245364e-01 -3.34277248e+00]\n",
" [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n",
" 1.36039746e+00 -2.94621849e+00]\n",
" ...\n",
" [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n",
" -3.23889256e-02 -2.30832791e+00]\n",
" [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n",
" -3.51897454e+00 -2.10093403e+00]\n",
" [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n",
" -7.38576293e-01 -9.44581270e-01]]\n",
"\n",
" [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n",
" -4.09437418e-01 -1.74107194e+00]\n",
" [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n",
" -3.03250861e+00 -2.37099743e+00]\n",
" [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n",
" -2.08650708e+00 -1.82109618e+00]\n",
" ...\n",
" [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n",
" -1.31862307e+00 -2.07174897e+00]\n",
" [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n",
" -9.57888126e-01 -3.82121086e-01]\n",
" [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n",
" 1.78124309e-02 -3.40287232e+00]]\n",
"\n",
" [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n",
" 1.65331578e+00 2.58466601e-02]\n",
" [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n",
" 1.94632411e+00 -1.06698799e+00]\n",
" [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n",
" 4.27859402e+00 3.66381669e+00]\n",
" ...\n",
" [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n",
" 4.30536079e+00 -2.49779320e+00]\n",
" [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n",
" 2.55930424e-01 4.50821817e-02]\n",
" [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n",
" 1.09565187e+00 -4.96661961e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n",
" 2.56748104e+00 1.67592216e+00]\n",
" [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n",
" 2.92086434e+00 2.08308399e-01]\n",
" [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n",
" 3.55370259e+00 4.75085378e-01]\n",
" ...\n",
" [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n",
" -1.66419971e+00 -1.70375630e-01]\n",
" [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n",
" 9.49141383e-02 8.88008535e-01]\n",
" [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n",
" 1.93200278e+00 8.19505215e-01]]\n",
"\n",
" [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n",
" 3.19132543e+00 3.17435169e+00]\n",
" [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n",
" 2.76059532e+00 4.83479440e-01]\n",
" [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n",
" 2.80596399e+00 1.52781498e+00]\n",
" ...\n",
" [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n",
" -2.26546407e-01 -1.18434143e+00]\n",
" [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n",
" -1.31319761e-01 -3.85830402e-01]\n",
" [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n",
" 1.68037796e+00 -6.50003672e-01]]\n",
"\n",
" [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n",
" 2.07197094e+00 4.48752761e-01]\n",
" [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n",
" 1.88659012e+00 1.48252523e+00]\n",
" [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n",
" 3.75945044e+00 -3.09408635e-01]\n",
" ...\n",
" [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n",
" -1.26373076e+00 -1.37826729e+00]\n",
" [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n",
" 1.50002241e-02 1.84633851e+00]\n",
" [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n",
" 1.53240371e+00 7.47349262e-02]]]\n",
"\n",
"\n",
" [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n",
" -1.08945215e+00 4.19084758e-01]\n",
" [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n",
" 8.21905077e-01 1.89097953e+00]\n",
" [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n",
" -2.78253704e-01 6.96343541e-01]\n",
" ...\n",
" [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n",
" -9.71581042e-02 1.71877682e+00]\n",
" [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n",
" -1.15296638e+00 4.35728967e-01]\n",
" [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n",
" -5.39051890e-01 1.13107657e+00]]\n",
"\n",
" [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n",
" 2.95773447e-02 1.27177119e+00]\n",
" [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n",
" 3.85564029e-01 4.57195908e-01]\n",
" [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n",
" -6.58698231e-02 -3.61778140e-01]\n",
" ...\n",
" [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n",
" 1.21484184e+00 2.22764325e+00]\n",
" [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n",
" 2.39702511e+00 2.43942714e+00]\n",
" [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n",
" 2.50086927e+00 1.91134131e+00]]\n",
"\n",
" [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n",
" -3.79540777e+00 -1.09796596e+00]\n",
" [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n",
" -2.59143281e+00 7.74220943e-01]\n",
" [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n",
" -1.23293114e+00 1.58148706e+00]\n",
" ...\n",
" [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n",
" 2.04232907e+00 1.97716308e+00]\n",
" [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n",
" 2.27003932e+00 -1.11734614e-01]\n",
" [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n",
" 5.78273535e-01 -2.59924948e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n",
" -1.53534544e+00 -6.67162359e-01]\n",
" [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n",
" 2.68609226e-01 -1.35124445e+00]\n",
" [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n",
" -5.04359961e-01 -1.81241560e+00]\n",
" ...\n",
" [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n",
" 3.29260826e-01 5.83679557e-01]\n",
" [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n",
" 2.91730791e-01 1.33400351e-01]\n",
" [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n",
" 4.90157723e-01 2.40562439e+00]]\n",
"\n",
" [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n",
" 1.91490650e+00 1.76274943e+00]\n",
" [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n",
" 2.53548074e+00 2.44395185e+00]\n",
" [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n",
" 2.94634581e+00 2.21452284e+00]\n",
" ...\n",
" [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n",
" -2.00367713e+00 -7.09145784e-01]\n",
" [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n",
" -7.83945560e-01 -1.01922750e-01]\n",
" [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n",
" 1.60286129e-02 -7.26446986e-01]]\n",
"\n",
" [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n",
" 8.52760911e-01 1.27920091e-01]\n",
" [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n",
" 3.97556186e-01 1.69182217e+00]\n",
" [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n",
" -1.07720590e+00 -3.81854475e-01]\n",
" ...\n",
" [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n",
" -4.48192549e+00 -6.14600539e-01]\n",
" [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n",
" -4.44984436e+00 -2.28429958e-01]\n",
" [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n",
" -1.29259872e+00 1.30360842e-01]]]\n",
"\n",
"\n",
" [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n",
" -6.68072224e-01 -1.13282454e+00]\n",
" [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n",
" -1.00205469e+00 -3.58956242e+00]\n",
" [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n",
" 2.25133196e-01 -1.02802217e+00]\n",
" ...\n",
" [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n",
" -1.63268471e+00 6.00247443e-01]\n",
" [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n",
" -1.47013855e+00 -1.45779204e+00]\n",
" [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n",
" -1.36285520e+00 -1.68160331e+00]]\n",
"\n",
" [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n",
" -2.32403350e+00 -4.25943947e+00]\n",
" [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n",
" -1.16508257e+00 -1.30353463e+00]\n",
" [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n",
" -1.91106403e+00 -1.07801402e+00]\n",
" ...\n",
" [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n",
" -5.27612507e-01 1.44604075e+00]\n",
" [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n",
" 4.67946082e-01 -7.24248111e-01]\n",
" [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n",
" -7.51657963e-01 -1.00530767e+00]]\n",
"\n",
" [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n",
" -2.35717297e-02 -1.71374834e+00]\n",
" [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n",
" -1.41204190e+00 -7.58325040e-01]\n",
" [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n",
" -1.07798588e+00 -1.48433352e+00]\n",
" ...\n",
" [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n",
" 1.24199319e+00 1.89949930e-01]\n",
" [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n",
" -5.78772545e-01 2.62665659e-01]\n",
" [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n",
" 2.79776216e-01 7.63269484e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n",
" 1.16390383e+00 -1.47593319e-01]\n",
" [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n",
" 1.89411759e+00 1.03220642e-01]\n",
" [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n",
" 2.07282662e+00 -2.56800652e-01]\n",
" ...\n",
" [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n",
" 1.80851030e+00 8.32634747e-01]\n",
" [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n",
" 1.30439472e+00 3.61029744e-01]\n",
" [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n",
" -9.16651785e-02 9.93353873e-02]]\n",
"\n",
" [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n",
" 3.94066262e+00 4.89832020e+00]\n",
" [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n",
" 3.22004294e+00 2.99872303e+00]\n",
" [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n",
" 2.98195982e+00 3.37315011e+00]\n",
" ...\n",
" [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n",
" 2.13971066e+00 1.94932449e+00]\n",
" [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n",
" 2.29117441e+00 2.52368808e+00]\n",
" [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n",
" 1.86903965e+00 2.27776885e+00]]\n",
"\n",
" [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n",
" 7.87163258e-01 1.20610237e+00]\n",
" [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n",
" 9.66498017e-01 1.18883348e+00]\n",
" [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n",
" 5.25399208e-01 -1.16850233e+00]\n",
" ...\n",
" [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n",
" -1.49128008e+00 -1.85064209e+00]\n",
" [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n",
" 9.67048705e-02 -4.37043965e-01]\n",
" [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n",
" -7.83308804e-01 -1.64871216e+00]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n",
" -1.29671764e+00 -1.61429560e+00]\n",
" [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n",
" -1.26383066e+00 -2.27464601e-01]\n",
" [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n",
" -3.68550658e-01 -2.58849680e-01]\n",
" ...\n",
" [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n",
" 2.67662477e+00 2.31014514e+00]\n",
" [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n",
" 2.77418542e+00 1.22875953e+00]\n",
" [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n",
" 1.95602298e+00 7.64306366e-01]]\n",
"\n",
" [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n",
" 1.90896463e+00 8.02283168e-01]\n",
" [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n",
" 1.53455496e+00 1.94702053e+00]\n",
" [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n",
" 7.79394627e-01 2.02670670e+00]\n",
" ...\n",
" [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n",
" 9.61961091e-01 1.40342009e+00]\n",
" [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n",
" 9.43485856e-01 2.34856772e+00]\n",
" [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n",
" 1.38143206e+00 2.00714707e+00]]\n",
"\n",
" [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n",
" -2.52133775e+00 -8.87715697e-01]\n",
" [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n",
" -6.10453069e-01 2.06534386e+00]\n",
" [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n",
" -1.12928271e+00 4.87097502e-02]\n",
" ...\n",
" [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n",
" 1.06069386e+00 1.36404145e+00]\n",
" [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n",
" 1.48393130e+00 2.69196725e+00]\n",
" [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n",
" 3.67927134e-01 2.55976987e+00]]\n",
"\n",
" ...\n",
"\n",
" [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n",
" 2.57500267e+00 -1.38083458e-01]\n",
" [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n",
" 2.34770632e+00 6.48133337e-01]\n",
" [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n",
" 4.34810448e+00 1.31588542e+00]\n",
" ...\n",
" [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n",
" 1.06280947e+00 -5.93325794e-02]\n",
" [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n",
" 8.06131065e-02 8.08142900e-01]\n",
" [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n",
" 3.13675582e-01 2.13851762e+00]]\n",
"\n",
" [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n",
" 3.17874146e+00 1.47012353e+00]\n",
" [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n",
" 2.73698759e+00 2.17057824e+00]\n",
" [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n",
" 3.43153954e+00 1.21802723e+00]\n",
" ...\n",
" [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n",
" 1.64856291e+00 1.15853786e+00]\n",
" [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n",
" 2.24890351e+00 2.22156644e+00]\n",
" [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n",
" 3.71970820e+00 1.69852686e+00]]\n",
"\n",
" [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n",
" 1.48759770e+00 -1.72001183e+00]\n",
" [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n",
" 1.74327165e-01 -1.83029580e+00]\n",
" [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n",
" 1.05627227e+00 -8.02519619e-01]\n",
" ...\n",
" [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n",
" 4.66731453e+00 4.98701715e+00]\n",
" [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n",
" 3.32219076e+00 3.83035636e+00]\n",
" [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n",
" 2.13785577e+00 4.33214951e+00]]]\n",
"\n",
"\n",
" [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n",
" 2.72914767e-01 3.92112255e-01]\n",
" [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n",
" -1.44200325e-01 4.07531261e-02]\n",
" [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n",
" -2.28934169e+00 -7.38609374e-01]\n",
" ...\n",
" [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n",
" -1.23878837e+00 2.21006930e-01]\n",
" [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n",
" -6.22699618e-01 -1.73162484e+00]\n",
" [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n",
" -1.15777385e+00 -1.13671994e+00]]\n",
"\n",
" [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n",
" 8.81283879e-02 -1.37762165e+00]\n",
" [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n",
" -9.98012185e-01 -1.11348581e+00]\n",
" [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n",
" -3.65677416e-01 -2.25250268e+00]\n",
" ...\n",
" [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n",
" -2.52748466e+00 -2.16131711e+00]\n",
" [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n",
" -2.44736362e+00 -8.78590107e-01]\n",
" [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n",
" -2.45255780e+00 -1.82384634e+00]]\n",
"\n",
" [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n",
" 4.13422853e-01 1.12761199e+00]\n",
" [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n",
" -8.72656107e-01 1.52415514e-01]\n",
" [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n",
" 1.87347198e+00 9.48901057e-01]\n",
" ...\n",
" [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n",
" -6.83587790e-01 1.49154460e+00]\n",
" [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n",
" 1.54248989e+00 -1.46732473e+00]\n",
" [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n",
" 1.12405360e-01 -6.87821358e-02]]\n",
"\n",
" ...\n",
"\n",
" [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n",
" 4.03424358e+00 3.73335361e+00]\n",
" [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n",
" 2.29408574e+00 3.14105940e+00]\n",
" [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n",
" 3.91481900e+00 5.32456112e+00]\n",
" ...\n",
" [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n",
" 3.87909842e+00 3.33359623e+00]\n",
" [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n",
" 1.55827272e+00 2.50741863e+00]\n",
" [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n",
" 3.28739667e+00 3.02895570e+00]]\n",
"\n",
" [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n",
" 1.52944064e+00 2.08810711e+00]\n",
" [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n",
" 7.46028006e-01 1.74789345e+00]\n",
" [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n",
" 1.49640536e+00 2.16613841e+00]\n",
" ...\n",
" [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n",
" 3.87935114e+00 3.69385266e+00]\n",
" [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n",
" 5.36551809e+00 2.94915986e+00]\n",
" [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n",
" 3.93462896e+00 3.68200874e+00]]\n",
"\n",
" [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n",
" 1.60273385e+00 4.90113163e+00]\n",
" [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n",
" 1.40544355e+00 2.84833241e+00]\n",
" [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n",
" 6.87886119e-01 2.58609629e+00]\n",
" ...\n",
" [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n",
" 2.30406880e+00 1.29892707e+00]\n",
" [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n",
" 8.59923482e-01 1.38731599e+00]\n",
" [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n",
" 1.44120216e+00 1.79993320e+00]]]\n",
"\n",
"\n",
" [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n",
" 1.65043330e+00 3.45897645e-01]\n",
" [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n",
" 9.12241280e-01 1.27429694e-01]\n",
" [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n",
" 1.00041127e+00 5.87104559e-02]\n",
" ...\n",
" [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n",
" 2.35047388e+00 1.61532843e+00]\n",
" [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n",
" 2.32770801e+00 9.96662021e-01]\n",
" [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n",
" 2.05776072e+00 5.34575820e-01]]\n",
"\n",
" [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n",
" -9.85817850e-01 -2.05517030e+00]\n",
" [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n",
" -1.79659498e+00 -1.58266973e+00]\n",
" [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n",
" -1.55507576e+00 -9.37123299e-01]\n",
" ...\n",
" [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n",
" 1.27850962e+00 1.35487700e+00]\n",
" [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n",
" 1.50716826e-01 1.33641326e+00]\n",
" [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n",
" 1.60940981e+00 1.78091609e+00]]\n",
"\n",
" [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n",
" -2.18273020e+00 -1.98621082e+00]\n",
" [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n",
" -2.12960863e+00 -3.45255351e+00]\n",
" [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n",
" -3.72989595e-01 -1.36516929e+00]\n",
" ...\n",
" [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n",
" 2.13982654e+00 -1.81909311e+00]\n",
" [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n",
" 2.07899213e+00 -3.71748984e-01]\n",
" [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n",
" -2.81265080e-01 -3.69353056e+00]]\n",
"\n",
" ...\n",
"\n",
" [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n",
" -1.75104141e-02 -3.54720926e+00]\n",
" [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n",
" -3.58282638e+00 -2.40778995e+00]\n",
" [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n",
" -1.67449510e+00 -1.14655948e+00]\n",
" ...\n",
" [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n",
" -1.48933101e+00 -1.36807609e+00]\n",
" [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n",
" -8.00529659e-01 -8.39800835e-01]\n",
" [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n",
" 1.67839837e+00 8.47821474e-01]]\n",
"\n",
" [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n",
" 8.83528054e-01 2.57416844e-01]\n",
" [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n",
" 2.38878536e+00 1.84097290e+00]\n",
" [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n",
" 1.61432290e+00 1.13214278e+00]\n",
" ...\n",
" [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n",
" -2.18670201e+00 -3.91811800e+00]\n",
" [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n",
" -1.59393203e+00 -9.08105671e-01]\n",
" [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n",
" 9.73127842e-01 1.19991803e+00]]\n",
"\n",
" [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n",
" 2.25974560e+00 1.02396965e+00]\n",
" [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n",
" -4.19382691e-01 2.91669458e-01]\n",
" [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n",
" -1.51922226e-01 7.44265199e-01]\n",
" ...\n",
" [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n",
" 5.72486281e-01 -4.35207397e-01]\n",
" [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n",
" 1.01334639e-02 -3.20627165e+00]\n",
" [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n",
" -1.73590891e-02 -1.80156028e+00]]]]\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n",
" 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n",
" -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n",
" -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n",
" -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n",
" -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n",
" 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n",
" -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n",
" 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n",
" -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n",
" -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n",
" -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n",
" -0.24313992 0.31392363]\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n",
" 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n",
" -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n",
" -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n",
" -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n",
" -0.40964937 -1.4454535 ]\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n",
" -0.0658682 ]\n",
" [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n",
" -0.05464617]\n",
" [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n",
" -0.40110156]\n",
" ...\n",
" [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n",
" 0.37071258]\n",
" [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n",
" 0.32735693]\n",
" [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n",
" 0.0752247 ]]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n",
" 0.3020359 ]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n",
" 0.16263628]\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n",
" -0.25661892]\n",
" [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n",
" -0.17545304]\n",
" [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n",
" 0.05315325]\n",
" ...\n",
" [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n",
" -0.01544908]\n",
" [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n",
" 0.14791614]\n",
" [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n",
" 0.26846847]]\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n",
" 0.16263627]\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n",
" -0.30598596]\n",
" [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n",
" -0.5335691 ]\n",
" [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n",
" 0.25522667]\n",
" ...\n",
" [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n",
" -0.07482046]\n",
" [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n",
" -0.2011079 ]\n",
" [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n",
" -0.17761081]]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n",
" -0.10503208]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n",
" 0.11344345]\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n",
" -0.29202533]\n",
" [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n",
" 0.05170784]\n",
" [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n",
" -0.08760723]\n",
" ...\n",
" [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n",
" -0.14726155]\n",
" [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n",
" 0.02569831]\n",
" [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n",
" 0.41056633]]\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n",
" 0.11344352]\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n",
" 0.05862035]\n",
" [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n",
" -0.3197178 ]\n",
" [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n",
" -0.44295847]\n",
" ...\n",
" [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n",
" -0.15528557]\n",
" [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n",
" -0.61484563]\n",
" [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n",
" 0.17856634]]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n",
" 1.6333911 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n",
" 2.069446 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n",
" -0.00845779]\n",
" [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n",
" -0.855925 ]\n",
" [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n",
" 0.11303265]\n",
" ...\n",
" [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n",
" -0.74844164]\n",
" [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n",
" 0.71526295]\n",
" [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n",
" 1.2852117 ]]\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n",
" 2.0694466 ]\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n",
" 7.4421698e-03 -2.3925617e+01]\n",
" [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n",
" 1.8560058e-02 -5.0687141e+01]\n",
" [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n",
" 2.1536341e-02 -6.1036335e+01]\n",
" ...\n",
" [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n",
" 1.3860047e-02 -5.2543671e+01]\n",
" [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n",
" 1.7470375e-02 -4.3619675e+01]\n",
" [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n",
" 2.2168703e-02 -6.7901680e+01]]\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n",
" 4.3404289e-02 -1.3512518e+02]\n"
]
}
],
"source": [
"loss.backward(retain_graph=False)\n",
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "selected-crazy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.]\n"
]
}
],
"source": [
"print(loss.grad)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bottom-engineer",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "stuffed-yeast",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-grade",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "broke-broad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"register user softmax to paddle, remove this when fixed!\n",
"register user log_softmax to paddle, remove this when fixed!\n",
"register user sigmoid to paddle, remove this when fixed!\n",
"register user log_sigmoid to paddle, remove this when fixed!\n",
"register user relu to paddle, remove this when fixed!\n",
"override cat of paddle if exists or register, remove this when fixed!\n",
"override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle if exists or register, remove this when fixed!\n",
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"register user view to paddle.Tensor, remove this when fixed!\n",
"register user view_as to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"register user fill_ to paddle.Tensor, remove this when fixed!\n",
"register user repeat to paddle.Tensor, remove this when fixed!\n",
"register user softmax to paddle.Tensor, remove this when fixed!\n",
"register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"register user relu to paddle.Tensor, remove this when fixed!\n",
"register user type_as to paddle.Tensor, remove this when fixed!\n",
"register user to to paddle.Tensor, remove this when fixed!\n",
"register user float to paddle.Tensor, remove this when fixed!\n",
"register user tolist to paddle.Tensor, remove this when fixed!\n",
"register user glu to paddle.nn.functional, remove this when fixed!\n",
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"register user Module to paddle.nn, remove this when fixed!\n",
"register user ModuleList to paddle.nn, remove this when fixed!\n",
"register user GLU to paddle.nn, remove this when fixed!\n",
"register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"register user export to paddle.jit, remove this when fixed!\n"
]
}
],
"source": [
"import numpy as np\n",
"import paddle\n",
"from yacs.config import CfgNode as CN\n",
"\n",
"from deepspeech.models.u2 import U2Model\n",
"from deepspeech.utils.layer_tools import print_params\n",
"from deepspeech.utils.layer_tools import summary"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "permanent-summary",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n",
"[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n",
"attention_heads: 4\n",
"dropout_rate: 0.1\n",
"input_layer: conv2d\n",
"linear_units: 2048\n",
"normalize_before: True\n",
"num_blocks: 12\n",
"output_size: 256\n",
"positional_dropout_rate: 0.1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"encoder.embed.conv.0.bias | [256] | 256 | True\n",
"encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"encoder.embed.conv.2.bias | [256] | 256 | True\n",
"encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n",
"encoder.embed.out.0.bias | [256] | 256 | True\n",
"encoder.after_norm.weight | [256] | 256 | True\n",
"encoder.after_norm.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n",
"decoder.embed.0.weight | [4233, 256] | 1083648 | True\n",
"decoder.after_norm.weight | [256] | 256 | True\n",
"decoder.after_norm.bias | [256] | 256 | True\n",
"decoder.output_layer.weight | [256, 4233] | 1083648 | True\n",
"decoder.output_layer.bias | [4233] | 4233 | True\n",
"decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n",
"ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n",
"ctc.ctc_lo.bias | [4233] | 4233 | True\n",
"Total parameters: 411.0, 32.01M elements.\n"
]
}
],
"source": [
"conf_str='examples/tiny/s1/conf/transformer.yaml'\n",
"cfg = CN().load_cfg(open(conf_str))\n",
"cfg.model.input_dim = 83\n",
"cfg.model.output_dim = 4233\n",
"cfg.model.cmvn_file = None\n",
"cfg.model.cmvn_file_type = 'json'\n",
"#cfg.model.encoder_conf.concat_after=True\n",
"cfg.freeze()\n",
"model = U2Model(cfg.model)\n",
"\n",
"print_params(model)\n"
]
},
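{
"cell_type": "code",
"execution_count": null,
"id": "param-count-check",
"metadata": {},
"outputs": [],
"source": [
"# Cross-check of the dump above -- a minimal sketch, assuming `model` is the\n",
"# paddle U2Model built in the previous cell: summing the element counts of\n",
"# all parameters should reproduce the reported total of ~32.01M elements.\n",
"total_elems = 0\n",
"n_params = 0\n",
"for name, param in model.named_parameters():\n",
"    n_params += 1\n",
"    total_elems += int(np.prod(param.shape))\n",
"print(f'Total parameters: {n_params}, {total_elems / 1e6:.2f}M elements.')"
]
},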
{
"cell_type": "code",
"execution_count": 4,
"id": "sapphire-agent",
"metadata": {},
"outputs": [],
"source": [
"#summary(model)\n",
"#print(model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fossil-means",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "45c2b75f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state\n",
"odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 
'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 
'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n"
]
}
],
"source": [
"#!pip install torch\n",
"import torch\n",
"\n",
"e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n",
"for k in e_model.files:\n",
" print(k)\n",
"state_dict = e_model['state']\n",
"state_dict = state_dict.tolist()\n",
"print(state_dict.keys())"
]
},
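{
"cell_type": "code",
"execution_count": null,
"id": "npz-pickle-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Why .tolist() recovers the dict -- a tiny self-contained sketch (the file\n",
"# path is arbitrary): np.savez wraps a dict value in a 0-d object array, and\n",
"# .tolist() on a 0-d array returns the wrapped Python object itself.\n",
"demo = {'w': np.zeros((2, 2))}\n",
"np.savez('/tmp/demo_state', state=demo)\n",
"loaded = np.load('/tmp/demo_state.npz', allow_pickle=True)\n",
"print(loaded['state'].shape)           # () -- a 0-d object array\n",
"print(type(loaded['state'].tolist()))  # <class 'dict'>"
]
},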
{
"cell_type": "code",
"execution_count": 6,
"id": "f187bb55",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n",
"# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n",
"# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n",
"# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n",
"# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n",
"# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"# espnet transformer\tconcat_linear只是保存了,但是未使用\n",
"\t\n",
"# \tpaddle transformer"
]
},
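{
"cell_type": "code",
"execution_count": null,
"id": "linear-layout-sketch",
"metadata": {},
"outputs": [],
"source": [
"# The shape flips noted above follow from the two Linear layout conventions:\n",
"# torch.nn.Linear stores its weight as [out_features, in_features], while\n",
"# paddle.nn.Linear stores [in_features, out_features]. A quick sketch:\n",
"t_linear = torch.nn.Linear(256, 2048)\n",
"p_linear = nn.Linear(256, 2048)\n",
"print(tuple(t_linear.weight.shape))  # (2048, 256)\n",
"print(tuple(p_linear.weight.shape))  # (256, 2048)\n",
"# hence the conversion cell below transposes every 2-D `*.weight` array"
]
},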
{
"cell_type": "code",
"execution_count": 7,
"id": "2a0428ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> encoder.embed.conv.0.weight\n",
"-> encoder.embed.conv.0.bias\n",
"-> encoder.embed.conv.2.weight\n",
"-> encoder.embed.conv.2.bias\n",
"-> encoder.embed.out.0.weight\n",
"encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n",
"-> encoder.embed.out.0.bias\n",
"-> encoder.encoders.0.self_attn.linear_q.weight\n",
"encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_q.bias\n",
"-> encoder.encoders.0.self_attn.linear_k.weight\n",
"encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_k.bias\n",
"-> encoder.encoders.0.self_attn.linear_v.weight\n",
"encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_v.bias\n",
"-> encoder.encoders.0.self_attn.linear_out.weight\n",
"encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_out.bias\n",
"-> encoder.encoders.0.feed_forward.w_1.weight\n",
"encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.0.feed_forward.w_1.bias\n",
"-> encoder.encoders.0.feed_forward.w_2.weight\n",
"encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.0.feed_forward.w_2.bias\n",
"-> encoder.encoders.0.norm1.weight\n",
"-> encoder.encoders.0.norm1.bias\n",
"-> encoder.encoders.0.norm2.weight\n",
"-> encoder.encoders.0.norm2.bias\n",
"-> encoder.encoders.1.self_attn.linear_q.weight\n",
"encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_q.bias\n",
"-> encoder.encoders.1.self_attn.linear_k.weight\n",
"encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_k.bias\n",
"-> encoder.encoders.1.self_attn.linear_v.weight\n",
"encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_v.bias\n",
"-> encoder.encoders.1.self_attn.linear_out.weight\n",
"encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_out.bias\n",
"-> encoder.encoders.1.feed_forward.w_1.weight\n",
"encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.1.feed_forward.w_1.bias\n",
"-> encoder.encoders.1.feed_forward.w_2.weight\n",
"encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.1.feed_forward.w_2.bias\n",
"-> encoder.encoders.1.norm1.weight\n",
"-> encoder.encoders.1.norm1.bias\n",
"-> encoder.encoders.1.norm2.weight\n",
"-> encoder.encoders.1.norm2.bias\n",
"-> encoder.encoders.2.self_attn.linear_q.weight\n",
"encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_q.bias\n",
"-> encoder.encoders.2.self_attn.linear_k.weight\n",
"encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_k.bias\n",
"-> encoder.encoders.2.self_attn.linear_v.weight\n",
"encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_v.bias\n",
"-> encoder.encoders.2.self_attn.linear_out.weight\n",
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_out.bias\n",
"-> encoder.encoders.2.feed_forward.w_1.weight\n",
"encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.2.feed_forward.w_1.bias\n",
"-> encoder.encoders.2.feed_forward.w_2.weight\n",
"encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.2.feed_forward.w_2.bias\n",
"-> encoder.encoders.2.norm1.weight\n",
"-> encoder.encoders.2.norm1.bias\n",
"-> encoder.encoders.2.norm2.weight\n",
"-> encoder.encoders.2.norm2.bias\n",
"-> encoder.encoders.3.self_attn.linear_q.weight\n",
"encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_q.bias\n",
"-> encoder.encoders.3.self_attn.linear_k.weight\n",
"encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_k.bias\n",
"-> encoder.encoders.3.self_attn.linear_v.weight\n",
"encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_v.bias\n",
"-> encoder.encoders.3.self_attn.linear_out.weight\n",
"encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_out.bias\n",
"-> encoder.encoders.3.feed_forward.w_1.weight\n",
"encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.3.feed_forward.w_1.bias\n",
"-> encoder.encoders.3.feed_forward.w_2.weight\n",
"encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.3.feed_forward.w_2.bias\n",
"-> encoder.encoders.3.norm1.weight\n",
"-> encoder.encoders.3.norm1.bias\n",
"-> encoder.encoders.3.norm2.weight\n",
"-> encoder.encoders.3.norm2.bias\n",
"-> encoder.encoders.4.self_attn.linear_q.weight\n",
"encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_q.bias\n",
"-> encoder.encoders.4.self_attn.linear_k.weight\n",
"encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_k.bias\n",
"-> encoder.encoders.4.self_attn.linear_v.weight\n",
"encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_v.bias\n",
"-> encoder.encoders.4.self_attn.linear_out.weight\n",
"encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_out.bias\n",
"-> encoder.encoders.4.feed_forward.w_1.weight\n",
"encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.4.feed_forward.w_1.bias\n",
"-> encoder.encoders.4.feed_forward.w_2.weight\n",
"encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.4.feed_forward.w_2.bias\n",
"-> encoder.encoders.4.norm1.weight\n",
"-> encoder.encoders.4.norm1.bias\n",
"-> encoder.encoders.4.norm2.weight\n",
"-> encoder.encoders.4.norm2.bias\n",
"-> encoder.encoders.5.self_attn.linear_q.weight\n",
"encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_q.bias\n",
"-> encoder.encoders.5.self_attn.linear_k.weight\n",
"encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_k.bias\n",
"-> encoder.encoders.5.self_attn.linear_v.weight\n",
"encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_v.bias\n",
"-> encoder.encoders.5.self_attn.linear_out.weight\n",
"encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_out.bias\n",
"-> encoder.encoders.5.feed_forward.w_1.weight\n",
"encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.5.feed_forward.w_1.bias\n",
"-> encoder.encoders.5.feed_forward.w_2.weight\n",
"encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.5.feed_forward.w_2.bias\n",
"-> encoder.encoders.5.norm1.weight\n",
"-> encoder.encoders.5.norm1.bias\n",
"-> encoder.encoders.5.norm2.weight\n",
"-> encoder.encoders.5.norm2.bias\n",
"-> encoder.encoders.6.self_attn.linear_q.weight\n",
"encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_q.bias\n",
"-> encoder.encoders.6.self_attn.linear_k.weight\n",
"encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_k.bias\n",
"-> encoder.encoders.6.self_attn.linear_v.weight\n",
"encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_v.bias\n",
"-> encoder.encoders.6.self_attn.linear_out.weight\n",
"encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_out.bias\n",
"-> encoder.encoders.6.feed_forward.w_1.weight\n",
"encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.6.feed_forward.w_1.bias\n",
"-> encoder.encoders.6.feed_forward.w_2.weight\n",
"encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.6.feed_forward.w_2.bias\n",
"-> encoder.encoders.6.norm1.weight\n",
"-> encoder.encoders.6.norm1.bias\n",
"-> encoder.encoders.6.norm2.weight\n",
"-> encoder.encoders.6.norm2.bias\n",
"-> encoder.encoders.7.self_attn.linear_q.weight\n",
"encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_q.bias\n",
"-> encoder.encoders.7.self_attn.linear_k.weight\n",
"encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_k.bias\n",
"-> encoder.encoders.7.self_attn.linear_v.weight\n",
"encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_v.bias\n",
"-> encoder.encoders.7.self_attn.linear_out.weight\n",
"encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_out.bias\n",
"-> encoder.encoders.7.feed_forward.w_1.weight\n",
"encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.7.feed_forward.w_1.bias\n",
"-> encoder.encoders.7.feed_forward.w_2.weight\n",
"encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.7.feed_forward.w_2.bias\n",
"-> encoder.encoders.7.norm1.weight\n",
"-> encoder.encoders.7.norm1.bias\n",
"-> encoder.encoders.7.norm2.weight\n",
"-> encoder.encoders.7.norm2.bias\n",
"-> encoder.encoders.8.self_attn.linear_q.weight\n",
"encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_q.bias\n",
"-> encoder.encoders.8.self_attn.linear_k.weight\n",
"encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_k.bias\n",
"-> encoder.encoders.8.self_attn.linear_v.weight\n",
"encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_v.bias\n",
"-> encoder.encoders.8.self_attn.linear_out.weight\n",
"encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_out.bias\n",
"-> encoder.encoders.8.feed_forward.w_1.weight\n",
"encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.8.feed_forward.w_1.bias\n",
"-> encoder.encoders.8.feed_forward.w_2.weight\n",
"encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.8.feed_forward.w_2.bias\n",
"-> encoder.encoders.8.norm1.weight\n",
"-> encoder.encoders.8.norm1.bias\n",
"-> encoder.encoders.8.norm2.weight\n",
"-> encoder.encoders.8.norm2.bias\n",
"-> encoder.encoders.9.self_attn.linear_q.weight\n",
"encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_q.bias\n",
"-> encoder.encoders.9.self_attn.linear_k.weight\n",
"encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_k.bias\n",
"-> encoder.encoders.9.self_attn.linear_v.weight\n",
"encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_v.bias\n",
"-> encoder.encoders.9.self_attn.linear_out.weight\n",
"encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_out.bias\n",
"-> encoder.encoders.9.feed_forward.w_1.weight\n",
"encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.9.feed_forward.w_1.bias\n",
"-> encoder.encoders.9.feed_forward.w_2.weight\n",
"encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.9.feed_forward.w_2.bias\n",
"-> encoder.encoders.9.norm1.weight\n",
"-> encoder.encoders.9.norm1.bias\n",
"-> encoder.encoders.9.norm2.weight\n",
"-> encoder.encoders.9.norm2.bias\n",
"-> encoder.encoders.10.self_attn.linear_q.weight\n",
"encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_q.bias\n",
"-> encoder.encoders.10.self_attn.linear_k.weight\n",
"encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_k.bias\n",
"-> encoder.encoders.10.self_attn.linear_v.weight\n",
"encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_v.bias\n",
"-> encoder.encoders.10.self_attn.linear_out.weight\n",
"encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_out.bias\n",
"-> encoder.encoders.10.feed_forward.w_1.weight\n",
"encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.10.feed_forward.w_1.bias\n",
"-> encoder.encoders.10.feed_forward.w_2.weight\n",
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.10.feed_forward.w_2.bias\n",
"-> encoder.encoders.10.norm1.weight\n",
"-> encoder.encoders.10.norm1.bias\n",
"-> encoder.encoders.10.norm2.weight\n",
"-> encoder.encoders.10.norm2.bias\n",
"-> encoder.encoders.11.self_attn.linear_q.weight\n",
"encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_q.bias\n",
"-> encoder.encoders.11.self_attn.linear_k.weight\n",
"encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_k.bias\n",
"-> encoder.encoders.11.self_attn.linear_v.weight\n",
"encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_v.bias\n",
"-> encoder.encoders.11.self_attn.linear_out.weight\n",
"encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_out.bias\n",
"-> encoder.encoders.11.feed_forward.w_1.weight\n",
"encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.11.feed_forward.w_1.bias\n",
"-> encoder.encoders.11.feed_forward.w_2.weight\n",
"encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.11.feed_forward.w_2.bias\n",
"-> encoder.encoders.11.norm1.weight\n",
"-> encoder.encoders.11.norm1.bias\n",
"-> encoder.encoders.11.norm2.weight\n",
"-> encoder.encoders.11.norm2.bias\n",
"-> encoder.after_norm.weight\n",
"-> encoder.after_norm.bias\n"
]
}
],
"source": [
"# dump torch model to paddle\n",
"#state_dict = model.state_dict()\n",
"paddle_state_dict = {}\n",
"\n",
"for n, p in state_dict.items():\n",
" if 'encoder' not in n:\n",
" continue \n",
" print(f'-> {n}')\n",
" \n",
" \n",
" name_change=True\n",
" if 'norm.running_mean' in n:\n",
" new_n = n.replace('norm.running_', 'norm._')\n",
" elif 'norm.running_var' in n:\n",
" new_n = n.replace('norm.running_var', 'norm._variance')\n",
" else:\n",
" name_change=False\n",
" new_n = n\n",
" if name_change:\n",
" print(f\"{n} -> {new_n}\")\n",
" \n",
" \n",
" p = p.cpu().detach().numpy()\n",
" if n.endswith('weight') and p.ndim == 2:\n",
" new_p = p.T\n",
" print(f\"{n}: {p.shape} -> {new_p.shape}\")\n",
" else:\n",
" new_p = p\n",
" \n",
" if 'global_cmvn.mean' in n:\n",
" print(p, p.dtype)\n",
" \n",
" paddle_state_dict[new_n] = new_p\n",
" \n",
"# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n",
"# state=paddle_state_dict)"
]
},
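{
"cell_type": "code",
"execution_count": null,
"id": "skip-warning-note",
"metadata": {},
"outputs": [],
"source": [
"# Note on the warnings below: only encoder weights were converted above, and\n",
"# the espnet checkpoint has no concat_linear entries, so loading the dict in\n",
"# the next cell skips every paddle parameter missing from it (the unused\n",
"# concat_linear layers and all decoder parameters) with a\n",
"# 'Skip loading for ...' UserWarning; those keep their initialized values."
]
},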
{
"cell_type": "code",
"execution_count": 8,
"id": "a1d97e9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. decoder.decoders.0.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n"
]
}
],
"source": [
"model.set_state_dict(paddle_state_dict)"
]
},
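{
"cell_type": "markdown",
"id": "b1c2d3e4",
"metadata": {},
"source": [
"The `Skip loading` warnings above are expected: the converted dict comes from\n",
"an espnet checkpoint that has no counterpart for several decoder parameters,\n",
"the `concat_linear*` layers, or the `ctc.ctc_lo` head, so `set_state_dict`\n",
"leaves those parameters at their random initialization. The next cell is a\n",
"minimal sketch (reusing `model` and `paddle_state_dict` from above) that\n",
"lists the affected keys directly instead of scraping the warnings."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2d3e4f5",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: keys in the model but not in the converted dict stay at their\n",
"# random init; keys in the dict but not in the model are silently unused.\n",
"model_keys = set(model.state_dict().keys())\n",
"dict_keys = set(paddle_state_dict.keys())\n",
"missing = sorted(model_keys - dict_keys)\n",
"unused = sorted(dict_keys - model_keys)\n",
"print(f'{len(missing)} missing, {len(unused)} unused')\n",
"print(missing[:5])"
]
},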
{
"cell_type": "code",
"execution_count": 13,
"id": "fc7edf1e",
"metadata": {},
"outputs": [],
"source": [
"e_state = model.encoder.state_dict()\n",
"for key, value in e_state.items():\n",
" if 'concat_linear' in key:\n",
" continue\n",
" if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n",
" print(key)"
]
},
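{
"cell_type": "markdown",
"id": "d3e4f5a6",
"metadata": {},
"source": [
"With the weights verified, the rest of the notebook compares activations:\n",
"the `.npz` files under `.notebook/espnet` hold tensors dumped from the\n",
"reference espnet run (input features in `feat.npz`, the final encoder output\n",
"in `encoder.npz`, and presumably per-layer dumps in `l0.npz`-`l11.npz`),\n",
"which the cells below compare against the Paddle encoder."
]
},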
{
"cell_type": "code",
"execution_count": 14,
"id": "fleet-despite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "abroad-oracle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 57, 83)\n",
"(8, 1, 57)\n",
"[57 50 48 38 32 31 28 25]\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n",
"xs=data['xs']\n",
"masks=data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"xs_lens = masks.sum(axis=-1).squeeze()\n",
"print(xs_lens)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "false-instrument",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[8, 13, 256]\n",
"[8, 1, 13]\n"
]
}
],
"source": [
"# ecnoder\n",
"xs = paddle.to_tensor(xs, dtype='float32')\n",
"x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n",
"model.eval()\n",
"encoder_out, encoder_mask = model.encoder(xs, x_lens)\n",
"print(encoder_out.shape)\n",
"print(encoder_mask.shape)"
]
},
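{
"cell_type": "markdown",
"id": "e4f5a6b7",
"metadata": {},
"source": [
"The mask went from `(8, 1, 57)` to `(8, 1, 13)`: the conv front end\n",
"subsamples time. Assuming the usual 4x subsampling (two stride-2 convs),\n",
"each length maps as `((l - 1) // 2 - 1) // 2`; the sketch below checks that\n",
"against the output mask. It is a sanity check added here, not part of the\n",
"original comparison."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5a6b7c8",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes a 4x Conv2d subsampling front end in the encoder):\n",
"out_lens = encoder_mask.numpy().sum(axis=-1).squeeze()\n",
"print(out_lens)  # lengths read off the output mask\n",
"print(((xs_lens - 1) // 2 - 1) // 2)  # expected under 4x subsampling"
]
},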
{
"cell_type": "code",
"execution_count": 31,
"id": "arctic-proxy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 13, 256)\n",
"(8, 1, 13)\n",
"False\n",
"False\n",
"True\n",
"True\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n",
"xs = data['xs']\n",
"masks = data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"print(np.allclose(xs, encoder_out.numpy()))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n",
"print(np.allclose(masks, encoder_mask.numpy()))"
]
},
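{
"cell_type": "markdown",
"id": "a6b7c8d9",
"metadata": {},
"source": [
"So the two encoders agree to `atol=1e-5` but not `1e-6`, which is ordinary\n",
"float32 cross-framework drift, and the masks match exactly. A minimal sketch\n",
"to quantify the gap directly (reusing `xs` and `encoder_out` from above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7c8d9e0",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: largest and mean elementwise gap between the two encoder outputs.\n",
"diff = np.abs(xs - encoder_out.numpy())\n",
"print('max abs diff: ', diff.max())\n",
"print('mean abs diff:', diff.mean())"
]
},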
{
"cell_type": "code",
"execution_count": 27,
"id": "seasonal-switch",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n",
" -0.6526459 ]\n",
" [ 2.1926146 2.1373641 -0.6548196 ... -0.897318 0.6044322\n",
" -0.63332295]\n",
" [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n",
" -0.05243584]\n",
" ...\n",
" [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n",
" -0.18188554]\n",
" [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n",
" -0.72916794]\n",
" [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n",
" -0.32597935]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n",
" [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n",
" [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n",
" ...,\n",
" [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n",
" [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n",
" [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n"
]
}
],
"source": [
"print(xs[0])\n",
"print(encoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "defined-brooks",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n",
" -0.16355833]\n",
" [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n",
" 0.11874156]\n",
" [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n",
" 0.27621832]\n",
" ...\n",
" [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n",
" -0.9794884 ]\n",
" [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n",
" -1.3076135 ]\n",
" [ 2.160009 0.98425585 -1.2231126 ... -0.03393313 1.9141548\n",
" -1.0099151 ]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n",
" [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n",
" [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n",
" ...,\n",
" [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n",
" [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n",
" [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n"
]
}
],
"source": [
"print(xs[1])\n",
"print(encoder_out[1])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}