diff --git a/.gitignore b/.gitignore index e4134a08240e4ba03f6a3a3979732d7d9960849b..cd2360e156fbe40f76b83cd107c71fedef8ed765 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,7 @@ tools/sox-14.4.2 tools/soxbindings tools/montreal-forced-aligner/ tools/Montreal-Forced-Aligner/ +tools/sctk +tools/sctk-20159b5/ *output/ diff --git a/.notebook/Linear_test.ipynb b/.notebook/Linear_test.ipynb deleted file mode 100644 index 5c7370cf328557eda1029317453c72dd999d0e2a..0000000000000000000000000000000000000000 --- a/.notebook/Linear_test.ipynb +++ /dev/null @@ -1,605 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "L = nn.Linear(256, 2048)\n", - "L2 = nn.Linear(2048, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n", - " [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n", - " [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n", - " ...,\n", - " [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n", - " [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n", - " [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n", - "\n", - " [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n", - " [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n", - " [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n", - " ...,\n", - " [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n", - " [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n", - " [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n", - "tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n", - " [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n", - " [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n", - " ...,\n", - " [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n", - " [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n", - " [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n", - "\n", - " [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n", - " [ 0.1022, 0.8130, 
2.3687, ..., 0.1682, 0.2924, 1.0541],\n", - " [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n", - " ...,\n", - " [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n", - " [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n", - " [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "print(px)\n", - "print(tx)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "mechanical-prisoner", - "metadata": {}, - "outputs": [], - "source": [ - "data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "t_norm_ff = data['norm_ff']\n", - "t_ff_out = data['ff_out']\n", - "t_ff_l_x = data['ff_l_x']\n", - "t_ff_l_a_x = data['ff_l_a_x']\n", - "t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "t_ps = data['ps']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "indie-marriage", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "assured-zambia", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n", - "L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n", - "\n", - "ps = []\n", - "for n, p in L.named_parameters():\n", - " ps.append(p)\n", - "\n", - "for n, p in L2.state_dict().items():\n", - " ps.append(p)\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p.numpy(), tp.T))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "committed-jacob", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "extreme-traffic", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "# t_norm_ff = data['norm_ff']\n", - "# t_ff_out = data['ff_out']\n", - "# t_ff_l_x = data['ff_l_x']\n", - "# t_ff_l_a_x = data['ff_l_a_x']\n", - "# t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "# t_ps = data['ps']\n", - "TL = torch.nn.Linear(256, 2048)\n", - "TL2 = torch.nn.Linear(2048, 256)\n", - "TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n", - "TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n", - "\n", - "# for n, p in 
TL.named_parameters():\n", - "# print(n, p)\n", - "# for n, p in TL2.named_parameters():\n", - "# print(n, p)\n", - "\n", - "ps = []\n", - "for n, p in TL.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for n, p in TL2.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p, tp))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n", - " -0.5787626 ]\n", - " [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n", - " -0.03350811]\n", - " [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n", - " -1.012989 ]\n", - " ...\n", - " [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n", - " 0.04070369]\n", - " [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n", - " -0.8450228 ]\n", - " [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n", - " -0.06663495]]\n", - "\n", - " [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n", - " -0.8161046 ]\n", - " [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n", - " -0.20766167]\n", - " [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n", - " -0.26811332]\n", - " ...\n", - " [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n", - " 0.06020184]\n", - " [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n", - " 0.15847842]\n", - " [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n", - " -0.6460344 ]]]\n", - "[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n", - " -0.5787625 ]\n", - " [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n", - " -0.03350812]\n", - " [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n", - " -1.0129892 ]\n", - " ...\n", - " [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n", - " 0.04070365]\n", - " [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n", - " -0.84502304]\n", - " [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n", - " -0.066635 ]]\n", - "\n", - " [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n", - " -0.8161046 ]\n", - " [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n", - " -0.20766166]\n", - " [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n", - " -0.26811326]\n", - " ...\n", - " [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n", - " 0.06020182]\n", - " [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n", - " 0.15847899]\n", - " [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n", - " -0.6460344 ]]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "y = L(px)\n", - "print(y.numpy())\n", - "\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.04476918 0.554463 -0.3027508 ... 
-0.49600336 0.3751858\n", - " 0.8254095 ]\n", - " [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n", - " -0.06974669]]\n", - "[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n", - " 0.8254096 ]\n", - " [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n", - " -0.0697467 ]]\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "y = L(px)\n", - "print(y.numpy())\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "improved-civilization", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "5e7e7c9fde8350084abf1898cf52651cfc84b17a\n" - ] - } - ], - "source": [ - "print(paddle.version.commit)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d1e2d3b4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c880c719", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.1.0\n" - ] - } - ], - "source": [ - "print(paddle.version.full_version)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "f26977bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n", - "None\n" - ] - } - ], - "source": [ - "print(paddle.version.show())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "04ad47f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.6.0\n" - ] - } - ], - "source": [ - "print(torch.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e1e03830", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '__version__',\n", - " 'cuda',\n", - " 'debug',\n", - " 'git_version',\n", - " 'hip']" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(torch.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "4ad0389b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'b31f58de6fa8bbda5353b3c77d9be4914399724d'" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.git_version" - ] - }, - { - 
"cell_type": "code", - "execution_count": 21, - "id": "7870ea10", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'10.2'" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.cuda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "db8ee5a7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6321ec2a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/WarmupLR.ipynb b/.notebook/WarmupLR.ipynb deleted file mode 100644 index 21abf9cbefe4caedd10e5854d27facbc94d15a29..0000000000000000000000000000000000000000 --- a/.notebook/WarmupLR.ipynb +++ /dev/null @@ -1,339 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d6a0e098", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union\n", - "\n", - "import torch\n", - "from torch.optim.lr_scheduler import _LRScheduler\n", - "\n", - "from typeguard import check_argument_types\n", - "\n", - "\n", - "class WarmupLR(_LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " optimizer: torch.optim.Optimizer,\n", - " warmup_steps: Union[int, float] = 25000,\n", - " last_epoch: int = -1,\n", - " ):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - "\n", - " # __init__() must be invoked before setting field\n", - " # because step() is also invoked in __init__()\n", - " super().__init__(optimizer, last_epoch)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return [\n", - " lr\n", - " * self.warmup_steps ** 0.5\n", - " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n", - " for lr in self.base_lrs\n", - " ]\n", - "\n", - " def set_step(self, step: int):\n", - " self.last_epoch = step" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0d496677", - "metadata": {}, - "outputs": [], - "source": [ - "import torch.optim as optim\n", - "model = torch.nn.Linear(10, 200)\n", - "optimizer = optim.Adam(model.parameters())\n", - "scheduler = WarmupLR(optimizer, warmup_steps=25000)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e3e3f3dc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0.0 -1\n" - ] - } - ], - "source": [ - "infos = {}\n", - "start_epoch = infos.get('epoch', -1) + 1\n", - "cv_loss = infos.get('cv_loss', 0.0)\n", - "step = 
infos.get('step', -1)\n", - "print(start_epoch, cv_loss, step)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "dc3d550c", - "metadata": {}, - "outputs": [], - "source": [ - "scheduler.set_step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "e527634e", - "metadata": {}, - "outputs": [], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " scheduler.step()\n", - " lrs.append(scheduler.get_lr())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f1452db9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting matplotlib\n", - " Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n", - "\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n", - "\u001b[?25hCollecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n", - "\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n", - "Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n", - "Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "import matplotlib.pyplot as plt\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0f36d04f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wzuUATwG/bquTc+
4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3N+3l3Y2VvLiiA
oDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHShYwUAr1IBQBGJ
IgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz7cQmDt4lj7zd7
VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "4f4e282c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "from typing import Union\n", - "\n", - "from paddle.optimizer.lr import LRScheduler\n", - "from typeguard import check_argument_types\n", - "\n", - "class WarmupLR(LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " warmup_steps: Union[int, float]=25000,\n", - " learning_rate=1.0,\n", - " last_epoch=-1,\n", - " verbose=False):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - " super().__init__(learning_rate, last_epoch, verbose)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return self.base_lr * self.warmup_steps**0.5 * min(\n", - " step_num**-0.5, step_num * self.warmup_steps**-1.5)\n", - "\n", - " def set_step(self, step: int):\n", - " self.step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "8c40b202", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-1\n" - ] - } - ], - "source": [ - "sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n", - "print(step)\n", - "#sc.set_step(step)\n", - "sc.set_step(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "ecbc7e37", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA/gj8JumOjnnnn
LO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0gQVGKf87ghAyvr
ak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBRt6GQBwLEqACgi
4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfpN39cQkoMo7/Fw
IXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " sc.step()\n", - " lrs.append(sc.get_lr())\n", - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e613fe16", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0fd9f40", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb deleted file mode 100644 index 04b4a3924a0131e6ede05ea24604149d8bff22df..0000000000000000000000000000000000000000 --- a/.notebook/audio_feature.ipynb +++ /dev/null @@ -1,1207 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 94, - "id": "matched-camera", - "metadata": {}, - "outputs": [], - "source": [ - "from nnAudio import Spectrogram\n", - "from scipy.io import wavfile\n", - "import torch\n", - "import soundfile as sf\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "quarterly-solution", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n" - ] - } - ], - "source": [ - "import scipy.io.wavfile as wav\n", - "\n", - "rate,sig = wav.read('./BAC009S0764W0124.wav')\n", - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sig)\n", - "print(song)\n", - "print(sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "id": "middle-salem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "sampling rate = 16000. 
Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2733 seconds\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n" - ] - } - ], - "source": [ - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "print(sr)\n", - "print(song)\n", - "print(song.shape)\n", - "print(song.dtype)\n", - "x = song\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(spec)" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "finished-sterling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "True\n", - "sampling rate = 16000. 
Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2001 seconds\n", - "torch.Size([1, 1025, 164, 2])\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n", - "True\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav)\n", - "print(wav.shape)\n", - "print(wav.dtype)\n", - "print(np.allclose(wav, song))\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(wav_spec.shape)\n", - "print(wav_spec)\n", - "print(np.allclose(wav_spec, spec))" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "running-technology", - "metadata": {}, - "outputs": [], - "source": [ - "import decimal\n", - "\n", - "import numpy\n", - "import math\n", - "import logging\n", - "def round_half_up(number):\n", - " return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n", - "\n", - "\n", - "def rolling_window(a, window, step=1):\n", - " # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n", - " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", - " strides = a.strides + (a.strides[-1],)\n", - " return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n", - "\n", - "\n", - "def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n", - " \"\"\"Frame a signal into overlapping frames.\n", - "\n", - " :param sig: the audio signal to frame.\n", - " :param frame_len: length of each frame measured in samples.\n", - " :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n", - " :param winfunc: the analysis window to apply to 
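# rolling_window above frames a signal as a zero-copy strided view; a quick
# check of its shape and step behaviour (uses the function defined above):
import numpy as np

a = np.arange(10)
f = rolling_window(a, window=4, step=2)
print(f.shape)  # (4, 4): frames start at samples 0, 2, 4, 6
print(f[1])     # [2 3 4 5]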
each frame. By default no window is applied.\n", - " :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n", - " :returns: an array of frames. Size is NUMFRAMES by frame_len.\n", - " \"\"\"\n", - " slen = len(sig)\n", - " frame_len = int(round_half_up(frame_len))\n", - " frame_step = int(round_half_up(frame_step))\n", - " if slen <= frame_len:\n", - " numframes = 1\n", - " else:\n", - " numframes = 1 + (( slen - frame_len) // frame_step)\n", - "\n", - " # check kaldi/src/feat/feature-window.h\n", - " padsignal = sig[:(numframes-1)*frame_step+frame_len]\n", - " if wintype is 'povey':\n", - " win = numpy.empty(frame_len)\n", - " for i in range(frame_len):\n", - " win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n", - " else: # the hamming window\n", - " win = numpy.hamming(frame_len)\n", - " \n", - " if stride_trick:\n", - " frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n", - " else:\n", - " indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n", - " numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n", - " indices = numpy.array(indices, dtype=numpy.int32)\n", - " frames = padsignal[indices]\n", - " win = numpy.tile(win, (numframes, 1))\n", - " \n", - " frames = frames.astype(numpy.float32)\n", - " raw_frames = numpy.zeros(frames.shape)\n", - " for frm in range(frames.shape[0]):\n", - " raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n", - " frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n", - " # raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n", - "\n", - " return frames * win, raw_frames\n", - "\n", - "\n", - "def magspec(frames, NFFT):\n", - " \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n", - " \"\"\"\n", - " if numpy.shape(frames)[1] > NFFT:\n", - " logging.warn(\n", - " 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n", - " numpy.shape(frames)[1], NFFT)\n", - " complex_spec = numpy.fft.rfft(frames, NFFT)\n", - " return numpy.absolute(complex_spec)\n", - "\n", - "\n", - "def powspec(frames, NFFT):\n", - " \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). 
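# Side note: `wintype is 'povey'` above happens to work only because of CPython
# string interning; `wintype == 'povey'` is the correct comparison. The povey
# window itself is just a Hann window raised to the power 0.85; a sketch:
import numpy as np

def povey_window(frame_len):
    i = np.arange(frame_len)
    return (0.5 - 0.5 * np.cos(2 * np.pi * i / (frame_len - 1))) ** 0.85

print(np.allclose(povey_window(400), np.hanning(400) ** 0.85))  # True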
Each row will be the power spectrum of the corresponding frame.\n", - " \"\"\"\n", - " return numpy.square(magspec(frames, NFFT))\n", - "\n", - "\n", - "def do_dither(signal, dither_value=1.0):\n", - " signal += numpy.random.normal(size=signal.shape) * dither_value\n", - " return signal\n", - " \n", - "def do_remove_dc_offset(signal):\n", - " signal -= numpy.mean(signal)\n", - " return signal\n", - "\n", - "def do_preemphasis(signal, coeff=0.97):\n", - " \"\"\"perform preemphasis on the input signal.\n", - "\n", - " :param signal: The signal to filter.\n", - " :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n", - " :returns: the filtered signal.\n", - " \"\"\"\n", - " return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "ignored-retreat", - "metadata": {}, - "outputs": [], - "source": [ - "def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " spec = magspec(frames, nfft) # nearly the same until this part\n", - " rspec = magspec(raw_frames, nfft)\n", - " return spec, rspec\n", - "\n", - "\n", - "\n", - "def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " return raw_frames" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "federal-teacher", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "import scipy # used only in CFP\n", - "\n", - "import numpy as np\n", - "from time import time\n", - "\n", - "def pad_center(data, size, axis=-1, **kwargs):\n", - "\n", - " kwargs.setdefault('mode', 'constant')\n", - "\n", - " n = data.shape[axis]\n", - "\n", - " lpad = int((size - n) // 2)\n", - "\n", - " lengths = [(0, 0)] * data.ndim\n", - " lengths[axis] = (lpad, int(size - n - lpad))\n", - "\n", - " if lpad < 0:\n", - " raise ParameterError(('Target size ({:d}) must be '\n", - " 'at least input size ({:d})').format(size, n))\n", - "\n", - " return np.pad(data, lengths, **kwargs)\n", - "\n", - "\n", - "\n", - "sz_float = 4 # size of a float\n", - "epsilon = 10e-8 # fudge factor for normalization\n", - "\n", - "def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n", - " freq_scale='linear', window='hann', verbose=True):\n", - "\n", - " if freq_bins==None: freq_bins = n_fft//2+1\n", - " if win_length==None: win_length = n_fft\n", - "\n", - " s = np.arange(0, n_fft, 1.)\n", - " wsin = np.empty((freq_bins,1,n_fft))\n", - " wcos = np.empty((freq_bins,1,n_fft))\n", - " start_freq = fmin\n", - " end_freq = fmax\n", - " bins2freq = []\n", - " binslist = []\n", - "\n", - " # num_cycles = start_freq*d/44000.\n", - " # scaling_ind = np.log(end_freq/start_freq)/k\n", - "\n", - " # Choosing window shape\n", - "\n", - " #window_mask = get_window(window, int(win_length), fftbins=True)\n", - " 
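# A minimal usage sketch of the per-frame chain defined above
# (dither -> remove DC offset -> preemphasis) on one synthetic frame:
import numpy as np

frame = np.random.randn(400) * 100 + 5.0   # synthetic frame with a DC offset
frame = do_dither(frame, 1.0)
frame = do_remove_dc_offset(frame)         # mean is ~0 after this step
frame = do_preemphasis(frame, 0.97)        # first-order high-pass emphasis
print(frame.mean())                        # close to zero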
window_mask = np.hamming(int(win_length))\n", - " window_mask = pad_center(window_mask, n_fft)\n", - "\n", - " if freq_scale == 'linear':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " \n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n", - " bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n", - " binslist.append((k*scaling_ind+start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'log':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = np.log(end_freq/start_freq)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n", - " bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n", - " binslist.append((np.exp(k*scaling_ind)*start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'no':\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " bins2freq.append(k*sr/n_fft)\n", - " binslist.append(k)\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - " else:\n", - " print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n", - " return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n", - "\n", - "\n", - "\n", - "def broadcast_dim(x):\n", - " \"\"\"\n", - " Auto broadcast input so that it can fits into a Conv1d\n", - " \"\"\"\n", - "\n", - " if x.dim() == 2:\n", - " x = x[:, None, :]\n", - " elif x.dim() == 1:\n", - " # If nn.DataParallel is used, this broadcast doesn't work\n", - " x = x[None, None, :]\n", - " elif x.dim() == 3:\n", - " pass\n", - " else:\n", - " raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n", - " return x\n", - "\n", - "\n", - "\n", - "### --------------------------- Spectrogram Classes ---------------------------###\n", - "class STFT(torch.nn.Module):\n", - "\n", - " def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n", - " freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n", - " fmin=50, fmax=6000, sr=22050, trainable=False,\n", - " output_format=\"Complex\", verbose=True):\n", - "\n", - " super().__init__()\n", - "\n", - " # Trying to make the default setting same as librosa\n", - " if win_length==None: win_length = n_fft\n", - " if hop_length==None: hop_length = int(win_length // 4)\n", - "\n", - " self.output_format = output_format\n", - " self.trainable = trainable\n", - " self.stride = hop_length\n", - " self.center = center\n", - " self.pad_mode = pad_mode\n", - " self.n_fft = n_fft\n", - " self.freq_bins = freq_bins\n", - " self.trainable = trainable\n", - " self.pad_amount = self.n_fft // 
2\n", - " self.window = window\n", - " self.win_length = win_length\n", - " self.iSTFT = iSTFT\n", - " self.trainable = trainable\n", - " start = time()\n", - "\n", - "\n", - "\n", - " # Create filter windows for stft\n", - " kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n", - " win_length=win_length,\n", - " freq_bins=freq_bins,\n", - " window=window,\n", - " freq_scale=freq_scale,\n", - " fmin=fmin,\n", - " fmax=fmax,\n", - " sr=sr,\n", - " verbose=verbose)\n", - "\n", - "\n", - " kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n", - " kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n", - " \n", - " # In this way, the inverse kernel and the forward kernel do not share the same memory...\n", - " kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n", - " kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n", - " \n", - " if iSTFT:\n", - " self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n", - " self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n", - "\n", - " # Applying window functions to the Fourier kernels\n", - " if window:\n", - " window_mask = torch.tensor(window_mask)\n", - " wsin = kernel_sin * window_mask\n", - " wcos = kernel_cos * window_mask\n", - " else:\n", - " wsin = kernel_sin\n", - " wcos = kernel_cos\n", - " \n", - " if self.trainable==False:\n", - " self.register_buffer('wsin', wsin)\n", - " self.register_buffer('wcos', wcos) \n", - " \n", - " if self.trainable==True:\n", - " wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n", - " wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n", - " self.register_parameter('wsin', wsin)\n", - " self.register_parameter('wcos', wcos) \n", - " \n", - " # Prepare the shape of window mask so that it can be used later in inverse\n", - " # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n", - " \n", - " if verbose==True:\n", - " print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n", - " else:\n", - " pass\n", - "\n", - " def forward(self, x, output_format=None):\n", - " \"\"\"\n", - " Convert a batch of waveforms to spectrograms.\n", - " \n", - " Parameters\n", - " ----------\n", - " x : torch tensor\n", - " Input signal should be in either of the following shapes.\\n\n", - " 1. ``(len_audio)``\\n\n", - " 2. ``(num_audio, len_audio)``\\n\n", - " 3. ``(num_audio, 1, len_audio)``\n", - " It will be automatically broadcast to the right shape\n", - " \n", - " output_format : str\n", - " Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n", - " Default value is ``Complex``. 
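# Sanity check for create_fourier_kernels above: with freq_scale='no' the
# kernels are the plain DFT basis, so projecting a signal onto wcos/wsin
# reproduces np.fft.rfft (the returned kernels are unwindowed):
import numpy as np

n_fft = 64
wsin, wcos, _, _, _ = create_fourier_kernels(n_fft, freq_scale='no',
                                             window='hann', verbose=False)
x = np.random.randn(n_fft).astype(np.float32)
re = (wcos[:, 0, :] * x).sum(-1)        # real part per bin
im = -(wsin[:, 0, :] * x).sum(-1)       # imaginary part (note the sign)
print(np.allclose(re + 1j * im, np.fft.rfft(x), atol=1e-3))  # True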
\n", - " \n", - " \"\"\"\n", - " output_format = output_format or self.output_format\n", - " self.num_samples = x.shape[-1]\n", - " \n", - " x = broadcast_dim(x)\n", - " if self.center:\n", - " if self.pad_mode == 'constant':\n", - " padding = nn.ConstantPad1d(self.pad_amount, 0)\n", - "\n", - " elif self.pad_mode == 'reflect':\n", - " if self.num_samples < self.pad_amount:\n", - " raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n", - " padding = nn.ReflectionPad1d(self.pad_amount)\n", - "\n", - " x = padding(x)\n", - " spec_imag = conv1d(x, self.wsin, stride=self.stride)\n", - " spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n", - "\n", - " # remove redundant parts\n", - " spec_real = spec_real[:, :self.freq_bins, :]\n", - " spec_imag = spec_imag[:, :self.freq_bins, :]\n", - "\n", - " if output_format=='Magnitude':\n", - " spec = spec_real.pow(2) + spec_imag.pow(2)\n", - " if self.trainable==True:\n", - " return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n", - " else:\n", - " return torch.sqrt(spec)\n", - "\n", - " elif output_format=='Complex':\n", - " return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n", - "\n", - " elif output_format=='Phase':\n", - " return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n", - "\n", - " def inverse(self, X, onesided=True, length=None, refresh_win=True):\n", - " \"\"\"\n", - " This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n", - " which is to convert spectrograms back to waveforms. \n", - " It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n", - " please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n", - " \n", - " Parameters\n", - " ----------\n", - " onesided : bool\n", - " If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n", - " else use ``onesided=False``\n", - "\n", - " length : int\n", - " To make sure the inverse STFT has the same output length of the original waveform, please\n", - " set `length` as your intended waveform length. By default, ``length=None``,\n", - " which will remove ``n_fft//2`` samples from the start and the end of the output.\n", - " \n", - " refresh_win : bool\n", - " Recalculating the window sum square. If you have an input with fixed number of timesteps,\n", - " you can increase the speed by setting ``refresh_win=False``. 
Else please keep ``refresh_win=True``\n", - " \n", - " \n", - " \"\"\"\n", - " if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n", - " raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n", - " \n", - " assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n", - " \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n", - " \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n", - " if onesided:\n", - " X = extend_fbins(X) # extend freq\n", - "\n", - " \n", - " X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n", - "\n", - " # broadcast dimensions to support 2D convolution\n", - " X_real_bc = X_real.unsqueeze(1)\n", - " X_imag_bc = X_imag.unsqueeze(1)\n", - " a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n", - " b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n", - " \n", - " # compute real and imag part. signal lies in the real part\n", - " real = a1 - b2\n", - " real = real.squeeze(-2)*self.window_mask\n", - "\n", - " # Normalize the amplitude with n_fft\n", - " real /= (self.n_fft)\n", - "\n", - " # Overlap and Add algorithm to connect all the frames\n", - " real = overlap_add(real, self.stride)\n", - " \n", - " # Prepare the window sumsqure for division\n", - " # Only need to create this window once to save time\n", - " # Unless the input spectrograms have different time steps\n", - " if hasattr(self, 'w_sum')==False or refresh_win==True:\n", - " self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n", - " self.nonzero_indices = (self.w_sum>1e-10) \n", - " else:\n", - " pass\n", - " real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n", - " # Remove padding\n", - " if length is None: \n", - " if self.center:\n", - " real = real[:, self.pad_amount:-self.pad_amount]\n", - "\n", - " else:\n", - " if self.center:\n", - " real = real[:, self.pad_amount:self.pad_amount + length] \n", - " else:\n", - " real = real[:, :length] \n", - " \n", - " return real\n", - " \n", - " def extra_repr(self) -> str:\n", - " return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n", - " self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n", - " ) " - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "id": "unusual-baker", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "(83792,)\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.0153 seconds\n", - "torch.Size([521, 257])\n", - "(522, 257)\n", - "[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n", - " 2.32206573e+01 1.90274487e+01]\n", - " [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n", - " 1.80534729e+02 1.84179596e+02]\n", - " [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n", - " 1.21321175e+02 1.08345497e+02]\n", - " ...\n", - " [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n", - " 1.25670158e+02 1.20691467e+02]\n", - " [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n", - " 9.01831894e+01 9.84055099e+01]\n", - " [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n", - " 8.94473553e+00 9.61348629e+00]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 
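# Cross-check sketch for the conv1d formulation above (assumptions: default
# freq_scale='no', and the hard-coded np.hamming window, which matches
# torch.hamming_window with periodic=False):
import torch

sig = torch.randn(1, 4096)
layer = STFT(n_fft=512, hop_length=160, window='hann', center=True,
             pad_mode='reflect', output_format='Magnitude', verbose=False)
mag_conv = layer(sig)                           # (1, 257, 26)
mag_ref = torch.stft(sig, n_fft=512, hop_length=160,
                     window=torch.hamming_window(512, periodic=False),
                     center=True, pad_mode='reflect',
                     return_complex=True).abs()
print(torch.allclose(mag_conv, mag_ref, atol=2e-2))  # True up to fp32 noise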
2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav.shape)\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n", - " window='', freq_scale='linear', center=False, pad_mode='constant',\n", - " fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "wav_spec = wav_spec[0].T\n", - "print(wav_spec.shape)\n", - "\n", - "\n", - "spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(spec.shape)\n", - "\n", - "print(wav_spec.numpy())\n", - "print(rspec)\n", - "# print(spec)\n", - "\n", - "# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n", - "# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - "# dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - "# wintype='hamming')\n", - "# print(rspec)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "white-istanbul", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 129, - "id": "modern-rescue", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n", - " 0.75 0.41317591 0.11697778 0. ]\n" - ] - }, - { - "data": { - "text/plain": [ - "array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n", - " 0.9045085, 0.6545085, 0.3454915, 0.0954915])" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(np.hanning(10))\n", - "from scipy.signal import get_window\n", - "get_window('hann', 10, fftbins=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "professional-journalism", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 153, - "id": "involved-motion", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(522, 400)\n", - "[[ 43. 75. 69. ... 46. 46. 45.]\n", - " [ 210. 215. 216. ... -86. -89. -91.]\n", - " [ 128. 128. 128. ... -154. -151. -151.]\n", - " ...\n", - " [ -60. -61. -61. ... 112. 109. 110.]\n", - " [ 20. 22. 24. ... 91. 87. 87.]\n", - " [ 111. 107. 108. ... -6. -4. 
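# What the two printouts above demonstrate: np.hanning is the *symmetric*
# window, while get_window(..., fftbins=True) is the *periodic* one; a
# periodic Hann of length n equals the symmetric Hann of length n+1 with the
# last sample dropped:
import numpy as np
from scipy.signal import get_window

n = 10
print(np.allclose(get_window('hann', n, fftbins=True),
                  np.hanning(n + 1)[:-1]))  # True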
-8.]]\n", - "torch.Size([1, 1, 83792])\n", - "torch.Size([400, 1, 512])\n", - "torch.Size([1, 400, 521])\n", - "conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n", - " [ 210., 215., 216., ..., -86., -89., -91.],\n", - " [ 128., 128., 128., ..., -154., -151., -151.],\n", - " ...,\n", - " [-143., -141., -142., ..., 96., 101., 101.],\n", - " [ -60., -61., -61., ..., 112., 109., 110.],\n", - " [ 20., 22., 24., ..., 91., 87., 87.]])\n", - "xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "torch.Size([521, 257])\n", - "yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "yy (522, 257)\n", - "[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 
6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(f.shape)\n", - "print(f)\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,n_fft))\n", - "wcos = np.empty((freq_bins,1,n_fft))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - "\n", - "\n", - "wsin = np.empty((n_fft,1,n_fft))\n", - "wcos = np.empty((n_fft,1,n_fft))\n", - "for k in range(n_fft): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " \n", - " \n", - "wsin = np.empty((400,1,n_fft))\n", - "wcos = np.empty((400,1,n_fft))\n", - "for k in range(400): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(400, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(400, n_fft)[k]\n", - " \n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :]\n", - "print(x.size())\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160)\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "print(spec_imag.size())\n", - "print(\"conv frame\", spec_imag[0].T)\n", - "# print(spec_imag[0].T[:, :400])\n", - "\n", - "# remove redundant parts\n", - "# spec_real = spec_real[:, :freq_bins, :]\n", - "# spec_imag = spec_imag[:, :freq_bins, :]\n", - "# spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "# spec = torch.sqrt(spec)\n", - "# print(spec)\n", - "\n", - "\n", - "\n", - "s = np.arange(0, 512, 1.)\n", - "# s = s[::-1]\n", - "wsin = np.empty((freq_bins, 400))\n", - "wcos = np.empty((freq_bins, 400))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - "spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n", - "spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n", - "\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "\n", - "print('xx', spec.numpy())\n", - "print(spec.size())\n", - "print('yy', rspec[:521, :])\n", - "print('yy', rspec.shape)\n", - "\n", - "\n", - "x = spec.numpy()\n", - "y = rspec[:-1, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "id": "mathematical-traffic", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([257, 1, 400])\n", - "tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n", - " 3.8693e+04, 3.1020e+04],\n", - " [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n", - " 1.3598e+04, 1.5920e+04],\n", - " [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n", - " 6.7707e+03, 4.3020e+03],\n", - " ...,\n", - " [2.0282e+01, 6.6927e+01, 
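# The identity-kernel conv1d above is just signal framing; torch's unfold
# produces the same frames without materializing a one-hot kernel bank:
import torch

sig = torch.arange(20.).reshape(1, 1, 20)
eye = torch.eye(4).reshape(4, 1, 4)                    # 4 one-hot kernels
frames_conv = torch.nn.functional.conv1d(sig, eye, stride=2)[0].T
frames_unfold = sig[0, 0].unfold(0, 4, 2)              # (n_frames, window)
print(torch.equal(frames_conv, frames_unfold))         # True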
2.8501e+01, ..., 2.6012e+01,\n", - " 6.1071e+01, 5.3685e+01],\n", - " [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n", - " 5.1310e+01, 6.3620e+01],\n", - " [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n", - " 3.5000e+01, 4.4000e+01]]])\n", - "[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n", - " 2.4064537e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n", - " 9.5781006e+01 1.4200000e+02]\n", - " ...\n", - " [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n", - " 7.8405350e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n", - " 5.1310150e+01 3.5000000e+01]\n", - " [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n", - " 6.3619797e+01 4.4000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,400))\n", - "wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :] #[B, C, T]\n", - "\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "print(spec)\n", - "\n", - "x = spec[0].T.numpy()\n", - "y = rspec[:, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 162, - "id": "olive-nicaragua", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - }, - { - "data": { - "text/plain": [ - "27241" - ] - }, - "execution_count": 162, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - 
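# Why np.allclose(x, y) above returns False: rspec contains exact zeros
# (flat index 27241, inspected below) where the conv1d STFT yields ~4e-10,
# so the default relative tolerance can never hold, and the elementwise
# relative error computed next divides by zero. A combined tolerance such as
# np.allclose(x, y, rtol=1e-4, atol=1e-3) is the appropriate test here
# (tolerances are an assumption based on the float32 magnitudes above).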
"np.argmax(np.abs(x -y) / np.abs(y))" - ] - }, - { - "cell_type": "code", - "execution_count": 165, - "id": "ultimate-assault", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.0" - ] - }, - "execution_count": 165, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "institutional-stock", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.2412265e-10" - ] - }, - "execution_count": 166, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "id": "integrated-courage", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y, x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "different-operation", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/compute_cmvn_loader_test.ipynb b/.notebook/compute_cmvn_loader_test.ipynb deleted file mode 100644 index 2b0a8b75ffc84c445152733fbd5a35789d393f08..0000000000000000000000000000000000000000 --- a/.notebook/compute_cmvn_loader_test.ipynb +++ /dev/null @@ -1,793 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "purple-consequence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "defensive-mason", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "patient-convention", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n" - ] - } - ], - "source": [ - "import argparse\n", - "import functools\n", - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.utils.utility import add_arguments\n", - "from deepspeech.utils.utility import print_arguments\n", - "\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n", - 
"add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool,\n", - " False,\n", - " \"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train.raw',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('num_workers',\n", - " default=16,\n", - " type=int,\n", - " help='num of subprocess workers for processing')\n", - "add_arg('output_path', str,\n", - " 'data/librispeech/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "enormous-currency", - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self):\n", - " pass\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for feat in batch:\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n", - " self.feature_func = feature_func\n", - " self._rng = rng\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.sample(manifest, num_samples)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " key = self.items[idx]['feat']\n", - " audioseg = AudioSegment.from_file(key)\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - " return feat" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "armed-semester", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "process 1000 wavs,450739 frames\n", - "process 2000 wavs,887447 frames\n", - "process 3000 wavs,1354148 frames\n", - "process 4000 wavs,1816494 frames\n", - "process 5000 wavs,2359211 frames\n", - "process 6000 wavs,2828455 frames\n", - "process 7000 wavs,3276186 frames\n", - "process 8000 wavs,3692234 frames\n", - "process 9000 wavs,4139360 frames\n", - "process 10000 wavs,4591528 frames\n", - "process 11000 
wavs,5020114 frames\n", - " ... (per-1000-wav progress lines from 12000 to 95000 wavs elided; the run reaches 42584564 frames at 95000 wavs) ...\n", - "process 
96000 wavs,43018598 frames\n", - "process 97000 wavs,43480662 frames\n", - "process 98000 wavs,43973670 frames\n", - "process 99000 wavs,44448190 frames\n", - "process 100000 wavs,44935034 frames\n", - "process 101000 wavs,45379812 frames\n", - "process 102000 wavs,45821207 frames\n", - "process 103000 wavs,46258420 frames\n", - "process 104000 wavs,46743733 frames\n", - "process 105000 wavs,47206922 frames\n", - "process 106000 wavs,47683041 frames\n", - "process 107000 wavs,48122809 frames\n", - "process 108000 wavs,48594623 frames\n", - "process 109000 wavs,49086358 frames\n", - "process 110000 wavs,49525568 frames\n", - "process 111000 wavs,49985820 frames\n", - "process 112000 wavs,50428262 frames\n", - "process 113000 wavs,50897957 frames\n", - "process 114000 wavs,51344589 frames\n", - "process 115000 wavs,51774621 frames\n", - "process 116000 wavs,52243372 frames\n", - "process 117000 wavs,52726025 frames\n", - "process 118000 wavs,53170026 frames\n", - "process 119000 wavs,53614141 frames\n", - "process 120000 wavs,54071271 frames\n" - ] - } - ], - "source": [ - "\n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc()\n", - "\n", - "dataset = AudioDataset(\n", - " args.manifest_path,\n", - " augment_and_featurize, \n", - " args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - " dataset,\n", - " batch_size=batch_size,\n", - " shuffle=False,\n", - " num_workers=args.num_workers,\n", - " collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - " all_mean_stat = None\n", - " all_var_stat = None\n", - " all_number = 0\n", - " wav_number = 0\n", - " for i, batch in enumerate(data_loader()):\n", - " #for batch in data_loader():\n", - " number, mean_stat, var_stat = batch\n", - " if i == 0:\n", - " all_mean_stat = mean_stat\n", - " all_var_stat = var_stat\n", - " else:\n", - " all_mean_stat += mean_stat\n", - " all_var_stat += var_stat\n", - " all_number += number\n", - " wav_number += batch_size\n", - "\n", - " if wav_number % 1000 == 0:\n", - " print('process {} wavs,{} frames'.format(wav_number,\n", - " all_number))\n", - "\n", - "cmvn_info = {\n", - " 'mean_stat': list(all_mean_stat.tolist()),\n", - " 'var_stat': list(all_var_stat.tolist()),\n", - " 'frame_num': all_number\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "danish-executive", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, 
-1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 
21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n" - ] - } - ], - "source": [ - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "accurate-terminal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "dominant-abuse", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 1000 wavs,450240 frames\n", - " \n", - "process 2000 wavs,886411 frames\n", - " \n", - "process 3000 wavs,1352580 frames\n", - " \n", - "process 4000 wavs,1814397 frames\n", - " \n", - "process 5000 wavs,2356587 frames\n", - " \n", - "process 6000 wavs,2825310 frames\n", - " \n", - "process 7000 wavs,3272506 frames\n", - " \n", - "process 8000 wavs,3688045 frames\n", - " \n", - "process 9000 wavs,4134669 frames\n", - " \n", - "process 10000 wavs,4586357 frames\n", - " \n", - "process 11000 wavs,5014429 frames\n", - " \n", - "process 12000 wavs,5453334 frames\n", - " \n", - "process 13000 wavs,5892888 frames\n", - " \n", - "process 14000 
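# Finalizing CMVN from the accumulated statistics printed above; a sketch of
# the standard mean/inverse-stddev computation (`eps` guards tiny variances):
import numpy as np

def cmvn_from_stats(mean_stat, var_stat, frame_num, eps=1e-20):
    mean = np.asarray(mean_stat) / frame_num
    var = np.asarray(var_stat) / frame_num - mean ** 2   # E[x^2] - E[x]^2
    istd = 1.0 / np.sqrt(np.maximum(var, eps))
    return mean, istd

# usage: mean, istd = cmvn_from_stats(**cmvn_info)
#        normalized = (frame - mean) * istd   # frame: (feat_dim,)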
wavs,6316059 frames\n", - " \n", - "process 15000 wavs,6728870 frames\n", - " \n", - "process 16000 wavs,7199442 frames\n", - " \n", - "process 17000 wavs,7629055 frames\n", - " \n", - "process 18000 wavs,8083729 frames\n", - " \n", - "process 19000 wavs,8519732 frames\n", - " \n", - "process 20000 wavs,8895694 frames\n", - " \n", - "process 21000 wavs,9341778 frames\n", - " \n", - "process 22000 wavs,9796126 frames\n", - " \n", - "process 23000 wavs,10236057 frames\n", - " \n", - "process 24000 wavs,10687461 frames\n", - " \n", - "process 25000 wavs,11113082 frames\n", - " \n", - "process 26000 wavs,11544482 frames\n", - " \n", - "process 27000 wavs,11996273 frames\n", - " \n", - "process 28000 wavs,12456350 frames\n", - " \n", - "process 29000 wavs,12900895 frames\n", - " \n", - "process 30000 wavs,13330353 frames\n", - " \n", - "process 31000 wavs,13736568 frames\n", - " \n", - "process 32000 wavs,14158472 frames\n", - " \n", - "process 33000 wavs,14625316 frames\n", - " \n", - "process 34000 wavs,15036206 frames\n", - " \n", - "process 35000 wavs,15514001 frames\n", - " \n", - "process 36000 wavs,16004323 frames\n", - " \n", - "process 37000 wavs,16418799 frames\n", - " \n", - "process 38000 wavs,16840100 frames\n", - " \n", - "process 39000 wavs,17287752 frames\n", - " \n", - "process 40000 wavs,17776206 frames\n", - " \n", - "process 41000 wavs,18243209 frames\n", - " \n", - "process 42000 wavs,18690449 frames\n", - " \n", - "process 43000 wavs,19137940 frames\n", - " \n", - "process 44000 wavs,19553966 frames\n", - " \n", - "process 45000 wavs,19969813 frames\n", - " \n", - "process 46000 wavs,20440963 frames\n", - " \n", - "process 47000 wavs,20862022 frames\n", - " \n", - "process 48000 wavs,21292801 frames\n", - " \n", - "process 49000 wavs,21713004 frames\n", - " \n", - "process 50000 wavs,22146346 frames\n", - " \n", - "process 51000 wavs,22596172 frames\n", - " \n", - "process 52000 wavs,23074160 frames\n", - " \n", - "process 53000 wavs,23499823 frames\n", - " \n", - "process 54000 wavs,23942151 frames\n", - " \n", - "process 55000 wavs,24390566 frames\n", - " \n", - "process 56000 wavs,24833905 frames\n", - " \n", - "process 57000 wavs,25307270 frames\n", - " \n", - "process 58000 wavs,25748720 frames\n", - " \n", - "process 59000 wavs,26185964 frames\n", - " \n", - "process 60000 wavs,26663953 frames\n", - " \n", - "process 61000 wavs,27117720 frames\n", - " \n", - "process 62000 wavs,27585349 frames\n", - " \n", - "process 63000 wavs,28032693 frames\n", - " \n", - "process 64000 wavs,28487074 frames\n", - " \n", - "process 65000 wavs,28956462 frames\n", - " \n", - "process 66000 wavs,29436358 frames\n", - " \n", - "process 67000 wavs,29918569 frames\n", - " \n", - "process 68000 wavs,30325682 frames\n", - " \n", - "process 69000 wavs,30762528 frames\n", - " \n", - "process 70000 wavs,31182319 frames\n", - " \n", - "process 71000 wavs,31627526 frames\n", - " \n", - "process 72000 wavs,32070556 frames\n", - " \n", - "process 73000 wavs,32504534 frames\n", - " \n", - "process 74000 wavs,32972775 frames\n", - " \n", - "process 75000 wavs,33409637 frames\n", - " \n", - "process 76000 wavs,33847861 frames\n", - " \n", - "process 77000 wavs,34298647 frames\n", - " \n", - "process 78000 wavs,34721536 frames\n", - " \n", - "process 79000 wavs,35159236 frames\n", - " \n", - "process 80000 wavs,35628628 frames\n", - " \n", - "process 81000 wavs,36080909 frames\n", - " \n", - "process 82000 wavs,36562496 frames\n", - " \n", - "process 83000 wavs,37042976 frames\n", - " \n", - "process 
84000 wavs,37474403 frames\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 85000 wavs,37943596 frames\n", - " \n", - "process 86000 wavs,38371620 frames\n", - " \n", - "process 87000 wavs,38844874 frames\n", - " \n", - "process 88000 wavs,39292686 frames\n", - " \n", - "process 89000 wavs,39746715 frames\n", - " \n", - "process 90000 wavs,40241800 frames\n", - " \n", - "process 91000 wavs,40672817 frames\n", - " \n", - "process 92000 wavs,41131773 frames\n", - " \n", - "process 93000 wavs,41612001 frames\n", - " \n", - "process 94000 wavs,42084822 frames\n", - " \n", - "process 95000 wavs,42535878 frames\n", - " \n", - "process 96000 wavs,42969365 frames\n", - " \n", - "process 97000 wavs,43430890 frames\n", - " \n", - "process 98000 wavs,43923378 frames\n", - " \n", - "process 99000 wavs,44397370 frames\n", - " \n", - "process 100000 wavs,44883695 frames\n", - " \n", - "process 101000 wavs,45327968 frames\n", - " \n", - "process 102000 wavs,45768860 frames\n", - " \n", - "process 103000 wavs,46205602 frames\n", - " \n", - "process 104000 wavs,46690407 frames\n", - " \n", - "process 105000 wavs,47153089 frames\n", - " \n", - "process 106000 wavs,47628699 frames\n", - " \n", - "process 107000 wavs,48067945 frames\n", - " \n", - "process 108000 wavs,48539256 frames\n", - " \n", - "process 109000 wavs,49030485 frames\n", - " \n", - "process 110000 wavs,49469189 frames\n", - " \n", - "process 111000 wavs,49928968 frames\n", - " \n", - "process 112000 wavs,50370921 frames\n", - " \n", - "process 113000 wavs,50840090 frames\n", - " \n", - "process 114000 wavs,51286249 frames\n", - " \n", - "process 115000 wavs,51715786 frames\n", - " \n", - "process 116000 wavs,52184017 frames\n", - " \n", - "process 117000 wavs,52666156 frames\n", - " \n", - "process 118000 wavs,53109645 frames\n", - " \n", - "process 119000 wavs,53553253 frames\n", - " \n", - "process 120000 wavs,54009877 frames\n", - "{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 
10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n" - ] - } - ], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "# https://github.com/PaddlePaddle/Paddle/pull/31481\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self, feature_func):\n", - " self.feature_func = feature_func\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for item in batch:\n", - " audioseg = AudioSegment.from_file(item['feat'])\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - "\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n", - " self._rng = rng if rng else np.random.RandomState(random_seed)\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.items[idx]\n", - " \n", - " \n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " 
specgram_type=args.specgram_type,\n", - "    feat_dim=args.feat_dim,\n", - "    delta_delta=args.delta_delta,\n", - "    stride_ms=args.stride_ms,\n", - "    window_ms=args.window_ms,\n", - "    n_fft=None,\n", - "    max_freq=None,\n", - "    target_sample_rate=args.sample_rate,\n", - "    use_dB_normalization=True,\n", - "    target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - "    augmentation_pipeline.transform_audio(audio_segment)\n", - "    return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc(augment_and_featurize)\n", - "\n", - "dataset = AudioDataset(\n", - "    args.manifest_path,\n", - "    args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - "    dataset,\n", - "    batch_size=batch_size,\n", - "    shuffle=False,\n", - "    num_workers=args.num_workers,\n", - "    collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - "    all_mean_stat = None\n", - "    all_var_stat = None\n", - "    all_number = 0\n", - "    wav_number = 0\n", - "    for i, batch in enumerate(data_loader):\n", - "        number, mean_stat, var_stat = batch\n", - "        if i == 0:\n", - "            all_mean_stat = mean_stat\n", - "            all_var_stat = var_stat\n", - "        else:\n", - "            all_mean_stat += mean_stat\n", - "            all_var_stat += var_stat\n", - "        all_number += number\n", - "        wav_number += batch_size\n", - "\n", - "        if wav_number % 1000 == 0:\n", - "            print('process {} wavs,{} frames'.format(wav_number,\n", - "                                                    all_number))\n", - "\n", - "cmvn_info = {\n", - "    'mean_stat': all_mean_stat.tolist(),\n", - "    'var_stat': all_var_stat.tolist(),\n", - "    'frame_num': all_number\n", - "}\n", - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unlike-search", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/dataloader.ipynb b/.notebook/dataloader.ipynb deleted file mode 100644 index 3de8f64a9be81c036222c76c8581beb4b4b438df..0000000000000000000000000000000000000000 --- a/.notebook/dataloader.ipynb +++ /dev/null @@ -1,389 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. 
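
The cell above accumulates CMVN statistics per feature dimension: for each (D, T) feature matrix it sums the values and the squared values over the time axis and counts frames. As written it also relies on `args`, `AugmentationPipeline`, and `AudioFeaturizer` being defined by earlier cells; the matching import paths (`deepspeech.frontend.augmentor.augmentation`, `deepspeech.frontend.featurizer.audio_featurizer`) appear in a later cell. Below is a minimal sketch, assuming the standard E[x]/E[x^2] identity rather than the repo's `FeatureNormalizer` internals, of how the printed `cmvn_info` turns into usable mean and stddev:

```python
import numpy as np

# Sketch only: recover per-dimension mean/std from the accumulated stats.
# variance = E[x^2] - E[x]^2, computed per feature dimension.
mean_stat = np.array(cmvn_info['mean_stat'])   # per-dim sum of feature values, shape (D,)
var_stat = np.array(cmvn_info['var_stat'])     # per-dim sum of squared values, shape (D,)
n_frames = cmvn_info['frame_num']              # total number of frames

mean = mean_stat / n_frames
std = np.sqrt(var_stat / n_frames - np.square(mean))
std = np.maximum(std, 1.0e-20)                 # assumed floor to avoid divide-by-zero

def apply_cmvn(feat):
    """Normalize a (D, T) feature matrix to zero mean / unit variance per dim."""
    return (feat - mean[:, None]) / std[:, None]
```
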
Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n", - "from data_utils.utility import read_manifest\n", - "from data_utils.augmentor.augmentation import AugmentationPipeline\n", - "from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from data_utils.speech import SpeechSegment\n", - "from data_utils.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from data_utils.dataset import (\n", - " DeepSpeech2Dataset,\n", - " DeepSpeech2DistributedBatchSampler,\n", - " DeepSpeech2BatchSampler,\n", - " SpeechCollator,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [ - "def create_dataloader(manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config='{}',\t\n", - " max_duration=float('inf'),\t\n", - " min_duration=0.0,\t\n", - " stride_ms=10.0,\t\n", - " window_ms=20.0,\t\n", - " max_freq=None,\t\n", - " specgram_type='linear',\t\n", - " use_dB_normalization=True,\t\n", - " random_seed=0,\t\n", - " keep_transcription_text=False,\t\n", - " is_training=False,\t\n", - " batch_size=1,\t\n", - " num_workers=0,\t\n", - " sortagrad=False,\t\n", - " shuffle_method=None,\t\n", - " dist=False):\t\n", - "\n", - " dataset = DeepSpeech2Dataset(\t\n", - " manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config=augmentation_config,\t\n", - " max_duration=max_duration,\t\n", - " min_duration=min_duration,\t\n", - " stride_ms=stride_ms,\t\n", - " window_ms=window_ms,\t\n", - " max_freq=max_freq,\t\n", - " specgram_type=specgram_type,\t\n", - " use_dB_normalization=use_dB_normalization,\t\n", - " random_seed=random_seed,\t\n", - " keep_transcription_text=keep_transcription_text)\t\n", - "\n", - " if dist:\t\n", - " batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n", - " dataset,\t\n", - " batch_size,\t\n", - " num_replicas=None,\t\n", - " rank=None,\t\n", - " shuffle=is_training,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - " else:\t\n", - " batch_sampler = DeepSpeech2BatchSampler(\t\n", - " dataset,\t\n", - " shuffle=is_training,\t\n", - " batch_size=batch_size,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - "\n", - " def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n", - " \"\"\"\t\n", - " Padding audio features with zeros to make them have the same shape (or\t\n", - " a user-defined shape) within one batch.\t\n", - "\n", - " If ``padding_to`` is -1, the maximum shape in the batch will be used\t\n", - " as the target shape for padding. 
Otherwise, `padding_to` will be the\t\n", - " target shape (only refers to the second axis).\t\n", - "\n", - " If `flatten` is True, features will be flattened to a 1-D array.\t\n", - " \"\"\"\t\n", - " new_batch = []\t\n", - " # get target shape\t\n", - " max_length = max([audio.shape[1] for audio, text in batch])\t\n", - " if padding_to != -1:\t\n", - " if padding_to < max_length:\t\n", - " raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n", - " \"than any instance's shape in the batch\")\t\n", - " max_length = padding_to\t\n", - " max_text_length = max([len(text) for audio, text in batch])\t\n", - " # padding\t\n", - " padded_audios = []\t\n", - " audio_lens = []\t\n", - " texts, text_lens = [], []\t\n", - " for audio, text in batch:\t\n", - " padded_audio = np.zeros([audio.shape[0], max_length])\t\n", - " padded_audio[:, :audio.shape[1]] = audio\t\n", - " if flatten:\t\n", - " padded_audio = padded_audio.flatten()\t\n", - " padded_audios.append(padded_audio)\t\n", - " audio_lens.append(audio.shape[1])\t\n", - "\n", - " padded_text = np.zeros([max_text_length])\n", - " if is_training:\n", - " padded_text[:len(text)] = text\t# ids\n", - " else:\n", - " padded_text[:len(text)] = [ord(t) for t in text] # string\n", - " \n", - " texts.append(padded_text)\t\n", - " text_lens.append(len(text))\t\n", - "\n", - " padded_audios = np.array(padded_audios).astype('float32')\t\n", - " audio_lens = np.array(audio_lens).astype('int64')\t\n", - " texts = np.array(texts).astype('int32')\t\n", - " text_lens = np.array(text_lens).astype('int64')\t\n", - " return padded_audios, texts, audio_lens, text_lens\t\n", - "\n", - " loader = DataLoader(\t\n", - " dataset,\t\n", - " batch_sampler=batch_sampler,\t\n", - " collate_fn=partial(padding_batch, is_training=is_training),\t\n", - " num_workers=num_workers)\t\n", - " return loader" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - 
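
The `padding_batch` collate above zero-pads every (D, T) feature in a batch to the longest time length (or to `padding_to`) and records the true lengths; in evaluation mode (`is_training=False`) transcripts are carried as Unicode code points via `ord`. A small toy check of the padding behavior, using made-up arrays rather than real features (the argparse cell at this point continues right after it):

```python
import numpy as np

# Toy check of the zero-padding scheme: two (D, T) features with D=2, T=3 and T=5.
batch = [
    (np.ones((2, 3)), [5, 9]),       # (feature, token ids)
    (np.ones((2, 5)), [7, 2, 4]),
]
max_length = max(audio.shape[1] for audio, _ in batch)   # -> 5
padded = []
for audio, _ in batch:
    buf = np.zeros([audio.shape[0], max_length])
    buf[:, :audio.shape[1]] = audio                      # left-aligned, zero tail
    padded.append(buf)

print(padded[0])                         # last two columns are zero padding
print([a.shape[1] for a, _ in batch])    # true lengths: [3, 5]
```
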
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "bearing-physics", - "metadata": {}, - "outputs": [], - "source": [ - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[22823, 26102, 20195, 37324, 0 , 0 ],\n", - " [22238, 26469, 23601, 22909, 0 , 0 ],\n", - " [20108, 26376, 22235, 26085, 0 , 0 ],\n", - " [36824, 35201, 20445, 25345, 32654, 24863],\n", - " [29042, 27748, 21463, 23456, 0 , 0 ]])\n", - "test raw 大时代里\n", - "test raw 煲汤受宠\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. 
],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "minus-modern", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/.notebook/dataloader_with_tokens_tokenids.ipynb b/.notebook/dataloader_with_tokens_tokenids.ipynb deleted file mode 100644 index 
7d93dd00940071c47c9674d75a3e57074f6bfff1..0000000000000000000000000000000000000000 --- a/.notebook/dataloader_with_tokens_tokenids.ipynb +++ /dev/null @@ -1,1204 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "medieval-monday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss 
of paddle.nn.functional if exists, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. 
Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/tiny/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/tiny/s1/data/manifest.tiny',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/tiny/s1/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/tiny/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/tiny/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "wired-principal", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 
40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/aishell/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/s1/data/manifest.test',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " '',\n", - " \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "bearing-physics", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. 
When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from deepspeech.frontend.speech import SpeechSegment\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from deepspeech.io.collator import SpeechCollator\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "from deepspeech.io.sampler import (\n", - " SortagradDistributedBatchSampler,\n", - " SortagradBatchSampler,\n", - ")\n", - "from deepspeech.io import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 
338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "[ 399 693 609 ... 1291 1270 1291] int16\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[ -750 -1254 -1107 ... 2276 1889 2067] int16\n", - "fbank\n", - "[ -127 -199 -149 ... -5243 -5065 -5398] int16\n", - "fbank\n", - "[ 465 783 677 ... 980 903 1008] int16\n", - "fbank\n", - "[ 90 160 157 ... -2 -16 -21] int16\n", - "fbank\n", - "[ 213 345 295 ... 2483 2246 2501] int16\n", - "fbank\n", - "[ -86 -159 -131 ... 270 258 290] int16\n", - "fbank\n", - "[-1023 -1714 -1505 ... 1532 1596 1575] int16\n", - "fbank\n", - "[-366 -602 -527 ... 374 370 379] int16\n", - "fbank\n", - "[ 761 1275 1127 ... 369 413 295] int16\n", - "fbank\n", - "[382 621 550 ... 161 161 174] int16\n", - "fbank\n", - "[ -28 -91 -120 ... 28 34 11] int16\n", - "fbank\n", - "[ -5 -5 -5 ... 268 294 341] int16\n", - "fbank\n", - "[240 417 684 ... 267 262 219] int16\n", - "fbank\n", - "[131 206 194 ... 383 320 343] int16\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n", - " [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n", - " [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n", - " [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n", - " [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n", - "fbank\n", - "[1155 1890 1577 ... 1092 989 1130] int16\n", - "fbank\n", - "[296 358 296 ... 140 140 168] int16\n", - "fbank\n", - "[-50 -91 -63 ... 104 104 86] int16\n", - "fbank\n", - "[-37 -66 -50 ... -31 -45 -52] int16\n", - "fbank\n", - "[-401 -652 -547 ... -339 -307 -344] int16\n", - "fbank\n", - "[-21 -47 -51 ... 94 81 107] int16\n", - "fbank\n", - "[ 533 887 755 ... 3074 2853 3254] int16\n", - "fbank\n", - "[ 44 71 66 ... -628 -733 -601] int16\n", - "fbank\n", - "[ 50 86 79 ... 129 116 138] int16\n", - "fbank\n", - "[ 92 146 126 ... -208 -193 -179] int16\n", - "test raw: 祝可爱的你\n", - "test raw: 去行政化\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n", - " [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n", - " [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
]],\n", - "\n", - " [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n", - " [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n", - " [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n", - " [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n", - " [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n", - " [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n", - " [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n", - " [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n", - " [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n", - " ...,\n", - " [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n", - " [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n", - " [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len:', audio_len)\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "minus-modern", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n", - " [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n", - " [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n", - " [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n", - " [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n", - "[ 399 693 609 ... 
1291 1270 1291] int16\n", - "test raw: ઇǹज৹©\n", - "test raw: ǝണٕƜ\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n", - " [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n", - " [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n", - " [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n", - " [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n", - " [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n", - " [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n", - " [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n", - " [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
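
With `keep_transcription_text=False`, the text tensor above holds vocabulary token ids rather than characters, so the `chr(i)` trick from the earlier cell prints mojibake (`ઇǹज৹©`): `chr` interprets ids as Unicode code points. A hypothetical lookup that decodes ids through the vocab file instead, assuming `vocab.txt` holds one token per line:

```python
# Hypothetical decoding helper: map token ids through the vocab file instead
# of chr(). Assumes examples/aishell/s1/data/vocab.txt lists one token per line.
def load_vocab(vocab_filepath):
    with open(vocab_filepath, encoding='utf8') as f:
        return [line.rstrip('\n') for line in f]

vocab = load_vocab('examples/aishell/s1/data/vocab.txt')
ids = [2695, 505, 2332, 2553, 169]       # first utterance above, -1 padding removed
print(''.join(vocab[i] for i in ids))    # should recover 祝可爱的你, as in the earlier run
```
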
]],\n", - "\n", - " [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n", - " [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n", - " [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n", - " ...,\n", - " [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n", - " [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n", - " [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n" - ] - } - ], - "source": [ - "keep_transcription_text=False\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=keep_transcription_text,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " print('audio len:', audio_len)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "competitive-mounting", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "knowing-military", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "\n", - "add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n", - "add_arg('specgram_type', str, 'fbank',\n", - " \"Audio feature type. 
Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('output_path', str,\n", - " 'examples/aishell/s1/data/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "args = parser.parse_args([])\n", - "print(vars(args))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "unnecessary-province", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "\n", - "\n", - "def mean(args):\n", - " augmentation_pipeline = AugmentationPipeline('{}')\n", - " audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "\n", - " def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - " normalizer = FeatureNormalizer(\n", - " mean_std_filepath=None,\n", - " manifest_path=args.manifest_path,\n", - " featurize_func=augment_and_featurize,\n", - " num_samples=args.num_samples)\n", - " normalizer.write_to_file(args.output_path)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "interested-camping", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.]\n", - "29746\n", - "fbank\n", - "[54 90 77 ... 58 58 61] int16\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 
10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n" - ] - } - ], - "source": [ - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n", - "test='祝可爱的你'\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=False,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "samples = AudioSegment.from_file(wav)\n", - "print(samples._samples)\n", - "print(samples._samples * 2**15)\n", - "print(len(samples._samples))\n", - "feat = audio_featurizer.featurize(samples, False, False)\n", - "feat = feat.T\n", - "print(feat.shape, feat.dtype)\n", - "print(feat)\n", - "\n", - "from python_speech_features import logfbank\n", - "max_freq = args.sample_rate / 2\n", - "fbank_feat = logfbank(\n", - " signal=samples.to('int16'),\n", - " samplerate=args.sample_rate,\n", - " winlen=0.001 * args.window_ms,\n", - " winstep=0.001 * args.stride_ms,\n", - " nfilt=args.feat_dim,\n", - " nfft=512,\n", - " lowfreq=20,\n", - " highfreq=max_freq,\n", - " preemph=0.97,\n", - " dither=0.0,\n", - " wintype='povey')\n", - "print(fbank_feat.shape, fbank_feat.dtype)\n", - "print(fbank_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "numeric-analyst", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 160)\n", - "[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n", - " 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n", - " 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n", - " 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n", - " 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n", - " 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n", - " 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n", - " 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n", - " 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n", - " 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n", - " 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n", - " 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n", - " 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n", - " 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n", - " 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n", - " 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n", - " 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n", - " 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n", - " 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n", - " 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n", - " 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n", - " 
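The two matrices printed in the cell above are numerically identical, which is worth asserting programmatically rather than eyeballing. A minimal sketch, assuming `feat` (from `AudioFeaturizer`) and `fbank_feat` (from `python_speech_features.logfbank`) are still in scope:

```python
import numpy as np

# Both pipelines should yield the same (184, 80) log-fbank matrix;
# allow a small tolerance for float64/float32 round-off.
assert feat.shape == fbank_feat.shape
print(np.allclose(feat, fbank_feat, atol=1e-5))
```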
13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n", - " 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n", - " 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n", - " 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n", - " 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n", - " 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n", - " 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n", - " 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n", - " 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n", - " 6.94653756 7.19874283 7.28515661 7.17574078]\n", - "(184,)\n", - "(184,)\n", - "[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n", - " 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n", - " 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n", - " 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n", - " 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n", - " 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n", - " 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n", - " 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n", - " 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n", - " 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n", - " 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n", - " 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n", - " 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n", - " 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n", - " 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n", - " 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n", - " 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n", - " 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n", - " 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n", - " 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n", - " 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n", - " 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n", - " 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n", - " 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n", - " 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n", - " 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n", - " 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n", - " 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n", - " 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n", - " 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n", - " 1.67784678 1.72180861 1.60703743 1.64850255]\n" - ] - } - ], - "source": [ - "a = np.hstack([feat, feat])\n", - "print(a.shape)\n", - "m = np.mean(a, axis=1)\n", - "print(m)\n", - "print(m.shape)\n", - "std = np.std(a, axis=1)\n", - "print(std.shape)\n", - "print(std)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "nonprofit-potato", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "hispanic-ethics", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torchaudio\n", - 
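The mean/std cell above reduces across the feature axis (`axis=1`) of the duplicated matrix, which is fine for poking at the numbers; note that the `mean_std.npz` written via `FeatureNormalizer` conventionally holds one statistic per feature dimension, computed across frames. A hedged sketch of that per-dimension variant (the axis choice is the only change; `feat` is the (184, 80) matrix from the earlier cell):

```python
import numpy as np

# Per-dimension CMVN statistics: one mean/std per feature bin,
# computed across the time axis (axis=0).
mean = feat.mean(axis=0)                   # shape (80,)
std = feat.std(axis=0)                     # shape (80,)
feat_norm = (feat - mean) / (std + 1e-20)  # eps guards near-constant bins
print(feat_norm.mean(axis=0).max(), feat_norm.std(axis=0).max())
```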
"import torchaudio.compliance.kaldi as kaldi\n", - "import torchaudio.sox_effects as sox_effects\n", - "from torch.nn.utils.rnn import pad_sequence\n", - "torchaudio.set_audio_backend(\"sox\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "changing-calvin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 29746])\n", - "tensor([[54., 90., 77., ..., 58., 58., 61.]])\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n", - "-----------\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "(184, 80)\n", - "[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n", - " [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n", - " [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n", - " ...\n", - " [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n", - " [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n", - " [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n", - "**************\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.] float32\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. 
Please use \"torchaudio.load\".\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - } - ], - "source": [ - "waveform, sample_rate = torchaudio.load_wav(wav)\n", - "print(waveform.shape)\n", - "print(waveform)\n", - "mat = kaldi.fbank(\n", - " waveform,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('-----------')\n", - "print(samples._samples)\n", - "aud = torch.tensor(samples._samples).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('**************')\n", - "print(samples._samples)\n", - "tmp = samples.to('int16').astype('float32')\n", - "print(tmp, tmp.dtype)\n", - "aud = torch.tensor(tmp).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "buried-dependence", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "silver-printing", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "outer-space", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(29746,)\n", - "[54 90 77 ... 58 58 61]\n", - "(184, 80)\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 13)\n", - "[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ 11.92813809 -15.9169855 8.78372271 ... 
-1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "from python_speech_features import mfcc\n", - "from python_speech_features import delta\n", - "from python_speech_features import logfbank\n", - "import scipy.io.wavfile as iowav\n", - "\n", - "(rate,sig) = iowav.read(wav)\n", - "print(sig.shape)\n", - "print(sig)\n", - "\n", - "# note that generally nfilt=40 is used for speech recognition\n", - "fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "# the computed fbank coefficents of english.wav with dimension [110,23]\n", - "# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n", - "# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n", - "\n", - "mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)\n", - "\n", - "# the computed mfcc coefficents of english.wav with dimension [110,13]\n", - "# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n", - "# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "sporting-school", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 80)\n", - "[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n", - " -12.98770428]\n", - " [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n", - " -11.95581463]\n", - " [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n", - " -12.83273051]\n", - " ...\n", - " [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n", - " -12.03825021]\n", - " [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n", - " -13.10154358]\n", - " [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n", - " -12.63434434]]\n", - "(184, 13)\n", - "[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ -8.86627732 -15.9169855 8.78372271 ... 
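The log-fbank matrix in the next cell is computed from `samples._samples` (floats in [-1, 1]) instead of the raw int16 signal, and every entry comes out lower by the constant 2·ln(2^15) ≈ 20.794, because filterbank energies scale with the square of the waveform amplitude. A self-contained sketch of that identity, checked against the first entries printed in the two cells (10.61737914 above vs. -10.17703627 below):

```python
import numpy as np

# fbank energy scales with amplitude**2, so dividing the waveform by
# 2**15 shifts every log-fbank entry down by 2 * ln(2**15):
offset = 2 * np.log(2.0 ** 15)
print(round(offset, 8))                # 20.79441542
print(round(10.61737914 - offset, 8)) # -10.17703628, matching below up to round-off
```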
-1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "restricted-license", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "specialized-threat", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/espnet_dataloader.ipynb b/.notebook/espnet_dataloader.ipynb deleted file mode 100644 index 1bfc13e3c169260cb76df82b703ae7a48d202aa8..0000000000000000000000000000000000000000 --- a/.notebook/espnet_dataloader.ipynb +++ /dev/null @@ -1,1541 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 147, - "id": "extensive-venice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/\n" - ] - }, - { - "data": { - "text/plain": [ - "'/'" - ] - }, - "execution_count": 147, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 148, - "id": "correct-window", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "manifest.dev\t manifest.test-clean\t manifest.train\r\n", - "manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n" - ] - } - ], - "source": [ - "!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "id": "exceptional-cheese", - "metadata": {}, - "outputs": [], - "source": [ - "dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "id": "extraordinary-orleans", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 151, - "id": "returning-lighter", - "metadata": {}, - "outputs": [], - "source": [ - "dev_json = read_manifest(dev_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 152, - "id": "western-founder", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n", - " 'name': 'input1',\n", - " 'shape': [1063, 83]}],\n", - " 'output': [{'name': 'target1',\n", - " 'shape': [41, 5002],\n", - " 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n", - " 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n", - " 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n", - " 'HITHER AND THITHER',\n", - " 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n", - " 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n", - 
" 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n", - " '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n", - " '▁THITHER',\n", - " 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n", - " '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n", - " '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n", - " '302 206 4504 4843 2394 627 4526'}],\n", - " 'utt': '116-288045-0000',\n", - " 'utt2spk': '116-288045'}\n", - "5542\n", - "\n" - ] - } - ], - "source": [ - "from pprint import pprint\n", - "pprint(dev_json[0])\n", - "print(len(dev_json))\n", - "print(type(dev_json))" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "motivated-receptor", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "import itertools\n", - "\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.log import Log\n", - "\n", - "__all__ = [\"make_batchset\"]\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "def batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size,\n", - " max_length_in,\n", - " max_length_out,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " iaxis=0,\n", - " okey=\"output\",\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_size: batch size\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int min_batch_size: mininum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str ikey: key to access input\n", - " (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param str okey: key to access output\n", - " (for ASR, MT okey=\"output\". 
for TTS okey=\"input\".)\n", - " :param int oaxis: dimension to access output\n", - " (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - " if batch_size <= 0:\n", - " raise ValueError(f\"Invalid batch_size={batch_size}\")\n", - "\n", - " # check #utts is more than min_batch_size\n", - " if len(sorted_data) < min_batch_size:\n", - " raise ValueError(\n", - " f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n", - " )\n", - "\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " _, info = sorted_data[start]\n", - " ilen = int(info[ikey][iaxis][\"shape\"][0])\n", - " olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n", - " max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n", - " factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n", - " # change batchsize depending on the input and output length\n", - " # if ilen = 1000 and max_length_in = 800\n", - " # then b = batchsize / 2\n", - " # and max(min_batches, .) avoids batchsize = 0\n", - " bs = max(min_batch_size, int(batch_size / (1 + factor)))\n", - " end = min(len(sorted_data), start + bs)\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - "\n", - " # check each batch is more than minimum batchsize\n", - " if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - "\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variably sized batch set, which maximizes\n", - "\n", - " the number of bins up to `batch_bins`.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_bins: Maximum frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". 
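The in-line comment about `ilen = 1000` in `batchfy_by_seq` is worth expanding into a complete worked example, since the integer divisions make the adaptive rule non-obvious. A self-contained sketch with illustrative numbers (not taken from the manifest above):

```python
# Adaptive batch size as computed inside batchfy_by_seq:
batch_size, max_length_in, max_length_out, min_batch_size = 10, 800, 400, 2
ilen, olen = 1000, 120   # one long utterance: 1000 input frames, 120 tokens
factor = max(int(ilen / max_length_in), int(olen / max_length_out))  # max(1, 0) = 1
bs = max(min_batch_size, int(batch_size / (1 + factor)))             # max(2, 5) = 5
print(factor, bs)  # 1 5
```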
for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if batch_bins <= 0:\n", - " raise ValueError(f\"invalid batch_bins={batch_bins}\")\n", - " length = len(sorted_data)\n", - " idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n", - " odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " minibatches = []\n", - " start = 0\n", - " n = 0\n", - " while True:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " next_size = 0\n", - " max_olen = 0\n", - " while next_size < batch_bins and (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n", - " if olen > max_olen:\n", - " max_olen = olen\n", - " next_size = (max_olen + ilen) * (b + 1)\n", - " if next_size <= batch_bins:\n", - " b += 1\n", - " elif next_size == 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n", - " f\"Please increase the value\")\n", - " end = min(length, start + max(min_batch_size, b))\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " if end == length:\n", - " break\n", - " start = end\n", - " n += 1\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in,\n", - " max_frames_out,\n", - " max_frames_inout,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n", - " :param int max_frames_in: Maximum input frames of a batch\n", - " :param int max_frames_out: Maximum output frames of a batch\n", - " :param int max_frames_inout: Maximum input+output frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". 
for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n", - " raise ValueError(\n", - " \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n", - " \"`--batch-frames-inout` should be > 0\")\n", - " length = len(sorted_data)\n", - " minibatches = []\n", - " start = 0\n", - " end = 0\n", - " while end != length:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " max_olen = 0\n", - " max_ilen = 0\n", - " while (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n", - " if ilen > max_frames_in and max_frames_in != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n", - " f\"Please increase the value\")\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n", - " if olen > max_frames_out and max_frames_out != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n", - " f\"Please increase the value\")\n", - " if ilen + olen > max_frames_inout and max_frames_inout != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n", - " f\"Please increase the value\")\n", - " max_olen = max(max_olen, olen)\n", - " max_ilen = max(max_ilen, ilen)\n", - " in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n", - " out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n", - " inout_ok = (max_ilen + max_olen) * (\n", - " b + 1) <= max_frames_inout or max_frames_inout == 0\n", - " if in_ok and out_ok and inout_ok:\n", - " # add more seq in the minibatch\n", - " b += 1\n", - " else:\n", - " # no more seq in the minibatch\n", - " break\n", - " end = min(length, start + b)\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " start = end\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - "\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n", - " shortest_first):\n", - " import random\n", - "\n", - " logger.info(\"use shuffled batch.\")\n", - " sorted_data = random.sample(data.items(), len(data.items()))\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " end = min(len(sorted_data), start + batch_size)\n", - " # check each batch is more than minimum batchsize\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - " 
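`batchfy_by_frame` grows a minibatch only while all three frame budgets still hold, with 0 disabling a constraint. Extracted as a standalone predicate (a sketch restating the in-loop checks, not code from the file), the admission test is:

```python
def fits(max_ilen, max_olen, b, max_in, max_out, max_inout):
    """Would a (b+1)-th sequence still respect every enabled frame budget?"""
    in_ok = max_in == 0 or max_ilen * (b + 1) <= max_in
    out_ok = max_out == 0 or max_olen * (b + 1) <= max_out
    inout_ok = max_inout == 0 or (max_ilen + max_olen) * (b + 1) <= max_inout
    return in_ok and out_ok and inout_ok

print(fits(300, 40, 9, 3000, 0, 0))   # True:  300 * 10 = 3000 <= 3000
print(fits(300, 40, 10, 3000, 0, 0))  # False: 300 * 11 = 3300 >  3000
```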
if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(minibatches)))\n", - " return minibatches\n", - "\n", - "\n", - "BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n", - "BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n", - "\n", - "\n", - "def make_batchset(\n", - " data,\n", - " batch_size=0,\n", - " max_length_in=float(\"inf\"),\n", - " max_length_out=float(\"inf\"),\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " batch_sort_key=\"input\",\n", - " count=\"auto\",\n", - " batch_bins=0,\n", - " batch_frames_in=0,\n", - " batch_frames_out=0,\n", - " batch_frames_inout=0,\n", - " iaxis=0,\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " if utts have \"category\" value,\n", - "\n", - " >>> data = {'utt1': {'category': 'A', 'input': ...},\n", - " ... 'utt2': {'category': 'B', 'input': ...},\n", - " ... 'utt3': {'category': 'B', 'input': ...},\n", - " ... 'utt4': {'category': 'A', 'input': ...}}\n", - " >>> make_batchset(data, batchsize=2, ...)\n", - " [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n", - "\n", - " Note that if any utts doesn't have \"category\",\n", - " perform as same as batchfy_by_{count}\n", - "\n", - " :param List[Dict[str, Any]] data: dictionary loaded from data.json\n", - " :param int batch_size: maximum number of sequences in a minibatch.\n", - " :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n", - " :param int batch_frames_in: maximum number of input frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of output frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n", - " :param str count: strategy to count maximum size of batch.\n", - " For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n", - "\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str batch_sort_key: how to sort data before creating minibatches\n", - " [\"input\", \"output\", \"shuffle\"]\n", - " :param bool swap_io: if True, use \"input\" as output and \"output\"\n", - " as input in `data` dict\n", - " :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n", - " as input in `data` dict\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n", - " reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - "\n", - " # 
check args\n", - " if count not in BATCH_COUNT_CHOICES:\n", - " raise ValueError(\n", - " f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n", - " if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n", - " raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n", - " f\"one of {BATCH_SORT_KEY_CHOICES}\")\n", - "\n", - " ikey = \"input\"\n", - " okey = \"output\"\n", - " batch_sort_axis = 0 # index of list \n", - "\n", - " if count == \"auto\":\n", - " if batch_size != 0:\n", - " count = \"seq\"\n", - " elif batch_bins != 0:\n", - " count = \"bin\"\n", - " elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n", - " count = \"frame\"\n", - " else:\n", - " raise ValueError(\n", - " f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n", - " )\n", - " logger.info(f\"count is auto detected as {count}\")\n", - "\n", - " if count != \"seq\" and batch_sort_key == \"shuffle\":\n", - " raise ValueError(\n", - " \"batch_sort_key=shuffle is only available if batch_count=seq\")\n", - "\n", - " category2data = {} # Dict[str, dict]\n", - " for v in data:\n", - " k = v['utt']\n", - " category2data.setdefault(v.get(\"category\"), {})[k] = v\n", - "\n", - " batches_list = [] # List[List[List[Tuple[str, dict]]]]\n", - " for d in category2data.values():\n", - " if batch_sort_key == \"shuffle\":\n", - " batches = batchfy_shuffle(d, batch_size, min_batch_size,\n", - " num_batches, shortest_first)\n", - " batches_list.append(batches)\n", - " continue\n", - "\n", - " # sort it by input lengths (long to short)\n", - " sorted_data = sorted(\n", - " d.items(),\n", - " key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n", - " reverse=not shortest_first, )\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " \n", - " if count == \"seq\":\n", - " batches = batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size=batch_size,\n", - " max_length_in=max_length_in,\n", - " max_length_out=max_length_out,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " iaxis=iaxis,\n", - " okey=okey,\n", - " oaxis=oaxis, )\n", - " if count == \"bin\":\n", - " batches = batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins=batch_bins,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " if count == \"frame\":\n", - " batches = batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in=batch_frames_in,\n", - " max_frames_out=batch_frames_out,\n", - " max_frames_inout=batch_frames_inout,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " batches_list.append(batches)\n", - "\n", - " if len(batches_list) == 1:\n", - " batches = batches_list[0]\n", - " else:\n", - " # Concat list. 
This way is faster than \"sum(batch_list, [])\"\n", - " batches = list(itertools.chain(*batches_list))\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " batches = batches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(batches)))\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return batches\n" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "acquired-hurricane", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "555\n" - ] - } - ], - "source": [ - "batch_size=10\n", - "maxlen_in=300\n", - "maxlen_out=400\n", - "minibatches=0 # for debug\n", - "min_batch_size=2\n", - "use_sortagrad=True\n", - "batch_count='seq'\n", - "batch_bins=0\n", - "batch_frames_in=3000\n", - "batch_frames_out=0\n", - "batch_frames_inout=0\n", - " \n", - "dev_data = make_batchset(\n", - " dev_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=use_sortagrad,\n", - " batch_sort_key=\"shuffle\",\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "print(len(dev_data))\n", - "# for i in range(len(dev_data)):\n", - "# print(len(dev_data[i]))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "warming-malpractice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n", - "Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n", - "\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n", - "You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" - ] - } - ], - "source": [ - "!pip install kaldiio" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "equipped-subject", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "superb-methodology", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "import kaldiio\n", - "\n", - "class LoadInputsAndTargets():\n", - " \"\"\"Create a mini-batch from a list of dicts\n", - "\n", - " >>> batch = [('utt1',\n", - " ... dict(input=[dict(feat='some.ark:123',\n", - " ... filetype='mat',\n", - " ... name='input1',\n", - " ... shape=[100, 80])],\n", - " ... output=[dict(tokenid='1 2 3 4',\n", - " ... name='target1',\n", - " ... 
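`make_batchset` returns a plain Python structure: one inner list per minibatch, each element a `(utt_id, info_dict)` pair taken from the manifest. A quick inspection sketch using the `dev_data` built in the cell above (the printed values are what the manifest entry shown earlier would suggest, so treat them as illustrative):

```python
# dev_data: List[List[Tuple[str, dict]]], 555 minibatches here
first_batch = dev_data[0]
print(len(first_batch))            # utterances in the first minibatch
utt_id, info = first_batch[0]
print(utt_id)                      # an utt key, e.g. '116-288045-0000'
print(info["input"][0]["shape"])   # [num_frames, 83]
```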
shape=[4, 31])]]))\n", - " >>> l = LoadInputsAndTargets()\n", - " >>> feat, target = l(batch)\n", - "\n", - " :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n", - " :param: str preprocess_conf: The path of a json file for pre-processing\n", - " :param: bool load_input: If False, not to load the input data\n", - " :param: bool load_output: If False, not to load the output data\n", - " :param: bool sort_in_input_length: Sort the mini-batch in descending order\n", - " of the input length\n", - " :param: bool use_speaker_embedding: Used for tts mode only\n", - " :param: bool use_second_target: Used for tts mode only\n", - " :param: dict preprocess_args: Set some optional arguments for preprocessing\n", - " :param: Optional[dict] preprocess_args: Used for tts mode only\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " mode=\"asr\",\n", - " preprocess_conf=None,\n", - " load_input=True,\n", - " load_output=True,\n", - " sort_in_input_length=True,\n", - " preprocess_args=None,\n", - " keep_all_data_on_mem=False, ):\n", - " self._loaders = {}\n", - "\n", - " if mode not in [\"asr\"]:\n", - " raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n", - "\n", - " if preprocess_conf is not None:\n", - " self.preprocessing = AugmentationPipeline(preprocess_conf)\n", - " logging.warning(\n", - " \"[Experimental feature] Some preprocessing will be done \"\n", - " \"for the mini-batch creation using {}\".format(\n", - " self.preprocessing))\n", - " else:\n", - " # If conf doesn't exist, this function don't touch anything.\n", - " self.preprocessing = None\n", - "\n", - " self.mode = mode\n", - " self.load_output = load_output\n", - " self.load_input = load_input\n", - " self.sort_in_input_length = sort_in_input_length\n", - " if preprocess_args is None:\n", - " self.preprocess_args = {}\n", - " else:\n", - " assert isinstance(preprocess_args, dict), type(preprocess_args)\n", - " self.preprocess_args = dict(preprocess_args)\n", - "\n", - " self.keep_all_data_on_mem = keep_all_data_on_mem\n", - "\n", - " def __call__(self, batch, return_uttid=False):\n", - " \"\"\"Function to load inputs and targets from list of dicts\n", - "\n", - " :param List[Tuple[str, dict]] batch: list of dict which is subset of\n", - " loaded data.json\n", - " :param bool return_uttid: return utterance ID information for visualization\n", - " :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :return: list of input feature sequences\n", - " [(T_1, D), (T_2, D), ..., (T_B, D)]\n", - " :rtype: list of float ndarray\n", - " :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :rtype: list of int ndarray\n", - "\n", - " \"\"\"\n", - " x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " uttid_list = [] # List[str]\n", - "\n", - " for uttid, info in batch:\n", - " uttid_list.append(uttid)\n", - "\n", - " if self.load_input:\n", - " # Note(kamo): This for-loop is for multiple inputs\n", - " for idx, inp in enumerate(info[\"input\"]):\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"input1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - " x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.load_output:\n", - " for idx, inp in 
enumerate(info[\"output\"]):\n", - " if \"tokenid\" in inp:\n", - " # ======= Legacy format for output =======\n", - " # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n", - " x = np.fromiter(\n", - " map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n", - " else:\n", - " # ======= New format =======\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"target1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - "\n", - " y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.mode == \"asr\":\n", - " return_batch, uttid_list = self._create_batch_asr(\n", - " x_feats_dict, y_feats_dict, uttid_list)\n", - " else:\n", - " raise NotImplementedError(self.mode)\n", - "\n", - " if self.preprocessing is not None:\n", - " # Apply pre-processing all input features\n", - " for x_name in return_batch.keys():\n", - " if x_name.startswith(\"input\"):\n", - " return_batch[x_name] = self.preprocessing(\n", - " return_batch[x_name], uttid_list,\n", - " **self.preprocess_args)\n", - "\n", - " if return_uttid:\n", - " return tuple(return_batch.values()), uttid_list\n", - "\n", - " # Doesn't return the names now.\n", - " return tuple(return_batch.values())\n", - "\n", - " def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n", - " \"\"\"Create a OrderedDict for the mini-batch\n", - "\n", - " :param OrderedDict x_feats_dict:\n", - " e.g. {\"input1\": [ndarray, ndarray, ...],\n", - " \"input2\": [ndarray, ndarray, ...]}\n", - " :param OrderedDict y_feats_dict:\n", - " e.g. {\"target1\": [ndarray, ndarray, ...],\n", - " \"target2\": [ndarray, ndarray, ...]}\n", - " :param: List[str] uttid_list:\n", - " Give uttid_list to sort in the same order as the mini-batch\n", - " :return: batch, uttid_list\n", - " :rtype: Tuple[OrderedDict, List[str]]\n", - " \"\"\"\n", - " # handle single-input and multi-input (paralell) asr mode\n", - " xs = list(x_feats_dict.values())\n", - "\n", - " if self.load_output:\n", - " ys = list(y_feats_dict.values())\n", - " assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n", - "\n", - " # get index of non-zero length samples\n", - " nonzero_idx = list(\n", - " filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n", - " for n in range(1, len(y_feats_dict)):\n", - " nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n", - " else:\n", - " # Note(kamo): Be careful not to make nonzero_idx to a generator\n", - " nonzero_idx = list(range(len(xs[0])))\n", - "\n", - " if self.sort_in_input_length:\n", - " # sort in input lengths based on the first input\n", - " nonzero_sorted_idx = sorted(\n", - " nonzero_idx, key=lambda i: -len(xs[0][i]))\n", - " else:\n", - " nonzero_sorted_idx = nonzero_idx\n", - "\n", - " if len(nonzero_sorted_idx) != len(xs[0]):\n", - " logging.warning(\n", - " \"Target sequences include empty tokenid (batch {} -> {}).\".\n", - " format(len(xs[0]), len(nonzero_sorted_idx)))\n", - "\n", - " # remove zero-length samples\n", - " xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n", - " uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n", - "\n", - " x_names = list(x_feats_dict.keys())\n", - " if self.load_output:\n", - " ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n", - " y_names = list(y_feats_dict.keys())\n", - "\n", - " # Keeping x_name and y_name, e.g. 
input1, for future extension\n", - " return_batch = OrderedDict([\n", - " * [(x_name, x) for x_name, x in zip(x_names, xs)],\n", - " * [(y_name, y) for y_name, y in zip(y_names, ys)],\n", - " ])\n", - " else:\n", - " return_batch = OrderedDict(\n", - " [(x_name, x) for x_name, x in zip(x_names, xs)])\n", - " return return_batch, uttid_list\n", - "\n", - " def _get_from_loader(self, filepath, filetype):\n", - " \"\"\"Return ndarray\n", - "\n", - " In order to make the fds to be opened only at the first referring,\n", - " the loader are stored in self._loaders\n", - "\n", - " >>> ndarray = loader.get_from_loader(\n", - " ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n", - "\n", - " :param: str filepath:\n", - " :param: str filetype:\n", - " :return:\n", - " :rtype: np.ndarray\n", - " \"\"\"\n", - " if filetype == \"hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = h5py.File(filepath, \"r\")\n", - " self._loaders[filepath] = loader\n", - " return loader[key][()]\n", - " elif filetype == \"sound.hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"sound.hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n", - " self._loaders[filepath] = loader\n", - " array, rate = loader[key]\n", - " return array\n", - " elif filetype == \"sound\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.wav\",\n", - " # \"filetype\": \"sound\"},\n", - " # Assume PCM16\n", - " if not self.keep_all_data_on_mem:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " return array\n", - " if filepath not in self._loaders:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " self._loaders[filepath] = array\n", - " return self._loaders[filepath]\n", - " elif filetype == \"npz\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"npz\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = np.load(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " elif filetype == \"npy\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npy\",\n", - " # \"filetype\": \"npy\"},\n", - " if not self.keep_all_data_on_mem:\n", - " return np.load(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = np.load(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype in [\"mat\", \"vec\"]:\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.ark:123\",\n", - " # \"filetype\": \"mat\"}]},\n", - " # In this case, \"123\" indicates the starting points of the matrix\n", - " # load_mat can load both matrix and vector\n", - " if not 
self.keep_all_data_on_mem:\n", - " return kaldiio.load_mat(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = kaldiio.load_mat(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype == \"scp\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"scp\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = kaldiio.load_scp(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " else:\n", - " raise NotImplementedError(\n", - " \"Not supported: loader_type={}\".format(filetype))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "monthly-muscle", - "metadata": {}, - "outputs": [], - "source": [ - "preprocess_conf=None\n", - "train_mode=True\n", - "load = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "periodic-senegal", - "metadata": {}, - "outputs": [], - "source": [ - "res = load(dev_data[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "id": "502d3f4d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "2\n", - "10\n", - "10\n", - "(1174, 83) float32\n", - "(29,) int64\n" - ] - } - ], - "source": [ - "print(type(res))\n", - "print(len(res))\n", - "print(len(res[0]))\n", - "print(len(res[1]))\n", - "print(res[0][0].shape, res[0][0].dtype)\n", - "print(res[1][0].shape, res[1][0].dtype)\n", - "# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n", - "# 2[10, 10]\n", - "# feats, labels" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "id": "humanitarian-container", - "metadata": {}, - "outputs": [], - "source": [ - "(inputs, outputs), utts = load(dev_data[0], return_uttid=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "id": "heard-prize", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n", - "10\n" - ] - } - ], - "source": [ - "print(utts, len(utts))\n", - "print(len(inputs))" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "id": "convinced-animation", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from deepspeech.io.utility import pad_list\n", - "class CustomConverter():\n", - " \"\"\"Custom batch converter.\n", - "\n", - " Args:\n", - " subsampling_factor (int): The subsampling factor.\n", - " dtype (paddle.dtype): Data type to convert.\n", - "\n", - " \"\"\"\n", - "\n", - " def __init__(self, subsampling_factor=1, dtype=np.float32):\n", - " \"\"\"Construct a CustomConverter object.\"\"\"\n", - " self.subsampling_factor = subsampling_factor\n", - " self.ignore_id = -1\n", - " self.dtype = dtype\n", - "\n", - " def __call__(self, batch):\n", - " \"\"\"Transform a batch and send it to a device.\n", - "\n", - " Args:\n", - " batch (list): The batch to transform.\n", - "\n", - " Returns:\n", - " tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n", - 
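One detail worth flagging in the reader construction above: `preprocess_args={"train": train_mode}` is the only thing that distinguishes a training reader from an evaluation reader, so spectrogram augmentation can be toggled without touching the config. A hedged usage sketch (reusing the names from the cells above):

```python
# Assumes LoadInputsAndTargets and preprocess_conf from the cells above.
train_load = LoadInputsAndTargets(
    mode="asr", load_output=True,
    preprocess_conf=preprocess_conf,
    preprocess_args={"train": True},   # augmentation on
)
eval_load = LoadInputsAndTargets(
    mode="asr", load_output=True,
    preprocess_conf=preprocess_conf,
    preprocess_args={"train": False},  # augmentation off
)
# Both return ((feats, labels), uttids) when called with return_uttid=True,
# as the shape probe above demonstrates.
```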
"\n", - " \"\"\"\n", - " # batch should be located in list\n", - " assert len(batch) == 1\n", - " (xs, ys), utts = batch[0]\n", - "\n", - " # perform subsampling\n", - " if self.subsampling_factor > 1:\n", - " xs = [x[::self.subsampling_factor, :] for x in xs]\n", - "\n", - " # get batch of lengths of input sequences\n", - " ilens = np.array([x.shape[0] for x in xs])\n", - "\n", - " # perform padding and convert to tensor\n", - " # currently only support real number\n", - " if xs[0].dtype.kind == \"c\":\n", - " xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n", - " xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n", - " # Note(kamo):\n", - " # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n", - " # Don't create ComplexTensor and give it E2E here\n", - " # because torch.nn.DataParellel can't handle it.\n", - " xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n", - " else:\n", - " xs_pad = pad_list(xs, 0).astype(self.dtype)\n", - "\n", - " # NOTE: this is for multi-output (e.g., speech translation)\n", - " ys_pad = pad_list(\n", - " [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n", - " self.ignore_id)\n", - "\n", - " olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n", - " return utts, xs_pad, ilens, ys_pad, olens" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "id": "0b92ade5", - "metadata": {}, - "outputs": [], - "source": [ - "convert = CustomConverter()" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "id": "8dbd847c", - "metadata": {}, - "outputs": [], - "source": [ - "utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "id": "31c085f4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n", - "(10, 1174, 83)\n", - "(10,)\n", - "[1174 821 716 628 597 473 463 441 419 358]\n", - "(10, 32)\n", - "[[4502 2404 4223 3204 4502 587 1018 3861 2932 713 2458 2916 253 4508\n", - " 627 1395 713 4504 957 2761 209 2967 3173 3918 2598 4100 3 2816\n", - " 4990 -1 -1 -1]\n", - " [1005 451 210 278 3411 206 482 2307 573 4502 3848 4577 4273 2388\n", - " 4444 89 4919 278 1264 4501 2371 3 139 113 2603 4962 3158 3325\n", - " 4577 814 4587 1422]\n", - " [2345 4144 2291 200 713 2345 532 999 2458 3076 545 2458 4832 3038\n", - " 4499 482 2812 1260 3080 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2345 832 4577 4920 4501 2345 2298 1236 381 288 389 101 2495 4172\n", - " 4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n", - " 4723 2371 -1 -1]\n", - " [4502 4207 432 3204 4502 2396 125 935 433 2598 483 18 327 2\n", - " 389 627 4512 2340 713 482 1981 4525 4031 269 2030 1340 101 2495\n", - " 4013 4844 -1 -1]\n", - " [4502 4892 3204 1892 3780 389 482 2774 3013 89 192 2495 4502 3475\n", - " 389 66 370 343 404 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2458 2314 4577 2340 2863 1254 303 269 2 389 932 2079 4577 299\n", - " 195 3233 4508 2 89 814 3144 1091 3204 3250 2193 3414 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2391 1785 443 78 39 4962 2340 829 599 4593 278 4681 202 407\n", - " 269 194 182 4577 482 4308 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 627 4873 2175 363 202 404 1018 4577 4502 
3412 4875 2286 107 122\n", - " 4832 2345 3896 89 2368 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 481 174 474 599 1881 3252 2842 742 4502 2545 107 88 3204 4525\n", - " 4517 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]]\n", - "[29 32 19 30 30 19 26 20 19 15]\n", - "float32\n", - "int64\n", - "int64\n", - "int64\n" - ] - } - ], - "source": [ - "print(utts)\n", - "print(xs.shape)\n", - "print(ilen.shape)\n", - "print(ilen)\n", - "print(ys.shape)\n", - "print(ys)\n", - "print(olen)\n", - "print(xs.dtype)\n", - "print(ilen.dtype)\n", - "print(ys.dtype)\n", - "print(olen.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "id": "72e9ba60", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 230, - "id": "64593e5f", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from paddle.io import DataLoader\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.io.batchfy import make_batchset\n", - "from deepspeech.io.converter import CustomConverter\n", - "from deepspeech.io.dataset import TransformDataset\n", - "from deepspeech.io.reader import LoadInputsAndTargets\n", - "from deepspeech.utils.log import Log\n", - "\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "class BatchDataLoader():\n", - " def __init__(self,\n", - " json_file: str,\n", - " train_mode: bool,\n", - " sortagrad: bool=False,\n", - " batch_size: int=0,\n", - " maxlen_in: float=float('inf'),\n", - " maxlen_out: float=float('inf'),\n", - " minibatches: int=0,\n", - " mini_batch_size: int=1,\n", - " batch_count: str='auto',\n", - " batch_bins: int=0,\n", - " batch_frames_in: int=0,\n", - " batch_frames_out: int=0,\n", - " batch_frames_inout: int=0,\n", - " preprocess_conf=None,\n", - " n_iter_processes: int=1,\n", - " subsampling_factor: int=1,\n", - " num_encs: int=1):\n", - " self.json_file = json_file\n", - " self.train_mode = train_mode\n", - " self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n", - " self.batch_size = batch_size\n", - " self.maxlen_in = maxlen_in\n", - " self.maxlen_out = maxlen_out\n", - " self.batch_count = batch_count\n", - " self.batch_bins = batch_bins\n", - " self.batch_frames_in = batch_frames_in\n", - " self.batch_frames_out = batch_frames_out\n", - " self.batch_frames_inout = batch_frames_inout\n", - " self.subsampling_factor = subsampling_factor\n", - " self.num_encs = num_encs\n", - " self.preprocess_conf = preprocess_conf\n", - " self.n_iter_processes = n_iter_processes\n", - "\n", - " \n", - " # read json data\n", - " self.data_json = read_manifest(json_file)\n", - "\n", - " # make minibatch list (variable length)\n", - " self.minibaches = make_batchset(\n", - " self.data_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=mini_batch_size,\n", - " shortest_first=self.use_sortagrad,\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "\n", - " # data reader\n", - " self.reader = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )\n", - "\n", - " # Setup a converter\n", - " if num_encs == 1:\n", - " self.converter = CustomConverter(\n", 
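Back in the converter, `pad_list` (imported from `deepspeech.io.utility`) does the actual batching work: it right-pads a list of variable-length arrays to the longest one, with 0 for features and `ignore_id = -1` for targets. A minimal NumPy equivalent, reconstructed from that behaviour (a sketch, not the library implementation):

```python
import numpy as np

def pad_list(xs, pad_value):
    """Right-pad 1-D/2-D arrays to the length of the longest one."""
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
    padded = np.full((n_batch, max_len) + xs[0].shape[1:], pad_value,
                     dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        padded[i, :x.shape[0]] = x
    return padded

ys = [np.array([4502, 2404]), np.array([1005])]
print(pad_list(ys, -1))  # [[4502 2404] [1005   -1]]
```

The `-1` entries are exactly the trailing values visible in the `ys` printout above; the losses are expected to skip them via `ignore_id`.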
- " subsampling_factor=subsampling_factor, dtype=np.float32)\n", - " else:\n", - " raise NotImplementedError(\"CustomConverterMulEnc not implemented.\")\n", - "\n", - " # hack to make the batch_size argument 1:\n", - " # the actual batch is already assembled inside each dataset item,\n", - " # and the default collate function would convert numpy arrays to tensors,\n", - " # so we use an identity collate function that returns the item as-is\n", - " self.dataset = TransformDataset(self.minibaches, \n", - " lambda data: self.converter([self.reader(data, return_uttid=True)]))\n", - " self.dataloader = DataLoader(\n", - " dataset=self.dataset,\n", - " batch_size=1,\n", - " shuffle=not self.use_sortagrad if train_mode else False,\n", - " collate_fn=lambda x: x[0],\n", - " num_workers=n_iter_processes, )\n", - "\n", - " def __repr__(self):\n", - " echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n", - " echo += f\"train_mode: {self.train_mode}, \"\n", - " echo += f\"sortagrad: {self.use_sortagrad}, \"\n", - " echo += f\"batch_size: {self.batch_size}, \"\n", - " echo += f\"maxlen_in: {self.maxlen_in}, \"\n", - " echo += f\"maxlen_out: {self.maxlen_out}, \"\n", - " echo += f\"batch_count: {self.batch_count}, \"\n", - " echo += f\"batch_bins: {self.batch_bins}, \"\n", - " echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n", - " echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n", - " echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n", - " echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n", - " echo += f\"num_encs: {self.num_encs}, \"\n", - " echo += f\"num_workers: {self.n_iter_processes}, \"\n", - " echo += f\"file: {self.json_file}\"\n", - " return echo\n", - " \n", - " def __len__(self):\n", - " return len(self.dataloader)\n", - " \n", - " def __iter__(self):\n", - " return self.dataloader.__iter__()\n", - " \n", - " def __call__(self):\n", - " return self.__iter__()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 231, - "id": "fcea3fd0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n" - ] - } - ], - "source": [ - "train = BatchDataLoader(dev_data, True, batch_size=20)" - ] - }, - { - "cell_type": "code", - "execution_count": 232, - "id": "e2a2c9a8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "278\n", - "['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n", - "<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: 
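The `batch_size=1` plus identity-`collate_fn` construction above is the standard trick for pushing pre-batched ESPnet-style minibatches through a `DataLoader`: each dataset item is already a full batch, so the loader only has to hand items through one at a time. A toy sketch of the same wiring (names hypothetical):

```python
import numpy as np
from paddle.io import DataLoader, Dataset

class PrebatchedDataset(Dataset):
    """Each item is already a complete minibatch (toy: one array per item)."""
    def __init__(self, minibatches):
        self.minibatches = minibatches
    def __getitem__(self, idx):
        return self.minibatches[idx]
    def __len__(self):
        return len(self.minibatches)

loader = DataLoader(
    dataset=PrebatchedDataset([np.arange(3), np.arange(5)]),
    batch_size=1,               # one pre-built minibatch per step
    collate_fn=lambda x: x[0],  # unwrap the singleton list; no re-batching
)
for batch in loader:
    print(batch.shape)  # [3] then [5]; ndarrays arrive as paddle Tensors
```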
/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n", - "278\n" - ] - } - ], - "source": [ - "print(len(train.dataloader))\n", - "print(dir(train.dataloader))\n", - "print(train)\n", - "print(len(train))" - ] - }, - { - "cell_type": "code", - "execution_count": 220, - "id": "a5ba7d6e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n", - "Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[-1.99415934, -1.80315673, -1.88801885, ..., 0.86933994, -0.59853148, 0.02596200],\n", - " [-1.95346808, -1.84891188, -2.17492867, ..., 0.83640492, -0.59853148, -0.11333394],\n", - " [-2.27899861, -2.21495342, -2.58480024, ..., 0.91874266, -0.59853148, -0.31453922],\n", - " ...,\n", - " [-2.64522028, -2.35221887, -2.91269732, ..., 1.48994756, -0.16100442, 0.36646330],\n", - " [-2.40107250, -2.21495342, -2.37986445, ..., 1.44072104, -0.13220564, 0.12656468],\n", - " [-2.15692472, -1.89466715, -2.25690317, ..., 1.31273174, -0.09620714, -0.15202725]],\n", - "\n", - " [[-0.28859532, -0.29033494, -0.86576819, ..., 1.37753224, -0.30570769, 0.25806731],\n", - " [-0.20149794, -0.17814466, -0.59891301, ..., 1.35188794, -0.30570769, -0.02964944],\n", - " [-0.34947991, -0.33597648, -0.96877253, ..., 1.38394332, -0.30570769, -0.38376236],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.44914246, -0.33902276, -0.78237975, ..., 1.38218808, 0.29214793, -0.16815147],\n", - " [-0.55490732, -0.41596055, -0.84425378, ..., 1.34530187, 0.25002354, -0.04004869],\n", - " [-0.83694696, -0.62112784, -1.07112527, ..., 1.19160914, 0.20789915, 0.37984371],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.24343657, -0.94188881, -1.41092563, ..., 0.96716309, 0.60345763, 0.15360183],\n", - " [-1.19466043, -0.80585432, -0.49723154, ..., 1.06735480, 0.60345763, 0.14511746],\n", - " [-0.94079566, -0.59330046, -0.40948665, ..., 0.82244170, 0.55614340, 0.28086722],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.21757117, 0.11361472, -0.33262897, ..., 0.76338506, -0.10711290, -0.57754958],\n", - " [-1.00205481, -0.61152041, -0.47124696, ..., 1.11897349, -0.10711290, 0.24931324],\n", - " [-1.03929281, -1.20336759, -1.16433656, ..., 0.88888687, -0.10711290, -0.04115745],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-1.25289667, -1.05046368, -0.82881606, ..., 1.23991334, 0.61702502, 0.05275881],\n", - " [-1.19659519, -0.78677225, -0.80407262, ..., 1.27644968, 0.61702502, -0.35079369],\n", - " [-1.49687004, -1.01750231, -0.82881606, ..., 1.29106426, 0.65006059, 0.17958963],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. 
, 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n", - "Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[3098, 1595, 389, ..., -1 , -1 , -1 ],\n", - " [2603, 4832, 482, ..., -1 , -1 , -1 ],\n", - " [2796, 303, 269, ..., -1 , -1 , -1 ],\n", - " ...,\n", - " [3218, 3673, 206, ..., -1 , -1 , -1 ],\n", - " [2371, 4832, 4031, ..., -1 , -1 , -1 ],\n", - " [2570, 2433, 4285, ..., -1 , -1 , -1 ]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n" - ] - } - ], - "source": [ - "for batch in train:\n", - " utts, xs, ilens, ys, olens = batch\n", - " print(utts)\n", - " print(xs)\n", - " print(ilens)\n", - " print(ys)\n", - " print(olens)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3c974a1e", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/hack_api_test.ipynb b/.notebook/hack_api_test.ipynb deleted file mode 100644 index f653084e6f9eb31cf8b1cdcf74854b488c3fd7bf..0000000000000000000000000000000000000000 --- a/.notebook/hack_api_test.ipynb +++ /dev/null @@ -1,290 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "breeding-haven", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "appropriate-theta", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LICENSE deepspeech examples\t\t requirements.txt tools\r\n", - "README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n", - "README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n" - ] - } - ], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "entire-bloom", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
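The NumPy warning above is emitted by Paddle itself (`convert_to_list(value, n, name, dtype=np.int)`), not by the notebook: `np.int` was deprecated in NumPy 1.20 and later removed, and it was only ever an alias for the builtin `int`. The fix is a one-character change in the signature; a sketch (function body simplified, not Paddle's actual code):

```python
# Deprecated since NumPy 1.20 (and removed in 1.24):
#     def convert_to_list(value, n, name, dtype=np.int): ...
# np.int was an alias for the builtin, so this is equivalent and warning-free:
def convert_to_list(value, n, name, dtype=int):
    if isinstance(value, (list, tuple)):
        assert len(value) == n, f"{name} must have {n} elements"
        return [dtype(v) for v in value]
    return [dtype(value)] * n
```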
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n", - "WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n", - "WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n", - "WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n" - ] - } - ], - "source": [ - "from deepspeech.modules import loss" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "governmental-aircraft", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import paddle" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "proprietary-disaster", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " paddle.VarBase>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "first-diagram", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.size" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "intelligent-david", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.cat" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "bronze-tenant", - "metadata": {}, - "outputs": [], - "source": [ - "a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "balanced-bearing", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "7" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "extreme-republic", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n", - " nargs = len(args)\n", - " assert (nargs <= 1)\n", - " s = paddle.shape(xs)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "\n", - "# logger.warn(\n", - "# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n", - "# )\n", - 
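Before the `paddle.Tensor.size = size` assignment below installs this helper, note what is being shadowed: stock Paddle exposes `Tensor.size` as a *property* holding the total element count (the earlier cell printed `7` for a 7-element vector), while PyTorch's `size()` is a method returning the shape. After the patch the torch-style calls work, at the price of losing the property. A usage sketch of the patched behaviour, matching the outputs printed in the next cells:

```python
import paddle

a = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
print(a.size)  # 6 -- stock Paddle: property, total number of elements

def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    """torch-style size(): full shape, or a single dimension."""
    s = paddle.shape(xs)
    return s[args[0]] if args else s

paddle.Tensor.size = size  # shadows the property; a.size is now a method

print(a.size())   # Tensor([2, 3]) -- full shape
print(a.size(0))  # Tensor([2])    -- one dimension, torch-style
```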
"paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "gross-addiction", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size(0)\n", - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "adverse-dining", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "popular-potato", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb deleted file mode 100644 index 20882c1ae75a5c934c66b4a6127b7d6f10d2b061..0000000000000000000000000000000000000000 --- a/.notebook/jit_infer.ipynb +++ /dev/null @@ -1,672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - 
WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n" - ] - } - ], - "source": [ - "import os\n", - "import time\n", - "import argparse\n", - "import functools\n", - "import paddle\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.socket_server import warm_up_test\n", - "from deepspeech.utils.socket_server import AsrTCPServer\n", - "from deepspeech.utils.socket_server import AsrRequestHandler\n", - "\n", - "from deepspeech.training.cli import default_argument_parser\n", - "from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "\n", - "from deepspeech.models.ds2 import DeepSpeech2Model\n", - "from deepspeech.models.ds2 import DeepSpeech2InferModel\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "\n", - "\n", - "\n", - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0.0\n", - "e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "OFF\n", - "OFF\n", - "commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "None\n", - "0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(paddle.__version__)\n", - "print(paddle.version.commit)\n", - "print(paddle.version.with_mkl)\n", - "print(paddle.version.mkl())\n", - "print(paddle.version.show())\n", - "print(paddle.version.patch)\n", - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data:\n", - " augmentation_config: conf/augmentation.config\n", - " batch_size: 64\n", - " dev_manifest: data/manifest.dev\n", - " keep_transcription_text: False\n", - " max_duration: 27.0\n", - " max_freq: None\n", - " mean_std_filepath: examples/aishell/data/mean_std.npz\n", - " min_duration: 0.0\n", - " n_fft: None\n", - " num_workers: 0\n", - " random_seed: 0\n", - " shuffle_method: batch_shuffle\n", - " sortagrad: True\n", - " specgram_type: linear\n", - " stride_ms: 10.0\n", - " target_dB: -20\n", - " target_sample_rate: 16000\n", - " test_manifest: examples/aishell/data/manifest.test\n", - " train_manifest: data/manifest.train\n", - " use_dB_normalization: True\n", - " vocab_filepath: examples/aishell/data/vocab.txt\n", - " window_ms: 20.0\n", - "decoding:\n", - " alpha: 2.6\n", - " batch_size: 128\n", - " beam_size: 300\n", - " beta: 5.0\n", - " cutoff_prob: 0.99\n", - " cutoff_top_n: 40\n", - " decoding_method: ctc_beam_search\n", - " error_rate_type: cer\n", - " lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n", - " num_proc_bsearch: 10\n", - "model:\n", - " num_conv_layers: 2\n", - " num_rnn_layers: 3\n", - " rnn_layer_size: 1024\n", - " share_rnn_weights: False\n", - " use_gru: True\n", - "training:\n", - " global_grad_clip: 5.0\n", - " lr: 0.0005\n", - " lr_decay: 0.83\n", - " n_epoch: 30\n", - " weight_decay: 1e-06\n", - "----------- Configuration Arguments -----------\n", - "checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n", - "config: examples/aishell/conf/deepspeech2.yaml\n", - "device: gpu\n", - "dump_config: None\n", - "export_path: None\n", - "host_ip: localhost\n", - "host_port: 8086\n", - "model_dir: None\n", - "model_file: examples/aishell/jit.model.pdmodel\n", - "nprocs: 1\n", - "opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n", - "output: None\n", - "params_file: examples/aishell/jit.model.pdiparams\n", - "speech_save_dir: demo_cache\n", - "use_gpu: False\n", - "warmup_manifest: examples/aishell/data/manifest.test\n", - "------------------------------------------------\n" - ] - } - ], - "source": [ - "parser = default_argument_parser()\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "add_arg('host_ip', str,\n", - " 'localhost',\n", - " \"Server's IP address.\")\n", - "add_arg('host_port', 
int, 8086, \"Server's IP port.\")\n", - "add_arg('speech_save_dir', str,\n", - " 'demo_cache',\n", - " \"Directory to save demo audios.\")\n", - "add_arg('warmup_manifest', \n", - " str, \n", - " \"examples/aishell/data/manifest.test\", \n", - " \"Filepath of manifest to warm up.\")\n", - "add_arg(\n", - " \"--model_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdmodel\",\n", - " help=\"Model filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--params_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdiparams\",\n", - " help=\n", - " \"Parameter filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--model_dir\",\n", - " type=str,\n", - " default=None,\n", - " help=\n", - " \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n", - ")\n", - "add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n", - "\n", - "\n", - "args = parser.parse_args(\n", - " \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n", - ")\n", - "\n", - "\n", - "config = get_cfg_defaults()\n", - "if args.config:\n", - " config.merge_from_file(args.config)\n", - "if args.opts:\n", - " config.merge_from_list(args.opts)\n", - "config.freeze()\n", - "print(config)\n", - "\n", - "args.warmup_manifest = config.data.test_manifest\n", - "\n", - "print_arguments(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = ManifestDataset(\n", - " config.data.test_manifest,\n", - " config.data.unit_type,\n", - " config.data.vocab_filepath,\n", - " config.data.mean_std_filepath,\n", - " augmentation_config=\"{}\",\n", - " max_duration=config.data.max_duration,\n", - " min_duration=config.data.min_duration,\n", - " stride_ms=config.data.stride_ms,\n", - " window_ms=config.data.window_ms,\n", - " n_fft=config.data.n_fft,\n", - " max_freq=config.data.max_freq,\n", - " target_sample_rate=config.data.target_sample_rate,\n", - " specgram_type=config.data.specgram_type,\n", - " feat_dim=config.data.feat_dim,\n", - " delta_delta=config.data.delat_delta,\n", - " use_dB_normalization=config.data.use_dB_normalization,\n", - " target_dB=config.data.target_dB,\n", - " random_seed=config.data.random_seed,\n", - " keep_transcription_text=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "layer summary:\n", - "encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n", - "encoder.conv.conv_in.bn.weight|[32]|32\n", - "encoder.conv.conv_in.bn.bias|[32]|32\n", - "encoder.conv.conv_in.bn._mean|[32]|32\n", - "encoder.conv.conv_in.bn._variance|[32]|32\n", - "encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n", - "encoder.conv.conv_stack.0.bn.weight|[32]|32\n", - "encoder.conv.conv_stack.0.bn.bias|[32]|32\n", - "encoder.conv.conv_stack.0.bn._mean|[32]|32\n", - 
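The `name|shape|numel` rows being printed here follow directly from iterating `state_dict()` -- which, unlike `named_parameters()`, also yields the `_mean`/`_variance` batch-norm buffers visible in the listing. A hypothetical helper that reproduces the format, reconstructed from the output (not the project's actual code):

```python
import numpy as np

def print_layer_summary(model):
    """Print one `name|shape|numel` row per entry, then the totals line."""
    total, count = 0, 0
    for name, param in model.state_dict().items():
        numel = int(np.prod(param.shape))
        total += numel
        count += 1
        print(f"{name}|{list(param.shape)}|{numel}")
    print(f"layer has {count} parameters, {total} elements.")
```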
"encoder.conv.conv_stack.0.bn._variance|[32]|32\n", - "encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n", - "decoder.ctc_lo.weight|[2048, 4300]|8806400\n", - "decoder.ctc_lo.bias|[4300]|4300\n", - "layer has 66 parameters, 
80148012 elements.\n" - ] - } - ], - "source": [ - "model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n", - " args.checkpoint_path)\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "examples/aishell/jit.model.pdmodel\n", - "examples/aishell/jit.model.pdiparams\n", - "0\n", - "False\n" - ] - } - ], - "source": [ - "\n", - "from paddle.inference import Config\n", - "from paddle.inference import PrecisionType\n", - "from paddle.inference import create_predictor\n", - "\n", - "args.use_gpu=False\n", - "paddle.set_device('cpu')\n", - "\n", - "def init_predictor(args):\n", - " if args.model_dir is not None:\n", - " config = Config(args.model_dir)\n", - " else:\n", - " config = Config(args.model_file, args.params_file)\n", - "\n", - " if args.use_gpu:\n", - " config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n", - "# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n", - "# use_calib_mode=True) # enable TensorRT inference at fp32 precision, with int8 offline (calibration) quantization\n", - " else:\n", - " # If mkldnn is not enabled, you can still set the BLAS thread count.\n", - " # The thread num should not be greater than the number of cores in the CPU.\n", - " config.set_cpu_math_library_num_threads(1)\n", - " config.enable_mkldnn()\n", - " \n", - " config.enable_memory_optim()\n", - " config.switch_ir_optim(True)\n", - " \n", - " print(config.model_dir())\n", - " print(config.prog_file())\n", - " print(config.params_file())\n", - " print(config.gpu_device_id())\n", - " print(args.use_gpu)\n", - " predictor = create_predictor(config)\n", - " return predictor\n", - "\n", - "def run(predictor, audio, audio_len):\n", - " # copy audio data to the input tensors\n", - " input_names = predictor.get_input_names()\n", - " for i, name in enumerate(input_names):\n", - " print(\"input:\", i, name)\n", - " \n", - " audio_tensor = predictor.get_input_handle('audio')\n", - " audio_tensor.reshape(audio.shape)\n", - " audio_tensor.copy_from_cpu(audio.copy())\n", - " \n", - " audiolen_tensor = predictor.get_input_handle('audio_len')\n", - " audiolen_tensor.reshape(audio_len.shape)\n", - " audiolen_tensor.copy_from_cpu(audio_len.copy())\n", - "\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " print(\"output:\", i, name)\n", - "\n", - " # do the inference\n", - " predictor.run()\n", - "\n", - " results = []\n", - " # fetch output data from the output tensors\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " output_tensor = predictor.get_output_handle(name)\n", - " output_data = output_tensor.copy_to_cpu()\n", - " results.append(output_data)\n", - "\n", - " return results\n", - "\n", - "\n", - "predictor = init_predictor(args)\n", - "\n", - "def file_to_transcript(filename):\n", - " print(filename)\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " \n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0], type(i_probs[0]))\n", - " \n", - " audio = paddle.to_tensor(audio)\n", - " audio_len = paddle.to_tensor(audio_len)\n", - " print(audio.shape)\n", - " print(audio_len.shape)\n", - " \n", - " #eouts, eouts_len = model.encoder(audio, audio_len)\n", - " #probs = 
model.decoder.softmax(eouts)\n", - " probs = model.forward(audio, audio_len)\n", - " print('paddle:', probs.numpy())\n", - " \n", - " flag = np.allclose(i_probs[0], probs.numpy())\n", - " print(flag)\n", - " \n", - " return probs\n", - "\n", - "# result_transcript = model.decode(\n", - "# audio,\n", - "# audio_len,\n", - "# vocab_list=dataset.vocab_list,\n", - "# decoding_method=config.decoding.decoding_method,\n", - "# lang_model_path=config.decoding.lang_model_path,\n", - "# beam_alpha=config.decoding.alpha,\n", - "# beam_beta=config.decoding.beta,\n", - "# beam_size=config.decoding.beam_size,\n", - "# cutoff_prob=config.decoding.cutoff_prob,\n", - "# cutoff_top_n=config.decoding.cutoff_top_n,\n", - "# num_processes=config.decoding.num_proc_bsearch)\n", - "# return result_transcript[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n", - " 8.91573707e-12 4.64317296e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n", - " 1.24587264e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n", - " 3.89252270e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]] \n", - "[1, 161, 522]\n", - "[1]\n", - "paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n", - "False\n" - ] - } - ], - "source": [ - "manifest = read_manifest(args.warmup_manifest)\n", - "\n", - "for idx, sample in enumerate(manifest[:1]):\n", - " print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n", - " start_time = time.time()\n", - " transcript = file_to_transcript(sample['audio_filepath'])\n", - " finish_time = time.time()\n", - "# print(\"Response Time: %f, Transcript: %s\" %\n", - "# (finish_time - start_time, transcript))\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 161, 522) (1,)\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 
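A note on the `False` from `np.allclose` above: the exported (jit) and dygraph probabilities agree to roughly six significant digits (compare `8.91786298e-12` with `8.91789680e-12`), so the mismatch is float32 graph-optimization noise rather than a logic error. When validating an exported model it is usually more informative to report the worst-case difference and to compare the decoded path than to rely on a bare boolean with default tolerances. A hedged sketch (the tolerances are assumptions, not project policy):

```python
import numpy as np

def compare_outputs(jit_probs, eager_probs, rtol=1e-4, atol=1e-6):
    """Report where two probability tensors diverge, not just whether."""
    diff = np.abs(jit_probs - eager_probs)
    print("max abs diff:", diff.max())
    print("allclose:", np.allclose(jit_probs, eager_probs, rtol=rtol, atol=atol))
    # For ASR, the argmax path matters more than raw per-frame probabilities:
    print("argmax paths equal:",
          np.array_equal(jit_probs.argmax(-1), eager_probs.argmax(-1)))
```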
8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n" - ] - } - ], - "source": [ - "def test(filename):\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " print(audio.shape, audio_len.shape)\n", - "\n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0])\n", - " return i_probs\n", - " \n", - "probs = test(sample['audio_filepath'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/.notebook/layer_norm_test.ipynb b/.notebook/layer_norm_test.ipynb deleted file mode 100644 index eac3566ff0590295a1f3b742cd8d038f420500ce..0000000000000000000000000000000000000000 --- a/.notebook/layer_norm_test.ipynb +++ /dev/null @@ -1,229 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 32, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n", - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), 
stop_gradient=False,\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n" - ] - } - ], - "source": [ - "L = nn.LayerNorm(256, epsilon=1e-12)\n", - "for p in L.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [ - "y = L(paddle.to_tensor(x, dtype='float32'))" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1.], requires_grad=True)\n", - "Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " requires_grad=True)\n" - ] - } - ], - "source": [ - "TL = torch.nn.LayerNorm(256, eps=1e-12)\n", - "for p in TL.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [], - "source": [ - "ty = TL(torch.tensor(x, dtype=torch.float32))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "y = L(paddle.to_tensor(x, dtype='float32'))\n", - "ty = TL(torch.tensor(x, dtype=torch.float32))\n", - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/mask_and_masked_fill_test.ipynb b/.notebook/mask_and_masked_fill_test.ipynb deleted file mode 100644 index 265ec536b93260b3165423e57de548574ac7a5de..0000000000000000000000000000000000000000 --- a/.notebook/mask_and_masked_fill_test.ipynb +++ /dev/null @@ -1,449 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "primary-organic", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "stopped-semester", - "metadata": {}, - "outputs": [], - "source": [ - "def mask_finished_scores(score: torch.Tensor,\n", - " flag: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. 
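On the LayerNorm comparison just above: the 3-D input reports `allclose == False` while the 2-D input passes, even though both layers use the same epsilon (1e-12). Both normalize over the trailing 256 features, so the discrepancy is most likely float32 reduction order, and an explicit absolute tolerance should absorb it. A hedged check (tolerance values are assumptions):

```python
import numpy as np

# y is the Paddle output, ty the PyTorch output from the cells above.
# allclose's defaults (rtol=1e-5, atol=1e-8) are strict for float32;
# atol around 1e-6 reflects single-precision accumulation error.
np.testing.assert_allclose(
    y.numpy(), ty.detach().numpy(), rtol=1e-5, atol=1e-6)
```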
This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.size(-1)\n", - " zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n", - " if beam_size > 1:\n", - " unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "agreed-portuguese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ True],\n", - " [False]])\n", - "tensor([[-0.8841, 0.7381, -0.9986],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ True, True],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "score = torch.randn((2, 3))\n", - "flag = torch.ones((2, 1), dtype=torch.bool)\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.repeat([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "clean-aspect", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[False, True, True],\n", - " [False, False, False]])\n", - "tensor([[ True, False, False],\n", - " [False, False, False]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n" - ] - } - ], - "source": [ - "r = mask_finished_scores(score, flag)\n", - "print(r)\n", - "print(score)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "thrown-airline", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True ],\n", - " [False]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "import paddle\n", - "\n", - "score = paddle.randn((2, 3))\n", - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.tile([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "internal-patent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[False, True , True ],\n", - " [False, False, False]])\n", - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , False, False],\n", - " [False, False, False]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, 
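The deleted mask_finished_scores above is the beam-search trick of keeping exactly one alive continuation per finished hypothesis: column 0 of a finished row gets score 0 and the remaining beam_size-1 columns get -inf, so top-k retains the hypothesis without letting it fork. A compact functional restatement without in-place mutation (and without the beam_size > 1 branch, since repeat(1, 0) already yields an empty block):

import torch

def mask_finished_scores(score: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    """score: (batch*beam, beam) real scores; flag: (batch*beam, 1) bool."""
    beam = score.size(-1)
    zeros = torch.zeros_like(flag)                              # all-False column
    unfinished = torch.cat((zeros, flag.repeat(1, beam - 1)), dim=1)
    finished = torch.cat((flag, zeros.repeat(1, beam - 1)), dim=1)
    return score.masked_fill(unfinished, -float('inf')).masked_fill(finished, 0.0)

score = torch.randn(2, 3)
flag = torch.tensor([[True], [False]])
print(mask_finished_scores(score, flag))   # row 0 -> [0., -inf, -inf]; row 1 untouched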
-0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n" - ] - } - ], - "source": [ - "paddle.bool = 'bool'\n", - "\n", - "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print(xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " xs = paddle.where(mask, trues, xs)\n", - " return xs\n", - "\n", - "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print('x', xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " ret = paddle.where(mask, trues, xs)\n", - " print('2', xs)\n", - " paddle.assign(ret, output=xs)\n", - " print('3', xs)\n", - "\n", - "paddle.Tensor.masked_fill = masked_fill\n", - "paddle.Tensor.masked_fill_ = masked_fill_\n", - "\n", - "def mask_finished_scores_pd(score: paddle.Tensor,\n", - " flag: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. 
This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.shape[-1]\n", - " zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n", - " if beam_size > 1:\n", - " unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " \n", - " #score.masked_fill_(unfinished, -float('inf'))\n", - " #score.masked_fill_(finished, 0)\n", - "# infs = paddle.ones_like(score) * -float('inf')\n", - "# score = paddle.where(unfinished, infs, score)\n", - "# score = paddle.where(finished, paddle.zeros_like(score), score)\n", - "\n", - "# score = score.masked_fill(unfinished, -float('inf'))\n", - "# score = score.masked_fill(finished, 0)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score\n", - "\n", - "r = mask_finished_scores_pd(score, flag)\n", - "print(r)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "vocal-prime", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "score.value" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "id": "bacterial-adolescent", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union, Any" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "id": "absent-fiber", - "metadata": {}, - "outputs": [], - "source": [ - "def repeat(xs : paddle.Tensor, *size: Any):\n", - " print(size)\n", - " return paddle.tile(xs, size)\n", - "paddle.Tensor.repeat = repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "id": "material-harbor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(1, 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "id": "acute-brighton", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [1]), 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(paddle.to_tensor(1), 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "id": "european-rugby", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs, *args: int):\n", - " nargs = len(args)\n", - " s = paddle.shape(xs)\n", - " assert(nargs <= 1)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - 
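The print-instrumented cell above is verifying that paddle.where plus paddle.assign can stand in for torch's Tensor.masked_fill / masked_fill_, which the paddle 2.x builds this notebook targets did not provide. Stripped of the debugging prints, the shim reduces to the sketch below; paddle.assign copies the where() result back into the original tensor, which is what makes the underscore variant behave in-place:

import paddle

def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: float) -> paddle.Tensor:
    assert xs.shape == mask.shape          # the original shim does not broadcast
    return paddle.where(mask, paddle.full_like(xs, value), xs)

def masked_fill_(xs: paddle.Tensor, mask: paddle.Tensor, value: float) -> paddle.Tensor:
    paddle.assign(masked_fill(xs, mask, value), output=xs)   # write back in place
    return xs

paddle.Tensor.masked_fill = masked_fill
paddle.Tensor.masked_fill_ = masked_fill_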
"execution_count": 86, - "id": "moral-special", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2, 1])" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "id": "ahead-coach", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [1])" - ] - }, - "execution_count": 87, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "id": "incomplete-fitness", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2])" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "upset-connectivity", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/position_embeding_check.ipynb b/.notebook/position_embeding_check.ipynb deleted file mode 100644 index d4b9098d989c40f5ceb3e36842354336c8d280dc..0000000000000000000000000000000000000000 --- a/.notebook/position_embeding_check.ipynb +++ /dev/null @@ -1,231 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "designing-borough", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 
1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=100\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - " dtype=torch.float32).unsqueeze(1)\n", - "toruch_position = position\n", - "div_term = torch.exp(\n", - " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - " -(math.log(10000.0) / d_model))\n", - "tourch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "\n", - "\n", - "torhc_sin = torch.sin(position * div_term)\n", - "torhc_cos = torch.cos(position * div_term)\n", - "print(torhc_sin.cpu().detach().numpy())\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n", - "pe[:, 0::2] = torhc_sin\n", - "pe[:, 1::2] = torhc_cos\n", - "tourch_pe = pe.cpu().detach().numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "swiss-referral", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - "----\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - ")))))))\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "----\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 
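The deleted position_embeding_check.ipynb builds the standard Transformer sinusoidal table, PE[pos, 2i] = sin(pos / 10000**(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000**(2i/d)), and checks the torch version against NumPy (the misspelled toruch_/torhc_ prefixes are just the notebook's variable names). A framework-free reference implementation of the same construction:

import math
import numpy as np

def sinusoid_pe(max_len: int, d_model: int) -> np.ndarray:
    position = np.arange(max_len, dtype=np.float32)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2, dtype=np.float32)
                      * -(math.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(position * div_term)   # even columns: sine
    pe[:, 1::2] = np.cos(position * div_term)   # odd columns: cosine
    return pe

print(sinusoid_pe(100, 256).shape)              # (100, 256)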
1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n" - ] - } - ], - "source": [ - "import paddle\n", - "paddle.set_device('cpu')\n", - "ppe = paddle.zeros((max_len, d_model), dtype='float32')\n", - "position = paddle.arange(0, max_len,\n", - " dtype='float32').unsqueeze(1)\n", - "print(np.allclose(position.numpy(), toruch_position))\n", - "div_term = paddle.exp(\n", - " paddle.arange(0, d_model, 2, dtype='float32') *\n", - " -(math.log(10000.0) / d_model))\n", - "print(np.allclose(div_term.numpy(), tourch_div_term))\n", - "\n", - "\n", - "\n", - "p_sin = paddle.sin(position * div_term)\n", - "p_cos = paddle.cos(position * div_term)\n", - "print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n", - "print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n", - "ppe[:, 0::2] = p_sin\n", - "ppe[:, 1::2] = p_cos\n", - "print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n", - "print(p_cos.numpy())\n", - "print(\"----\")\n", - "print(torhc_cos.cpu().detach().numpy())\n", - "print(\")))))))\")\n", - "print(p_sin.numpy())\n", - "print(\"----\")\n", - "print(torhc_sin.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "integrated-boards", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(ppe.numpy(), pe.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "flying-reserve", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "revised-divide", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb deleted file mode 100644 index 819d4c48f8191f96fd0ade84055ae3e8fb20cd94..0000000000000000000000000000000000000000 --- a/.notebook/python_test.ipynb +++ /dev/null @@ -1,1680 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-lender", - "metadata": {}, - "outputs": [], - "source": [ - "eng=\"one minute a voice said and the time buzzer sounded\"\n", - "chn=\"可控是病毒武器最基本的要求\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ruled-kuwait", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "o\n", - "n\n", - "e\n", - " \n", - "m\n", - "i\n", - "n\n", - "u\n", - "t\n", - "e\n", - " \n", - "a\n", - " \n", - "v\n", - "o\n", - "i\n", - "c\n", - "e\n", - " \n", - "s\n", - "a\n", - "i\n", - "d\n", - " \n", - "a\n", - "n\n", - "d\n", - " \n", - "t\n", - "h\n", - "e\n", - " \n", - "t\n", - "i\n", - "m\n", - "e\n", - " \n", - "b\n", - "u\n", - "z\n", - "z\n", - "e\n", - "r\n", - " \n", - "s\n", - "o\n", - "u\n", - "n\n", - "d\n", - "e\n", - "d\n" - ] - } - ], - "source": [ - "for char in eng:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": 
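The False prints in the paddle cell above, including the final allclose(ppe, pe), are what float32 looks like under NumPy's defaults (rtol=1e-5, atol=1e-8): the printed matrices agree elementwise to roughly 1e-7 absolute, and entries near zero trip the tiny default atol. A comparison that reflects single precision, assuming the ppe/pe arrays from the cells above are in scope:

import numpy as np

# ~1e-6 absolute slack is appropriate for float32 sin/cos tables.
print(np.allclose(ppe.numpy(), pe.numpy(), rtol=1e-5, atol=1e-6))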
"passive-petite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可\n", - "控\n", - "是\n", - "病\n", - "毒\n", - "武\n", - "器\n", - "最\n", - "基\n", - "本\n", - "的\n", - "要\n", - "求\n" - ] - } - ], - "source": [ - "for char in chn:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "olympic-realtor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "one\n", - "minute\n", - "a\n", - "voice\n", - "said\n", - "and\n", - "the\n", - "time\n", - "buzzer\n", - "sounded\n" - ] - } - ], - "source": [ - "for word in eng.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "induced-enhancement", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可控是病毒武器最基本的要求\n" - ] - } - ], - "source": [ - "for word in chn.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "lovely-bottle", - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'StringIO'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'StringIO'" - ] - } - ], - "source": [ - "import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "interested-cardiff", - "metadata": {}, - "outputs": [], - "source": [ - "from io import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "portable-ivory", - "metadata": {}, - "outputs": [], - "source": [ - "inputs = StringIO()" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "compatible-destination", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "federal-margin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - ], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "consecutive-entity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "desirable-anxiety", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "nor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - 
], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "employed-schedule", - "metadata": {}, - "outputs": [], - "source": [ - "import tempfile" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "unlikely-honduras", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__ne__', '__new__', '__next__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_checkClosed', '_checkReadable', '_checkSeekable', '_checkWritable', '_dealloc_warn', '_finalizing', 'close', 'closed', 'detach', 'fileno', 'flush', 'isatty', 'mode', 'name', 'peek', 'raw', 'read', 'read1', 'readable', 'readinto', 'readinto1', 'readline', 'readlines', 'seek', 'seekable', 'tell', 'truncate', 'writable', 'write', 'writelines']\n", - "57\n" - ] - } - ], - "source": [ - "with tempfile.TemporaryFile() as fp:\n", - " print(dir(fp))\n", - " print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "needed-trail", - "metadata": {}, - "outputs": [], - "source": [ - "a = tempfile.mkstemp(suffix=None, prefix='test', dir=None, text=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "hazardous-choir", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'count', 'index']\n" - ] - } - ], - "source": [ - "print(dir(a))" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "front-sauce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(57, '/tmp/test27smzbzc')\n" - ] - } - ], - "source": [ - "print(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "shared-wages", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "print(a.index)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "charged-carnival", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_closer', 'close', 'delete', 'file', 'name']\n", - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "fp= tempfile.NamedTemporaryFile(mode='w', delete=False)\n", - "print(dir(fp))\n", - "print(fp.name)\n", - "fp.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "religious-terror", - "metadata": {}, - "outputs": [ - { - 
"name": "stdout", - "output_type": "stream", - "text": [ - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "import os\n", - "os.path.exists(fp.name)\n", - "print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "communist-gospel", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fp.write" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "simplified-clarity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'example'" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s='/home/ubuntu/python/example.py'\n", - "os.path.splitext(os.path.basename(s))[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "popular-genius", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import Counter" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "studied-burner", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('hello', 1), ('world', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update([\"hello\"])\n", - "counter.update([\"world\"])\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "mineral-ceremony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 1), ('e', 1), ('l', 3), ('o', 2), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update(\"hello\")\n", - "counter.update(\"world\")\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "nonprofit-freedom", - "metadata": {}, - "outputs": [], - "source": [ - "counter.update(list(\"hello\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "extended-methodology", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 2), ('e', 2), ('l', 5), ('o', 3), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "grand-benjamin", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['h', 'e', 'l', 'l', 'o']" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(\"hello\")" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "marine-fundamentals", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{}\n" - ] - } - ], - "source": [ - "from io import StringIO\n", - "a = StringIO(initial_value='{}', newline='')\n", - "print(a.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "suitable-charlotte", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "expected str, bytes or os.PathLike object, not _io.StringIO", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m 
\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not _io.StringIO" - ] - } - ], - "source": [ - "with io.open(a) as f:\n", - " print(f.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "institutional-configuration", - "metadata": {}, - "outputs": [], - "source": [ - "io.open?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "pregnant-modem", - "metadata": {}, - "outputs": [], - "source": [ - "def get_default_args(fn):\n", - " if fn is None:\n", - " return {}\n", - "\n", - " signature = inspect.signature(fn)\n", - " return {\n", - " k: v.default\n", - " for k, v in signature.parameters.items()\n", - " if v.default is not inspect.Parameter.empty\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "first-release", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'inspect' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_default_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mget_default_args\u001b[0;34m(fn)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m return {\n\u001b[1;32m 7\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'inspect' is not defined" - ] - } - ], - "source": [ - "get_default_args(io.open)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "convertible-roulette", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sox in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (1.4.1)\n", - "Requirement already satisfied: numpy>=1.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from sox) (1.20.1)\n", - "Requirement already satisfied: librosa in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (0.8.0)\n", - "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.24.1)\n", - "Requirement already satisfied: numba>=0.43.0 in 
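Both tracebacks above have one-line fixes: io.open takes a path, not a StringIO (a StringIO is already a file object, so it can be read directly), and get_default_args fails only because inspect was never imported. With the import in place:

import inspect

def get_default_args(fn):
    if fn is None:
        return {}
    signature = inspect.signature(fn)
    return {k: v.default
            for k, v in signature.parameters.items()
            if v.default is not inspect.Parameter.empty}

# e.g. {'mode': 'r', 'buffering': -1, ...}; some builtins expose no
# signature and raise ValueError instead.
print(get_default_args(open))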
/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.52.0)\n", - "Requirement already satisfied: pooch>=1.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.3.0)\n", - "Requirement already satisfied: scipy>=1.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.2.1)\n", - "Requirement already satisfied: numpy>=1.15.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.20.1)\n", - "Requirement already satisfied: decorator>=3.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (4.4.2)\n", - "Requirement already satisfied: resampy>=0.2.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.2.2)\n", - "Requirement already satisfied: audioread>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (2.1.9)\n", - "Requirement already satisfied: soundfile>=0.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.9.0.post1)\n", - "Requirement already satisfied: joblib>=0.14 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.0.1)\n", - "Requirement already satisfied: setuptools in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (51.0.0)\n", - "Requirement already satisfied: llvmlite<0.36,>=0.35.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (0.35.0)\n", - "Requirement already satisfied: appdirs in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (1.4.4)\n", - "Requirement already satisfied: packaging in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (20.9)\n", - "Requirement already satisfied: requests in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (2.25.1)\n", - "Requirement already satisfied: six>=1.3 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from resampy>=0.2.2->librosa) (1.15.0)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)\n", - "Requirement already satisfied: cffi>=0.6 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)\n", - "Requirement already satisfied: pycparser in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cffi>=0.6->soundfile>=0.9.0->librosa) (2.20)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)\n", - "Requirement already satisfied: idna<3,>=2.5 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2.10)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (1.26.3)\n" - ] - } - ], - 
"source": [ - "!pip install sox\n", - "!pip install librosa" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "cutting-fleece", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import sox\n", - "tfm = sox.Transformer()\n", - "sample_rate = 44100\n", - "y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)\n", - "print(y.dtype.type)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "historical-diving", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.06264832 0.12505052 ... -0.18696144 -0.12505052\n", - " -0.06264832]\n" - ] - } - ], - "source": [ - "output_array = tfm.build_array(input_array=y, sample_rate_in=sample_rate)\n", - "print(output_array)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "similar-spice", - "metadata": {}, - "outputs": [], - "source": [ - "tfm.build_array?" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "grand-influence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['8svx', 'aif', 'aifc', 'aiff', 'aiffc', 'al', 'amb', 'amr-nb', 'amr-wb', 'anb', 'au', 'avr', 'awb', 'caf', 'cdda', 'cdr', 'cvs', 'cvsd', 'cvu', 'dat', 'dvms', 'f32', 'f4', 'f64', 'f8', 'fap', 'flac', 'fssd', 'gsm', 'gsrt', 'hcom', 'htk', 'ima', 'ircam', 'la', 'lpc', 'lpc10', 'lu', 'mat', 'mat4', 'mat5', 'maud', 'nist', 'ogg', 'paf', 'prc', 'pvf', 'raw', 's1', 's16', 's2', 's24', 's3', 's32', 's4', 's8', 'sb', 'sd2', 'sds', 'sf', 'sl', 'sln', 'smp', 'snd', 'sndfile', 'sndr', 'sndt', 'sou', 'sox', 'sph', 'sw', 'txw', 'u1', 'u16', 'u2', 'u24', 'u3', 'u32', 'u4', 'u8', 'ub', 'ul', 'uw', 'vms', 'voc', 'vorbis', 'vox', 'w64', 'wav', 'wavpcm', 'wv', 'wve', 'xa', 'xi']\n" - ] - } - ], - "source": [ - "print(sox.core._get_valid_formats())" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "wireless-hypothetical", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "(59471,)\n", - "16000\n", - "(54065,)\n", - "1.0999907518727459\n" - ] - } - ], - "source": [ - "import soundfile as sf\n", - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/dev/S0724/BAC009S0724W0190.wav'\n", - "samples, sr = sf.read(wav)\n", - "print(samples.dtype)\n", - "print(samples.shape)\n", - "print(sr)\n", - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "output_array.dtype\n", - "print(output_array.shape)\n", - "print(len(samples)/len(output_array))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "designed-fluid", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import IPython.display as ipd\n", - "ipd.Audio(wav) # load a local WAV file" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "cultural-friendship", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.0)\n", - 
"output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "fossil-lotus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "constitutional-poker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(0.9)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "threaded-strap", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "66078\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6
zDquQW23zcUn06N6Ea+Jo/gp+F9KDLXZ4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUX
tYl5niLbwbJArimtrsMzV+zSdR/BHsxQBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4
PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMPDLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvo
[remainder of base64-encoded image/png data omitted: waveform plot output]\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)\n", - "print(len(samples_out))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "infectious-welcome", - "metadata": {}, - "outputs": [], - "source": [ - "import librosa\n", - "x, sr = librosa.load(wav, sr=16000)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "musical-anatomy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float32\n", - "float64\n" - ] - } - ], - "source": [ - "print(x.dtype)\n", - "print(samples.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "lucky-paraguay", - "metadata": {}, - "outputs": [], - "source": [ - "sf.read?" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "annual-christmas", - "metadata": {}, - "outputs": [], - "source": [ - "librosa.load?" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "infectious-seeker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(x, samples)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "pregnant-conditioning", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "logical-happiness", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "rocky-plastic", - "metadata": {}, - "outputs": [], - "source": [ - "random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "focused-compensation", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.RandomState?" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "centered-repository", - "metadata": {}, - "outputs": [], - "source": [ - "random.sample?" 
- ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "inner-invite", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(['3', '5'], dtype='<U1')" [span lost in extraction where angle-bracketed text was stripped; the next recoverable cell defines change_speed] - "def change_speed(samples, speed_rate):\n", - " \"\"\"Change the speed of the audio by linear interpolation.\n", - "\n", - " :param speed_rate: Rate of speed change:\n", - " speed_rate > 1.0, speed up the audio;\n", - " speed_rate = 1.0, unchanged;\n", - " speed_rate < 1.0, slow down the audio;\n", - " speed_rate <= 0.0, not allowed, raise ValueError.\n", - " :type speed_rate: float\n", - " :raises ValueError: If speed_rate <= 0.0.\n", - " \"\"\"\n", - " if speed_rate <= 0:\n", - " raise ValueError(\"speed_rate should be greater than zero.\")\n", - " old_length = samples.shape[0]\n", - " new_length = int(old_length / speed_rate)\n", - " old_indices = np.arange(old_length)\n", - " new_indices = np.linspace(start=0, stop=old_length, num=new_length)\n", - " samples = np.interp(new_indices, old_indices, samples)\n", - " return samples" - ] - },
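A minimal usage sketch for the change_speed cell above; it is not part of the original notebook diff. It assumes that cell has been executed, that soundfile is imported as sf (as elsewhere in this notebook), and that a hypothetical mono file demo.wav exists.

    import numpy as np
    import soundfile as sf

    samples, sr = sf.read('demo.wav')    # hypothetical input, float64 mono
    faster = change_speed(samples, 1.1)  # interpolated onto a shorter grid
    slower = change_speed(samples, 0.9)  # interpolated onto a longer grid

    # np.interp evaluates the signal at new_length = int(old_length / speed_rate)
    # evenly spaced positions, so the duration scales by 1/speed_rate while the
    # nominal sample rate sr is unchanged.
    assert len(faster) == int(len(samples) / 1.1)
    assert len(slower) == int(len(samples) / 0.9)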
(1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.8.1)\n", - "Collecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[K |████████████████████████████████| 1.1 MB 45.9 MB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.4.7)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: six in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import librosa.display" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "undefined-parade", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6zDquQW23zcUn06N6Ea+Jo/gp+F9KDLX
Z4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUXtYl5niLbwbJArimtrsMzV+zSdR/BHsx
QBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMP
DLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvoxSI7YhAAYVARHlENFsIlqv/FVNYkNEY
5Vj1hPR2JDtXxLRWiJaonwSevq+a4tGtt6vo8aAsJaNM+rZ7Y2YQk5smYVXr+mPN64bEPc1APt7a3nNMyN6KhkdEYTmd4rX5KYFM8wrGSmBzkOLkBFhaOyA0d9FrefvVGxCy8aSngIwPiIYD2AuM3cFMFdZrwMR5QC4D8BAAAMA3BemMH7HzH2Uzx6D8riaKwZGj+Q1m9Ym5mAZ2t1+HW2G/TbVxDTFZmGlIjDSoAbTccea37HCrt7zBGei7687s6Mj93UbRt+S0QCmKstTAVykcswFAGYzczEzHwAwG4C7x8YWYbeducrEYa8Tc92X9DOehvrB0cbz0ADGfrvwdNrxuuVqwYiSOTWkBsIrY/PRpmkD1ePOtaBTEJ591S6s/C28hFFF0JKZdyrLuwC0VDmmDYDQ7GE7lG1BXlPMQvdQgseJN7PZF7tdtrd9pJuYkJO/ReP4JrzNpKD38ajc+y3OsaS3XQt1CHjk14HAO2ZgaI+WEV2erxqUhykhcRBf/nUI/vcHe+r4GsXseJBEIaYiIKI5RLRC5TM69DgOjBf19pt+x8wnARisfK6KIsc4IiokosK9e/fqvI07GHXKCbbda1iPlnjkYnN6w4D2OYNmMfIX2U16ij43zBtDipaHYtbYysyKWmro7UsF8zllpafUOhaET4w/NubkOsoMqDvyaJCW7Jn6CTI1rE7Mp5KZhzFzb5XPJwB2E1FrAFD+qtn4iwCEjo3bKtvAzMG/JQDeRmAOIZIcU5g5n5nzc3Od90aJBzsHPFkZKbobQTWG9QiYAbTK/s9LTzF8Tyc5rVOzetsu6nMCBndtHvc1gw3ruLM64byeaoNm89Br6khR8klVR7F9XZLfrl7uqND7EOm/75+GHg/gs9MsdP8o9Spnfsdo92QagKAX0FgAn6gcMwvA+USUrUwSnw9gFhGlEFFzACCiVAAXAlhhUB5BwaxEai+P7R+IKNV4ucqqKLklPMqdw7sb8i4J/hR3jeih6zrrHx6u+156O+bBmIC3b4weZxFOaL8gLTlJd2nNUBfSf15iX+fBqbkIt2PUYDYZwHtEdD2ArQAuBQAiygdwMzPfwMzFRPQggEXKOZOUbZkIKIRUAMkA5gB4yaA8gkVoVSztXVYI3gyyGxozd8Ufoa2/n9ZAZ/W5oKmqT0jyQS0/dWhSvKYN09BE4/Oh1aVZsBdDioCZ9wMYqrK9EMANIeuvAng17JhSANEzbwlx40QkbGoy6a4uZSXBUpBGmKQxMjvIWzcMxO9eXlhnm13pKf73hzNAOnO2qkUJa/E2C0+OqtV0GBpbkpOZhuLSCk3nCdbiPidrwRTMzrGvJRiusto9U3ErHrgAvU0oaq63V35Gl/pzCXrNJvFC0O/mq3b8f3/cFvM8M2IhFt89DD/eVa8fGRdPX97HlOv4FVEENhNvIXenuW1Y/bQAbkatME48qE3uXnN6Hp67oi/SoiiJX/dtg6X3no8fJw5FrxOMKyQtEOmv4aDWoB/WUHgmXjUQ2kEhIrQwKbI3r5mYnIwgisBmmjeyx6/d7L65U2m0nUbt97ptWFdcePIJUXvfAzvmoEnDVLTIsi+FAYF0e6bFG1AVrwfcsUprnAn0jlAuzZdJ41BEESQodqfXzbbJ/OEGgtXmpv9RPa/P69f2x0V926ju08vP95yn+dh43Djj1e/xWoZObGlNvi3dJjHT6t8lBhJmJ5jCtFvPRKVKGUgv8tq1/XHta4vqbR87qANG9TmhtrHtEmHexMzSh9k6AvSCjWFuVjr2lpRrOsfukZ5VsTRaLts/L5BC47ExJ6umGvczMiJIUOz2GmqX0xCdcu3NrmoVfdo2Vd3+wOje6Nehbprqn3T02K0m2MvV09RmqAQdanl03GYp1NLDD+auuiS/neQYCkNGBAmKGYVh/Ep2Zhomje6FI+Wxe405mWlo3igd+46UW17KMRbBXrGeTvcNgzvizuHddd/LbYV6wt1Z1ZDGPzIyInCAm85Wz2djJmZmHg0Sq9ylW1hsgmfW1YPy8PshXUyQxj6CE6anRBjRqNEgNblOMJnW893kKgzot/kndnpL/YgiSFCssNef1LZJPW+LrIwUZKS66zFqZpNn1nHc0SgGG7cXfneqoeu0y4kdHR4ccd41Qv9owgq0dPZDG3/RA3UR05AT2NBulFuU8+cfY07BwI7NcPv7SwE4U3BcUCfYuBnJcDrztsER6xCEEqxnPO6sznHfy0wSPIO95YgicIAcG1I1V1iY/C3UFt3W4zUPEonwxjArPQUlGuY5QtGaIiSveabjcyKhaNEDoccM6twMp7Zvapk8XsNdY3qfcL0N5fGszA8ffKGeuqyPZffwEu/eNMiUvEZGCW8MNXWSE6QjrTegrEuLLHz0e28U07EDUQQOYFZxkk8iVIV6ZWw+/nWl9fn8RvcxJ2jKTF6/tr/t9+yc28iUvEZGSQ1znfGTl4w2neef70Mvogg8TPMs9UnRri2y6pQgNJsuuVmWXdso3Vq5VzarCRaZAYABHXO0pc92xzy3YcxIgudnRBF4mEiPfmiDYAUntW3iKvuwn/nNqce9uEJjR967aVBc9Qys5M3rB1p2bdEDxnDXkyLoItLQ3yv1YwVtvHjlqXX+hpKvpE0AUG8UGF572Gn0Fs3Rg5/MYFYgisDDROoF+fmlaKCjiIxXKOjdWlmK/rvqKaBjRe98aPfoOZaqLMxFpaXzI6OGyIgicAmTRusvqk0g3H5e/ToBmSbl4vcas247C00NlpX0Gh00BH+p0aSB+dliX7km+kS9FdHutUgjbwhDioCIcohoNhGtV/5mRzhuJhEdJKLPwrZ3JKKFRLSBiN4lIn+9xSEM61G/AEosiIB+efW/cj09w0QhOYl8N1G8ZfLIiBlQo3HN6XnHzzO5Af3tgHYR91npWSWTxcYwOiIYD2AuM3cFMFdZV+MxAFepbP87gCeZuQuAAwCuNyiPZ4mnr0QQlzi/MUOpgbBQKfGo57kZ0i0XAHD/qF6W2ev/rDJCDWLFKCSIXQWfEhWjimA0gKnK8lQAF6kdxMxzAZSEbqNAGOS5AD6Idb4fsLuQTKLhB3VIBPQ8oXHtMqAvC6hatHlLk0pF1uLgYzzqlBOcu7nHMaoIWjLzTmV5FwA99o1mAA4yczAGfgcA90Uo2UQ8ekDyqxzHUvuzSwg1+QVHgjU6HpxrTs+rs75l8kh0TpAaEkBgpCPER0xFQERziGiFymd06HEc6NJa9jYS0TgiKiSiwr1791p1G9toZkK+IYKMJEb3CfQCO+UmdvHytJQkdA2ZD6gdEei4xvm9WmHVpAvMFcxFxOoWSccpMjEVATMPY+beKp9PAOwmotYAoPzdo+Pe+wE0JaKgi0tbAEVR5JjCzPnMnJ+bm6vjNu5kcVhlq3ja8yQi1cliP9KpeWIrgnUPDccJIVlBGymeYXo7Ag3TrPUoc7Jb0ihD/X9b82CBzZJ4D6OmoWkAxirLYwF8ovVEZQQxH8CYeM5PNKIN8SM+yASkq5QaFBKbLZNH1pqJ0lwWPezkADVSJLUfvej0YvQpmgzgPCJa
D2CYsg4iyieil4MHEdE3AN4HMJSIdhBRcHx6J4C/ENEGBOYMXjEoj2eJ9v5EepAb+TReIJTgYN+vFrIWUSZ7/fqdCPox1JIw834AQ1W2FwK4IWRdtXoJM28CMMCIDIlCPLb+YARxw7RklFVUmy2SJ2AAN53VCWd2be60KI6Rk5mG4tIKp8UQPIy7xpU+ZcvkkYZsq3eP7GmaLG7j4lNjO5JNGNEDg7t6f97IbC7NjxzcZRVuy28UikwVR0YUgUvQOyB447rjA6krBrY3WRr30LSBb4PNNXOsUn00eONZnWyWJDK/H+KOkpaCOqIIXMIJTTMwWId5o1+Hut5C5/XUn6LC63TOzcSZXfxrEgry0e9Px/Q/Ol8hDYjcobmjwB1F7gV1RBG4hIZpKfiPjoyQ4YnlXro632yRXM/c24fgEgfMH26je6vG6HWC8xXSAOtrYcTLyJNbY0DHHKfFcC3idiK4GokB8hYtskxOWWGQ684I1Ad//or6tRyE48iIwEHUcudf0Mt/Jh4hsdj4yAjcOLij02IAAE7t0NRpETyBKAIHUZvkvfK0DoaumWg96AT7d3xBchJJOgePIaYhB7mjoBtuOttczw4pUym4gbxmx1N+bHpkhIOSCFqQEYGDpKckm25TTbQCHQn27/iG3w5oVxv5nuRg50TqdWhDFIHLMJoWINHqFYuJwZsQES7vb79Hl9kjbL8gisAjFPRqpek4K6tACYIefFAiImGQOQKXEakDrMU/+9s7z4mYgdGryHjAu+gpmmMWYgqKj8RqNRKINiG55wFtk8BtsxuaX3rQaeS99ixOFE1yc64jNyOKwCMkJ/nzp5Ienndxg2nIiVGJF/Fn6+JBEszioxmZKzaGnvxVZlPtAtPQnNW7bZfBi8gcgUdING8gwV6euPQU2+95y9md0bddU9vvG0ppuT/rdOjFp/1M7zHqlNh5+RMRUX/mcFEf+5+fdjkNbU8K2Kpxep11MQ1pQxSBy+ic20h1+6DOzdDH4d6VE0T0opIRkiaCzgNOBnXZydWD8uqsd22h/j4JdTGkCIgoh4hmE9F65W92hONmEtFBIvosbPvrRLSZiJYonz5G5EkETgjzFgqlfU5DAECrRPMMikKkyeIF48+1WRJv8tBFvbFo4jCnxbCNcIU34qTWDkniLYyOCMYDmMvMXQHMVdbVeAzAVRH2/Y2Z+yifJQblSWiCveMJI/xT5CPSiKBFVrr6DqEOGanJyPXxd5VoKVeswqgiGA1gqrI8FcBFagcx81wAJQbv5TuW338+1j88vHY9RXEh7RCS0Mtv3HpOFwCSekLQhjwm2jCqCFoy805leReAeJLpP0xEy4joSSLyb9clhDsKuuGOgm7IykitEyl8z4U9AADJPnq6w//Tm6X2raADGRFoI6b7KBHNAaCW6GZi6AozMxHpnaKfgIACSQMwBcCdACZFkGMcgHEA0L594hZrB4DfD+miur1pw0Ahd1+5koa9yE5EqwrepWlDyb2lhZiKgJkjzjQR0W4ias3MO4moNYA9em4eMpooJ6LXAPw1yrFTEFAWyM/P93VrkJ3pn4fbRypPsIBozhfCcYyahqYBGKssjwXwiZ6TFeUBChh8LwKwwqA8Cc+mR0agdRP/PNwtwvzCfd0DEDQx67aznBbBcxhVBJMBnEdE6wEMU9ZBRPlE9HLwICL6BsD7AIYS0Q4iukDZ9RYRLQewHEBzAA8ZlCfh8Ys/eJDf9m+PxXcfH5SKZUiIhVotcCE6hlJMMPN+AENVthcCuCFkfXCE88UZXIhKUhKhWaOQUYEoAiEGMj+sH4ksFjxFeqo8skJ0/DZqNgN5qwRPkZGajC2TRzothuBiRA/oRxSBIAgJhZ/ibMxCFIEgCAlFY6nbrRtRBIIgJBQZqcno3irLaTE8hSgCQRAEnyOKQPAMF/U5wWkRBI8gSQn1IYpA8AwN06WyqqANUQP6EEUgCELCIQMCfYgiEDyDvNuCVkQR6EMUgeAZ5OUWtBKpxKmgjhhdBUFIOP5wThes/OWQ02J4BlEEgiAkHAW9W6Ggt1o9LUENMQ0JnqBzbibO7JLrtBiCkJDIiEDwBHNvH+K0CIKQsMiIQBAEweeIIhAEQfA5oggEQRB8jiFFQEQ5RDSbiNYrf7NVjulDRN8T0UoiWkZEl4Xs60hEC4loAxG9S0RpRuQRBEEQ9GN0RDAewFxm7gpgrrIeThmAq5m5F4ACAE8RUVNl398BPMnMXQAcAHC9QXkEQRAEnRhVBKMBTFWWpwK4KPwAZl7HzOuV5V8A7AGQS4H0gOcC+CDa+YIgCIK1GFUELZl5p7K8C0DLaAcT0QAAaQA2AmgG4CAzVym7dwBoY1AeQRAEQScx4wiIaA4AtRC9iaErzMxExFGu0xrAfwCMZeYavfnCiWgcgHEA0L59e13nCoIgCJGJqQiYeVikfUS0m4haM/NOpaHfE+G4xgCmA5jIzD8om/cDaEpEKcqooC2AoihyTAEwRbleCRGtjSW7gzQHsM9pIWLgdhndLh8gMpqB2+UD3C+jHvk6qG00Glk8DcBYAJOVv5+EH6B4An0M4A1mDs4HBEcQ8wGMAfBOpPMjsJaZ8w3KbhlEVOhm+QD3y+h2+QCR0QzcLh/gfhnNkM/oHMFkAOcR0XoAw5R1EFE+Eb2sHHMpgLMAXENES5RPH2XfnQD+QkQbEJgzeMWgPIIgCIJODI0ImHk/gKEq2wsB3KAsvwngzQjnbwIwwIgMgiAIgjG8Glk8xWkBYuB2+QD3y+h2+QCR0QzcLh/gfhkNy0fMER19BEEQBB/g1RGBIAiCYBKeUgREVEBEa5XcRGrpLFwlDxFdQ0R7QybJb3BCzjCZXiWiPUS0wmlZgNjyENEQIjoU8h3ea7eMKjK1I6L5RLRKyaH1JzfL4tLvMIOIfiSipYrcD7hZFje+y0GIKJmIfiaiz+K+CDN74gMgGYGI5E4IRCcvBdDTzfIAuAbAc05/d2EynQXgVAArnJZFizwAhgD4zGk5w2RqDeBUZTkLwDqnnkUtsrj0OyQAjZTlVAALAZzmVlnc+C6HyPYXAG8b+Y29NCIYAGADM29i5goEYg9Gizz6YOavARQ7LUcQt8mjBWbeycw/KcslAFbDofQobpJFDxzgiLKaqnwcmbB0kyx6IaK2AEYCeDnWsdHwkiJoA2B7yLrTuYm0yvMbJf32B0TUzh7REo5ByrD9cyLq5bQwoRBRHoC+CPQiHSWGLK77DhWTxhIEMhLMZmbHvkONsrjxXX4KwB0AaoxcxEuKwIt8CiCPmU8GMBvHM7UK2vkJQAdmPgXAswD+56w4xyGiRgA+BHAbMx92sSyu/A6ZuZqZ+yCQXmYAEfV2sSyue5eJ6EIAe5h5sdFreUkRFAEI1cJRcxPZQEx5mHk/M5crqy8D6GeTbAkDMx8ODtuZeQaAVCJq7rBYIKJUBBret5j5IzfL4tbvMAgzHwQwH4F6JY4SSRaXvstnABhFRFsQME2fS0Sqwbux8JIiWASgKwWqmqUBuByBXEeulUdJxBdkFAL2W0EHRNSKlFS1FEhjnoR
AwkInZSIE0qGsZuYn3C6LS7/DXFIKVBFRAwDnAVjjVlnc+C4z8wRmbsvMeQi0P/OY+cp4rmU06ZxtMHMVEd0KYBYCHjuvMvNKt8lDRJMAFDLzNAB/JKJRAKoQmBC9xil5gxDRfxHwImlORDsA3MfMjuV4UpMHgck6MPOLCCQlvIWIqgAcBXA5K64SDnIGgKsALFfsygBwl9LbdoUsANoDrv4OWwOYSkTJCCim95g5fvdHC2Rx+7tsJhJZLAiC4HO8ZBoSBEEQLEAUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPkcUgSBEgYiahWSc3EVERcryESJ6wWn5BMEMxH1UEDRCRPcDOMLM/3RaFkEwExkRCEIcKDn+P1OW7yeiqUT0DRFtJaKLiegfRLSciGYqKSBARP2I6CsiWkxEs8KiVQXBMUQRCII5dAZwLgLpB94EMJ+ZT0IgknekogyeBTCGmfsBeBXAw04JKwiheCbFhCC4nM+ZuZKIliOQcmSmsn05gDwA3QD0BjBbSfuTDGCnA3IKQj1EEQiCOZQDADPXEFFlSC6fGgTeMwKwkpkHOSWgIERCTEOCYA9rAeQS0SAgkDraLQViBEEUgSDYgFLOdAyAvxPRUgBLAJzuqFCCoCDuo4IgCD5HRgSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPgcUQSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPic/wcvziJ0eY2VRAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "special-delicious", - "metadata": {}, - "outputs": [], - "source": [ - "import getpass" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "seasonal-consensus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['GetPassWarning',\n", - " '__all__',\n", - " '__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '_raw_input',\n", - " 'contextlib',\n", - " 'fallback_getpass',\n", - " 'getpass',\n", - " 'getuser',\n", - " 'io',\n", - " 'os',\n", - " 'sys',\n", - " 'termios',\n", - " 'unix_getpass',\n", - " 'warnings',\n", - " 'win_getpass']" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(getpass)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "dress-distinction", - "metadata": {}, - "outputs": [], - "source": [ - "getpass?" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "rental-anthony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Worker:" - ] - } - ], - "source": [ - "import multiprocessing\n", - "import cProfile\n", - "import time\n", - "\n", - "def worker(num):\n", - " time.sleep(3)\n", - " print('Worker:', num)\n", - "\n", - "def profile_worker(num):\n", - " cProfile.runctx('worker(num)', globals(), locals(), 'profile-%d.out' %num)\n", - "\n", - "\n", - "\n", - "for i in range(5):\n", - " p = multiprocessing.Process(target=profile_worker, args=(i,))\n", - " p.start()" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "separated-restriction", - "metadata": {}, - "outputs": [], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "painted-variable", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(2, 2)\n", - "[ 1 20]\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "l = [(1, 20), (2, 30)]\n", - "scores = np.array(l)\n", - "print(scores.shape)\n", - "print(scores[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "satellite-insider", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0 1]\n" - ] - } - ], - "source": [ - "sort_idx = np.argsort(scores[:, -1])\n", - "print(sort_idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "developed-thirty", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx][::1]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "official-bench", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "ranking-camera", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x14\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x1e\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\n", - "[ 1 20 2 30]\n", - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning: tostring() is deprecated. Use tobytes() instead.\n", - " \"\"\"Entry point for launching an IPython kernel.\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:3: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n", - " This is separate from the ipykernel package so we can avoid doing imports until\n" - ] - } - ], - "source": [ - "a = scores.tostring()\n", - "print(a)\n", - "b = np.fromstring(a, scores.dtype)\n", - "print(b)\n", - "print(scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "breeding-proxy", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "numpy.int16" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.int16" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "coordinate-hungary", - "metadata": {}, - "outputs": [], - "source": [ - "dtype = np.dtype('int16')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "specified-jackson", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "int16\n", - "16\n" - ] - } - ], - "source": [ - "print(dtype)\n", - "dtype is np.int16\n", - "print(np.iinfo(dtype).bits)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "activated-insight", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/train_test.ipynb b/.notebook/train_test.ipynb deleted file mode 100644 index 67212e50a803af7bf2a9009d00e0c295a3717fa9..0000000000000000000000000000000000000000 --- a/.notebook/train_test.ipynb +++ /dev/null @@ -1,1887 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cloudy-glass", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISISBLE_DEVICES'] = '0'" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "grand-stephen", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.0.0\n" - ] - } - ], - "source": [ - "import paddle\n", - "print(paddle.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "isolated-prize", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "romance-samuel", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. 
Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "timely-bikini", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. 
In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "organized-warrior", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[14 , 34 , 322 , 233 , 0 , 0 ],\n", - " [238 , 38 , 122 , 164 , 0 , 0 ],\n", - " [8 , 52 , 49 , 42 , 0 , 0 ],\n", - " [109 , 47 , 146 , 193 , 210 , 479 ],\n", - " [3330, 1751, 208 , 1923, 0 , 0 ]])\n", - "test raw 大时代里的的\n", - "test raw 煲汤受宠的的\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. 
]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - " for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "confidential-radius", - "metadata": {}, - "outputs": [], - "source": [ - "# reader = batch_reader()\n", - "# audio, test , audio_len, text_len = reader.next()\n", - "# print('test', text)\n", - "# print('t len', text_len) #[B, T]\n", - "# print('audio len', audio_len)\n", - "# print(audio)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "future-vermont", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "煲汤受宠\n" - ] - } - ], - "source": [ - "print(u'\\u7172\\u6c64\\u53d7\\u5ba0')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dental-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "sunrise-contact", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hispanic-asthma", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hearing-leadership", - "metadata": 
{}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "skilled-friday", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "copyrighted-measure", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "employed-lightweight", - "metadata": {}, - "outputs": [], - "source": [ - "from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n", - "\n", - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from paddle.nn import functional as F\n", - "from paddle.nn import initializer as I\n", - "\n", - "import math\n", - "\n", - "def brelu(x, t_min=0.0, t_max=24.0, name=None):\n", - " t_min = paddle.to_tensor(t_min)\n", - " t_max = paddle.to_tensor(t_max)\n", - " return x.maximum(t_min).minimum(t_max)\n", - "\n", - "def sequence_mask(x_len, max_len=None, dtype='float32'):\n", - " max_len = max_len or x_len.max()\n", - " x_len = paddle.unsqueeze(x_len, -1)\n", - " row_vector = paddle.arange(max_len)\n", - " mask = row_vector > x_len # maybe a bug\n", - " mask = paddle.cast(mask, dtype)\n", - " print(f'seq mask: {mask}')\n", - " return mask\n", - "\n", - "\n", - "class ConvBn(nn.Layer):\n", - " def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n", - " padding, act):\n", - "\n", - " super().__init__()\n", - " self.kernel_size = kernel_size\n", - " self.stride = stride\n", - " self.padding = padding\n", - "\n", - " self.conv = nn.Conv2D(\n", - " num_channels_in,\n", - " num_channels_out,\n", - " kernel_size=kernel_size,\n", - " stride=stride,\n", - " padding=padding,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - "\n", - " self.bn = nn.BatchNorm2D(\n", - " num_channels_out,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - " self.act = F.relu if act == 'relu' else brelu\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x(Tensor): audio, shape [B, C, D, T]\n", - " \"\"\"\n", - " x = self.conv(x)\n", - " x = self.bn(x)\n", - " x = self.act(x)\n", - "\n", - " x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n", - " ) // self.stride[1] + 1\n", - "\n", - " # reset padding part to 0\n", - " masks = sequence_mask(x_len) #[B, T]\n", - " masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n", - " x = x.multiply(masks)\n", - "\n", - " return x, x_len\n", - "\n", - "\n", - "class ConvStack(nn.Layer):\n", - " def __init__(self, feat_size, num_stacks):\n", - " super().__init__()\n", - " self.feat_size = feat_size # D\n", - " self.num_stacks = num_stacks\n", - "\n", - " self.conv_in = ConvBn(\n", - " num_channels_in=1,\n", - " num_channels_out=32,\n", - " kernel_size=(41, 11), #[D, T]\n", - " stride=(2, 3),\n", - " 
padding=(20, 5),\n", - " act='brelu')\n", - "\n", - " out_channel = 32\n", - " self.conv_stack = nn.Sequential([\n", - " ConvBn(\n", - " num_channels_in=32,\n", - " num_channels_out=out_channel,\n", - " kernel_size=(21, 11),\n", - " stride=(2, 1),\n", - " padding=(10, 5),\n", - " act='brelu') for i in range(num_stacks - 1)\n", - " ])\n", - "\n", - " # conv output feat_dim\n", - " output_height = (feat_size - 1) // 2 + 1\n", - " for i in range(self.num_stacks - 1):\n", - " output_height = (output_height - 1) // 2 + 1\n", - " self.output_height = out_channel * output_height\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x: shape [B, C, D, T]\n", - " x_len : shape [B]\n", - " \"\"\"\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = self.conv_in(x, x_len)\n", - " for i, conv in enumerate(self.conv_stack):\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = conv(x, x_len)\n", - " print(f\"conv out: {x_len}\")\n", - " return x, x_len\n", - " \n", - " \n", - "\n", - "class RNNCell(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n", - " computes the outputs and updates states.\n", - " The formula used is as follows:\n", - " .. math::\n", - " h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n", - " y_{t} & = h_{t}\n", - " \n", - " where :math:`act` is for :attr:`activation`.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " hidden_size,\n", - " activation=\"tanh\",\n", - " weight_ih_attr=None,\n", - " weight_hh_attr=None,\n", - " bias_ih_attr=None,\n", - " bias_hh_attr=None,\n", - " name=None):\n", - " super().__init__()\n", - " std = 1.0 / math.sqrt(hidden_size)\n", - " self.weight_hh = self.create_parameter(\n", - " (hidden_size, hidden_size),\n", - " weight_hh_attr,\n", - " default_initializer=I.Uniform(-std, std))\n", - " # self.bias_ih = self.create_parameter(\n", - " # (hidden_size, ),\n", - " # bias_ih_attr,\n", - " # is_bias=True,\n", - " # default_initializer=I.Uniform(-std, std))\n", - " self.bias_ih = None\n", - " self.bias_hh = self.create_parameter(\n", - " (hidden_size, ),\n", - " bias_hh_attr,\n", - " is_bias=True,\n", - " default_initializer=I.Uniform(-std, std))\n", - "\n", - " self.hidden_size = hidden_size\n", - " if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n", - " raise ValueError(\n", - " \"activation for SimpleRNNCell should be tanh or relu, \"\n", - " \"but get {}\".format(activation))\n", - " self.activation = activation\n", - " self._activation_fn = paddle.tanh \\\n", - " if activation == \"tanh\" \\\n", - " else F.relu\n", - " if activation == 'brelu':\n", - " self._activation_fn = brelu\n", - "\n", - " def forward(self, inputs, states=None):\n", - " if states is None:\n", - " states = self.get_initial_states(inputs, self.state_shape)\n", - " pre_h = states\n", - " i2h = inputs\n", - " if self.bias_ih is not None:\n", - " i2h += self.bias_ih\n", - " h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n", - " if self.bias_hh is not None:\n", - " h2h += self.bias_hh\n", - " h = self._activation_fn(i2h + h2h)\n", - " return h, h\n", - "\n", - " @property\n", - " def state_shape(self):\n", - " return (self.hidden_size, )\n", - "\n", - "\n", - "class GRUCellShare(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n", - " it computes the outputs and updates states.\n", - " The formula for GRU used is as follows:\n", - " .. 
- "class GRUCellShare(nn.RNNCellBase):\n",
- "    r\"\"\"\n",
- "    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,\n",
- "    it computes the outputs and updates states.\n",
- "    The formula for GRU used is as follows:\n",
- "    .. math::\n",
- "        r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n",
- "        z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n",
- "        \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n",
- "        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n",
- "        y_{t} & = h_{t}\n",
- "\n",
- "    where :math:`\\sigma` is the sigmoid function, and * is the elementwise\n",
- "    multiplication operator.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self,\n",
- "                 input_size,\n",
- "                 hidden_size,\n",
- "                 weight_ih_attr=None,\n",
- "                 weight_hh_attr=None,\n",
- "                 bias_ih_attr=None,\n",
- "                 bias_hh_attr=None,\n",
- "                 name=None):\n",
- "        super().__init__()\n",
- "        std = 1.0 / math.sqrt(hidden_size)\n",
- "        self.weight_hh = self.create_parameter(\n",
- "            (3 * hidden_size, hidden_size),\n",
- "            weight_hh_attr,\n",
- "            default_initializer=I.Uniform(-std, std))\n",
- "        # self.bias_ih = self.create_parameter(\n",
- "        #     (3 * hidden_size, ),\n",
- "        #     bias_ih_attr,\n",
- "        #     is_bias=True,\n",
- "        #     default_initializer=I.Uniform(-std, std))\n",
- "        self.bias_ih = None\n",
- "        self.bias_hh = self.create_parameter(\n",
- "            (3 * hidden_size, ),\n",
- "            bias_hh_attr,\n",
- "            is_bias=True,\n",
- "            default_initializer=I.Uniform(-std, std))\n",
- "\n",
- "        self.hidden_size = hidden_size\n",
- "        self.input_size = input_size\n",
- "        self._gate_activation = F.sigmoid\n",
- "        #self._activation = paddle.tanh\n",
- "        self._activation = F.relu\n",
- "\n",
- "    def forward(self, inputs, states=None):\n",
- "        if states is None:\n",
- "            states = self.get_initial_states(inputs, self.state_shape)\n",
- "\n",
- "        pre_hidden = states\n",
- "        x_gates = inputs\n",
- "        if self.bias_ih is not None:\n",
- "            x_gates = x_gates + self.bias_ih\n",
- "        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n",
- "        if self.bias_hh is not None:\n",
- "            h_gates = h_gates + self.bias_hh\n",
- "\n",
- "        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n",
- "        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n",
- "\n",
- "        r = self._gate_activation(x_r + h_r)\n",
- "        z = self._gate_activation(x_z + h_z)\n",
- "        c = self._activation(x_c + r * h_c)  # apply reset gate after mm\n",
- "        h = (pre_hidden - c) * z + c\n",
- "        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n",
- "        #h = (1-z) * pre_hidden + z * c\n",
- "\n",
- "        return h, h\n",
- "\n",
- "    @property\n",
- "    def state_shape(self):\n",
- "        r\"\"\"\n",
- "        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n",
- "        size would be automatically inserted into shape). The shape corresponds\n",
- "        to the shape of :math:`h_{t-1}`.\n",
- "        \"\"\"\n",
- "        return (self.hidden_size, )\n",
- "\n",
- "\n",
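The update actually executed, h = (pre_hidden - c) * z + c, is an algebraic rearrangement of the docstring formula h_t = z_t * h_{t-1} + (1 - z_t) * \widetilde{h}_{t}; the commented-out line is the flipped-gate convention of the linked dynamic_gru API, which is a different update. A standalone numpy check of the rearrangement:

    import numpy as np

    # (pre_hidden - c) * z + c  ==  z * pre_hidden + (1 - z) * c
    rng = np.random.default_rng(0)
    pre_hidden, c, z = rng.random(8), rng.random(8), rng.random(8)
    assert np.allclose((pre_hidden - c) * z + c,
                       z * pre_hidden + (1 - z) * c)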
- "class BiRNNWithBN(nn.Layer):\n",
- "    \"\"\"Bidirectional simple rnn layer with sequence-wise batch normalization.\n",
- "    The batch normalization is only performed on input-state weights.\n",
- "\n",
- "    :param name: Name of the layer parameters.\n",
- "    :type name: string\n",
- "    :param size: Dimension of RNN cells.\n",
- "    :type size: int\n",
- "    :param share_weights: Whether to share input-hidden weights between\n",
- "                          forward and backward directional RNNs.\n",
- "    :type share_weights: bool\n",
- "    :return: Bidirectional simple rnn layer.\n",
- "    :rtype: Variable\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, i_size, h_size, share_weights):\n",
- "        super().__init__()\n",
- "        self.share_weights = share_weights\n",
- "        if self.share_weights:\n",
- "            # input-hidden weights shared between bi-directional rnn.\n",
- "            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
- "            # batch norm is only performed on input-state projection\n",
- "            self.fw_bn = nn.BatchNorm1D(\n",
- "                h_size, bias_attr=None, data_format='NLC')\n",
- "            self.bw_fc = self.fw_fc\n",
- "            self.bw_bn = self.fw_bn\n",
- "        else:\n",
- "            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
- "            self.fw_bn = nn.BatchNorm1D(\n",
- "                h_size, bias_attr=None, data_format='NLC')\n",
- "            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
- "            self.bw_bn = nn.BatchNorm1D(\n",
- "                h_size, bias_attr=None, data_format='NLC')\n",
- "\n",
- "        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
- "        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
- "        self.fw_rnn = nn.RNN(\n",
- "            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]\n",
- "        self.bw_rnn = nn.RNN(\n",
- "            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]\n",
- "\n",
- "    def forward(self, x, x_len):\n",
- "        # x, shape [B, T, D]\n",
- "        fw_x = self.fw_bn(self.fw_fc(x))\n",
- "        bw_x = self.bw_bn(self.bw_fc(x))\n",
- "        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
- "        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
- "        x = paddle.concat([fw_x, bw_x], axis=-1)\n",
- "        return x, x_len\n",
- "\n",
- "\n",
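Both bidirectional wrappers originally passed self.fw_cell to the backward nn.RNN as well; that is why the bw_cell parameters report grad: None in the last cell of this notebook. It is corrected to self.bw_cell here and in BiGRUWithBN below. With share_weights=True only the projection and its batch norm are shared, which a short sketch can verify, assuming the class above is defined:

    # With share_weights=True the projection objects are the very same
    # sublayer, while the two recurrent cells remain separate parameters.
    layer = BiRNNWithBN(i_size=8, h_size=4, share_weights=True)
    assert layer.fw_fc is layer.bw_fc and layer.fw_bn is layer.bw_bn
    assert layer.fw_cell is not layer.bw_cell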
- "class BiGRUWithBN(nn.Layer):\n",
- "    \"\"\"Bidirectional gru layer with sequence-wise batch normalization.\n",
- "    The batch normalization is only performed on input-state weights.\n",
- "\n",
- "    :param name: Name of the layer.\n",
- "    :type name: string\n",
- "    :param input: Input layer.\n",
- "    :type input: Variable\n",
- "    :param size: Dimension of GRU cells.\n",
- "    :type size: int\n",
- "    :param act: Activation type.\n",
- "    :type act: string\n",
- "    :return: Bidirectional GRU layer.\n",
- "    :rtype: Variable\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, i_size, h_size, act):\n",
- "        super().__init__()\n",
- "        hidden_size = h_size * 3\n",
- "        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
- "        self.fw_bn = nn.BatchNorm1D(\n",
- "            hidden_size, bias_attr=None, data_format='NLC')\n",
- "        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
- "        self.bw_bn = nn.BatchNorm1D(\n",
- "            hidden_size, bias_attr=None, data_format='NLC')\n",
- "\n",
- "        self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
- "        self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
- "        self.fw_rnn = nn.RNN(\n",
- "            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]\n",
- "        self.bw_rnn = nn.RNN(\n",
- "            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]\n",
- "\n",
- "    def forward(self, x, x_len):\n",
- "        # x, shape [B, T, D]\n",
- "        fw_x = self.fw_bn(self.fw_fc(x))\n",
- "        bw_x = self.bw_bn(self.bw_fc(x))\n",
- "        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
- "        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
- "        x = paddle.concat([fw_x, bw_x], axis=-1)\n",
- "        return x, x_len\n",
- "\n",
- "\n",
- "class RNNStack(nn.Layer):\n",
- "    \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n",
- "\n",
- "    :param input: Input layer.\n",
- "    :type input: Variable\n",
- "    :param size: Dimension of RNN cells in each layer.\n",
- "    :type size: int\n",
- "    :param num_stacks: Number of stacked rnn layers.\n",
- "    :type num_stacks: int\n",
- "    :param use_gru: Use gru if set True. Use simple rnn if set False.\n",
- "    :type use_gru: bool\n",
- "    :param share_rnn_weights: Whether to share input-hidden weights between\n",
- "                              forward and backward directional RNNs.\n",
- "                              It is only available when use_gru=False.\n",
- "    :type share_rnn_weights: bool\n",
- "    :return: Output layer of the RNN group.\n",
- "    :rtype: Variable\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n",
- "        super().__init__()\n",
- "        self.rnn_stacks = nn.LayerList()\n",
- "        for i in range(num_stacks):\n",
- "            if use_gru:\n",
- "                # default: GRU using tanh\n",
- "                self.rnn_stacks.append(\n",
- "                    BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n",
- "            else:\n",
- "                self.rnn_stacks.append(\n",
- "                    BiRNNWithBN(\n",
- "                        i_size=i_size,\n",
- "                        h_size=h_size,\n",
- "                        share_weights=share_rnn_weights))\n",
- "            i_size = h_size * 2\n",
- "\n",
- "    def forward(self, x, x_len):\n",
- "        \"\"\"\n",
- "        x: shape [B, T, D]\n",
- "        x_len: shape [B]\n",
- "        \"\"\"\n",
- "        for i, rnn in enumerate(self.rnn_stacks):\n",
- "            x, x_len = rnn(x, x_len)\n",
- "            masks = sequence_mask(x_len)  #[B, T]\n",
- "            masks = masks.unsqueeze(-1)  # [B, T, 1]\n",
- "            x = x.multiply(masks)\n",
- "        return x, x_len\n",
- "\n",
- "\n",
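RNNStack re-zeroes the padded timesteps after every stacked layer, which is why "seq mask:" is printed once per RNN layer in the outputs below. A tiny standalone example of the mask itself, assuming the fixed sequence_mask above:

    # The [B, T] mask for lengths [2, 4] at max_len 4:
    # 1.0 over valid steps, 0.0 over padding.
    m = sequence_mask(paddle.to_tensor([2, 4]), max_len=4)
    # [[1., 1., 0., 0.],
    #  [1., 1., 1., 1.]]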
- "class DeepSpeech2Test(DeepSpeech2):\n",
- "    def __init__(self,\n",
- "                 feat_size,\n",
- "                 dict_size,\n",
- "                 num_conv_layers=2,\n",
- "                 num_rnn_layers=3,\n",
- "                 rnn_size=256,\n",
- "                 use_gru=False,\n",
- "                 share_rnn_weights=True):\n",
- "        super().__init__(feat_size,\n",
- "                         dict_size,\n",
- "                         num_conv_layers=num_conv_layers,\n",
- "                         num_rnn_layers=num_rnn_layers,\n",
- "                         rnn_size=rnn_size,\n",
- "                         use_gru=use_gru,\n",
- "                         share_rnn_weights=share_rnn_weights)\n",
- "        self.feat_size = feat_size  # 161 for linear\n",
- "        self.dict_size = dict_size\n",
- "\n",
- "        self.conv = ConvStack(feat_size, num_conv_layers)\n",
- "\n",
- "#         self.fc = nn.Linear(1312, dict_size + 1)\n",
- "\n",
- "        i_size = self.conv.output_height  # H after conv stack\n",
- "        self.rnn = RNNStack(\n",
- "            i_size=i_size,\n",
- "            h_size=rnn_size,\n",
- "            num_stacks=num_rnn_layers,\n",
- "            use_gru=use_gru,\n",
- "            share_rnn_weights=share_rnn_weights)\n",
- "\n",
- "        self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n",
- "\n",
- "    def infer(self, audio, audio_len):\n",
- "        # [B, D, T] -> [B, C=1, D, T]\n",
- "        audio = audio.unsqueeze(1)\n",
- "\n",
- "        # convolution group\n",
- "        x, audio_len = self.conv(audio, audio_len)\n",
- "        print('conv out', x.shape)\n",
- "\n",
- "        # convert data from convolution feature map to sequence of vectors\n",
- "        B, C, D, T = paddle.shape(x)\n",
- "        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]\n",
- "        x = x.reshape([B, T, C * D])  #[B, T, C*D]\n",
- "        print('rnn input', x.shape)\n",
- "\n",
- "        # remove padding part\n",
- "        x, audio_len = self.rnn(x, audio_len)  #[B, T, D]\n",
- "        print('rnn output', x.shape)\n",
- "\n",
- "        logits = self.fc(x)  #[B, T, V + 1]\n",
- "\n",
- "        # the ctc decoder needs probs, not log_probs\n",
- "        probs = F.softmax(logits)\n",
- "\n",
- "        return logits, probs, audio_len\n",
- "\n",
- "    def forward(self, audio, audio_len, text, text_len):\n",
- "        \"\"\"\n",
- "        audio: shape [B, D, T]\n",
- "        text: shape [B, T]\n",
- "        audio_len: shape [B]\n",
- "        text_len: shape [B]\n",
- "        \"\"\"\n",
- "        return self.infer(audio, audio_len)\n",
- "\n",
- "\n",
- "feat_dim = 161\n",
- "\n",
- "model = DeepSpeech2Test(\n",
- "    feat_size=feat_dim,\n",
- "    dict_size=batch_reader.dataset.vocab_size,\n",
- "    num_conv_layers=args.num_conv_layers,\n",
- "    num_rnn_layers=args.num_rnn_layers,\n",
- "    rnn_size=1024,\n",
- "    use_gru=args.use_gru,\n",
- "    share_rnn_weights=args.share_rnn_weights,\n",
- ")\n",
- "dp_model = model\n",
- "#dp_model = paddle.DataParallel(model)\n",
- "\n",
- "loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "divided-incentive",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "id": "discrete-conjunction",
- "metadata": {},
- "outputs": [],
- "source": [
- "audio, audio_len, text, text_len = None, None, None, None\n",
- "\n",
- "for idx, inputs in enumerate(batch_reader):\n",
- "    audio, audio_len, text, text_len = inputs\n",
- "#     print(idx)\n",
- "#     print('a', audio.shape, audio.place)\n",
- "#     print('t', text)\n",
- "#     print('al', audio_len)\n",
- "#     print('tl', text_len)\n",
- "    break"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "protected-announcement",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
- "       [163, 167, 180, 186, 186])\n",
- "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [[1., 1., 1., ..., 0., 0., 0.],\n",
- "        [1., 1., 1., ..., 0., 0., 0.],\n",
- "        [1., 1., 1., ..., 0., 0., 0.],\n",
- "        [1., 1., 1., ..., 1., 1., 1.],\n",
- "        [1., 1., 1., ..., 1., 1., 1.]])\n",
- "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [55, 56, 60, 62, 62])\n",
- "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [[1., 1., 1., ..., 0., 0., 0.],\n",
- "        ...\n",
- "        [1., 1., 1., ..., 1., 1., 1.]])\n",
- "conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [55, 56, 60, 62, 62])\n",
- "conv out [5, 32, 41, 62]\n",
- "rnn input [5, 62, 1312]\n",
- "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [[1., 1., 1., ..., 0., 0., 0.],\n",
- "        ...\n",
- "        [1., 1., 1., ..., 1., 1., 1.]])\n",
- "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [[1., 1., 1., ..., 0., 0., 0.],\n",
- "        ...\n",
- "        [1., 1., 1., ..., 1., 1., 1.]])\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
- "  return (isinstance(seq, collections.Sequence) and\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [[1., 1., 1., ..., 0., 0., 0.],\n",
- "        ...\n",
- "        [1., 1., 1., ..., 1., 1., 1.]])\n",
- "rnn output [5, 62, 2048]\n",
- "logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
- "       [55, 56, 60, 62, 62])\n",
- "loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
- "       [2316.82153320])\n"
- ]
- }
- ],
- "source": [
- "outputs = dp_model(audio, audio_len, text, text_len)\n",
- "logits, _, logits_len = outputs\n",
- "print('logits len', logits_len)\n",
- "loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
- "print('loss', loss)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "id": "universal-myrtle",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n",
- "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
- "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
- "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n",
- "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n",
- "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
- "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n",
- "param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n"
- ]
- }
- ],
"source": [ - "for n, p in dp_model.named_parameters():\n", - " print(\n", - " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "referenced-double", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n", - " 4.757795 ]\n", - " [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n", - " 5.1787405 ]\n", - " [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n", - " 3.5026932 ]\n", - " ...\n", - " [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n", - " 2.359191 ]\n", - " [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n", - " 4.829253 ]\n", - " [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n", - " 3.7477999 ]]]\n", - "\n", - "\n", - " [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n", - " 1.0245132 ]\n", - " [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n", - " -0.41603735]\n", - " [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n", - " -0.96788657]\n", - " ...\n", - " [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n", - " 2.6364415 ]\n", - " [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n", - " 3.3839993 ]\n", - " [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n", - " 4.1078367 ]]]\n", - "\n", - "\n", - " [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n", - " -6.139402 ]\n", - " [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n", - " -6.158736 ]\n", - " [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n", - " -5.416684 ]\n", - " ...\n", - " [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n", - " -0.8772557 ]\n", - " [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n", - " -3.6650617 ]\n", - " [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n", - " -2.5240796 ]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n", - " 2.101063 ]\n", - " [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n", - " 1.4128599 ]\n", - " [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n", - " 1.6222539 ]\n", - " ...\n", - " [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n", - " 2.7494807 ]\n", - " [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n", - " 2.2483966 ]\n", - " [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n", - " 1.7467536 ]]]\n", - "\n", - "\n", - " [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n", - " -2.6987422 ]\n", - " [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n", - " -3.6595795 ]\n", - " [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n", - " -3.481941 ]\n", - " ...\n", - " [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n", - " -1.2167063 ]\n", - " [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n", - " 0.0715822 ]\n", - " [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n", - " 1.4772662 ]]]\n", - "\n", - "\n", - " [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n", - " -3.5771027 ]\n", - " [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n", - " -2.0113711 ]\n", - " [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n", - " -4.3225455 ]\n", - " ...\n", - " [ 1.2754489 0.9560999 1.5239805 ... 
-0.0105865 -1.00876\n", - " 2.6247358 ]\n", - " [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n", - " -0.9618193 ]\n", - " [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n", - " -0.79297894]]]]\n", - "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n", - " 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n", - " 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n", - " -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n", - " 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n", - " 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n", - " -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n", - " 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n", - "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n", - " -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n", - " 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n", - " 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n", - " 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n", - " -5.628145 -1.0894046 ]\n", - "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n", - " -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n", - " -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n", - " 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n", - " 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n", - " -4.116946 -0.9909375 ]\n", - "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n", - " 6.71104014e-01 -1.95339048e+00]\n", - " [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n", - " -9.42245364e-01 -3.34277248e+00]\n", - " [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n", - " 1.36039746e+00 -2.94621849e+00]\n", - " ...\n", - " [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n", - " -3.23889256e-02 -2.30832791e+00]\n", - " [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n", - " -3.51897454e+00 -2.10093403e+00]\n", - " [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n", - " -7.38576293e-01 -9.44581270e-01]]\n", - "\n", - " [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n", - " -4.09437418e-01 -1.74107194e+00]\n", - " [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n", - " -3.03250861e+00 -2.37099743e+00]\n", - " [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n", - " -2.08650708e+00 -1.82109618e+00]\n", - " ...\n", - " [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n", - " -1.31862307e+00 -2.07174897e+00]\n", - " [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n", - " -9.57888126e-01 -3.82121086e-01]\n", - " [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n", - " 1.78124309e-02 -3.40287232e+00]]\n", - "\n", - " [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... 
-4.53077674e-01\n", - " 1.65331578e+00 2.58466601e-02]\n", - " [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n", - " 1.94632411e+00 -1.06698799e+00]\n", - " [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n", - " 4.27859402e+00 3.66381669e+00]\n", - " ...\n", - " [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n", - " 4.30536079e+00 -2.49779320e+00]\n", - " [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n", - " 2.55930424e-01 4.50821817e-02]\n", - " [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n", - " 1.09565187e+00 -4.96661961e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n", - " 2.56748104e+00 1.67592216e+00]\n", - " [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n", - " 2.92086434e+00 2.08308399e-01]\n", - " [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n", - " 3.55370259e+00 4.75085378e-01]\n", - " ...\n", - " [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n", - " -1.66419971e+00 -1.70375630e-01]\n", - " [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n", - " 9.49141383e-02 8.88008535e-01]\n", - " [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n", - " 1.93200278e+00 8.19505215e-01]]\n", - "\n", - " [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n", - " 3.19132543e+00 3.17435169e+00]\n", - " [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n", - " 2.76059532e+00 4.83479440e-01]\n", - " [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n", - " 2.80596399e+00 1.52781498e+00]\n", - " ...\n", - " [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n", - " -2.26546407e-01 -1.18434143e+00]\n", - " [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n", - " -1.31319761e-01 -3.85830402e-01]\n", - " [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n", - " 1.68037796e+00 -6.50003672e-01]]\n", - "\n", - " [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n", - " 2.07197094e+00 4.48752761e-01]\n", - " [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n", - " 1.88659012e+00 1.48252523e+00]\n", - " [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n", - " 3.75945044e+00 -3.09408635e-01]\n", - " ...\n", - " [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n", - " -1.26373076e+00 -1.37826729e+00]\n", - " [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n", - " 1.50002241e-02 1.84633851e+00]\n", - " [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n", - " 1.53240371e+00 7.47349262e-02]]]\n", - "\n", - "\n", - " [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n", - " -1.08945215e+00 4.19084758e-01]\n", - " [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n", - " 8.21905077e-01 1.89097953e+00]\n", - " [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n", - " -2.78253704e-01 6.96343541e-01]\n", - " ...\n", - " [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n", - " -9.71581042e-02 1.71877682e+00]\n", - " [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n", - " -1.15296638e+00 4.35728967e-01]\n", - " [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... 
-3.07936192e-01\n", - " -5.39051890e-01 1.13107657e+00]]\n", - "\n", - " [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n", - " 2.95773447e-02 1.27177119e+00]\n", - " [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n", - " 3.85564029e-01 4.57195908e-01]\n", - " [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n", - " -6.58698231e-02 -3.61778140e-01]\n", - " ...\n", - " [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n", - " 1.21484184e+00 2.22764325e+00]\n", - " [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n", - " 2.39702511e+00 2.43942714e+00]\n", - " [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n", - " 2.50086927e+00 1.91134131e+00]]\n", - "\n", - " [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n", - " -3.79540777e+00 -1.09796596e+00]\n", - " [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n", - " -2.59143281e+00 7.74220943e-01]\n", - " [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n", - " -1.23293114e+00 1.58148706e+00]\n", - " ...\n", - " [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n", - " 2.04232907e+00 1.97716308e+00]\n", - " [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n", - " 2.27003932e+00 -1.11734614e-01]\n", - " [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n", - " 5.78273535e-01 -2.59924948e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n", - " -1.53534544e+00 -6.67162359e-01]\n", - " [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n", - " 2.68609226e-01 -1.35124445e+00]\n", - " [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n", - " -5.04359961e-01 -1.81241560e+00]\n", - " ...\n", - " [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n", - " 3.29260826e-01 5.83679557e-01]\n", - " [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n", - " 2.91730791e-01 1.33400351e-01]\n", - " [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n", - " 4.90157723e-01 2.40562439e+00]]\n", - "\n", - " [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n", - " 1.91490650e+00 1.76274943e+00]\n", - " [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n", - " 2.53548074e+00 2.44395185e+00]\n", - " [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n", - " 2.94634581e+00 2.21452284e+00]\n", - " ...\n", - " [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n", - " -2.00367713e+00 -7.09145784e-01]\n", - " [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n", - " -7.83945560e-01 -1.01922750e-01]\n", - " [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n", - " 1.60286129e-02 -7.26446986e-01]]\n", - "\n", - " [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n", - " 8.52760911e-01 1.27920091e-01]\n", - " [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n", - " 3.97556186e-01 1.69182217e+00]\n", - " [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n", - " -1.07720590e+00 -3.81854475e-01]\n", - " ...\n", - " [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n", - " -4.48192549e+00 -6.14600539e-01]\n", - " [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... 
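The cell above shows every grad as None because no backward pass has run yet; the dump below was taken after one. A minimal sketch of the step in between, assuming the loss_fn, dp_model and logits objects defined earlier in this notebook (note that in the dump the bw_cell entries still show grad: None, the symptom of the fw_cell/bw_cell mix-up fixed above):

    # Gradients are populated only after a backward pass on the CTC loss.
    loss = loss_fn.forward(logits, text, logits_len, text_len)
    loss.backward()
    print(dp_model.fc.weight.grad is None)   # False once backward() has run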
-1.97504926e+00\n", - " -4.44984436e+00 -2.28429958e-01]\n", - " [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n", - " -1.29259872e+00 1.30360842e-01]]]\n", - "\n", - "\n", - " [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n", - " -6.68072224e-01 -1.13282454e+00]\n", - " [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n", - " -1.00205469e+00 -3.58956242e+00]\n", - " [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n", - " 2.25133196e-01 -1.02802217e+00]\n", - " ...\n", - " [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n", - " -1.63268471e+00 6.00247443e-01]\n", - " [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n", - " -1.47013855e+00 -1.45779204e+00]\n", - " [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n", - " -1.36285520e+00 -1.68160331e+00]]\n", - "\n", - " [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n", - " -2.32403350e+00 -4.25943947e+00]\n", - " [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n", - " -1.16508257e+00 -1.30353463e+00]\n", - " [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n", - " -1.91106403e+00 -1.07801402e+00]\n", - " ...\n", - " [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n", - " -5.27612507e-01 1.44604075e+00]\n", - " [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n", - " 4.67946082e-01 -7.24248111e-01]\n", - " [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n", - " -7.51657963e-01 -1.00530767e+00]]\n", - "\n", - " [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n", - " -2.35717297e-02 -1.71374834e+00]\n", - " [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n", - " -1.41204190e+00 -7.58325040e-01]\n", - " [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n", - " -1.07798588e+00 -1.48433352e+00]\n", - " ...\n", - " [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n", - " 1.24199319e+00 1.89949930e-01]\n", - " [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n", - " -5.78772545e-01 2.62665659e-01]\n", - " [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n", - " 2.79776216e-01 7.63269484e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n", - " 1.16390383e+00 -1.47593319e-01]\n", - " [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n", - " 1.89411759e+00 1.03220642e-01]\n", - " [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n", - " 2.07282662e+00 -2.56800652e-01]\n", - " ...\n", - " [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n", - " 1.80851030e+00 8.32634747e-01]\n", - " [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n", - " 1.30439472e+00 3.61029744e-01]\n", - " [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n", - " -9.16651785e-02 9.93353873e-02]]\n", - "\n", - " [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n", - " 3.94066262e+00 4.89832020e+00]\n", - " [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n", - " 3.22004294e+00 2.99872303e+00]\n", - " [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n", - " 2.98195982e+00 3.37315011e+00]\n", - " ...\n", - " [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 
1.20492995e+00\n", - " 2.13971066e+00 1.94932449e+00]\n", - " [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n", - " 2.29117441e+00 2.52368808e+00]\n", - " [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n", - " 1.86903965e+00 2.27776885e+00]]\n", - "\n", - " [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n", - " 7.87163258e-01 1.20610237e+00]\n", - " [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n", - " 9.66498017e-01 1.18883348e+00]\n", - " [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n", - " 5.25399208e-01 -1.16850233e+00]\n", - " ...\n", - " [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n", - " -1.49128008e+00 -1.85064209e+00]\n", - " [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n", - " 9.67048705e-02 -4.37043965e-01]\n", - " [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n", - " -7.83308804e-01 -1.64871216e+00]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n", - " -1.29671764e+00 -1.61429560e+00]\n", - " [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n", - " -1.26383066e+00 -2.27464601e-01]\n", - " [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n", - " -3.68550658e-01 -2.58849680e-01]\n", - " ...\n", - " [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n", - " 2.67662477e+00 2.31014514e+00]\n", - " [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n", - " 2.77418542e+00 1.22875953e+00]\n", - " [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n", - " 1.95602298e+00 7.64306366e-01]]\n", - "\n", - " [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n", - " 1.90896463e+00 8.02283168e-01]\n", - " [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n", - " 1.53455496e+00 1.94702053e+00]\n", - " [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n", - " 7.79394627e-01 2.02670670e+00]\n", - " ...\n", - " [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n", - " 9.61961091e-01 1.40342009e+00]\n", - " [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n", - " 9.43485856e-01 2.34856772e+00]\n", - " [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n", - " 1.38143206e+00 2.00714707e+00]]\n", - "\n", - " [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n", - " -2.52133775e+00 -8.87715697e-01]\n", - " [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n", - " -6.10453069e-01 2.06534386e+00]\n", - " [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n", - " -1.12928271e+00 4.87097502e-02]\n", - " ...\n", - " [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n", - " 1.06069386e+00 1.36404145e+00]\n", - " [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n", - " 1.48393130e+00 2.69196725e+00]\n", - " [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n", - " 3.67927134e-01 2.55976987e+00]]\n", - "\n", - " ...\n", - "\n", - " [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n", - " 2.57500267e+00 -1.38083458e-01]\n", - " [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n", - " 2.34770632e+00 6.48133337e-01]\n", - " [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 
5.51510525e+00\n", - " 4.34810448e+00 1.31588542e+00]\n", - " ...\n", - " [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n", - " 1.06280947e+00 -5.93325794e-02]\n", - " [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n", - " 8.06131065e-02 8.08142900e-01]\n", - " [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n", - " 3.13675582e-01 2.13851762e+00]]\n", - "\n", - " [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n", - " 3.17874146e+00 1.47012353e+00]\n", - " [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n", - " 2.73698759e+00 2.17057824e+00]\n", - " [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n", - " 3.43153954e+00 1.21802723e+00]\n", - " ...\n", - " [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n", - " 1.64856291e+00 1.15853786e+00]\n", - " [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n", - " 2.24890351e+00 2.22156644e+00]\n", - " [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n", - " 3.71970820e+00 1.69852686e+00]]\n", - "\n", - " [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n", - " 1.48759770e+00 -1.72001183e+00]\n", - " [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n", - " 1.74327165e-01 -1.83029580e+00]\n", - " [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n", - " 1.05627227e+00 -8.02519619e-01]\n", - " ...\n", - " [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n", - " 4.66731453e+00 4.98701715e+00]\n", - " [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n", - " 3.32219076e+00 3.83035636e+00]\n", - " [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n", - " 2.13785577e+00 4.33214951e+00]]]\n", - "\n", - "\n", - " [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n", - " 2.72914767e-01 3.92112255e-01]\n", - " [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n", - " -1.44200325e-01 4.07531261e-02]\n", - " [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n", - " -2.28934169e+00 -7.38609374e-01]\n", - " ...\n", - " [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n", - " -1.23878837e+00 2.21006930e-01]\n", - " [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n", - " -6.22699618e-01 -1.73162484e+00]\n", - " [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n", - " -1.15777385e+00 -1.13671994e+00]]\n", - "\n", - " [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n", - " 8.81283879e-02 -1.37762165e+00]\n", - " [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n", - " -9.98012185e-01 -1.11348581e+00]\n", - " [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n", - " -3.65677416e-01 -2.25250268e+00]\n", - " ...\n", - " [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n", - " -2.52748466e+00 -2.16131711e+00]\n", - " [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n", - " -2.44736362e+00 -8.78590107e-01]\n", - " [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n", - " -2.45255780e+00 -1.82384634e+00]]\n", - "\n", - " [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n", - " 4.13422853e-01 1.12761199e+00]\n", - " [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n", - " -8.72656107e-01 1.52415514e-01]\n", - " [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 
3.18863559e+00\n", - " 1.87347198e+00 9.48901057e-01]\n", - " ...\n", - " [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n", - " -6.83587790e-01 1.49154460e+00]\n", - " [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n", - " 1.54248989e+00 -1.46732473e+00]\n", - " [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n", - " 1.12405360e-01 -6.87821358e-02]]\n", - "\n", - " ...\n", - "\n", - " [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n", - " 4.03424358e+00 3.73335361e+00]\n", - " [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n", - " 2.29408574e+00 3.14105940e+00]\n", - " [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n", - " 3.91481900e+00 5.32456112e+00]\n", - " ...\n", - " [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n", - " 3.87909842e+00 3.33359623e+00]\n", - " [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n", - " 1.55827272e+00 2.50741863e+00]\n", - " [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n", - " 3.28739667e+00 3.02895570e+00]]\n", - "\n", - " [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n", - " 1.52944064e+00 2.08810711e+00]\n", - " [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n", - " 7.46028006e-01 1.74789345e+00]\n", - " [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n", - " 1.49640536e+00 2.16613841e+00]\n", - " ...\n", - " [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n", - " 3.87935114e+00 3.69385266e+00]\n", - " [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n", - " 5.36551809e+00 2.94915986e+00]\n", - " [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n", - " 3.93462896e+00 3.68200874e+00]]\n", - "\n", - " [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n", - " 1.60273385e+00 4.90113163e+00]\n", - " [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n", - " 1.40544355e+00 2.84833241e+00]\n", - " [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n", - " 6.87886119e-01 2.58609629e+00]\n", - " ...\n", - " [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n", - " 2.30406880e+00 1.29892707e+00]\n", - " [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n", - " 8.59923482e-01 1.38731599e+00]\n", - " [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n", - " 1.44120216e+00 1.79993320e+00]]]\n", - "\n", - "\n", - " [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n", - " 1.65043330e+00 3.45897645e-01]\n", - " [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n", - " 9.12241280e-01 1.27429694e-01]\n", - " [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n", - " 1.00041127e+00 5.87104559e-02]\n", - " ...\n", - " [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n", - " 2.35047388e+00 1.61532843e+00]\n", - " [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n", - " 2.32770801e+00 9.96662021e-01]\n", - " [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n", - " 2.05776072e+00 5.34575820e-01]]\n", - "\n", - " [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n", - " -9.85817850e-01 -2.05517030e+00]\n", - " [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... 
-1.42974198e-01 -1.79659498e+00 -1.58266973e+00] ... -1.09794116e+00 -1.73590891e-02 -1.80156028e+00]]]] (interior rows of this gradient dump elided)\n",
- "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 ... -1.1920929e-06] (32 values elided)\n",
- "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 ... 0.31392363] (32 values elided)\n",
- "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 ... -1.4454535] (32 values elided)\n",
- "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
- "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 ... 0.0752247]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 ... 0.3020359] (values elided)\n",
- "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 ... 0.16263628] (values elided)\n",
- "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 ... 0.26846847]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 ... 0.16263627] (values elided)\n",
- "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[0.04214853 ... -0.17761081]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 ... -0.10503208] (values elided)\n",
- "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 ... 0.11344345] (values elided)\n",
- "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 ... 0.41056633]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 ... 0.11344352] (values elided)\n",
- "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 ... 0.17856634]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 ... 1.6333911] (values elided)\n",
- "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 ... 2.069446] (values elided)\n",
- "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
- "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 ... 1.2852117]] (matrix values elided)\n",
- "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 ... 2.0694466] (values elided)\n",
- "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
- "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
- "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[1.4382483e-02 ... -6.7901680e+01]] (matrix values elided)\n",
- "param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02 4.3404289e-02 -1.3512518e+02]\n"
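The cell whose source follows below walks dp_model.named_parameters() after loss.backward(retain_graph=False) and prints every gradient in full. Two patterns in the dump above are worth noticing: the BatchNorm _mean/_variance buffers report grad: None because they are created with stop_gradient=True, while the backward-direction RNN weights (bw_cell.weight_hh, bw_cell.bias_hh) also report grad: None even though their stop_grad flag is False, which is exactly the kind of gap a dump like this surfaces. A more compact sketch of the same check (a minimal sketch assuming any Paddle 2.x dygraph Layer; the helper name is illustrative, not from the notebook):

    import numpy as np
    import paddle

    def summarize_grads(model: paddle.nn.Layer) -> None:
        # One line per parameter: shape, stop_gradient flag, and grad status,
        # instead of printing every gradient value.
        for name, p in model.named_parameters():
            g = p.grad
            # Depending on the Paddle version, .grad is an ndarray or a Tensor.
            if g is not None and hasattr(g, "numpy"):
                g = g.numpy()
            status = "None" if g is None else "norm=%.4e" % np.linalg.norm(g)
            print(f"{name}: shape={p.shape} stop_grad={p.stop_gradient} grad={status}")

The print(loss.grad) cell that follows shows [1.]: the seed gradient dL/dL = 1 from which backward() starts the chain rule.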
- ] - } - ], - "source": [
- "loss.backward(retain_graph=False)\n",
- "for n, p in dp_model.named_parameters():\n",
- " print(\n",
- " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - {
- "cell_type": "code", - "execution_count": 26, - "id": "selected-crazy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1.]\n" - ] - } - ], - "source": [ - "print(loss.grad)" - ] - }, - {
- "cell_type": "code", - "execution_count": null, - "id": "bottom-engineer", - "metadata": {}, - "outputs": [], - "source": [] - }, - {
- "cell_type": "code", - "execution_count": null, - "id": "stuffed-yeast", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file
diff --git a/.notebook/u2_confermer_model_wenet.ipynb b/.notebook/u2_confermer_model_wenet.ipynb
deleted file mode 100644
index a425e16cb6c5cc3c7f3a8883f39711d6c50fb8f0..0000000000000000000000000000000000000000
--- a/.notebook/u2_confermer_model_wenet.ipynb
+++ /dev/null
@@ -1,4608 +0,0 @@
-{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - {
- "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision.
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
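The block of "register user ... to paddle, remove this when fixed!" lines above is printed by DeepSpeech's compatibility shims, which graft torch-style helpers onto paddle and paddle.Tensor at import time so that model code ported from PyTorch (the U2 conformer implementation follows WeNet) keeps working. A minimal sketch of the mechanism for one of the helpers named in the log, assuming current Paddle; the real shims live inside the deepspeech package:

    import paddle

    def view(x: paddle.Tensor, *shape) -> paddle.Tensor:
        # torch-style alias: Tensor.view(*dims) delegates to reshape.
        if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
            shape = shape[0]  # also accept view([d0, d1, ...])
        return x.reshape(shape)

    if not hasattr(paddle.Tensor, "view"):
        print("register user view to paddle.Tensor, remove this when fixed!")
        paddle.Tensor.view = view

Each shim follows the same pattern: check whether the attribute already exists, log the registration, and monkey-patch the helper in. The "remove this when fixed!" suffix marks these as stopgaps to drop once upstream Paddle provides the API natively.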
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.bias 
| [256] | 256 | True\n", - "encoder.encoders.0.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | 
True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | 
[256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - 
"encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - 
"encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | 
[256] | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | 
[256] | 256 | True\n", - "encoder.encoders.7.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", 
- "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias 
| [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | 
[256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - 
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - 
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - 
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 
2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 687.0, 49355282.0 elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/aishell/s1/conf/conformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 80\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n", - "cfg.model.cmvn_file_type = 'json'\n", - "cfg.freeze()\n", - "\n", - "model = U2Model(cfg.model)\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | [80] | 80\n", - "encoder.global_cmvn.istd | [80] | 80\n", - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n", - "encoder.embed.conv.0.bias | [256] | 256\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n", - "encoder.embed.conv.2.bias | [256] | 256\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184\n", - "encoder.embed.out.0.bias | [256] | 256\n", - "encoder.after_norm.weight | [256] | 256\n", - "encoder.after_norm.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 
131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256\n", - "encoder.encoders.0.norm_conv.bias | [256] | 256\n", - "encoder.encoders.0.norm_final.weight | [256] | 256\n", - "encoder.encoders.0.norm_final.bias | [256] | 256\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256\n", - "encoder.encoders.1.norm_mha.weight | 
[256] | 256\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256\n", - "encoder.encoders.1.norm_final.weight | [256] | 256\n", - "encoder.encoders.1.norm_final.bias | [256] | 256\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256\n", - "encoder.encoders.2.norm_final.weight | [256] | 256\n", - "encoder.encoders.2.norm_final.bias | [256] | 256\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 
256\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256\n", - "encoder.encoders.3.norm_final.weight | [256] | 256\n", - "encoder.encoders.3.norm_final.bias | [256] | 256\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n", - 
"encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256\n", - "encoder.encoders.4.norm_final.weight | [256] | 256\n", - "encoder.encoders.4.norm_final.bias | [256] | 256\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 
256\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256\n", - "encoder.encoders.5.norm_final.weight | [256] | 256\n", - "encoder.encoders.5.norm_final.bias | [256] | 256\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256\n", - "encoder.encoders.6.norm_final.weight | [256] | 256\n", - "encoder.encoders.6.norm_final.bias | [256] | 256\n", - 
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256\n", - "encoder.encoders.7.norm_conv.bias | [256] | 256\n", - "encoder.encoders.7.norm_final.weight | [256] | 256\n", - "encoder.encoders.7.norm_final.bias | [256] | 256\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 
256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256\n", - "encoder.encoders.8.norm_final.weight | [256] | 256\n", - "encoder.encoders.8.norm_final.bias | [256] | 256\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | 
[512] | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256\n", - "encoder.encoders.9.norm_final.weight | [256] | 256\n", - "encoder.encoders.9.norm_final.bias | [256] | 256\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256\n", - 
"encoder.encoders.10.norm_mha.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256\n", - "encoder.encoders.10.norm_final.weight | [256] | 256\n", - "encoder.encoders.10.norm_final.bias | [256] | 256\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256\n", - "encoder.encoders.11.norm_final.weight | [256] | 256\n", - "encoder.encoders.11.norm_final.bias | [256] | 256\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256\n", - "decoder.embed.0.weight | [4233, 256] | 1083648\n", - "decoder.after_norm.weight | [256] | 256\n", - "decoder.after_norm.bias | [256] | 256\n", - "decoder.output_layer.weight | [256, 4233] | 1083648\n", - 
"decoder.output_layer.bias | [4233] | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.0.norm1.weight | [256] | 256\n", - "decoder.decoders.0.norm1.bias | [256] | 256\n", - "decoder.decoders.0.norm2.weight | [256] | 256\n", - "decoder.decoders.0.norm2.bias | [256] | 256\n", - "decoder.decoders.0.norm3.weight | [256] | 256\n", - "decoder.decoders.0.norm3.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.1.norm1.weight | [256] | 256\n", - "decoder.decoders.1.norm1.bias | [256] | 256\n", - "decoder.decoders.1.norm2.weight | [256] | 256\n", - "decoder.decoders.1.norm2.bias | [256] | 
256\n", - "decoder.decoders.1.norm3.weight | [256] | 256\n", - "decoder.decoders.1.norm3.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.2.norm1.weight | [256] | 256\n", - "decoder.decoders.2.norm1.bias | [256] | 256\n", - "decoder.decoders.2.norm2.weight | [256] | 256\n", - "decoder.decoders.2.norm2.bias | [256] | 256\n", - "decoder.decoders.2.norm3.weight | [256] | 256\n", - "decoder.decoders.2.norm3.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - 
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.3.norm1.weight | [256] | 256\n", - "decoder.decoders.3.norm1.bias | [256] | 256\n", - "decoder.decoders.3.norm2.weight | [256] | 256\n", - "decoder.decoders.3.norm2.bias | [256] | 256\n", - "decoder.decoders.3.norm3.weight | [256] | 256\n", - "decoder.decoders.3.norm3.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.4.norm1.weight | [256] | 256\n", - "decoder.decoders.4.norm1.bias | [256] | 256\n", - "decoder.decoders.4.norm2.weight | [256] | 256\n", - "decoder.decoders.4.norm2.bias | [256] | 256\n", - "decoder.decoders.4.norm3.weight | [256] | 256\n", - "decoder.decoders.4.norm3.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n", - 
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.5.norm1.weight | [256] | 256\n", - "decoder.decoders.5.norm1.bias | [256] | 256\n", - "decoder.decoders.5.norm2.weight | [256] | 256\n", - "decoder.decoders.5.norm2.bias | [256] | 256\n", - "decoder.decoders.5.norm3.weight | [256] | 256\n", - "decoder.decoders.5.norm3.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648\n", - "ctc.ctc_lo.bias | [4233] | 4233\n", - "Total parameters: 689, 49355442 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "U2Model(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " (conv): Sequential(\n", - " (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (1): ReLU()\n", - " (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (encoders): LayerList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, 
momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): 
Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, 
out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", 
- " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): 
Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, 
out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " 
(norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256, sparse=False)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (decoders): LayerList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, 
dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): 
PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " 
(norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTCDecoder(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (criterion): CTCLoss(\n", - " (loss): CTCLoss()\n", - " )\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fossil-means", - "metadata": {}, - "outputs": [], - "source": [ - "# load feat" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb encoder.npz\r\n", - "dataloader.ipynb hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz layer_norm_test.ipynb\r\n", - "decoder.npz Linear_test.ipynb\r\n", - "enc_0_ff_out.npz mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz model.npz\r\n", - "enc_0.npz position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz python_test.ipynb\r\n", - "enc_2.npz train_test.ipynb\r\n", - "enc_all.npz u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246' 'BAC009S0727W0424' 
'BAC009S0753W0412'\n", - " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n", - " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n", - " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n", - " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n", - " 'BAC009S0727W0418']\n", - "(16, 207, 80)\n", - "[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n", - " [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n", - " [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n", - " ...\n", - " [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n", - " [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n", - " [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n", - "\n", - " [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n", - " [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n", - " [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n", - " ...\n", - " [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n", - " [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n", - " [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n", - "\n", - " [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n", - " [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n", - " [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n", - " ...\n", - " [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " ...\n", - "\n", - " [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n", - " [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n", - " [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n", - " [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n", - " [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n", - " [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n", - " [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. 
]]]\n", - "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n", - "[[2995 3116 1209 565 -1 -1]\n", - " [ 236 1176 331 66 3925 4077]\n", - " [2693 524 234 1145 366 -1]\n", - " [3875 4211 3062 700 -1 -1]\n", - " [ 272 987 1134 494 2959 -1]\n", - " [1936 3715 120 2553 2695 2710]\n", - " [ 25 1149 3930 -1 -1 -1]\n", - " [1753 1778 1237 482 3925 110]\n", - " [3703 2 565 3827 -1 -1]\n", - " [1150 2734 10 2478 3490 -1]\n", - " [ 426 811 95 489 144 -1]\n", - " [2313 2006 489 975 -1 -1]\n", - " [3702 3414 205 1488 2966 1347]\n", - " [ 70 1741 702 1666 -1 -1]\n", - " [ 703 1778 1030 849 -1 -1]\n", - " [ 814 1674 115 3827 -1 -1]]\n", - "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/data.npz', allow_pickle=True)\n", - "keys=data['keys']\n", - "feat=data['feat']\n", - "feat_len=data['feat_len']\n", - "text=data['text']\n", - "text_len=data['text_len']\n", - "print(keys)\n", - "print(feat.shape)\n", - "print(feat)\n", - "print(feat_len)\n", - "print(text)\n", - "print(text_len)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "false-instrument", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [], - "source": [ - "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "# torch.Size([16, 207, 80])\n", - "# tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - "# [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - "# [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - "# ...,\n", - "# [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - "# [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - "# [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - "# [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - "# [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - "# [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - "# ...,\n", - "# [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - "# [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - "# [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - "# [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - "# [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - "# [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - "# ...,\n", - "# [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# ...,\n", - "\n", - "# [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - "# [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - "# [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - "# [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - "# [ 8.6262, 9.6759, 9.8410, ..., 
11.3788, 10.9221, 8.9914],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - "# [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - "# [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - "# 166, 163], dtype=torch.int32)\n", - "# tensor([[2995, 3116, 1209, 565, -1, -1],\n", - "# [ 236, 1176, 331, 66, 3925, 4077],\n", - "# [2693, 524, 234, 1145, 366, -1],\n", - "# [3875, 4211, 3062, 700, -1, -1],\n", - "# [ 272, 987, 1134, 494, 2959, -1],\n", - "# [1936, 3715, 120, 2553, 2695, 2710],\n", - "# [ 25, 1149, 3930, -1, -1, -1],\n", - "# [1753, 1778, 1237, 482, 3925, 110],\n", - "# [3703, 2, 565, 3827, -1, -1],\n", - "# [1150, 2734, 10, 2478, 3490, -1],\n", - "# [ 426, 811, 95, 489, 144, -1],\n", - "# [2313, 2006, 489, 975, -1, -1],\n", - "# [3702, 3414, 205, 1488, 2966, 1347],\n", - "# [ 70, 1741, 702, 1666, -1, -1],\n", - "# [ 703, 1778, 1030, 849, -1, -1],\n", - "# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb\t encoder.npz\r\n", - "dataloader.ipynb\t\t hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz\t\t\t layer_norm_test.ipynb\r\n", - "decoder.npz\t\t\t Linear_test.ipynb\r\n", - "enc_0_ff_out.npz\t\t mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz\t\t model.npz\r\n", - "enc_0.npz\t\t\t position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz\t\t python_test.ipynb\r\n", - "enc_2.npz\t\t\t train_test.ipynb\r\n", - "enc_all.npz\t\t\t u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "# load model param\n", - "!ls .notebook\n", - "data = np.load('.notebook/model.npz', allow_pickle=True)\n", - "state_dict = data['state'].item()\n", - "\n", - "for key, _ in model.state_dict().items():\n", - " if key not in state_dict:\n", - " print(f\"{key} not find.\")\n", - "\n", - "model.set_state_dict(state_dict)\n", - "\n", - "now_state_dict = model.state_dict()\n", - "for key, value in now_state_dict.items():\n", - " if not np.allclose(value.numpy(), state_dict[key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "exempt-viewer", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "confident-piano", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. 
Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - "  elif dtype == np.bool:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - "       [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - "       [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - "       [377.33258057])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n", - "  format(lhs_dtype, rhs_dtype, lhs_dtype))\n" - ] - } - ], - "source": [ - "# compute loss\n", - "import paddle\n", - "feat=paddle.to_tensor(feat)\n", - "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n", - "text=paddle.to_tensor(text, dtype='int64')\n", - "text_len=paddle.to_tensor(text_len, dtype='int64')\n", - "\n", - "model.eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - "                                             text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "better-senator", - "metadata": {}, - "outputs": [], - "source": [ - "# tensor(142.4888, device='cuda:0', grad_fn=) \n", - "# tensor(41.8415, device='cuda:0', grad_fn=) \n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# 142.4888 41.84146 377.33258" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "related-banking", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "olympic-problem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[16, 51, 256]\n", - "[16, 1, 51]\n", - "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - "       [[-0.70194179,  0.56254166,  0.68803459, ...,  1.12373221,  0.78039235,  1.13693869],\n", - "        [-0.77877808,  0.39126658,  0.71887815, ...,  1.25188220,  0.88616788,  1.31734526],\n", - "        [-0.95908946,  0.63460249,  0.87671334, ...,  0.98183727,  0.74401081,  1.29032660],\n", - "        ...,\n", - "        [-1.07322502,  0.67236906,  0.92303109, ...,  0.90754563,  0.81767166,  1.32396567],\n", - "        [-1.16541159,  0.68199694,  0.69394493, ...,  1.22383487,  0.80282891,  1.45065081],\n", - "        [-1.27320945,  0.71458030,  0.75819558, ...,  0.94154912,  0.87748396,  1.26230514]])\n" - ] - } - ], - "source": [ - "# encoder\n", - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "shaped-alaska", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "deepspeech  examples   README_cn.md\tsetup.sh  tools\r\n", - "docs\t    LICENSE    README.md\t\ttests\t  utils\r\n", - "env.sh\t    log        requirements.txt\tthird_party\r\n" - ] - } - ], - "source": [ - "!ls\n", - "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n", - "torch_mask = data['mask']\n", - "torch_encoder_out = data['out']" - ] - }, -
{ - "cell_type": "code", - "execution_count": 15, - "id": "federal-rover", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n" - ] - } - ], - "source": [ - "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "regulated-interstate", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "----\n", - "[[-0.7019418 0.56254166 0.6880346 ... 1.1237322 0.78039235\n", - " 1.1369387 ]\n", - " [-0.7787781 0.39126658 0.71887815 ... 1.2518822 0.8861679\n", - " 1.3173453 ]\n", - " [-0.95908946 0.6346025 0.87671334 ... 0.9818373 0.7440108\n", - " 1.2903266 ]\n", - " ...\n", - " [-1.073225 0.67236906 0.9230311 ... 0.9075456 0.81767166\n", - " 1.3239657 ]\n", - " [-1.1654116 0.68199694 0.69394493 ... 1.2238349 0.8028289\n", - " 1.4506508 ]\n", - " [-1.2732095 0.7145803 0.7581956 ... 0.9415491 0.87748396\n", - " 1.2623051 ]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(torch_encoder_out[0])\n", - "print(\"----\")\n", - "print(encoder_out.numpy()[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "proof-scheduling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [377.33258057])\n", - "[1.]\n", - "[[ 3.16902876e+00 -1.51763987e-02 4.91095744e-02 ... -2.47971853e-03\n", - " -5.93360700e-03 -7.26609165e-03]\n", - " [-1.74184477e+00 7.75874173e-03 -4.49434854e-02 ... 9.92412097e-04\n", - " 2.46337592e-03 2.31892057e-03]\n", - " [-2.33343339e+00 1.30475955e-02 -2.66557075e-02 ... 2.27532350e-03\n", - " 5.76924905e-03 7.48788286e-03]\n", - " ...\n", - " [-4.30358458e+00 2.46054661e-02 -9.00950655e-02 ... 4.43156436e-03\n", - " 1.16122244e-02 1.44715561e-02]\n", - " [-3.36921120e+00 1.73153952e-02 -6.36872873e-02 ... 3.28363618e-03\n", - " 8.58010259e-03 1.07794888e-02]\n", - " [-6.62045336e+00 3.49955931e-02 -1.23962618e-01 ... 6.36671018e-03\n", - " 1.60814095e-02 2.03891303e-02]]\n", - "[-4.3777819e+00 2.3245810e-02 -9.3339294e-02 ... 
4.2569344e-03\n", - " 1.0919910e-02 1.3787797e-02]\n" - ] - } - ], - "source": [ - "from paddle.nn import functional as F\n", - "def ctc_loss(logits,\n", - " labels,\n", - " input_lengths,\n", - " label_lengths,\n", - " blank=0,\n", - " reduction='mean',\n", - " norm_by_times=False):\n", - " loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n", - " input_lengths, label_lengths)\n", - " loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n", - " assert reduction in ['mean', 'sum', 'none']\n", - " if reduction == 'mean':\n", - " loss_out = paddle.mean(loss_out / label_lengths)\n", - " elif reduction == 'sum':\n", - " loss_out = paddle.sum(loss_out)\n", - " return loss_out\n", - "\n", - "F.ctc_loss = ctc_loss\n", - "\n", - "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n", - "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.bias.grad)\n", - "\n", - "\n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# None\n", - "# [[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - "# -5.93366381e-03 -7.26613170e-03]\n", - "# [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - "# 2.46338220e-03 2.31891591e-03]\n", - "# [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - "# 5.76929189e-03 7.48792710e-03]\n", - "# ...\n", - "# [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - "# 1.16123557e-02 1.44716976e-02]\n", - "# [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - "# 8.58021621e-03 1.07796099e-02]\n", - "# [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - "# 1.60815325e-02 2.03892551e-02]]\n", - "# [-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - "# 1.0920014e-02 1.3787906e-02]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "enclosed-consolidation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "synthetic-hungarian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118]) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n", - " text, text_len)\n", - "print(loss_att, acc_att)\n", - "#tensor(41.8416, device='cuda:0', grad_fn=) 0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "indian-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 202, - "id": "marine-cuisine", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 
2.6038891e-01\n", - "  -7.9347193e-04  4.2537671e-01],\n", - " [-3.8279903e-01 -8.1206715e-01  7.4943429e-01 ...  2.6173013e-01\n", - "  -1.0499060e-03  4.2678756e-01]]\n" - ] - } - ], - "source": [ - "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n", - "torch_decoder_out = data['decoder_out']\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 180, - "id": "several-result", - "metadata": {}, - "outputs": [], - "source": [ - "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n", - "                ignore_id: int):\n", - "    \"\"\"Add <sos> and <eos> labels.\n", - "    Args:\n", - "        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n", - "        sos (int): index of <sos>\n", - "        eos (int): index of <eos>\n", - "        ignore_id (int): index of padding\n", - "    Returns:\n", - "        ys_in (paddle.Tensor) : (B, Lmax + 1)\n", - "        ys_out (paddle.Tensor) : (B, Lmax + 1)\n", - "    Examples:\n", - "        >>> sos_id = 10\n", - "        >>> eos_id = 11\n", - "        >>> ignore_id = -1\n", - "        >>> ys_pad\n", - "        tensor([[ 1,  2,  3,  4,  5],\n", - "                [ 4,  5,  6, -1, -1],\n", - "                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)\n", - "        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n", - "        >>> ys_in\n", - "        tensor([[10,  1,  2,  3,  4,  5],\n", - "                [10,  4,  5,  6, 11, 11],\n", - "                [10,  7,  8,  9, 11, 11]])\n", - "        >>> ys_out\n", - "        tensor([[ 1,  2,  3,  4,  5, 11],\n", - "                [ 4,  5,  6, 11, -1, -1],\n", - "                [ 7,  8,  9, 11, -1, -1]])\n", - "    \"\"\"\n", - "    # TODO(Hui Zhang): using comment code, \n", - "    #_sos = paddle.to_tensor(\n", - "    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - "    #_eos = paddle.to_tensor(\n", - "    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - "    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys\n", - "    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n", - "    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n", - "    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n", - "    B = ys_pad.size(0)\n", - "    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n", - "    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n", - "    ys_in = paddle.cat([_sos, ys_pad], dim=1)\n", - "    mask_pad = (ys_in == ignore_id)\n", - "    ys_in = ys_in.masked_fill(mask_pad, eos)\n", - "    \n", - "\n", - "    ys_out = paddle.cat([ys_pad, _eos], dim=1)\n", - "    ys_out = ys_out.masked_fill(mask_pad, eos)\n", - "    mask_eos = (ys_out == ignore_id)\n", - "    ys_out = ys_out.masked_fill(mask_eos, eos)\n", - "    ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n", - "    return ys_in, ys_out" - ] - }, - { - "cell_type": "code", - "execution_count": 181, - "id": "possible-bulgaria", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - "       [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n", - "        [4232, 236 , 1176, 331 , 66  , 3925, 4077],\n", - "        [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n", - "        [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n", - "        [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n", - "        [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n", - "        [4232, 25  , 1149, 3930, 4232, 4232, 4232],\n", - "        [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n", - "        [4232, 3703, 2   , 565 , 3827, 4232, 4232],\n", - "        [4232, 1150, 2734, 10  , 2478, 3490, 4232],\n", - "        [4232, 426 , 811 , 95  , 489 , 144 , 4232],\n", - "        [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n", - "        [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n", - "        [4232, 70  , 1741, 702 
, 1666, 4232, 4232],\n", - " [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n", - " [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n", - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", - " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1 ],\n", - " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", - " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", - " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", - " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", - " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", - " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 285, - "id": "north-walter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "True\n", - "False\n", - "[[-3.76389682e-01 -8.22720408e-01 7.42762923e-01 ... 3.42005253e-01\n", - " 1.50350705e-02 4.03372347e-01]\n", - " [-8.73864174e-01 -3.13894272e-01 4.19878662e-01 ... 3.77237231e-01\n", - " -1.43528014e-01 -1.00236630e+00]\n", - " [-4.35050905e-01 3.45046446e-02 -2.87102997e-01 ... 7.72742853e-02\n", - " -1.16722476e+00 -2.68485069e-01]\n", - " ...\n", - " [ 4.24714804e-01 5.88856399e-01 2.02039629e-02 ... 3.74054879e-01\n", - " 4.54700664e-02 -3.71394157e-01]\n", - " [-3.79784584e-01 -8.10841978e-01 7.57250786e-01 ... 2.60389000e-01\n", - " -7.93404877e-04 4.25376773e-01]\n", - " [-3.82798851e-01 -8.12067091e-01 7.49434292e-01 ... 2.61730075e-01\n", - " -1.04988366e-03 4.26787734e-01]]\n", - "---\n", - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", - " -7.9347193e-04 4.2537671e-01]\n", - " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 
2.6173013e-01\n", - " -1.0499060e-03 4.2678756e-01]]\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n", - "print(decoder_out.numpy()[0])\n", - "print('---')\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "armed-cowboy", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fifty-earth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "proud-commonwealth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 183, - "id": "assisted-fortune", - "metadata": {}, - "outputs": [], - "source": [ - "from paddle import nn\n", - "import paddle\n", - "from paddle.nn import functional as F\n", - "\n", - "class LabelSmoothingLoss(nn.Layer):\n", - "\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool=False):\n", - " super().__init__()\n", - " self.size = size\n", - " self.padding_idx = padding_idx\n", - " self.smoothing = smoothing\n", - " self.confidence = 1.0 - smoothing\n", - " self.normalize_length = normalize_length\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - "\n", - " def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - " \n", - " Args:\n", - " x (paddle.Tensor): prediction (batch, seqlen, class)\n", - " target (paddle.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (paddle.Tensor) : The KL loss, scalar float value\n", - " \"\"\"\n", - " B, T, D = paddle.shape(x)\n", - " assert D == self.size\n", - " x = x.reshape((-1, self.size))\n", - " target = target.reshape([-1])\n", - "\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - "\n", - " #target = target * (1 - ignore) # avoid -1 index\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " \n", - " \n", - " #true_dist += F.one_hot(target, self.size) * self.confidence\n", - " target_mask = F.one_hot(target, self.size)\n", - " true_dist *= (1 - target_mask)\n", - " true_dist += target_mask * self.confidence\n", - " \n", - "\n", - " kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n", - " \n", - " #TODO(Hui Zhang): sum not support bool type\n", - " #total = len(target) - int(ignore.sum())\n", - " total = len(target) - int(ignore.type_as(target).sum())\n", - " denom = total if self.normalize_length else B\n", - "\n", - " #numer = (kl * (1 - ignore)).sum()\n", - " numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " return numer / denom\n" - ] - }, - { - "cell_type": "code", - "execution_count": 184, - 
"id": "weighted-delight", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " ...,\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "VarType.INT64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)\n", - "# tensor(41.8416, device='cuda:0', grad_fn=)" - ] - }, - { - "cell_type": "code", - "execution_count": 286, - "id": "dress-shelter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118])\n", - "4233\n", - "-1\n", - "0.1\n", - "False\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "print(model.criterion_att.size)\n", - "print(model.criterion_att.padding_idx)\n", - "print(model.criterion_att.smoothing)\n", - "print(model.criterion_att.normalize_length)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "growing-tooth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "going-hungary", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "naughty-citizenship", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "experimental-emerald", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adverse-saskatchewan", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "speaking-shelf", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "from typing import Optional\n", - "from typing import Tuple\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from typeguard import check_argument_types\n", - "\n", - "from deepspeech.modules.activation import get_activation\n", - "from deepspeech.modules.attention import MultiHeadedAttention\n", - "from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n", - "from deepspeech.modules.conformer_convolution import ConvolutionModule\n", - "from deepspeech.modules.embedding import PositionalEncoding\n", - "from 
deepspeech.modules.embedding import RelPositionalEncoding\n", - "from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n", - "from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n", - "from deepspeech.modules.mask import add_optional_chunk_mask\n", - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling4\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling6\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling8\n", - "from deepspeech.modules.subsampling import LinearNoSubsampling\n", - "\n", - "class BaseEncoder(nn.Layer):\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"abs_pos\",\n", - " normalize_before: bool=True,\n", - " concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: paddle.nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False, ):\n", - " \"\"\"\n", - " Args:\n", - " input_size (int): input dim, d_feature\n", - " output_size (int): dimension of attention, d_model\n", - " attention_heads (int): the number of heads of multi head attention\n", - " linear_units (int): the hidden units number of position-wise feed\n", - " forward\n", - " num_blocks (int): the number of encoder blocks\n", - " dropout_rate (float): dropout rate\n", - " attention_dropout_rate (float): dropout rate in attention\n", - " positional_dropout_rate (float): dropout rate after adding\n", - " positional encoding\n", - " input_layer (str): input layer type.\n", - " optional [linear, conv2d, conv2d6, conv2d8]\n", - " pos_enc_layer_type (str): Encoder positional encoding layer type.\n", - " opitonal [abs_pos, scaled_abs_pos, rel_pos]\n", - " normalize_before (bool):\n", - " True: use layer_norm before each sub-block of a layer.\n", - " False: use layer_norm after each sub-block of a layer.\n", - " concat_after (bool): whether to concat attention layer's input\n", - " and output.\n", - " True: x -> x + linear(concat(x, att(x)))\n", - " False: x -> x + att(x)\n", - " static_chunk_size (int): chunk size for static chunk training and\n", - " decoding\n", - " use_dynamic_chunk (bool): whether use dynamic chunk size for\n", - " training or not, You can only use fixed chunk(chunk_size > 0)\n", - " or dyanmic chunk size(use_dynamic_chunk = True)\n", - " global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n", - " use_dynamic_left_chunk (bool): whether use dynamic left chunk in\n", - " dynamic chunk training\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__()\n", - " self._output_size = output_size\n", - "\n", - " if pos_enc_layer_type == \"abs_pos\":\n", - " pos_enc_class = PositionalEncoding\n", - " elif pos_enc_layer_type == \"rel_pos\":\n", - " pos_enc_class = RelPositionalEncoding\n", - " else:\n", - " raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n", - "\n", - " if input_layer == \"linear\":\n", - " subsampling_class = LinearNoSubsampling\n", - " elif input_layer == \"conv2d\":\n", - " subsampling_class = Conv2dSubsampling4\n", - " elif input_layer == 
\"conv2d6\":\n", - " subsampling_class = Conv2dSubsampling6\n", - " elif input_layer == \"conv2d8\":\n", - " subsampling_class = Conv2dSubsampling8\n", - " else:\n", - " raise ValueError(\"unknown input_layer: \" + input_layer)\n", - "\n", - " self.global_cmvn = global_cmvn\n", - " self.embed = subsampling_class(\n", - " idim=input_size,\n", - " odim=output_size,\n", - " dropout_rate=dropout_rate,\n", - " pos_enc_class=pos_enc_class(\n", - " d_model=output_size, dropout_rate=positional_dropout_rate), )\n", - "\n", - " self.normalize_before = normalize_before\n", - " self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n", - " self.static_chunk_size = static_chunk_size\n", - " self.use_dynamic_chunk = use_dynamic_chunk\n", - " self.use_dynamic_left_chunk = use_dynamic_left_chunk\n", - "\n", - " def output_size(self) -> int:\n", - " return self._output_size\n", - "\n", - " def forward(\n", - " self,\n", - " xs: paddle.Tensor,\n", - " xs_lens: paddle.Tensor,\n", - " decoding_chunk_size: int=0,\n", - " num_decoding_left_chunks: int=-1,\n", - " ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n", - " \"\"\"Embed positions in tensor.\n", - " Args:\n", - " xs: padded input tensor (B, L, D)\n", - " xs_lens: input length (B)\n", - " decoding_chunk_size: decoding chunk size for dynamic chunk\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " encoder output tensor, lens and mask\n", - " \"\"\"\n", - " masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)\n", - "\n", - " if self.global_cmvn is not None:\n", - " xs = self.global_cmvn(xs)\n", - " #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n", - " xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n", - " #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n", - " masks = masks.astype(paddle.bool)\n", - " #TODO(Hui Zhang): mask_pad = ~masks\n", - " mask_pad = masks.logical_not()\n", - " chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n", - " decoding_chunk_size, self.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - " for layer in self.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " if self.normalize_before:\n", - " xs = self.after_norm(xs)\n", - " # Here we assume the mask is not changed in encoder layers, so just\n", - " # return the masks before encoder layers, and the masks will be used\n", - " # for cross attention with decoder later\n", - " return xs, masks" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "sharp-municipality", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "class ConformerEncoder(BaseEncoder):\n", - " \"\"\"Conformer encoder module.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"rel_pos\",\n", - " normalize_before: bool=True,\n", - " 
concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False,\n", - " positionwise_conv_kernel_size: int=1,\n", - " macaron_style: bool=True,\n", - " selfattention_layer_type: str=\"rel_selfattn\",\n", - " activation_type: str=\"swish\",\n", - " use_cnn_module: bool=True,\n", - " cnn_module_kernel: int=15,\n", - " causal: bool=False,\n", - " cnn_module_norm: str=\"batch_norm\", ):\n", - " \"\"\"Construct ConformerEncoder\n", - " Args:\n", - " input_size to use_dynamic_chunk, see in BaseEncoder\n", - " positionwise_conv_kernel_size (int): Kernel size of positionwise\n", - " conv1d layer.\n", - " macaron_style (bool): Whether to use macaron style for\n", - " positionwise layer.\n", - " selfattention_layer_type (str): Encoder attention layer type,\n", - " the parameter has no effect now, it's just for configure\n", - " compatibility.\n", - " activation_type (str): Encoder activation function type.\n", - " use_cnn_module (bool): Whether to use convolution module.\n", - " cnn_module_kernel (int): Kernel size of convolution module.\n", - " causal (bool): whether to use causal convolution or not.\n", - " cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__(input_size, output_size, attention_heads, linear_units,\n", - " num_blocks, dropout_rate, positional_dropout_rate,\n", - " attention_dropout_rate, input_layer,\n", - " pos_enc_layer_type, normalize_before, concat_after,\n", - " static_chunk_size, use_dynamic_chunk, global_cmvn,\n", - " use_dynamic_left_chunk)\n", - " activation = get_activation(activation_type)\n", - "\n", - " # self-attention module definition\n", - " encoder_selfattn_layer = RelPositionMultiHeadedAttention\n", - " encoder_selfattn_layer_args = (attention_heads, output_size,\n", - " attention_dropout_rate)\n", - " # feed-forward module definition\n", - " positionwise_layer = PositionwiseFeedForward\n", - " positionwise_layer_args = (output_size, linear_units, dropout_rate,\n", - " activation)\n", - " # convolution module definition\n", - " convolution_layer = ConvolutionModule\n", - " convolution_layer_args = (output_size, cnn_module_kernel, activation,\n", - " cnn_module_norm, causal)\n", - "\n", - " self.encoders = nn.LayerList([\n", - " ConformerEncoderLayer(\n", - " size=output_size,\n", - " self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n", - " feed_forward=positionwise_layer(*positionwise_layer_args),\n", - " feed_forward_macaron=positionwise_layer(\n", - " *positionwise_layer_args) if macaron_style else None,\n", - " conv_module=convolution_layer(*convolution_layer_args)\n", - " if use_cnn_module else None,\n", - " dropout_rate=dropout_rate,\n", - " normalize_before=normalize_before,\n", - " concat_after=concat_after) for _ in range(num_blocks)\n", - " ])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "tutorial-syndication", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.modules.cmvn import GlobalCMVN\n", - "\n", - "configs=cfg.model\n", - "mean, istd = load_cmvn(configs['cmvn_file'],\n", - " configs['cmvn_file_type'])\n", - "global_cmvn = GlobalCMVN(\n", - " paddle.to_tensor(mean, dtype=paddle.float),\n", - " paddle.to_tensor(istd, dtype=paddle.float))\n", - "\n", - "\n", - "input_dim = configs['input_dim']\n", - "vocab_size = 
configs['output_dim']\n", - "encoder_type = configs.get('encoder', 'transformer')\n", - " \n", - "encoder = ConformerEncoder(\n", - " input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "fuzzy-register", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "o = global_cmvn(feat)\n", - "o2 = model.encoder.global_cmvn(feat)\n", - "print(np.allclose(o.numpy(), o2.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "explicit-triumph", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "humanitarian-belgium", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dying-proposal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "honest-quick", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bound-cholesterol", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "viral-packaging", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 203, - "id": "balanced-locator", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]]])\n" - ] - } - ], - "source": [ - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.mask import make_pad_mask\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 204, - "id": "induced-proposition", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n", - " [-0.04464225, 0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n", - " [ 0.50035292, 0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n", - " ...,\n", - " [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n", - " [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n", - " [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n", - "\n", - " [[ 0.95454037, 0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n", - " [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n", - " [ 0.47653601, 0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n", - " ...,\n", - " [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n", - " [ 0.36112556, 0.10380996, -1.15994537, ..., 
-1.04394984, -1.02212358, -1.02083635],\n", - " [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n", - "\n", - " [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n", - " [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n", - " [-0.02471367, 0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n", - " ...,\n", - " [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.64982772, 0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n", - " [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n", - " [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n", - " [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n", - " [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n", - " [ 0.92353648, 0.63770926, -0.52810186, ..., -0.12927933, -0.20342292, 0.16655664],\n", - " [ 0.49337494, -0.00911332, -0.73301607, ..., 0.10074048, -0.09811471, -0.00923573],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "id": "cutting-julian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0.00209083],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0.01194306, 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.04610471, 0. ],\n", - " [0. , 0. , 0. , ..., 0.00967231, 0.04613467, 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n", - " [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n", - " [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n", - " ...,\n", - " [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n", - " [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n", - " [0. , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n", - " [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n", - " [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n", - " ...,\n", - " [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n", - " [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n", - " [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.01456045, 0.09447514, 0. , ..., 0. , 0. , 0. ],\n", - " [0.01500242, 0.02963220, 0. , ..., 0. , 0. , 0. ],\n", - " [0.03295187, 0. , 0. , ..., 0.04584959, 0.02043908, 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.04425837],\n", - " [0. , 0. , 0.02556529, ..., 0. , 0.00900441, 0.04908358]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.11141267, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n", - " [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n", - " [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n", - " ...,\n", - " [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n", - " [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n", - " [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
]],\n", - "\n", - " [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n", - " [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n", - " [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n", - " ...,\n", - " [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n", - " [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n", - " [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0.00578516, 0. , ..., 0.00748384, 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0.03035110, 0. , 0.00026720],\n", - " [0.00094807, 0. , 0. , ..., 0.00795512, 0. , 0. ],\n", - " ...,\n", - " [0.02032628, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.01080076, 0. ],\n", - " [0.18470290, 0. , 0. , ..., 0.05058352, 0.09475817, 0.05914564]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n", - " [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n", - " [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n", - " ...,\n", - " [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n", - " [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n", - " [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n", - " [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n", - " [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n", - " ...,\n", - " [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n", - " [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n", - " [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.04178291, 0. , 0.01580476, ..., 0. , 0.02250817, 0. ],\n", - " [0.04323414, 0.07786420, 0. , ..., 0.01634724, 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03209178, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.13563479, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0. , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n", - " [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n", - " [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n", - " [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n", - " [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.02465414, 0. , 0. , ..., 0. , 0. , 0.03390232],\n", - " [0. , 0. , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n", - " [0.09921519, 0. , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n", - " [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n", - " [0.28786880, 0.24030517, 0.22565845, ..., 0. , 0.06418110, 0.05872961],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. 
, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n", - " [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n", - " [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n", - " [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n", - " [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.10033827, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.15627117, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.05144687, 0. , 0. , ..., 0. , 0. , 0.00436414],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0. , 0. ],\n", - " [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0. , 0. ],\n", - " [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0. , 0. ],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0.02218930, 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.02848953],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. 
]],\n", - "\n", - " [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n", - " [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n", - " [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "\n", - "#xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "# print(xs)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 206, - "id": "friendly-nightlife", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.03426375, 0.14291267, -0.06718873, ..., 0.09064753, 0.01809387, -0.04340880],\n", - " [-0.05007839, 0.11054724, -0.10399298, ..., 0.11457238, 0.04244684, -0.01249714],\n", - " [-0.10695291, 0.16910909, -0.08352133, ..., 0.07710276, 0.01168563, -0.03584499],\n", - " ...,\n", - " [-0.06060536, 0.14455931, -0.05470302, ..., 0.05364908, 0.03033342, -0.02610814],\n", - " [-0.08505894, 0.13611752, -0.11132983, ..., 0.13079923, 0.01580139, -0.02281028],\n", - " [-0.10604677, 0.14714901, -0.10885533, ..., 0.08543444, 0.03719445, -0.04634233]],\n", - "\n", - " [[-0.12392755, 0.14486063, -0.05674079, ..., 0.02573164, 0.03128851, 0.00545091],\n", - " [-0.04775286, 0.08473608, -0.08507854, ..., 0.04573154, 0.04240163, 0.01053247],\n", - " [-0.05940291, 0.10023535, -0.08143730, ..., 0.03596500, 0.01673085, 0.02089563],\n", - " ...,\n", - " [-0.09222981, 0.15823206, -0.07700447, ..., 0.08122957, 0.03136991, -0.00646474],\n", - " [-0.07331756, 0.14482647, -0.07838815, ..., 0.10869440, 0.01356864, -0.02777974],\n", - " [-0.07937264, 0.20143102, -0.05544947, ..., 0.10287814, 0.00608235, -0.04799180]],\n", - "\n", - " [[-0.03670349, 0.08931590, -0.08718812, ..., 0.01314050, 0.00642052, 0.00573716],\n", - " [ 0.01089254, 0.11146393, -0.10263617, ..., 0.05070438, 0.01960694, 0.03521532],\n", - " [-0.02182280, 0.11443964, -0.06678198, ..., 0.04327708, 0.00861394, 0.02871092],\n", - " ...,\n", - " [-0.06792898, 0.14376275, -0.07899005, ..., 0.11248926, 0.03208683, -0.03264240],\n", - " [-0.07884051, 0.17024788, -0.08583611, ..., 0.09028331, 0.03588808, -0.02075090],\n", - " [-0.13792302, 0.27163863, -0.23930418, ..., 0.13391261, 0.07521040, -0.08621951]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.02446348, 0.11595841, -0.03591986, ..., 0.06288970, 0.02895011, -0.06532725],\n", - " [-0.05378424, 0.12607370, -0.09023033, ..., 0.09078894, 0.01035743, 0.03701983],\n", - " [-0.04566649, 0.14275314, -0.06686870, ..., 0.09890588, -0.00612222, 0.03439377],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.01012144, 0.03909408, -0.07077143, ..., 0.00452683, -0.01377654, 0.02897627],\n", - " [-0.00519154, 0.03594019, 
-0.06831125, ..., 0.05693541, -0.00406374, 0.04561640],\n", - " [-0.01762631, 0.00500899, -0.05886075, ..., 0.02112178, -0.00729015, 0.02782153],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402, 0.04296734, 0.06151697],\n", - " [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064, 0.05696633, 0.03304394],\n", - " [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231, 0.00362198, 0.01916483],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]]])\n", - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.54821998, 2.28660274, -1.07501972, ..., 1.45036042, 0.28950194, -0.69454080],\n", - " [-0.80125421, 1.76875579, -1.66388774, ..., 1.83315802, 0.67914939, -0.19995420],\n", - " [-1.71124649, 2.70574546, -1.33634126, ..., 1.23364413, 0.18697014, -0.57351983],\n", - " ...,\n", - " [-0.96968573, 2.31294894, -0.87524825, ..., 0.85838526, 0.48533469, -0.41773027],\n", - " [-1.36094308, 2.17788029, -1.78127730, ..., 2.09278774, 0.25282228, -0.36496443],\n", - " [-1.69674826, 2.35438418, -1.74168527, ..., 1.36695099, 0.59511113, -0.74147725]],\n", - "\n", - " [[-1.98284078, 2.31777000, -0.90785271, ..., 0.41170627, 0.50061619, 0.08721463],\n", - " [-0.76404583, 1.35577726, -1.36125672, ..., 0.73170459, 0.67842603, 0.16851945],\n", - " [-0.95044655, 1.60376561, -1.30299675, ..., 0.57544005, 0.26769355, 0.33433008],\n", - " ...,\n", - " [-1.47567701, 2.53171301, -1.23207152, ..., 1.29967308, 0.50191855, -0.10343577],\n", - " [-1.17308092, 2.31722355, -1.25421047, ..., 1.73911047, 0.21709818, -0.44447583],\n", - " [-1.26996231, 3.22289634, -0.88719147, ..., 1.64605021, 0.09731755, -0.76786882]],\n", - "\n", - " [[-0.58725590, 1.42905438, -1.39500988, ..., 0.21024795, 0.10272825, 0.09179455],\n", - " [ 0.17428070, 1.78342295, -1.64217877, ..., 0.81127012, 0.31371105, 0.56344515],\n", - " [-0.34916472, 1.83103430, -1.06851172, ..., 0.69243336, 0.13782299, 0.45937473],\n", - " ...,\n", - " [-1.08686376, 2.30020404, -1.26384079, ..., 1.79982817, 0.51338923, -0.52227837],\n", - " [-1.26144814, 2.72396612, -1.37337780, ..., 1.44453299, 0.57420933, -0.33201432],\n", - " [-2.20676827, 4.34621811, -3.82886696, ..., 2.14260173, 1.20336640, -1.37951219]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.39141566, 1.85533464, -0.57471782, ..., 1.00623512, 0.46320182, -1.04523599],\n", - " [-0.86054784, 2.01717925, -1.44368529, ..., 1.45262301, 0.16571884, 0.59231722],\n", - " [-0.73066384, 2.28405023, -1.06989920, ..., 1.58249414, -0.09795550, 0.55030036],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.16194311, 0.62550521, -1.13234293, ..., 0.07242929, -0.22042468, 0.46362036],\n", - " [-0.08306468, 0.57504302, -1.09298003, ..., 0.91096652, -0.06501988, 0.72986233],\n", - " 
[-0.28202093, 0.08014385, -0.94177192, ..., 0.33794850, -0.11664233, 0.44514441],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435, 0.68747747, 0.98427159],\n", - " [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022, 0.91146135, 0.52870303],\n", - " [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697, 0.05795169, 0.30663735],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]]])\n", - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "b, c, t, f = paddle.shape(x)\n", - "x = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n", - "print(x)\n", - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x)\n", - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "id": "guilty-cache", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 208, - "id": "iraqi-payday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [ 9.5625257e-01 -2.9254240e-01 4.8925215e-01 ... 8.3807874e-01\n", - " 5.1154459e-01 8.5925674e-01]\n", - " [ 2.7049953e-01 -9.6272010e-01 9.9170387e-01 ... 8.3801574e-01\n", - " 5.1163691e-01 8.5920173e-01]\n", - " [-6.6394955e-01 -7.4777740e-01 6.9544029e-01 ... 
8.3795273e-01\n", - "  5.1172924e-01  8.5914677e-01]]]\n", - "[1, 5000, 256]\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=5000\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - "                        dtype=torch.float32).unsqueeze(1)\n", - "torch_position = position\n", - "div_term = torch.exp(\n", - "    torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - "    -(math.log(10000.0) / d_model))\n", - "torch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "torch_sin = torch.sin(position * div_term)\n", - "torch_cos = torch.cos(position * div_term)\n", - "\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "pe[:, 0::2] = torch_sin\n", - "pe[:, 1::2] = torch_cos\n", - "pe = pe.unsqueeze(0) \n", - "torch_pe = pe.cpu().detach().numpy()\n", - "print(torch_pe)\n", - "bak_pe = model.encoder.embed.pos_enc.pe\n", - "print(bak_pe.shape)\n", - "model.encoder.embed.pos_enc.pe = paddle.to_tensor(torch_pe)" - ] - }, - { - "cell_type": "code", - "execution_count": 210, - "id": "exempt-cloud", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "#print(xs)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "composite-involvement", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 269, - "id": "handed-harris", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "False\n", - "True\n", - "[256, 2048]\n", - "[2048]\n", - "[2048, 256]\n", - "[256]\n", - "--------ff-------\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "True\n", - "linear_714.w_0 True\n", - "linear_714.b_0 True\n", - "linear_715.w_0 True\n", - "linear_715.b_0 True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - "    xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - "    decoding_chunk_size, model.encoder.static_chunk_size,\n", - "    num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", -
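# Hedged aside, added for illustration (not part of the original run): with a
# fixed chunk size and full left context, the chunk mask lets frame i attend
# to every frame whose chunk index is <= its own. In pure NumPy:
toy_T, toy_chunk = 6, 2                                # toy sizes, not this batch
toy_idx = np.arange(toy_T) // toy_chunk                # chunk index per frame
toy_chunk_mask = toy_idx[None, :] <= toy_idx[:, None]  # (toy_T, toy_T) bool mask
# add_optional_chunk_mask builds this kind of mask, logical-ANDs it with the
# padding mask, and samples the chunk size per batch when use_dynamic_chunk is
# on; the assert_equal checks here compare that result with the Torch dump.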
"np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " print(layer.feed_forward_macaron is not None)\n", - " print(layer.normalize_before)\n", - " \n", - " data = np.load('.notebook/enc_0_norm_ff.npz')\n", - " t_norm_ff = data['norm_ff']\n", - " t_xs = data['xs']\n", - " \n", - " \n", - " x = xs\n", - " print(np.allclose(t_xs, x.numpy()))\n", - " residual = x\n", - " print(np.allclose(t_xs, residual.numpy()))\n", - " x_nrom = layer.norm_ff_macaron(x)\n", - " print(np.allclose(t.numpy(), x_nrom.numpy()))\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - "# for n, p in layer.norm_ff_macaron.state_dict().items():\n", - "# print(n, p)\n", - "# pass\n", - "\n", - " layer.eval()\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_nrom)\n", - " \n", - " ps=[]\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p)\n", - " ps.append(p)\n", - " print(p.shape)\n", - " pass\n", - "\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_nrom)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n", - " t_norm_ff = data['norm_ff']\n", - " t_ff_out = data['ff_out']\n", - " t_ff_l_x = data['ff_l_x']\n", - " t_ff_l_a_x = data['ff_l_a_x']\n", - " t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - " t_ps = data['ps']\n", - " \n", - " print(\"--------ff-------\")\n", - " print(np.allclose(x_nrom.numpy(), t_norm_ff))\n", - " print(np.allclose(x.numpy(), t_ff_out))\n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n", - " print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n", - " print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n", - " \n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n", - " for p, t_p in zip(ps, t_ps):\n", - " print(p.name, np.allclose(p.numpy(), t_p.T))\n", - " \n", - " \n", - "# residual = x\n", - "# x = layer.norm_mha(x)\n", - "# x_q = x\n", - " \n", - " data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n", - " tx_q = data['x_q']\n", - " tx = data['x']\n", - " tpos_emb=data['pos_emb']\n", - " tmask=data['mask']\n", - " tt_x_att=data['x_att']\n", - " x_q = paddle.to_tensor(tx_q)\n", - " x = paddle.to_tensor(tx)\n", - " pos_emb = paddle.to_tensor(tpos_emb)\n", - " mask = paddle.to_tensor(tmask)\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n", - " print(np.allclose(x_att.numpy(), t_x_att))\n", - " print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n", - " \n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 270, - "id": "sonic-thumb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, 
model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " break\n", - "data = np.load('.notebook/enc_0.npz')\n", - "torch_xs = data['enc_0']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-6))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 273, - "id": "brave-latino", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "--------layers_______\n", - "False\n", - "True\n", - "[[-0.70194244 0.56254214 0.6880346 ... 1.1237319 0.7803924\n", - " 1.1369387 ]\n", - " [-0.7787783 0.3912667 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654118 0.6819967 0.6939453 ... 1.2238353 0.8028295\n", - " 1.4506507 ]\n", - " [-1.2732092 0.7145806 0.75819594 ... 0.94154835 0.8774845\n", - " 1.2623049 ]]\n", - "xxxxxx\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.273209 0.71458095 0.75819623 ... 
0.9415484 0.8774842\n", - " 1.2623055 ]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "print(\"--------layers_______\")\n", - "i =0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i+=1\n", - "# if i == 2:\n", - "# data = np.load('.notebook/enc_2.npz')\n", - "# torch_xs = data['enc_2']\n", - "# print(np.allclose(xs.numpy(), torch_xs))\n", - "# print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "# print(xs[0].numpy())\n", - "# print('xxxxxx')\n", - "# print(torch_xs[0])\n", - "# print('----i==2')\n", - "data = np.load('.notebook/enc_all.npz')\n", - "torch_xs = data['enc_all']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "print(xs[0].numpy())\n", - "print('xxxxxx')\n", - "print(torch_xs[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "municipal-stock", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 278, - "id": "macro-season", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-0.7019424 0.5625421 0.68803453 ... 1.1237317 0.7803923\n", - " 1.1369386 ]\n", - " [-0.7787783 0.39126673 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654117 0.68199664 0.6939452 ... 1.2238352 0.8028294\n", - " 1.4506506 ]\n", - " [-1.2732091 0.71458054 0.7581958 ... 0.9415482 0.8774844\n", - " 1.2623048 ]]\n", - "---\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 
0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "False\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "encoder_out, mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.numpy()[0])\n", - "print(\"---\")\n", - "print(torch_encoder_out[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "associate-sampling", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/u2_tansformer_model_espnet.ipynb b/.notebook/u2_tansformer_model_espnet.ipynb deleted file mode 100644 index 75c2ea5c6c371d4cda43922f614c4a12cb389f9b..0000000000000000000000000000000000000000 --- a/.notebook/u2_tansformer_model_espnet.ipynb +++ /dev/null @@ -1,1672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n", - "[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n", - "attention_heads: 4\n", - "dropout_rate: 0.1\n", - "input_layer: conv2d\n", - "linear_units: 2048\n", - "normalize_before: True\n", - "num_blocks: 12\n", - "output_size: 256\n", - "positional_dropout_rate: 0.1\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm2.bias | [256] | 256 | True\n", - 
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - 
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - 
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", 
- "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 
256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | 
True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | 
[256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - 
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 411.0, 32.01M elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/tiny/s1/conf/transformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 83\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = None\n", - "cfg.model.cmvn_file_type = 'json'\n", - "#cfg.model.encoder_conf.concat_after=True\n", - "cfg.freeze()\n", - "model = U2Model(cfg.model)\n", - "\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [], - "source": [ - "#summary(model)\n", - "#print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "fossil-means", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "45c2b75f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state\n", - "odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 
'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 
'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 
'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n" - ] - } - ], - "source": [ - "#!pip install torch\n", - "import torch\n", - "\n", - "e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n", - "for k in e_model.files:\n", - " print(k)\n", - "state_dict = e_model['state']\n", - "state_dict = state_dict.tolist()\n", - "print(state_dict.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f187bb55", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - "  and should_run_async(code)\n" - ] - } - ], - "source": [ - "# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n", - "# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n", - "# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n", - "# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n", - "# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n", - "# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "# espnet transformer\tconcat_linear is only saved in the checkpoint, but never used\n", - "\t\n", - "# \tpaddle transformer" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2a0428ae", - 
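A note on the mapping sketched in the comment cell above: torch's nn.Linear stores weight as [out_features, in_features] while paddle.nn.Linear stores [in_features, out_features], so porting the ESPnet checkpoint means transposing every 2-D Linear weight, including the square 256x256 attention projections whose shape does not change under transposition. The conversion log in the next cell prints exactly these transposes. Below is a minimal sketch of that rule, not code from this repo: the helper name espnet_to_paddle and its arguments are made up, and the embedding exemption is an assumption (both frameworks store embedding tables as [vocab, dim]).

import numpy as np

def espnet_to_paddle(espnet_state, paddle_shapes):
    """Sketch: transpose torch Linear weights into paddle's layout.

    espnet_state: dict, name -> np.ndarray from the torch checkpoint.
    paddle_shapes: dict, name -> shape tuple expected by the Paddle model.
    """
    converted = {}
    for name, value in espnet_state.items():
        if name not in paddle_shapes:
            # torch-only params (sfc.*, deconv.*, xlm_*) are dropped; the
            # Paddle-only concat_linear.* params keep their initial values.
            continue
        if value.ndim == 2 and not name.endswith("embed.0.weight"):
            # Linear weights: [out, in] (torch) -> [in, out] (paddle).
            print(f"{name}: {value.shape} -> {value.T.shape}")
            value = value.T
        converted[name] = value
    return converted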
"metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-> encoder.embed.conv.0.weight\n", - "-> encoder.embed.conv.0.bias\n", - "-> encoder.embed.conv.2.weight\n", - "-> encoder.embed.conv.2.bias\n", - "-> encoder.embed.out.0.weight\n", - "encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n", - "-> encoder.embed.out.0.bias\n", - "-> encoder.encoders.0.self_attn.linear_q.weight\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_q.bias\n", - "-> encoder.encoders.0.self_attn.linear_k.weight\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_k.bias\n", - "-> encoder.encoders.0.self_attn.linear_v.weight\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_v.bias\n", - "-> encoder.encoders.0.self_attn.linear_out.weight\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_out.bias\n", - "-> encoder.encoders.0.feed_forward.w_1.weight\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.0.feed_forward.w_1.bias\n", - "-> encoder.encoders.0.feed_forward.w_2.weight\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.0.feed_forward.w_2.bias\n", - "-> encoder.encoders.0.norm1.weight\n", - "-> encoder.encoders.0.norm1.bias\n", - "-> encoder.encoders.0.norm2.weight\n", - "-> encoder.encoders.0.norm2.bias\n", - "-> encoder.encoders.1.self_attn.linear_q.weight\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_q.bias\n", - "-> encoder.encoders.1.self_attn.linear_k.weight\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_k.bias\n", - "-> encoder.encoders.1.self_attn.linear_v.weight\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_v.bias\n", - "-> encoder.encoders.1.self_attn.linear_out.weight\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_out.bias\n", - "-> encoder.encoders.1.feed_forward.w_1.weight\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.1.feed_forward.w_1.bias\n", - "-> encoder.encoders.1.feed_forward.w_2.weight\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.1.feed_forward.w_2.bias\n", - "-> encoder.encoders.1.norm1.weight\n", - "-> encoder.encoders.1.norm1.bias\n", - "-> encoder.encoders.1.norm2.weight\n", - "-> encoder.encoders.1.norm2.bias\n", - "-> encoder.encoders.2.self_attn.linear_q.weight\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_q.bias\n", - "-> encoder.encoders.2.self_attn.linear_k.weight\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_k.bias\n", - "-> encoder.encoders.2.self_attn.linear_v.weight\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_v.bias\n", - "-> encoder.encoders.2.self_attn.linear_out.weight\n", - 
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_out.bias\n", - "-> encoder.encoders.2.feed_forward.w_1.weight\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.2.feed_forward.w_1.bias\n", - "-> encoder.encoders.2.feed_forward.w_2.weight\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.2.feed_forward.w_2.bias\n", - "-> encoder.encoders.2.norm1.weight\n", - "-> encoder.encoders.2.norm1.bias\n", - "-> encoder.encoders.2.norm2.weight\n", - "-> encoder.encoders.2.norm2.bias\n", - "-> encoder.encoders.3.self_attn.linear_q.weight\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_q.bias\n", - "-> encoder.encoders.3.self_attn.linear_k.weight\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_k.bias\n", - "-> encoder.encoders.3.self_attn.linear_v.weight\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_v.bias\n", - "-> encoder.encoders.3.self_attn.linear_out.weight\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_out.bias\n", - "-> encoder.encoders.3.feed_forward.w_1.weight\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.3.feed_forward.w_1.bias\n", - "-> encoder.encoders.3.feed_forward.w_2.weight\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.3.feed_forward.w_2.bias\n", - "-> encoder.encoders.3.norm1.weight\n", - "-> encoder.encoders.3.norm1.bias\n", - "-> encoder.encoders.3.norm2.weight\n", - "-> encoder.encoders.3.norm2.bias\n", - "-> encoder.encoders.4.self_attn.linear_q.weight\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_q.bias\n", - "-> encoder.encoders.4.self_attn.linear_k.weight\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_k.bias\n", - "-> encoder.encoders.4.self_attn.linear_v.weight\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_v.bias\n", - "-> encoder.encoders.4.self_attn.linear_out.weight\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_out.bias\n", - "-> encoder.encoders.4.feed_forward.w_1.weight\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.4.feed_forward.w_1.bias\n", - "-> encoder.encoders.4.feed_forward.w_2.weight\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.4.feed_forward.w_2.bias\n", - "-> encoder.encoders.4.norm1.weight\n", - "-> encoder.encoders.4.norm1.bias\n", - "-> encoder.encoders.4.norm2.weight\n", - "-> encoder.encoders.4.norm2.bias\n", - "-> encoder.encoders.5.self_attn.linear_q.weight\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_q.bias\n", - "-> encoder.encoders.5.self_attn.linear_k.weight\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> 
encoder.encoders.5.self_attn.linear_k.bias\n", - "-> encoder.encoders.5.self_attn.linear_v.weight\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_v.bias\n", - "-> encoder.encoders.5.self_attn.linear_out.weight\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_out.bias\n", - "-> encoder.encoders.5.feed_forward.w_1.weight\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.5.feed_forward.w_1.bias\n", - "-> encoder.encoders.5.feed_forward.w_2.weight\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.5.feed_forward.w_2.bias\n", - "-> encoder.encoders.5.norm1.weight\n", - "-> encoder.encoders.5.norm1.bias\n", - "-> encoder.encoders.5.norm2.weight\n", - "-> encoder.encoders.5.norm2.bias\n", - "-> encoder.encoders.6.self_attn.linear_q.weight\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_q.bias\n", - "-> encoder.encoders.6.self_attn.linear_k.weight\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_k.bias\n", - "-> encoder.encoders.6.self_attn.linear_v.weight\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_v.bias\n", - "-> encoder.encoders.6.self_attn.linear_out.weight\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_out.bias\n", - "-> encoder.encoders.6.feed_forward.w_1.weight\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.6.feed_forward.w_1.bias\n", - "-> encoder.encoders.6.feed_forward.w_2.weight\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.6.feed_forward.w_2.bias\n", - "-> encoder.encoders.6.norm1.weight\n", - "-> encoder.encoders.6.norm1.bias\n", - "-> encoder.encoders.6.norm2.weight\n", - "-> encoder.encoders.6.norm2.bias\n", - "-> encoder.encoders.7.self_attn.linear_q.weight\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_q.bias\n", - "-> encoder.encoders.7.self_attn.linear_k.weight\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_k.bias\n", - "-> encoder.encoders.7.self_attn.linear_v.weight\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_v.bias\n", - "-> encoder.encoders.7.self_attn.linear_out.weight\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_out.bias\n", - "-> encoder.encoders.7.feed_forward.w_1.weight\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.7.feed_forward.w_1.bias\n", - "-> encoder.encoders.7.feed_forward.w_2.weight\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.7.feed_forward.w_2.bias\n", - "-> encoder.encoders.7.norm1.weight\n", - "-> encoder.encoders.7.norm1.bias\n", - "-> encoder.encoders.7.norm2.weight\n", - "-> encoder.encoders.7.norm2.bias\n", - "-> 
encoder.encoders.8.self_attn.linear_q.weight\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_q.bias\n", - "-> encoder.encoders.8.self_attn.linear_k.weight\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_k.bias\n", - "-> encoder.encoders.8.self_attn.linear_v.weight\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_v.bias\n", - "-> encoder.encoders.8.self_attn.linear_out.weight\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_out.bias\n", - "-> encoder.encoders.8.feed_forward.w_1.weight\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.8.feed_forward.w_1.bias\n", - "-> encoder.encoders.8.feed_forward.w_2.weight\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.8.feed_forward.w_2.bias\n", - "-> encoder.encoders.8.norm1.weight\n", - "-> encoder.encoders.8.norm1.bias\n", - "-> encoder.encoders.8.norm2.weight\n", - "-> encoder.encoders.8.norm2.bias\n", - "-> encoder.encoders.9.self_attn.linear_q.weight\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_q.bias\n", - "-> encoder.encoders.9.self_attn.linear_k.weight\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_k.bias\n", - "-> encoder.encoders.9.self_attn.linear_v.weight\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_v.bias\n", - "-> encoder.encoders.9.self_attn.linear_out.weight\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_out.bias\n", - "-> encoder.encoders.9.feed_forward.w_1.weight\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.9.feed_forward.w_1.bias\n", - "-> encoder.encoders.9.feed_forward.w_2.weight\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.9.feed_forward.w_2.bias\n", - "-> encoder.encoders.9.norm1.weight\n", - "-> encoder.encoders.9.norm1.bias\n", - "-> encoder.encoders.9.norm2.weight\n", - "-> encoder.encoders.9.norm2.bias\n", - "-> encoder.encoders.10.self_attn.linear_q.weight\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_q.bias\n", - "-> encoder.encoders.10.self_attn.linear_k.weight\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_k.bias\n", - "-> encoder.encoders.10.self_attn.linear_v.weight\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_v.bias\n", - "-> encoder.encoders.10.self_attn.linear_out.weight\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_out.bias\n", - "-> encoder.encoders.10.feed_forward.w_1.weight\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.10.feed_forward.w_1.bias\n", - "-> encoder.encoders.10.feed_forward.w_2.weight\n", - 
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.10.feed_forward.w_2.bias\n", - "-> encoder.encoders.10.norm1.weight\n", - "-> encoder.encoders.10.norm1.bias\n", - "-> encoder.encoders.10.norm2.weight\n", - "-> encoder.encoders.10.norm2.bias\n", - "-> encoder.encoders.11.self_attn.linear_q.weight\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_q.bias\n", - "-> encoder.encoders.11.self_attn.linear_k.weight\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_k.bias\n", - "-> encoder.encoders.11.self_attn.linear_v.weight\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_v.bias\n", - "-> encoder.encoders.11.self_attn.linear_out.weight\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_out.bias\n", - "-> encoder.encoders.11.feed_forward.w_1.weight\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.11.feed_forward.w_1.bias\n", - "-> encoder.encoders.11.feed_forward.w_2.weight\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.11.feed_forward.w_2.bias\n", - "-> encoder.encoders.11.norm1.weight\n", - "-> encoder.encoders.11.norm1.bias\n", - "-> encoder.encoders.11.norm2.weight\n", - "-> encoder.encoders.11.norm2.bias\n", - "-> encoder.after_norm.weight\n", - "-> encoder.after_norm.bias\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "#state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " if 'encoder' not in n:\n", - " continue \n", - " print(f'-> {n}')\n", - " \n", - " \n", - " name_change=True\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - "# state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a1d97e9f", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. 
encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. 
decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. 
decoder.decoders.0.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. 
decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. 
decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. 
decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. 
decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n" - ] - } - ], - "source": [ - "model.set_state_dict(paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "fc7edf1e", - "metadata": {}, - "outputs": [], - "source": [ - "e_state = model.encoder.state_dict()\n", - "for key, value in e_state.items():\n", - " if 'concat_linear' in key:\n", - " continue\n", - " if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "572097d0", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "748250b7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91e5deee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 57, 83)\n", - "(8, 1, 57)\n", - "[57 50 48 38 32 31 28 25]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n", - "xs=data['xs']\n", - "masks=data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "xs_lens = masks.sum(axis=-1).squeeze()\n", - "print(xs_lens)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "false-instrument", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[8, 13, 256]\n", - "[8, 1, 13]\n" - ] - } - ], - "source": [ - "# ecnoder\n", - "xs = paddle.to_tensor(xs, dtype='float32')\n", - "x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n", - "model.eval()\n", - "encoder_out, encoder_mask = model.encoder(xs, x_lens)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 13, 256)\n", - "(8, 1, 13)\n", - "False\n", - "False\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n", - "xs = data['xs']\n", - "masks = data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "print(np.allclose(xs, encoder_out.numpy()))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(masks, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n", - " -0.6526459 ]\n", - " [ 2.1926146 2.1373641 -0.6548196 ... 
-0.897318 0.6044322\n", - " -0.63332295]\n", - " [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n", - " -0.05243584]\n", - " ...\n", - " [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n", - " -0.18188554]\n", - " [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n", - " -0.72916794]\n", - " [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n", - " -0.32597935]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n", - " [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n", - " [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n", - " ...,\n", - " [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n", - " [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n", - " [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n" - ] - } - ], - "source": [ - "print(xs[0])\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n", - " -0.16355833]\n", - " [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n", - " 0.11874156]\n", - " [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n", - " 0.27621832]\n", - " ...\n", - " [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n", - " -0.9794884 ]\n", - " [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n", - " -1.3076135 ]\n", - " [ 2.160009 0.98425585 -1.2231126 ... 
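# --- Editor's aside (sketch, not from the original notebook) -----------------
# The encoder checks above pass at atol=1e-5 but fail at 1e-6, so a bare
# allclose() boolean hides how large the discrepancy actually is. A helper like
# this (`ref` = the ESPnet reference, `out` = the Paddle output) reports the
# worst-case deviation instead:

import numpy as np

def max_deviation(ref, out):
    ref = np.asarray(ref, dtype=np.float64)
    out = np.asarray(out, dtype=np.float64)
    abs_err = np.abs(ref - out)
    rel_err = abs_err / (np.abs(ref) + 1e-12)   # guard against division by zero
    return abs_err.max(), rel_err.max()

# e.g. max_deviation(xs, encoder_out.numpy()) should land between 1e-6 and
# 1e-5 here, consistent with float32 accumulation-order differences.
# ------------------------------------------------------------------------------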
-0.03393313 1.9141548\n", - " -1.0099151 ]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n", - " [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n", - " [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n", - " ...,\n", - " [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n", - " [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n", - " [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n" - ] - } - ], - "source": [ - "print(xs[1])\n", - "print(encoder_out[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0504e3f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/wenet_model.ipynb b/.notebook/wenet_model.ipynb deleted file mode 100644 index 8e10b6c4bf94206ef1664b697b64ce1973d7f60c..0000000000000000000000000000000000000000 --- a/.notebook/wenet_model.ipynb +++ /dev/null @@ -1,5015 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cfb832c0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/wenet\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/wenet'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd /workspace/wenet/\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "62277538", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import argparse\n", - "import copy\n", - "import logging\n", - "import os\n", - "\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.optim as optim\n", - "import yaml\n", - "from tensorboardX import SummaryWriter\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from wenet.dataset.dataset import AudioDataset, CollateFunc\n", - "from wenet.transformer.asr_model import init_asr_model\n", - "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n", - "from wenet.utils.executor import Executor\n", - "from wenet.utils.scheduler import WarmupLR\n", - "\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "2f6ea33a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n" - ] - } - ], - "source": [ - "parser = argparse.ArgumentParser(description='training your network')\n", - "parser.add_argument('--config', 
default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config file')\n", - "parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", help='train data file')\n", - "parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n", - "parser.add_argument('--gpu',\n", - " type=int,\n", - " default=-1,\n", - " help='gpu id for this local rank, -1 for cpu')\n", - "parser.add_argument('--model_dir' , help='save model dir')\n", - "parser.add_argument('--checkpoint', help='checkpoint model')\n", - "parser.add_argument('--tensorboard_dir',\n", - " default='tensorboard',\n", - " help='tensorboard log dir')\n", - "parser.add_argument('--ddp.rank',\n", - " dest='rank',\n", - " default=0,\n", - " type=int,\n", - " help='global rank for distributed training')\n", - "parser.add_argument('--ddp.world_size',\n", - " dest='world_size',\n", - " default=-1,\n", - " type=int,\n", - " help='''number of total processes/gpus for\n", - " distributed training''')\n", - "parser.add_argument('--ddp.dist_backend',\n", - " dest='dist_backend',\n", - " default='nccl',\n", - " choices=['nccl', 'gloo'],\n", - " help='distributed backend')\n", - "parser.add_argument('--ddp.init_method',\n", - " dest='init_method',\n", - " default=None,\n", - " help='ddp init method')\n", - "parser.add_argument('--num_workers',\n", - " default=0,\n", - " type=int,\n", - " help='num of subprocess workers for reading')\n", - "parser.add_argument('--pin_memory',\n", - " action='store_true',\n", - " default=False,\n", - " help='Use pinned memory buffers used for reading')\n", - "parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n", - "\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f5d6af9b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n" - ] - } - ], - "source": [ - "# Set random seed\n", - "torch.manual_seed(777)\n", - "print(args)\n", - "with open(args.config, 'r') as fin:\n", - " configs = yaml.load(fin, Loader=yaml.FullLoader)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "264bd353", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "7507 batches\n", - "896\n" - ] - } - ], - "source": [ - "raw_wav = configs['raw_wav']\n", - "\n", - "train_collate_func = CollateFunc(**configs['collate_conf'],\n", - " raw_wav=raw_wav)\n", - "\n", - "cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n", - "# no augmenation on cv set\n", - "cv_collate_conf['spec_aug'] = False\n", - "cv_collate_conf['spec_sub'] = False\n", - "if raw_wav:\n", - " cv_collate_conf['feature_dither'] = 0.0\n", - " cv_collate_conf['speed_perturb'] = False\n", - " cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n", - "cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n", - "\n", - "dataset_conf = configs.get('dataset_conf', {})\n", - "train_dataset = AudioDataset(args.train_data,\n", - 
" **dataset_conf,\n", - " raw_wav=raw_wav)\n", - "cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n", - "# 120098 data/train/wav.scp\n", - "print(len(train_dataset), 'batches')\n", - "# 14326 data/dev/wav.scp\n", - "print(len(cv_dataset))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "88863d3c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "896\n" - ] - } - ], - "source": [ - "train_sampler = None\n", - "cv_sampler = None\n", - "train_data_loader = DataLoader(train_dataset,\n", - " collate_fn=train_collate_func,\n", - " sampler=train_sampler,\n", - " #shuffle=(train_sampler is None),\n", - " shuffle=False,\n", - " pin_memory=args.pin_memory,\n", - " batch_size=1,\n", - " num_workers=args.num_workers)\n", - "cv_data_loader = DataLoader(cv_dataset,\n", - " collate_fn=cv_collate_func,\n", - " sampler=cv_sampler,\n", - " shuffle=False,\n", - " batch_size=1,\n", - " pin_memory=args.pin_memory,\n", - " num_workers=args.num_workers)\n", - "print(len(cv_data_loader))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "10d5acd4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4233 vocab\n", - "80 feat dim\n" - ] - } - ], - "source": [ - "if raw_wav:\n", - " input_dim = configs['collate_conf']['feature_extraction_conf'][\n", - " 'mel_bins']\n", - "else:\n", - " input_dim = train_dataset.input_dim\n", - "vocab_size = train_dataset.output_dim\n", - "print(vocab_size, 'vocab')\n", - "print(input_dim , 'feat dim')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "0380ef5a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "examples/aishell/s0/raw_wav/train/global_cmvn\n" - ] - } - ], - "source": [ - "# Save configs to model_dir/train.yaml for inference and export\n", - "configs['input_dim'] = input_dim\n", - "configs['output_dim'] = vocab_size\n", - "configs['cmvn_file'] = args.cmvn\n", - "configs['is_json_cmvn'] = raw_wav\n", - "print(args.cmvn)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "15ebf2bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(80,)\n", - "(80,)\n", - "[ 9.87176362 9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n", - " 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n", - " 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n", - " 12.09285066 11.79395822 11.62259065 11.9263303 11.8154422 11.95122567\n", - " 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n", - " 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n", - " 12.45014401 12.4966879 12.48653775 12.3550783 12.39291732 12.2553737\n", - " 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n", - " 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n", - " 13.0588804 13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n", - " 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n", - " 12.51480191 12.5785164 12.64719411 12.73762568 12.80017069 12.86872766\n", - " 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n", - " 12.87936075 11.18310185]\n", - "[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n", - " 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n", - " 0.23751781 0.22987273 
0.22659963 0.2268427 0.23059031 0.23420722\n", - " 0.23771761 0.2411352 0.24404673 0.24557175 0.24724932 0.25055198\n", - " 0.25482755 0.2602407 0.26363878 0.26503898 0.2648467 0.26435072\n", - " 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n", - " 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n", - " 0.28313664 0.2863325 0.28713431 0.28649323 0.28636648 0.2867843\n", - " 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n", - " 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n", - " 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n", - " 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n", - " 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n", - " 0.30955538 0.32904205]\n" - ] - } - ], - "source": [ - "import json\n", - "import logging\n", - "import math\n", - "import sys\n", - "import numpy as np\n", - "def _load_json_cmvn(json_cmvn_file):\n", - " \"\"\" Load the json format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " json_cmvn_file: cmvn stats file in json format\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " with open(json_cmvn_file) as f:\n", - " cmvn_stats = json.load(f)\n", - "\n", - " means = cmvn_stats['mean_stat']\n", - " variance = cmvn_stats['var_stat']\n", - " count = cmvn_stats['frame_num']\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_kaldi_cmvn(kaldi_cmvn_file):\n", - " \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " kaldi_cmvn_file: kaldi text style global cmvn file, which\n", - " is generated by:\n", - " compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " means = []\n", - " variance = []\n", - " with open(kaldi_cmvn_file, 'r') as fid:\n", - " # kaldi binary file starts with '\\0B'\n", - " if fid.read(2) == '\\0B':\n", - " logging.error('kaldi cmvn binary file is not supported, please '\n", - " 'recompute it by: compute-cmvn-stats --binary=false '\n", - " ' scp:feats.scp global_cmvn')\n", - " sys.exit(1)\n", - " fid.seek(0)\n", - " arr = fid.read().split()\n", - " assert (arr[0] == '[')\n", - " assert (arr[-2] == '0')\n", - " assert (arr[-1] == ']')\n", - " feat_dim = int((len(arr) - 2 - 2) / 2)\n", - " for i in range(1, feat_dim + 1):\n", - " means.append(float(arr[i]))\n", - " count = float(arr[feat_dim + 1])\n", - " for i in range(feat_dim + 2, 2 * feat_dim + 2):\n", - " variance.append(float(arr[i]))\n", - "\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n", - " npzfile = np.load(npz_cmvn_file)\n", - " means = npzfile[\"mean\"] #(1, D)\n", - " std = npzfile[\"std\"] #(1, D)\n", - " std = np.clip(std, eps, None)\n", - " variance = 1.0 / std\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def load_cmvn(cmvn_file: str, filetype: str):\n", - 
\"\"\"load cmvn from file.\n", - "\n", - " Args:\n", - " cmvn_file (str): cmvn path.\n", - " filetype (str): file type, optional[npz, json, kaldi].\n", - "\n", - " Raises:\n", - " ValueError: file type not support.\n", - "\n", - " Returns:\n", - " Tuple[np.ndarray, np.ndarray]: mean, istd\n", - " \"\"\"\n", - " assert filetype in ['npz', 'json', 'kaldi'], filetype\n", - " filetype = filetype.lower()\n", - " if filetype == \"json\":\n", - " cmvn = _load_json_cmvn(cmvn_file)\n", - " elif filetype == \"kaldi\":\n", - " cmvn = _load_kaldi_cmvn(cmvn_file)\n", - " elif filetype == \"npz\":\n", - " cmvn = _load_npz_cmvn(cmvn_file)\n", - " else:\n", - " raise ValueError(f\"cmvn file type no support: {filetype}\")\n", - " return cmvn[0], cmvn[1]\n", - "\n", - "mean, istd = load_cmvn(args.cmvn, 'json')\n", - "print(mean.shape)\n", - "print(istd.shape)\n", - "print(mean)\n", - "print(istd)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3cfa5e23", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ASRModel(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (conv): Sequential(\n", - " (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (1): ReLU()\n", - " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, bias=True)\n", - " )\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (encoders): ModuleList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " 
(norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): 
BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " 
(feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, 
bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " 
(norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): 
BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " 
(feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n", - " (decoders): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, 
inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, 
out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, 
out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTC(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n", - " (ctc_loss): CTCLoss()\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "# Init asr model from configs\n", - "model = init_asr_model(configs)\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "3c780af5", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def summary(layer, print_func=print):\n", - " num_params = num_elements = 0\n", - " for name, param in layer.state_dict().items():\n", - " if print_func:\n", - " print_func(\n", - " \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n", - " num_elements += np.prod(param.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(\n", - " f\"Total parameters: {num_params}, {num_elements} elements.\"\n", - " )\n", - " \n", - "def print_params(model, print_func=print):\n", - " if print_func is None:\n", - " return\n", - " total = 0.0\n", - " num_params = 0.0\n", - " for n, p in model.named_parameters():\n", - " msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n", - " total += np.prod(p.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(msg)\n", - " if print_func:\n", - " print_func(f\"Total parameters: {num_params}, {total} elements.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e159a200", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | torch.Size([80]) | 80\n", - "encoder.global_cmvn.istd | torch.Size([80]) | 80\n", - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256\n", - "encoder.after_norm.weight | torch.Size([256]) | 256\n", - "encoder.after_norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - 
"encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - 
"encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - 
"encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - 
"encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - 
"encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_mean | 
torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - 
"encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - 
"encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n", - 
"encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - 
"encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.weight | 
torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 
131072\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.after_norm.weight | torch.Size([256]) | 256\n", - "decoder.after_norm.bias | torch.Size([256]) | 256\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | 
torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - 
"decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | 
torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear2.bias 
| torch.Size([256]) | 256\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n", - "Total parameters: 701, 49355454.0 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8494c6ab", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0648a969", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - 
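The deleted cells above and below both print a per-tensor model summary via `summary(model)`. The helper itself is not part of this diff, so the sketch below is only an assumption reconstructed from the printed output: a hypothetical PyTorch function that walks the module and prints `name | shape | numel` lines plus a total, with an optional requires_grad column matching the second dump (which lists only trainable parameters, hence no BatchNorm buffers).

from torch import nn

def summary(model: nn.Module, show_trainable: bool = False) -> None:
    """Print one `name | shape | numel` line per tensor, then a total.

    A minimal sketch, assuming a PyTorch nn.Module; the name, signature,
    and format are modelled on the dumps in the deleted notebook cells.
    """
    count, elements = 0, 0
    if show_trainable:
        # named_parameters() yields trainable tensors only; this matches the
        # second dump (no running_mean/running_var buffers, trailing flag).
        for name, p in model.named_parameters():
            count += 1
            elements += p.numel()
            print(f"{name} | {p.shape} | {p.numel()} | {p.requires_grad}")
    else:
        # state_dict() also includes buffers; this matches entries such as
        # `conv_module.norm.running_mean` in the first dump.
        for name, t in model.state_dict().items():
            count += 1
            elements += t.numel()
            print(f"{name} | {t.shape} | {t.numel()}")
    print(f"Total parameters: {count}, {elements} elements.")

# Usage on a small stand-in module:
summary(nn.Linear(256, 2048), show_trainable=True)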
"encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | 
torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.bias 
| torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - 
"encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - 
"encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | 
torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - 
"decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | 
True\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - 
"decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | 
True\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | 
True\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n", - "Total parameters: 663.0, 49349138.0 elements.\n" - ] - } - ], - "source": [ - "print_params(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "5ad6de2a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "torch.Size([16, 207, 80])\n", - "tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - " [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - " [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - " ...,\n", - " [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - " [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - " [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - " [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - " [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - " [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - " ...,\n", - " [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - " [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - " [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - " [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - " [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - " [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - " ...,\n", - " [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " ...,\n", - "\n", - " [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - " [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - " [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - " [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - " [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - " [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - " [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "tensor([207, 207, 205, 205, 
203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - " 166, 163], dtype=torch.int32)\n", - "tensor([[2995, 3116, 1209, 565, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077],\n", - " [2693, 524, 234, 1145, 366, -1],\n", - " [3875, 4211, 3062, 700, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710],\n", - " [ 25, 1149, 3930, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110],\n", - " [3703, 2, 565, 3827, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, -1],\n", - " [ 426, 811, 95, 489, 144, -1],\n", - " [2313, 2006, 489, 975, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347],\n", - " [ 70, 1741, 702, 1666, -1, -1],\n", - " [ 703, 1778, 1030, 849, -1, -1],\n", - " [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n" - ] - } - ], - "source": [ - "for batch in cv_data_loader:\n", - " keys, feat, text, feat_len, text_len = batch\n", - " print(keys)\n", - " print(feat.shape)\n", - " print(feat)\n", - " print(feat_len)\n", - " print(text)\n", - " print(text_len)\n", - " np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "852a9c95", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CODE_OF_CONDUCT.md data.npz install.sh README.md\t tools\r\n", - "CONTRIBUTING.md docs LICENSE\t requirements.txt venv\r\n", - "CPPLINT.cfg\t examples Makefile\t runtime\t wenet\r\n" - ] - } - ], - "source": [ - "!ls\n", - "!cp data.npz /workspace/DeepSpeech-2.x/.notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "cde24c4e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(111.9988)\n", - "tensor(830.9634, grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True])\n", - "tensor(669.4633, grad_fn=)\n", - "tensor(142.4888, grad_fn=) tensor(41.8415, grad_fn=) tensor(377.3326, grad_fn=)\n" - ] - } - ], - "source": [ - "model.cpu().eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "be5b2a2c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], - "source": [ - "print(total_loss.device)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5b791771", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
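The batch is snapshotted with `np.savez` so it can be replayed byte-for-byte on the Paddle side; a sketch of reloading it, assuming the same `data.npz` written above:

```python
import numpy as np
import torch

# Reload the snapshotted batch saved above.
data = np.load('data.npz', allow_pickle=True)
keys = data['keys'].tolist()
feat = torch.from_numpy(data['feat'])          # (16, 207, 80) float32 fbank
feat_len = torch.from_numpy(data['feat_len'])  # int32 lengths, sorted descending
text = torch.from_numpy(data['text'])          # token ids padded with -1
text_len = torch.from_numpy(data['text_len'])
```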
"tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "cuda:0\n", - "142.4888 41.84146 377.33258\n" - ] - } - ], - "source": [ - "model.cuda().eval()\n", - "feat=feat.cuda()\n", - "feat_len=feat_len.cuda()\n", - "text=text.cuda()\n", - "text_len=text_len.cuda()\n", - "\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss.device)\n", - "print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "1baef537", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n", - "torch.Size([16, 1, 51])\n", - "tensor([[-0.7019, 0.5625, 0.6880, ..., 1.1237, 0.7804, 1.1369],\n", - " [-0.7788, 0.3913, 0.7189, ..., 1.2519, 0.8862, 1.3173],\n", - " [-0.9591, 0.6346, 0.8767, ..., 0.9818, 0.7440, 1.2903],\n", - " ...,\n", - " [-1.0732, 0.6724, 0.9230, ..., 0.9075, 0.8177, 1.3240],\n", - " [-1.1654, 0.6820, 0.6939, ..., 1.2238, 0.8028, 1.4507],\n", - " [-1.2732, 0.7146, 0.7582, ..., 0.9415, 0.8775, 1.2623]],\n", - " device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n", - " mask=encoder_mask.cpu().detach().numpy(), \n", - " out=encoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e22c782", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "30b6b946", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n", - " 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n", - " 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n", - " 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n", - " 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n", - " 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n", - " 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n", - " 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n", - " 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n", - " 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n", - " 12.48942 12.274784 12.261631 12.286319 12.31956 
12.422907\n", - " 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n", - " 12.966668 13.064786 13.159159 13.272843 13.310819 13.239043\n", - " 12.879361 11.183102 ] float32\n", - "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n", - "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n", - "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n", - "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n", - "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n", - "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n", - "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 
256)\n", - "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n", - "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n", - "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n", - "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n", - "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n", - "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n", - "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n", - 
"encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n", - "encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n", - "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n", - "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n", - "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n", - "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n", - "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n", - "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_pos.weight: (256, 
256) -> (256, 256)\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n", - "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n", - "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n", - "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n", - "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n", - "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n", - "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.concat_linear2.weight: (256, 512) 
-> (512, 256)\n", - "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.5.feed_forward.w_2.weight: 
(256, 2048) -> (2048, 256)\n", - "decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "import numpy as np\n", - "state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " name_change=True\n", - "\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " \n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - " state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7307dc5b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "d99b29bc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(377.3326, device='cuda:0', grad_fn=)\n", - "None\n", - "[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - " -5.93366381e-03 -7.26613170e-03]\n", - " [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - " 2.46338220e-03 2.31891591e-03]\n", - " [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - " 5.76929189e-03 7.48792710e-03]\n", - " ...\n", - " [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - " 1.16123557e-02 1.44716976e-02]\n", - " [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - " 8.58021621e-03 1.07796099e-02]\n", - " [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - " 1.60815325e-02 2.03892551e-02]]\n", - "[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - " 1.0920014e-02 1.3787906e-02]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
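The converted dict is written by `np.savez` as a single object array under the key `state`; on the Paddle side it could be restored roughly like this (`paddle_model` is assumed to be the equivalent Paddle network, which does not appear in this notebook):

```python
import numpy as np
import paddle

# np.savez stored the dict as a 0-d object array; .item() unwraps it.
npz = np.load('/workspace/DeepSpeech-2.x/.notebook/model.npz', allow_pickle=True)
paddle_state_dict = npz['state'].item()

# paddle_model is assumed elsewhere; set_state_dict matches keys by name.
paddle_model.set_state_dict(
    {k: paddle.to_tensor(v) for k, v in paddle_state_dict.items()})
```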
See github.com/pytorch/pytorch/pull/30531 for more informations.\n", - " print(loss_ctc.grad)\n" - ] - } - ], - "source": [ - "encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "dir(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "#print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n", - "print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "49b05d6d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n", - " text, text_len)\n", - "print(loss_att, acc_att)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "413b413f", - "metadata": {}, - "outputs": [], - "source": [ - "def pad_list(xs, pad_value: int):\n", - " n_batch = len(xs)\n", - " max_len = max([x.size(0) for x in xs])\n", - " pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n", - " pad = pad.fill_(pad_value)\n", - " for i in range(n_batch):\n", - " pad[i, :xs[i].size(0)] = xs[i]\n", - "\n", - " return pad\n", - "\n", - "def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n", - " ignore_id: int):\n", - "\n", - " _sos = torch.tensor([sos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " _eos = torch.tensor([eos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", - " ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n", - " ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n", - " return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ff0c2400", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[4232, 2995, 3116, 1209, 565, 4232, 4232],\n", - " [4232, 236, 1176, 331, 66, 3925, 4077],\n", - " [4232, 2693, 524, 234, 1145, 366, 4232],\n", - " [4232, 3875, 4211, 3062, 700, 4232, 4232],\n", - " [4232, 272, 987, 1134, 494, 2959, 4232],\n", - " [4232, 1936, 3715, 120, 2553, 2695, 2710],\n", - " [4232, 25, 1149, 3930, 4232, 4232, 4232],\n", - " [4232, 1753, 1778, 1237, 482, 3925, 
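A tiny worked example of `add_sos_eos`, using the notebook's `sos = eos = 4232` and `ignore_id = -1` that are visible in the padded tensors printed here:

```python
# First utterance of the batch, padded with ignore_id = -1.
ys = torch.tensor([[2995, 3116, 1209, 565, -1, -1]])
ys_in, ys_out = add_sos_eos(ys, sos=4232, eos=4232, ignore_id=-1)
print(ys_in)   # tensor([[4232, 2995, 3116, 1209,  565]])
print(ys_out)  # tensor([[2995, 3116, 1209,  565, 4232]])
# In a full batch, shorter rows of ys_in are padded with eos and rows of
# ys_out with ignore_id, exactly as in the 16-row tensors shown here.
```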
110],\n", - " [4232, 3703, 2, 565, 3827, 4232, 4232],\n", - " [4232, 1150, 2734, 10, 2478, 3490, 4232],\n", - " [4232, 426, 811, 95, 489, 144, 4232],\n", - " [4232, 2313, 2006, 489, 975, 4232, 4232],\n", - " [4232, 3702, 3414, 205, 1488, 2966, 1347],\n", - " [4232, 70, 1741, 702, 1666, 4232, 4232],\n", - " [4232, 703, 1778, 1030, 849, 4232, 4232],\n", - " [4232, 814, 1674, 115, 3827, 4232, 4232]], device='cuda:0')\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "3e84da38", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 7, 4233])\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "print(decoder_out.shape)\n", - "print(decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aac441ea", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "5ddbca73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "torch.int64\n", - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, 
False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "print(decoder_out.dtype)\n", - "print(ys_out_pad.dtype)\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad)\n", - "print(decoder_out[0])\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n", - " decoder_out=decoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78f98c0b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "8d968cd3", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn\n", - "\n", - "\n", - "class LabelSmoothingLoss(nn.Module):\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool = False):\n", - " \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n", - " super(LabelSmoothingLoss, self).__init__()\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - " self.padding_idx = padding_idx\n", - " self.confidence = 1.0 - smoothing\n", - " self.smoothing = smoothing\n", - " self.size = size\n", - " self.normalize_length = normalize_length\n", - "\n", - " def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - "\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - "\n", - " Args:\n", - " x (torch.Tensor): prediction (batch, seqlen, class)\n", - " target (torch.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (torch.Tensor) : The KL 
loss, scalar float value\n", - " \"\"\"\n", - " assert x.size(2) == self.size\n", - " batch_size = x.size(0)\n", - " x = x.view(-1, self.size)\n", - " target = target.view(-1)\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = torch.zeros_like(x)\n", - " true_dist.fill_(self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - " total = len(target) - ignore.sum().item()\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n", - " print(true_dist.dtype)\n", - " print(true_dist.square().sum())\n", - " kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n", - " print(kl.sum())\n", - " denom = total if self.normalize_length else batch_size\n", - " print(ignore)\n", - " numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " print(numer)\n", - " return numer /denom" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "3df340ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " ...,\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05]], device='cuda:0')\n", - "torch.float32\n", - "tensor(90.7203, device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "torch.int64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "badc410d", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. 
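The constant printed by the forward pass is just the off-target probability mass `smoothing / (size - 1)`; checking it against the output above:

```python
size, smoothing = 4233, 0.1
confidence = 1.0 - smoothing
print(smoothing / (size - 1))  # 2.3629489603024576e-05, as printed above
# Each row of true_dist puts confidence (0.9) on the target index and
# spreads the remaining 0.1 uniformly over the other 4232 classes,
# so every row still sums to 1.
```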
Specify retain_graph=True when calling backward the first time.", - "output_type": "error", - "traceback": [ - "---------------------------------------------------------------------------", - "RuntimeError                              Traceback (most recent call last)", - " in ", - "----> 1 loss_att.backward()", - "      2 print(loss_att.grad)", - "      3 print(decoder_out.grad)", - "/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)", - "    183                 products. Defaults to ``False``.", - "    184         \"\"\"", - "--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)", - "    186 ", - "    187     def register_hook(self, hook):", - "/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)", - "    123         retain_graph = create_graph", - "    124 ", - "--> 125     Variable._execution_engine.run_backward(", - "    126         tensors, grad_tensors, retain_graph, create_graph,", - "    127         allow_unreachable=True)  # allow_unreachable flag", - "RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."
- ] - } - ], - "source": [ - "loss_att.backward()\n", - "print(loss_att.grad)\n", - "print(decoder_out.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "219eb41f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n", - " device='cuda:0')\n", - "tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n", - " -2.5918e-05, 3.3754e-04],\n", - " [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n", - " 1.6911e-04, 2.7067e-04],\n", - " [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n", - " 1.4590e-01, -4.6850e-02],\n", - " ...,\n", - " [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n", - " 2.1159e-04, 6.6838e-04],\n", - " [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n", - " 2.4114e-04, 1.5338e-04],\n", - " [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n", - " -2.7850e-01, -4.4679e-01]], device='cuda:0')\n" - ] - } - ], - "source": [ - "print(model.decoder.output_layer.bias.grad)\n", - "print(model.decoder.output_layer.weight.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "40d00a54", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n", - " -1.0265e+00, -9.6301e-01],\n", - " [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n", - " -1.0325e+00, -7.5987e-01],\n", - " [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n", - " -8.7123e-01, -1.0306e+00],\n", - " ...,\n", - " [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n", - " -1.2806e+00, -1.0518e+00],\n", - " [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n", - " -9.4519e-01, -7.2506e-01],\n", - " [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n", - " -1.1943e+00, -1.2290e+00]],\n", - "\n", - " [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n", - " -1.2845e+00, -1.2015e+00],\n", - " [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n", - " -1.1255e+00, -1.2599e+00],\n", - " [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n", - " -1.1430e+00, -1.0242e+00],\n", - " ...,\n", - " [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n", - " -1.3957e+00, -1.0464e+00],\n", - " [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n", - " -1.0221e+00, -1.0208e+00],\n", - " [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n", - " -9.3785e-01, -1.0371e+00]],\n", - "\n", - " [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n", - " -7.5163e-01, -6.7765e-01],\n", - " [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n", - " -5.8927e-01, -7.3723e-01],\n", - " [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n", - " -7.6696e-01, -7.1652e-01],\n", - " ...,\n", - " [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n", - " -1.1163e+00, -9.3108e-01],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n", - " -1.1073e+00, -1.3253e+00],\n", - " [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n", - " -1.0727e+00, -1.1924e+00],\n", - " [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n", - " -1.0952e+00, -1.1182e+00],\n", - " ...,\n", - " [-6.0434e+00, 
-4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n", - " -5.6783e-01, -7.7852e-01],\n", - " [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n", - " -7.2550e-01, -6.2796e-01],\n", - " [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n", - " -6.0588e-01, -7.2115e-01],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n", - " -4.3747e-01, -4.0007e-02],\n", - " [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n", - " -2.0342e-01, 1.6656e-01],\n", - " [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n", - " -9.8115e-02, -9.2357e-03],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "505ca294", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "from wenet.utils.mask import make_pad_mask\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "aa03c2b9", - "metadata": {}, - "outputs": [], - "source": [ - "xs, pos_emb, masks = model.encoder.embed(xs, masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "ebc0ea12", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 0.2895, -0.6945],\n", - " [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n", - " [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n", - " ...,\n", - " [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n", - " [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n", - " [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n", - "\n", - " [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n", - " [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n", - " [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n", - " ...,\n", - " [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n", - " [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n", - " [-1.2700, 3.2229, 
-0.8872, ..., 1.6461, 0.0973, -0.7679]],\n", - "\n", - " [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n", - " [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n", - " [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n", - " ...,\n", - " [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n", - " [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n", - " [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n", - " [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n", - " [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n", - " [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n", - " [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.5458, -0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n", - " [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n", - " [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", - " 0.0000e+00, 1.0000e+00],\n", - " [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n", - " 1.0746e-04, 1.0000e+00],\n", - " [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n", - " 2.1492e-04, 1.0000e+00],\n", - " ...,\n", - " [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n", - " 5.1581e-03, 9.9999e-01],\n", - " [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n", - " 5.2656e-03, 9.9999e-01],\n", - " [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n", - " 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n", - "tensor([[[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, 
True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, 
True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, False, False, False, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, False, False, False, False, False, False, False, False, False,\n", - " False]]], device='cuda:0')\n", - "torch.Size([16, 1, 51])\n" - ] - } - ], - "source": [ - "print(xs)\n", - "print(pos_emb)\n", - "print(masks)\n", - "print(masks.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "4289461b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 
6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "print(xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "67e10d73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.0908e-03],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 4.6105e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n", - " 4.6135e-02, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.2816e-01, 2.4615e-01, 2.5304e-01, 
..., 2.0402e-01,\n", - " 2.3248e-01, 3.1191e-01],\n", - " [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n", - " 2.0346e-01, 1.9934e-01],\n", - " [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n", - " 1.9302e-01, 2.3810e-01],\n", - " ...,\n", - " [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n", - " 2.2104e-01, 2.3906e-01],\n", - " [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n", - " 2.6742e-01, 1.5427e-01],\n", - " [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n", - " 1.8479e-01, 2.2483e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n", - " 3.3568e-01, 3.4552e-01],\n", - " [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n", - " 3.7082e-01, 3.5110e-01],\n", - " [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n", - " 3.2321e-01, 3.5189e-01],\n", - " ...,\n", - " [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n", - " 3.7412e-01, 3.9986e-01],\n", - " [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n", - " 3.1597e-01, 3.6335e-01],\n", - " [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n", - " 3.3657e-01, 3.4130e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n", - " 2.0439e-02, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.4258e-02],\n", - " [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n", - " 9.0044e-03, 4.9084e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " 
[[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n", - " 2.3351e-01, 1.9004e-01],\n", - " [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n", - " 1.5855e-01, 1.3587e-01],\n", - " [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n", - " 2.3927e-01, 2.1909e-01],\n", - " ...,\n", - " [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n", - " 2.5374e-01, 2.3137e-01],\n", - " [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n", - " 2.5192e-01, 2.0837e-01],\n", - " [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n", - " 2.5080e-01, 2.5939e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n", - " 3.5292e-01, 3.2669e-01],\n", - " [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n", - " 3.8524e-01, 3.6155e-01],\n", - " [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n", - " 3.2452e-01, 3.1036e-01],\n", - " ...,\n", - " [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n", - " 3.4338e-01, 3.3603e-01],\n", - " [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n", - " 3.3873e-01, 3.6724e-01],\n", - " [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n", - " 3.6039e-01, 3.8022e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n", - " 0.0000e+00, 2.6720e-04],\n", - " [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 1.0801e-02, 0.0000e+00],\n", - " [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n", - " 9.4758e-02, 5.9146e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 
0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n", - " 1.6031e-01, 2.1136e-01],\n", - " [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n", - " 2.5765e-01, 1.9682e-01],\n", - " [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n", - " 2.2537e-01, 2.2214e-01],\n", - " ...,\n", - " [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n", - " 2.3157e-01, 1.8514e-01],\n", - " [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n", - " 2.5073e-01, 2.0976e-01],\n", - " [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n", - " 9.5310e-01, 1.0013e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n", - " 1.7826e-01, 1.4777e-01],\n", - " [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n", - " 3.2997e-01, 3.2337e-01],\n", - " [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n", - " 3.5243e-01, 2.9739e-01],\n", - " ...,\n", - " [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n", - " 3.5737e-01, 3.3088e-01],\n", - " [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n", - " 4.0126e-01, 3.6181e-01],\n", - " [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n", - " 2.3985e-01, 2.3903e-01]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n", - " 2.2508e-02, 0.0000e+00],\n", - " [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " 
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n", - " 2.2354e-01, 1.9149e-01],\n", - " [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n", - " 2.0928e-01, 2.6150e-01],\n", - " [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n", - " 1.2470e-01, 2.3927e-01],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n", - " 3.3579e-01, 3.9013e-01],\n", - " [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n", - " 3.8291e-01, 3.6686e-01],\n", - " [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n", - " 3.8068e-01, 3.0646e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 3.3902e-02],\n", - " [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n", - " 9.4838e-03, 7.4535e-02],\n", - " [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n", - " 5.1401e-02, 1.9239e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - 
" 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n", - " 7.0643e-02, 4.9741e-02],\n", - " [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n", - " 1.1969e-01, 1.2005e-01],\n", - " [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n", - " 6.4181e-02, 5.8730e-02],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n", - " 1.5051e-01, 1.9620e-01],\n", - " [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n", - " 4.5657e-02, 1.1501e-01],\n", - " [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n", - " 2.1357e-01, 1.6986e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n", - " 8.6563e-02, 1.4494e-01],\n", - " [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n", - " 6.2590e-02, 8.3154e-02],\n", - " [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n", - " 8.9799e-02, 5.3263e-02],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.3641e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 
0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 2.2189e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.8490e-02],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n", - " 8.2694e-02, 9.9127e-02],\n", - " [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n", - " 1.1725e-01, 9.9405e-02],\n", - " [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n", - " 7.9126e-02, 1.2846e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n", - " grad_fn=)\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "9a9478ad", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n", - " -0.0434088 ]\n", - " [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n", - " -0.01249714]\n", - " [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n", - " -0.03584499]\n", - " ...\n", - " [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n", - " -0.02610814]\n", - " [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n", - " -0.02281028]\n", - " [-0.10604677 0.14714901 -0.10885533 ... 0.08543444 0.03719445\n", - " -0.04634233]]\n", - "\n", - " [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n", - " 0.00545091]\n", - " [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n", - " 0.01053247]\n", - " [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n", - " 0.02089563]\n", - " ...\n", - " [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n", - " -0.00646474]\n", - " [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n", - " -0.02777974]\n", - " [-0.07937264 0.20143102 -0.05544947 ... 
0.10287814 0.00608235\n", - " -0.0479918 ]]\n", - "\n", - " [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n", - " 0.00573716]\n", - " [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n", - " 0.03521532]\n", - " [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n", - " 0.02871092]\n", - " ...\n", - " [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n", - " -0.0326424 ]\n", - " [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n", - " -0.0207509 ]\n", - " [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n", - " -0.08621951]]\n", - "\n", - " ...\n", - "\n", - " [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n", - " -0.06532725]\n", - " [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n", - " 0.03701983]\n", - " [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n", - " 0.03439377]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n", - " 0.02897627]\n", - " [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n", - " 0.0456164 ]\n", - " [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n", - " 0.02782153]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n", - " 0.06151697]\n", - " [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n", - " 0.03304394]\n", - " [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n", - " 0.01916483]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]]\n", - "torch.Size([16, 51, 256])\n" - ] - } - ], - "source": [ - "b, c, t, f = x.size()\n", - "x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n", - "print(x.cpu().detach().numpy())\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "fd69003f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 
0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "8ed88489", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n", - " 5.1580933e-03 9.9998671e-01]\n", - " [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n", - " 5.2655530e-03 9.9998611e-01]\n", - " [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 
9.9998331e-01\n", - " 5.3730118e-03 9.9998558e-01]]]\n" - ] - } - ], - "source": [ - "print(pos_emb.dtype)\n", - "print(pos_emb.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "5e277881", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'mask' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m )\n", - "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined" - ] - } - ], - "source": [ - "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n", - " use_dynamic_chunk: bool,\n", - " use_dynamic_left_chunk: bool,\n", - " decoding_chunk_size: int, static_chunk_size: int,\n", - " num_decoding_left_chunks: int):\n", - " \"\"\" Apply optional mask for encoder.\n", - " Args:\n", - " xs (torch.Tensor): padded input, (B, L, D), L for max length\n", - " mask (torch.Tensor): mask for xs, (B, 1, L)\n", - " use_dynamic_chunk (bool): whether to use dynamic chunk or not\n", - " use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n", - " training.\n", - " decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " static_chunk_size (int): chunk size for static chunk training/decoding\n", - " if it's greater than 0, if use_dynamic_chunk is 
true,\n", - " this parameter will be ignored\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " torch.Tensor: chunk mask of the input xs.\n", - " \"\"\"\n", - " # Whether to use chunk mask or not\n", - " if use_dynamic_chunk:\n", - " max_len = xs.size(1)\n", - " if decoding_chunk_size < 0:\n", - " chunk_size = max_len\n", - " num_left_chunks = -1\n", - " elif decoding_chunk_size > 0:\n", - " chunk_size = decoding_chunk_size\n", - " num_left_chunks = num_decoding_left_chunks\n", - " else:\n", - " # chunk size is either [1, 25] or full context(max_len).\n", - " # Since we use 4 times subsampling and allow up to 1s(100 frames)\n", - " # delay, the maximum frame is 100 / 4 = 25.\n", - " chunk_size = torch.randint(1, max_len, (1, )).item()\n", - " num_left_chunks = -1\n", - " if chunk_size > max_len // 2:\n", - " chunk_size = max_len\n", - " else:\n", - " chunk_size = chunk_size % 25 + 1\n", - " if use_dynamic_left_chunk:\n", - " max_left_chunks = (max_len - 1) // chunk_size\n", - " num_left_chunks = torch.randint(0, max_left_chunks,\n", - " (1, )).item()\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " elif static_chunk_size > 0:\n", - " num_left_chunks = num_decoding_left_chunks\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " else:\n", - " chunk_masks = masks\n", - " return chunk_masks\n", - "\n", - "from wenet.utils.mask import make_pad_mask\n", - "\n", - "\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n", - "xs = model.encoder.global_cmvn(feat)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "\n", - "mask_pad = masks\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "use_dynamic_left_chunk=-1\n", - "use_dynamic_chunk=False\n", - "static_chunk_size=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, \n", - " masks, \n", - " use_dynamic_chunk,\n", - " use_dynamic_left_chunk,\n", - " decoding_chunk_size, \n", - " static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n", - " embed_out=xs.cpu().detach().numpy(), \n", - " pos_emb=pos_emb.cpu().detach().numpy(),\n", - " chunk_masks=chunk_masks.cpu().detach().numpy(),\n", - " mask_pad=mask_pad.cpu().detach().numpy())\n", - "\n", - "model.eval()\n", - "# print(chunk_masks)\n", - "print(xs.shape)\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n", - " \n", - " x = xs\n", - " residual = x\n", - " x_norm = layer.norm_ff_macaron(x)\n", - " !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " xs=xs.cpu().detach().numpy())\n", - " #print(x.cpu().detach().numpy())\n", - " for p in layer.norm_ff_macaron.parameters():\n", - " #print(p, p.sum())\n", - " pass\n", - " \n", - " x = 
residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n", - " \n", - " ps = []\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p.cpu().data.numpy())\n", - " ps.append(p.cpu().data.numpy())\n", - " pass\n", - "\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " ff_out=x.cpu().detach().numpy(),\n", - " ff_l_x = ff_l_x.cpu().detach().numpy(),\n", - " ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n", - " ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n", - " ps=ps,\n", - " )\n", - " \n", - " \n", - " residual = x\n", - " x = layer.norm_mha(x)\n", - " x_q = x\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n", - " x_q=x_q.cpu().detach().numpy(),\n", - " x=x.cpu().detach().numpy(),\n", - " pos_emb = pos_emb.cpu().detach().numpy(),\n", - " mask=mask.cpu().detach().numpy(),\n", - " x_att=x_att.cpu().detach().numpy(),\n", - " )\n", - " \n", - " break\n", - "#print(xs.cpu().detach().numpy())\n", - "\n", - "\n", - "i = 0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i += 1\n", - " if i == 2:\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c43fd4f1", - "metadata": {}, - "outputs": [], - "source": [ - "out, mask = model.encoder(feat, feat_len)\n", - "#print(out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e73db22", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f506114", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/README.md b/README.md index de24abe2f674f4fe1f62555a7c5474c9eac64b64..71bc636380922c6d0b0e084e9d813d152a61ab56 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -[中文版](README_cn.md) - # PaddlePaddle Speech to Any toolkit ![License](https://img.shields.io/badge/license-Apache%202-red.svg) @@ -11,31 +9,29 @@ ## Features - See [feature list](doc/src/feature_list.md) for more information. + See [feature list](docs/src/feature_list.md) for more information. ## Setup All tested under: * Ubuntu 16.04 * python>=3.7 -* paddlepaddle>=2.1.2 +* paddlepaddle>=2.2.0rc -Please see [install](doc/src/install.md). +Please see [install](docs/src/install.md). ## Getting Started -Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md). +Please see [Getting Started](docs/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md). 
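A quick install sanity check is to import the package, which as a side effect registers a few ops that some `paddle` builds lack (see the `deepspeech/__init__.py` hunk further down in this diff). The snippet below is an illustrative sketch, not documented API — only the `paddle` calls and the patched attribute names are taken from this repo:

```python
import paddle
import deepspeech  # import-time side effect: patches paddle.softmax, paddle.sigmoid, ...

x = paddle.randn([2, 4])
print(paddle.__version__)  # the setup notes above expect >= 2.2.0rc
print(paddle.softmax(x))   # resolves to paddle.nn.functional.softmax where missing
```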
## More Information -* [Data Prepration](doc/src/data_preparation.md) -* [Data Augmentation](doc/src/augmentation.md) -* [Ngram LM](doc/src/ngram_lm.md) -* [Server Demo](doc/src/server.md) -* [Benchmark](doc/src/benchmark.md) -* [Relased Model](doc/src/released_model.md) -* [FAQ](doc/src/faq.md) +* [Data Preparation](docs/src/data_preparation.md) +* [Data Augmentation](docs/src/augmentation.md) +* [Ngram LM](docs/src/ngram_lm.md) +* [Benchmark](docs/src/benchmark.md) +* [Released Model](docs/src/released_model.md) ## Questions and Help @@ -45,8 +41,8 @@ You are welcome to submit questions in [Github Discussions](https://github.com/P ## License -DeepASR is provided under the [Apache-2.0 License](./LICENSE). +DeepSpeech is provided under the [Apache-2.0 License](./LICENSE). ## Acknowledgement -We depends on many open source repos. See [References](doc/src/reference.md) for more information. +We depend on many open source repos. See [References](docs/src/reference.md) for more information. diff --git a/README_cn.md b/README_cn.md deleted file mode 100644 index 4b92736252ef7249e204cecc4ed80349a7316299..0000000000000000000000000000000000000000 --- a/README_cn.md +++ /dev/null @@ -1,51 +0,0 @@ -[English](README.md) - -# PaddlePaddle Speech to Any toolkit - -![License](https://img.shields.io/badge/license-Apache%202-red.svg) -![python version](https://img.shields.io/badge/python-3.7+-orange.svg) -![support os](https://img.shields.io/badge/os-linux-yellow.svg) - -*DeepSpeech* is an open-source project for an end-to-end automatic speech recognition engine built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to provide easy-to-use, efficient, compact, and scalable tools for speech recognition in industrial applications and academic research, covering training, inference, and deployment. - -## Features - - See the [feature list](doc/src/feature_list.md). - - -## Setup - -Tested and verified under: - -* Ubuntu 16.04 -* python>=3.7 -* paddlepaddle>=2.1.2 - -See [install](doc/src/install.md). - -## Getting Started - -Please see [Getting Started](doc/src/getting_started.md) and the [tiny egs](examples/tiny/s0/README.md). - -## More Information - -* [Data Preparation](doc/src/data_preparation.md) -* [Data Augmentation](doc/src/augmentation.md) -* [Ngram LM](doc/src/ngram_lm.md) -* [Server Demo](doc/src/server.md) -* [Benchmark](doc/src/benchmark.md) -* [Released Model](doc/src/released_model.md) -* [FAQ](doc/src/faq.md) - -## Questions and Help - -You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and to report bugs in [Github Issues](https://github.com/PaddlePaddle/models/issues). Contributions to this project are also welcome. - -## License - -DeepASR is provided under the [Apache-2.0 License](./LICENSE). - -## Acknowledgement - -Development referenced several excellent open-source repos; see [References](doc/src/reference.md) for details. diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index d85a3dde7d44a388878a0b0f411f4a2bd594800d..5505ecbf04434477b1c3490ba893368433390b0d 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype): if not hasattr(paddle, 'softmax'): - logger.warn("register user softmax to paddle, remove this when fixed!") + logger.debug("register user softmax to paddle, remove this when fixed!") setattr(paddle, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle, 'log_softmax'): - logger.warn("register user log_softmax to paddle, remove this when fixed!") + logger.debug("register user log_softmax to paddle, remove this when fixed!") setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) if not hasattr(paddle, 'sigmoid'): - logger.warn("register user sigmoid to paddle, remove this when fixed!") + logger.debug("register user sigmoid to paddle, remove this when fixed!") setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) if not hasattr(paddle, 'log_sigmoid'): - logger.warn("register user log_sigmoid to paddle, remove this 
when fixed!") + logger.debug("register user log_sigmoid to paddle, remove this when fixed!") setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) if not hasattr(paddle, 'relu'): - logger.warn("register user relu to paddle, remove this when fixed!") + logger.debug("register user relu to paddle, remove this when fixed!") setattr(paddle, 'relu', paddle.nn.functional.relu) @@ -105,7 +105,7 @@ def cat(xs, dim=0): if not hasattr(paddle, 'cat'): - logger.warn( + logger.debug( "override cat of paddle if exists or register, remove this when fixed!") paddle.cat = cat @@ -116,7 +116,7 @@ def item(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'item'): - logger.warn( + logger.debug( "override item of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.item = item @@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'long'): - logger.warn( + logger.debug( "override long of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.long = func_long if not hasattr(paddle.Tensor, 'numel'): - logger.warn( + logger.debug( "override numel of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.numel = paddle.numel @@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor, if not hasattr(paddle.Tensor, 'new_full'): - logger.warn( + logger.debug( "override new_full of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.new_full = new_full @@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'eq'): - logger.warn( + logger.debug( "override eq of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.eq = eq if not hasattr(paddle, 'eq'): - logger.warn( + logger.debug( "override eq of paddle if exists or register, remove this when fixed!") paddle.eq = eq @@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'contiguous'): - logger.warn( + logger.debug( "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.contiguous = contiguous @@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: #`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.warn( +logger.debug( "override size of paddle.Tensor " "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" 
) @@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view'): - logger.warn("register user view to paddle.Tensor, remove this when fixed!") + logger.debug("register user view to paddle.Tensor, remove this when fixed!") paddle.Tensor.view = view @@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view_as'): - logger.warn( + logger.debug( "register user view_as to paddle.Tensor, remove this when fixed!") paddle.Tensor.view_as = view_as @@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill'): - logger.warn( + logger.debug( "register user masked_fill to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill = masked_fill @@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill_'): - logger.warn( + logger.debug( "register user masked_fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill_ = masked_fill_ @@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'fill_'): - logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.fill_ = fill_ @@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'repeat'): - logger.warn( + logger.debug( "register user repeat to paddle.Tensor, remove this when fixed!") paddle.Tensor.repeat = repeat if not hasattr(paddle.Tensor, 'softmax'): - logger.warn( + logger.debug( "register user softmax to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle.Tensor, 'sigmoid'): - logger.warn( + logger.debug( "register user sigmoid to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) if not hasattr(paddle.Tensor, 'relu'): - logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + logger.debug("register user relu to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) @@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'type_as'): - logger.warn( + logger.debug( "register user type_as to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'type_as', type_as) @@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'to'): - logger.warn("register user to to paddle.Tensor, remove this when fixed!") + logger.debug("register user to to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'to', to) @@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'float'): - logger.warn("register user float to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user float to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'float', func_float) @@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'int'): - logger.warn("register user int to paddle.Tensor, remove this when fixed!") + logger.debug("register user int to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'int', func_int) @@ -348,23 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]: if not 
hasattr(paddle.Tensor, 'tolist'): - logger.warn( + logger.debug( "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) - - -########### hcak paddle.nn ############# -class GLU(nn.Layer): - """Gated Linear Units (GLU) Layer""" - - def __init__(self, dim: int=-1): - super().__init__() - self.dim = dim - - def forward(self, xs): - return F.glu(xs, axis=self.dim) - - -if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'GLU', GLU) diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp index 4dcc7c899934e25b13bce6ca2b03c6623cc05e7d..fcb1f76425e0c5704f12c5517eb96f3db03b5ced 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp @@ -35,7 +35,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { @@ -48,7 +49,7 @@ std::vector> ctc_beam_search_decoder( // assign blank id // size_t blank_id = vocabulary.size(); - size_t blank_id = 0; + // size_t blank_id = 0; // assign space id auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); @@ -57,7 +58,6 @@ std::vector> ctc_beam_search_decoder( if ((size_t)space_id >= vocabulary.size()) { space_id = -2; } - // init prefixes' root PathTrie root; root.score = root.log_prob_b_prev = 0.0; @@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool ThreadPool pool(num_processes); @@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch( beam_size, cutoff_prob, cutoff_top_n, - ext_scorer)); + ext_scorer, + blank_id)); } // get decoding results diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.h b/deepspeech/decoders/swig/ctc_beam_search_decoder.h index c31510da34dc5444b0eab05ab682291902cc0bda..eaba9da8c8a05ae10391d81f30c368f125524aac 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.h +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.h @@ -43,7 +43,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); /* CTC Beam Search Decoder for batch data @@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp index 1c735c424bec1cc2bfb33f9d848ff63f03faaf14..18008cced8be1929c3c7e044b4e28b79bb8eeb15 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp @@ -17,17 +17,18 @@ std::string ctc_greedy_decoder( const std::vector> &probs_seq, - const std::vector &vocabulary) { + const std::vector &vocabulary, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { VALID_CHECK_EQ(probs_seq[i].size(), - 
vocabulary.size() + 1, + vocabulary.size(), "The shape of probs_seq does not match with " "the shape of the vocabulary"); } - size_t blank_id = vocabulary.size(); + // size_t blank_id = vocabulary.size(); std::vector max_idx_vec(num_time_steps, 0); std::vector idx_vec; diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.h b/deepspeech/decoders/swig/ctc_greedy_decoder.h index 5e8c5c251fd76c8618504e96361cb692cfc6ed43..dd1b333153510766ed64875200825d145a374186 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.h +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.h @@ -29,6 +29,7 @@ */ std::string ctc_greedy_decoder( const std::vector>& probs_seq, - const std::vector& vocabulary); + const std::vector& vocabulary, + size_t blank_id); #endif // CTC_GREEDY_DECODER_H diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 8fb792962b7b970efe831916ea01453c21cb0d8a..c089f96cd75f41ac336a20a4b67955b4928d13f4 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') # yapf: disable FILES = [ - fn for fn in FILES - if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( - 'unittest.cc')) + fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') + or fn.endswith('unittest.cc')) ] # yapf: enable diff --git a/deepspeech/decoders/swig_wrapper.py b/deepspeech/decoders/swig_wrapper.py index 3ffdb9c74d73f1863e9ba730a0ac03b2eafc7fce..d883d430cfa2db91c7d15474290df60bfc46e778 100644 --- a/deepspeech/decoders/swig_wrapper.py +++ b/deepspeech/decoders/swig_wrapper.py @@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer): swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) -def ctc_greedy_decoder(probs_seq, vocabulary): +def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): """Wrapper for ctc best path decoder in swig. :param probs_seq: 2-D list of probability distributions over each time @@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): :return: Decoding result string. :rtype: str """ - result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) + result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, + blank_id) return result @@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq, beam_size, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time @@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq, """ beam_results = swig_decoders.ctc_beam_search_decoder( probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, - ext_scoring_func) + ext_scoring_func, blank_id) beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] return beam_results @@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split, num_processes, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the batched CTC beam search decoder. 
:param probs_seq: 3-D list with each element as an instance of 2-D list @@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split, batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( probs_split, vocabulary, beam_size, num_processes, cutoff_prob, - cutoff_top_n, ext_scoring_func) + cutoff_top_n, ext_scoring_func, blank_id) batch_beam_results = [[(res[0], res[1]) for res in beam_results] for beam_results in batch_beam_results] return batch_beam_results diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 69ff043a08d28171711543afcbd51bcd571e69d2..6740f288fd667d2f84bca96277c46da9ce894c42 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -27,7 +27,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py deleted file mode 100644 index 94a9b6c47df79e3e4240ba0695d1c48c03a5ff57..0000000000000000000000000000000000000000 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
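The decoder changes above make the CTC blank index an explicit `blank_id` argument instead of hard-coding it (the C++ kernels previously assumed `vocabulary.size()` or 0). As a reference for what the greedy kernel computes, a pure-Python sketch with the blank made explicit; the toy vocabulary is illustrative:

```python
import numpy as np

def ctc_greedy_decode(probs_seq: np.ndarray, vocabulary: list,
                      blank_id: int = 0) -> str:
    """Best-path CTC: argmax per frame, merge repeats, drop blanks."""
    best_path = probs_seq.argmax(axis=-1)  # (T,)
    merged = [idx for i, idx in enumerate(best_path)
              if i == 0 or idx != best_path[i - 1]]
    return ''.join(vocabulary[idx] for idx in merged if idx != blank_id)

vocab = ['<blank>', 'a', 'b']  # blank at index 0, matching the new default
probs = np.array([[0.1, 0.8, 0.1],    # 'a'
                  [0.1, 0.8, 0.1],    # repeated 'a', merged
                  [0.9, 0.05, 0.05],  # blank, dropped
                  [0.1, 0.1, 0.8]])   # 'b'
assert ctc_greedy_decode(probs, vocab, blank_id=0) == 'ab'
```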
-"""Beam search parameters tuning for DeepSpeech2 model.""" -import functools -import sys - -import numpy as np -from paddle.io import DataLoader - -from deepspeech.exps.deepspeech2.config import get_cfg_defaults -from deepspeech.io.collator import SpeechCollator -from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.ds2 import DeepSpeech2Model -from deepspeech.training.cli import default_argument_parser -from deepspeech.utils import error_rate -from deepspeech.utils.utility import add_arguments -from deepspeech.utils.utility import print_arguments - - -def tune(config, args): - """Tune parameters alpha and beta incrementally.""" - if not args.num_alphas >= 0: - raise ValueError("num_alphas must be non-negative!") - if not args.num_betas >= 0: - raise ValueError("num_betas must be non-negative!") - config.defrost() - config.data.manfiest = config.data.dev_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True - dev_dataset = ManifestDataset.from_config(config) - - valid_loader = DataLoader( - dev_dataset, - batch_size=config.data.batch_size, - shuffle=False, - drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) - - model = DeepSpeech2Model.from_pretrained(valid_loader, config, - args.checkpoint_path) - model.eval() - - # decoders only accept string encoded in utf-8 - vocab_list = valid_loader.dataset.vocab_list - errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors - - # create grid for search - cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) - cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) - params_grid = [(alpha, beta) for alpha in cand_alphas - for beta in cand_betas] - - err_sum = [0.0 for i in range(len(params_grid))] - err_ave = [0.0 for i in range(len(params_grid))] - - num_ins, len_refs, cur_batch = 0, 0, 0 - # initialize external scorer - model.decoder.init_decode(args.alpha_from, args.beta_from, - config.decoding.lang_model_path, vocab_list, - config.decoding.decoding_method) - ## incremental tuning parameters over multiple batches - print("start tuning ...") - for infer_data in valid_loader(): - if (args.num_batches >= 0) and (cur_batch >= args.num_batches): - break - - def ordid2token(texts, texts_len): - """ ord() id to chr() chr """ - trans = [] - for text, n in zip(texts, texts_len): - n = n.numpy().item() - ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) - return trans - - audio, audio_len, text, text_len = infer_data - target_transcripts = ordid2token(text, text_len) - num_ins += audio.shape[0] - - # model infer - eouts, eouts_len = model.encoder(audio, audio_len) - probs = model.decoder.softmax(eouts) - - # grid search - for index, (alpha, beta) in enumerate(params_grid): - print(f"tuneing: alpha={alpha} beta={beta}") - result_transcripts = model.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, - config.decoding.decoding_method, - config.decoding.lang_model_path, alpha, beta, - config.decoding.beam_size, config.decoding.cutoff_prob, - config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch) - - for target, result in zip(target_transcripts, result_transcripts): - errors, len_ref = errors_func(target, result) - err_sum[index] += errors - - # accumulate the length of references of every batchπ - # in the first iteration - if args.alpha_from == alpha and args.beta_from == beta: - len_refs += len_ref - - err_ave[index] = err_sum[index] / len_refs - if index 
% 2 == 0: - sys.stdout.write('.') - sys.stdout.flush() - print("tuneing: one grid done!") - - # output on-line tuning result at the end of current batch - err_ave_min = min(err_ave) - min_index = err_ave.index(err_ave_min) - print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), " - " min [%s] = %f" % - (cur_batch, num_ins, "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1], - config.decoding.error_rate_type, err_ave_min)) - cur_batch += 1 - - # output WER/CER at every (alpha, beta) - print("\nFinal %s:\n" % config.decoding.error_rate_type) - for index in range(len(params_grid)): - print("(alpha, beta) = (%s, %s), [%s] = %f" % - ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1], - config.decoding.error_rate_type, err_ave[index])) - - err_ave_min = min(err_ave) - min_index = err_ave.index(err_ave_min) - print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" % - (cur_batch, "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1])) - - print("finish tuning") - - -def main(config, args): - tune(config, args) - - -if __name__ == "__main__": - parser = default_argument_parser() - add_arg = functools.partial(add_arguments, argparser=parser) - add_arg('num_batches', int, -1, "# of batches tuning on. " - "Default -1, on whole dev set.") - add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.") - add_arg('num_betas', int, 8, "# of beta candidates for tuning.") - add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.") - add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.") - add_arg('beta_from', float, 0.1, "Where beta starts tuning from.") - add_arg('beta_to', float, 0.45, "Where beta ends tuning with.") - - add_arg('batch_size', int, 256, "# of samples per batch.") - add_arg('beam_size', int, 500, "Beam search width.") - add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.") - add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") - add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") - - args = parser.parse_args() - print_arguments(args, globals()) - - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - - config.data.batch_size = args.batch_size - config.decoding.beam_size = args.beam_size - config.decoding.num_proc_bsearch = args.num_proc_bsearch - config.decoding.cutoff_prob = args.cutoff_prob - config.decoding.cutoff_top_n = args.cutoff_top_n - - config.freeze() - print(config) - - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - - main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index f3e3fcadf99daacea13e39d0f6273e2124c0d01a..79a676345fdcb7544bec1511861f5e7accc17928 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -15,9 +15,11 @@ import os import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -34,12 +36,14 @@ from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.reporter import report 
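The `report` import added above (together with the `ObsScope` used by the U2 trainer later in this patch) routes per-step metrics into a scoped dictionary instead of formatting log strings inline. A minimal stand-in for the pattern, not the actual `deepspeech.training.reporter` implementation:

```python
from collections import OrderedDict

_scope = None  # module-level current observation dict

def report(key, value):
    """Record a metric into the active scope; no-op outside any scope."""
    if _scope is not None:
        _scope[key] = value

class ObsScope:
    """Make `observation` the active sink for report() inside the block."""
    def __init__(self, observation):
        self.observation = observation
    def __enter__(self):
        global _scope
        self._old, _scope = _scope, self.observation
        return self.observation
    def __exit__(self, *exc):
        global _scope
        _scope = self._old

observation = OrderedDict()
with ObsScope(observation):
    report("epoch", 3)
    report("step_cost", 0.125)
print(", ".join(f"{k}: {v}" for k, v in observation.items()))
```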
from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -65,29 +69,52 @@ class DeepSpeech2Trainer(Trainer): super().__init__(config, args) def train_batch(self, batch_index, batch_data, msg): + batch_size = self.config.collator.batch_size + accum_grad = self.config.training.accum_grad + start = time.time() + + # forward utt, audio, audio_len, text, text_len = batch_data loss = self.model(audio, audio_len, text, text_len) - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - self.optimizer.step() - self.optimizer.clear_grad() - iteration_time = time.time() - start - losses_np = { 'train_loss': float(loss), } - msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.collator.batch_size) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - logger.info(msg) + + # loss backward + if (batch_index + 1) % accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step + if (batch_index + 1) % accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.iteration += 1 + + iteration_time = time.time() - start + + for k, v in losses_np.items(): + report(k, v) + report("batch_size", batch_size) + report("accum", accum_grad) + report("step_cost", iteration_time) if dist.get_rank() == 0 and self.visualizer: for k, v in losses_np.items(): + # `step -1` since we update `step` after optimizer.step(). 
self.visualizer.add_scalar("train/{}".format(k), v, - self.iteration) - self.iteration += 1 + self.iteration - 1) @paddle.no_grad() def valid(self): @@ -124,10 +151,9 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config.clone() - config.defrost() - config.model.feat_size = self.train_loader.collate_fn.feature_size - config.model.dict_size = self.train_loader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.model.feat_size = self.train_loader.collate_fn.feature_size + config.model.dict_size = self.train_loader.collate_fn.vocab_size if self.args.model_type == 'offline': model = DeepSpeech2Model.from_config(config.model) @@ -280,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("Current error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -325,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch metrics = self.compute_metrics(utts, audio, audio_len, texts, @@ -378,7 +405,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() @@ -610,7 +637,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): def setup(self): """Setup the experiment. 
""" - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index 9dd0041dd6bf144f51a334f4b97edf3be30afd39..17fb08a6c41202da97fec48911ba714942cf29c7 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer from deepspeech.training.cli import default_argument_parser from deepspeech.utils.utility import print_arguments +# from deepspeech.exps.u2.trainer import U2Trainer as Trainer + def main_sp(config, args): exp = Trainer(config, args) @@ -30,7 +32,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 0662e38d9fcdbf60ab764f8a2936f4b2006790f1..5cb0962a7fda3dbac6730b52e660f56192a01b68 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -17,9 +17,12 @@ import os import sys import time from collections import defaultdict +from collections import OrderedDict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -32,7 +35,10 @@ from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.reporter import ObsScope +from deepspeech.training.reporter import report from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -41,6 +47,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -79,21 +86,36 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() - utt, audio, audio_len, text, text_len = batch_data + # forward + utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + # When using cpu w/o DDP, model does not have `no_sync` + context = self.model.no_sync if self.parallel else nullcontext + else: + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -102,14 +124,13 @@ class U2Trainer(Trainer): iteration_time = time.time() - start - if (batch_index + 1) % train_conf.log_interval == 0: - msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.collator.batch_size) - msg += "accum: {}, ".format(train_conf.accum_grad) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - logger.info(msg) + for k, v in losses_np.items(): + report(k, v) + report("batch_size", self.config.collator.batch_size) + report("accum", train_conf.accum_grad) + report("step_cost", iteration_time) + if (batch_index + 1) % train_conf.accum_grad == 0: if dist.get_rank() == 0 and self.visualizer: losses_np_v = losses_np.copy() losses_np_v.update({"lr": self.lr_scheduler()}) @@ -163,46 +184,58 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. 
- dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("lr", self.lr_scheduler()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips[sent./sec]'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += "," + if (batch_index + 1 + ) % self.config.training.log_interval == 0: + logger.info(msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -294,10 +327,11 @@ class U2Trainer(Trainer): def setup_model(self): config = self.config model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.collate_fn.feature_size - model_conf.output_dim = self.train_loader.collate_fn.vocab_size - model_conf.freeze() + + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + model = U2Model.from_config(model_conf) if self.parallel: @@ -433,9 +467,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -460,7 +495,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] @@ -540,7 +575,7 @@ class U2Tester(U2Trainer): # 1. 
Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -548,26 +583,25 @@ class U2Tester(U2Trainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -575,7 +609,7 @@ class U2Tester(U2Trainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -621,7 +655,7 @@ class U2Tester(U2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/exps/u2/trainer.py b/deepspeech/exps/u2/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8634ac331f569a8be68e84c1947fb057b830ab --- /dev/null +++ b/deepspeech/exps/u2/trainer.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
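Every `train_batch` rewritten in this patch follows the same gradient-accumulation shape: intermediate micro-batches run `backward()` under `model.no_sync()` so DDP skips the gradient all-reduce, and the optimizer steps once per `accum_grad` batches. A condensed sketch of that control flow, with illustrative names:

```python
from contextlib import nullcontext

def backward_with_accumulation(model, loss, optimizer, batch_index,
                               accum_grad, parallel):
    """Accumulate gradients over accum_grad micro-batches (sketch)."""
    if (batch_index + 1) % accum_grad != 0:
        # intermediate batch: accumulate locally, skip DDP all-reduce;
        # a plain (non-DataParallel) model has no no_sync(), hence the guard
        context = model.no_sync if parallel else nullcontext
    else:
        # boundary batch: let DDP synchronize the accumulated gradients
        context = nullcontext
    with context():
        loss.backward()
    if (batch_index + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.clear_grad()
```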
+"""Contains U2 model.""" +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.u2 import U2Evaluator +from deepspeech.models.u2 import U2Model +from deepspeech.models.u2 import U2Updater +from deepspeech.training.extensions.snapshot import Snapshot +from deepspeech.training.extensions.visualizer import VisualDL +from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer +from deepspeech.training.trainer import Trainer +from deepspeech.training.updaters.trainer import Trainer as NewTrainer +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig + +logger = Log(__name__).getlog() + + +class U2Trainer(Trainer): + def __init__(self, config, args): + super().__init__(config, args) + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + # train/valid dataset, return token ids + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + + # test dataset, return raw text + config.data.manifest = config.data.test_manifest + # filter test examples, will cause less examples, but no mismatch with training + # and can use large batch size , save training time, so filter test egs now. 
+ config.data.min_input_len = 0.0 # second + config.data.max_input_len = float('inf') # second + config.data.min_output_len = 0.0 # tokens + config.data.max_output_len = float('inf') # tokens + config.data.min_output_input_ratio = 0.00 + config.data.max_output_input_ratio = float('inf') + + test_dataset = ManifestDataset.from_config(config) + # return text ord id + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") + + def setup_model(self): + config = self.config + model_conf = config.model + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + + model = U2Model.from_config(model_conf) + + if self.parallel: + model = paddle.DataParallel(model) + + model.train() + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + + scheduler_args = { + "learning_rate": optim_conf.lr, + "verbose": False, + "warmup_steps": scheduler_conf.warmup_steps, + "gamma": scheduler_conf.lr_decay, + "d_model": model_conf.encoder_conf.output_size, + } + lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, + scheduler_args) + + def optimizer_args( + config, + parameters, + lr_scheduler=None, ): + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + return { + "grad_clip": train_config.global_grad_clip, + "weight_decay": optim_conf.weight_decay, + "learning_rate": lr_scheduler + if lr_scheduler else optim_conf.lr, + "parameters": parameters, + "epsilon": 1e-9 if optim_type == 'noam' else None, + "beta1": 0.9 if optim_type == 'noam' else None, + "beta2": 0.98 if optim_type == 'noam' else None, + } + + optim_args = optimizer_args(config, model.parameters(), lr_scheduler) + optimizer = OptimizerFactory.from_args(optim_type, optim_args) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + def setup_updater(self): + output_dir = self.output_dir + config = self.config.training + + updater = U2Updater( + model=self.model, + optimizer=self.optimizer, + scheduler=self.lr_scheduler, + dataloader=self.train_loader, + output_dir=output_dir, + accum_grad=config.accum_grad) + + trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir) + + evaluator = U2Evaluator(self.model, self.valid_loader) + + trainer.extend(evaluator, trigger=(1, "epoch")) + + if dist.get_rank() == 0: + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) + num_snapshots = config.checkpoint.kbest_n + trainer.extend( + Snapshot( + mode='kbest', + max_size=num_snapshots, + indicator='VALID/LOSS', + less_better=True), + trigger=(1, 'epoch')) + # 
print(trainer.extensions) + # trainer.run() + self.trainer = trainer + + def run(self): + """The routine of the experiment after setup. This method is intended + to be used by the user. + """ + self.setup_updater() + with Timer("Training Done: {}"): + self.trainer.run() diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py index 1dcd154d35bcd941db66260fe54e60873694dc28..d909727f3a5a637f977b2f5569d553db5ff382d1 100644 --- a/deepspeech/exps/u2_kaldi/bin/train.py +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -36,7 +36,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 6a932d75137b302f98eb9f8e66c402dbacc6d787..d38afe25cf315b9bf5a0bae603397fd83ac61aa6 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -17,9 +17,11 @@ import os import sys import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -31,6 +33,7 @@ from deepspeech.io.dataloader import BatchDataLoader from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -39,6 +42,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -83,20 +87,34 @@ class U2Trainer(Trainer): train_conf = self.config.training start = time.time() + # forward utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -167,43 +185,42 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 
0 epoch - self.save(tag='init') - self.lr_scheduler.step(self.iteration) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. 
+ dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -300,10 +317,10 @@ class U2Trainer(Trainer): # model model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.feat_dim - model_conf.output_dim = self.train_loader.vocab_size - model_conf.freeze() + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + model = U2Model.from_config(model_conf) if self.parallel: model = paddle.DataParallel(model) @@ -429,9 +446,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -456,7 +474,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] @@ -526,9 +544,8 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") - stride_ms = self.config.collater.stride_ms - token_dict = self.args.char_list - + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list with open(self.args.result_file, 'w') as fout: # one example in batch for i, batch in enumerate(self.align_loader): @@ -537,7 +554,7 @@ class U2Tester(U2Trainer): # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -545,26 +562,25 @@ class U2Tester(U2Trainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. 
gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -572,7 +588,7 @@ class U2Tester(U2Trainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -623,7 +639,7 @@ class U2Tester(U2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/exps/u2_st/bin/train.py b/deepspeech/exps/u2_st/bin/train.py index 86a0f000051c720276fccdadcf5d8cd0e27a9c9c..1e6a746b848ed7e387aec91217c96640e628a39e 100644 --- a/deepspeech/exps/u2_st/bin/train.py +++ b/deepspeech/exps/u2_st/bin/train.py @@ -30,7 +30,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index 5734e15f58c5fdcc843602b69475fbf60ecd006c..e4e70292cda53226bd1eeaebffae7f3752275f87 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -17,9 +17,11 @@ import os import sys import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -37,6 +39,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2_st import U2STModel from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import bleu_score from deepspeech.utils import ctc_utils @@ -45,6 +48,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -83,6 +87,7 @@ class U2STTrainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() + # forward utt, audio, audio_len, text, text_len = batch_data if isinstance(text, list) and isinstance(text_len, list): # joint training with ASR. 
Two decoding texts [translation, transcription] @@ -94,18 +99,30 @@ class U2STTrainer(Trainer): else: loss, st_loss, attention_loss, ctc_loss = self.model( audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} - losses_np['st_loss'] = float(st_loss) if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -182,46 +199,42 @@ class U2STTrainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. 
- dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -327,10 +340,10 @@ class U2STTrainer(Trainer): def setup_model(self): config = self.config model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.collate_fn.feature_size - model_conf.output_dim = self.train_loader.collate_fn.vocab_size - model_conf.freeze() + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + model = U2STModel.from_config(model_conf) if self.parallel: @@ -467,8 +480,10 @@ class U2STTester(U2STTrainer): len_refs += len(target.split()) num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example BLEU = %s" % (bleu_func([result], [[target]]).prec_str)) @@ -496,7 +511,7 @@ class U2STTester(U2STTrainer): len_refs, num_ins = 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_translation_metrics( *batch, bleu_func=bleu_func, fout=fout) @@ -569,7 +584,7 @@ class U2STTester(U2STTrainer): # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -577,26 +592,25 @@ class U2STTester(U2STTrainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. 
gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -604,7 +618,7 @@ class U2STTester(U2STTrainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -650,7 +664,7 @@ class U2STTester(U2STTrainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index e4364f70a2baa40fbcb2f1b237465ad4df8c3848..7dc01c40afd8b29576bef2179d91b8db6c81efd8 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -76,7 +76,7 @@ class TextFeaturizer(): Args: text (str): Text. - + Returns: List[int]: List of token indices. """ @@ -89,7 +89,7 @@ class TextFeaturizer(): def defeaturize(self, idxs): """Convert a list of token indices to text string, - ignore index after eos_id. + ignore index after eos_id. Args: idxs (List[int]): List of token indices. 
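Since the testers above now write one JSON object per utterance instead of plain "utt result" pairs, downstream scoring scripts should parse the result file with jsonlines rather than splitting on whitespace. A minimal sketch of consuming the new format, assuming a hypothetical result path (the actual scoring is omitted):

    import jsonlines

    # each record was written as {"utt": ..., "ref": ..., "hyp": ...}
    with jsonlines.open("exp/conformer/test.result", "r") as reader:  # hypothetical path
        for record in reader:
            print(record["utt"])
            print("  ref:", record["ref"])
            print("  hyp:", record["hyp"])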
@@ -196,7 +196,12 @@ class TextFeaturizer(): [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - - unk_id = vocab_list.index(UNK) - eos_id = vocab_list.index(EOS) + if UNK in vocab_list: + unk_id = vocab_list.index(UNK) + else: + unk_id = -1 + if EOS in vocab_list: + eos_id = vocab_list.index(EOS) + else: + eos_id = -1 return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 73b3a4ba6e788a6bbe4ccdde17132c39d2ef47ed..6ace4fc6ddcaccd75436c1a528e7f6b512ba8292 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -130,7 +130,8 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" - mean, istd = load_cmvn(filepath, filetype='json') + filetype = filepath.split(".")[-1] + mean, istd = load_cmvn(filepath, filetype=filetype) self._mean = np.expand_dims(mean, axis=0) self._istd = np.expand_dims(istd, axis=0) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 72dfc98dd3fe5132c320c29c93860d3468846945..f7e2cb2142b998735c48fb4367fe88a147bffb4c 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains data helper functions.""" -import codecs import json import math from typing import List from typing import Optional from typing import Text +import jsonlines import numpy as np from deepspeech.utils.log import Log @@ -69,19 +69,19 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, - in seconds for raw wav, in frame numbers for feature data. + max_input_len ([type], optional): maximum input seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, - in seconds for raw wav, in frame numbers for feature data. + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, + max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, + min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): + max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to 10.0. - min_output_input_ratio (float, optional): + min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.05. 
Raises: @@ -92,26 +92,22 @@ def read_manifest( """ manifest = [] - for json_line in codecs.open(manifest_path, 'r', 'utf-8'): - try: - json_data = json.loads(json_line) - except Exception as e: - raise IOError("Error reading manifest: %s" % str(e)) - - feat_len = json_data["feat_shape"][ - 0] if 'feat_shape' in json_data else 1.0 - token_len = json_data["token_shape"][ - 0] if 'token_shape' in json_data else 1.0 - conditions = [ - feat_len >= min_input_len, - feat_len <= max_input_len, - token_len >= min_output_len, - token_len <= max_output_len, - token_len / feat_len >= min_output_input_ratio, - token_len / feat_len <= max_output_input_ratio, - ] - if all(conditions): - manifest.append(json_data) + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + feat_len = json_data["feat_shape"][ + 0] if 'feat_shape' in json_data else 1.0 + token_len = json_data["token_shape"][ + 0] if 'token_shape' in json_data else 1.0 + conditions = [ + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, + ] + if all(conditions): + manifest.append(json_data) return manifest @@ -131,7 +127,7 @@ def rms_to_dbfs(rms: float): """Root Mean Square to dBFS. https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ Audio is a mix of sine waves; a full-scale (1.0 amplitude) sine wave has an RMS of 0.7071, i.e. -3.0103 dB. - + dB = dBFS + 3.0103 dBFS = dB - 3.0103 e.g. 0 dB = -3.0103 dBFS @@ -146,26 +142,26 @@ def max_dbfs(sample_data: np.ndarray): - """Peak dBFS based on the maximum energy sample. + """Peak dBFS based on the maximum energy sample. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) def mean_dbfs(sample_data): - """Peak dBFS based on the RMS energy. + """dBFS based on the RMS energy. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ return rms_to_dbfs( math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) @@ -185,7 +181,7 @@ def gain_db_to_ratio(gain_db: float): def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): """Normalize audio to dBFS. - + Args: sample_data (np.ndarray): input wave samples, [-1, 1]. dbfs (float, optional): target dBFS. Defaults to -3.0103. 
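The rewritten read_manifest keeps a record only when all six length and ratio conditions hold at once. A compact sketch of the same predicate, with thresholds mirroring the docstring defaults and an illustrative record (not taken from any real manifest):

    def keep(record,
             max_input_len=float('inf'), min_input_len=0.0,
             max_output_len=500.0, min_output_len=0.0,
             max_output_input_ratio=10.0, min_output_input_ratio=0.05):
        # feat_shape[0] is the input length, token_shape[0] the output length
        feat_len = record["feat_shape"][0] if "feat_shape" in record else 1.0
        token_len = record["token_shape"][0] if "token_shape" in record else 1.0
        return (min_input_len <= feat_len <= max_input_len and
                min_output_len <= token_len <= max_output_len and
                min_output_input_ratio <= token_len / feat_len <= max_output_input_ratio)

    # e.g. a 4.95s utterance with 12 target tokens passes the default thresholds
    assert keep({"feat_shape": [4.95, 80], "token_shape": [12, 5000]})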
@@ -284,6 +280,13 @@ def load_cmvn(cmvn_file: str, filetype: str): cmvn = _load_json_cmvn(cmvn_file) elif filetype == "kaldi": cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] else: raise ValueError(f"cmvn file type no support: {filetype}") return cmvn[0], cmvn[1] diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index df300479059fa2dcc5119f313e723365f7a11b78..15b89ab9f7768518314ba30a11f046c143a8d860 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -292,10 +292,6 @@ class SpeechCollator(): olens = np.array(text_lens).astype(np.int64) return utts, xs_pad, ilens, ys_pad, olens - @property - def manifest(self): - return self._manifest - @property def vocab_size(self): return self._speech_featurizer.vocab_size diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index a35a0bc09695f63d5bbfa147d45b988cdf44dcab..310f5f581826460c947580a440e1aec5cc2a146a 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -44,7 +44,7 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]], def batch_collate(x): - """de-tuple. + """de-minibatch, since the user composes the batch. Args: x (List[Tuple]): [(utts, xs, ilens, ys, olens)] diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index d1fe047077cd851a799131f68de4e0e2b9ab1b12..56e534756e47fce796887b4de9e25ba9323bfec3 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -76,19 +76,19 @@ class ManifestDataset(Dataset): Args: manifest_path (str): manifest json file path - max_input_len ([type], optional): maximum output seq length, + max_input_len ([type], optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, + min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, + max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, + min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. + max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to 10.0. min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.05. - + """ super().__init__() @@ -147,3 +147,131 @@ class TransformDataset(Dataset): def __getitem__(self, idx): """[] operator.""" return self.converter([self.reader(self.data[idx], return_uttid=True)]) + + +class AudioDataset(Dataset): + def __init__(self, + data_file, + max_length=10240, + min_length=0, + token_max_length=200, + token_min_length=1, + batch_type='static', + batch_size=1, + max_frames_in_batch=0, + sort=True, + raw_wav=True, + stride_ms=10): + """Dataset for loading audio data. 
+ Attributes:: + data_file: input data file + Plain text data file, each line contains the following 7 fields, + split by '\t': + utt:utt1 + feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30 + feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames) + text:i love you + token: i l o v e y o u + tokenid: int id of this token + token_shape: M,N # M is the number of tokens, N is vocab size + max_length: drop utterances longer than max_length, unit 10ms. + min_length: drop utterances shorter than min_length, unit 10ms. + token_max_length: drop utterances with more than token_max_length tokens, + especially when using char units for English modeling + token_min_length: drop utterances with fewer than token_min_length tokens + batch_type: static or dynamic, see max_frames_in_batch(dynamic) + batch_size: number of utterances in a batch, + it's for static batch size. + max_frames_in_batch: max feature frames in a batch, + when batch_type is dynamic, it's for dynamic batch size. + Then batch_size is ignored, we will keep filling the + batch until the total frames in the batch reach max_frames_in_batch. + sort: whether to sort all data, so that utterances with similar + lengths can be filled into the same batch. + raw_wav: use raw wave or extracted feature. + if raw wave is used, dynamic waveform-level augmentation could be used + and the feature is extracted by torchaudio. + if extracted feature(e.g. by kaldi) is used, only feature-level + augmentation such as specaug could be used. + """ + assert batch_type in ['static', 'dynamic'] + # read manifest + data = read_manifest(data_file) + if sort: + data = sorted(data, key=lambda x: x["feat_shape"][0]) + if raw_wav: + # raw-wav manifests point at audio files, not kaldi .ark/.scp entries + assert not data[0]['feat'].split(':')[0].endswith(('.ark', '.scp')) + # for raw wav, feat_shape[0] is a duration in seconds; convert it to + # a frame count before length-based filtering and batching + for x in data: + x['feat_shape'][0] = float(x['feat_shape'][0]) * 1000 / stride_ms + + self.input_dim = data[0]['feat_shape'][1] + self.output_dim = data[0]['token_shape'][1] + + valid_data = [] + for i in range(len(data)): + length = data[i]['feat_shape'][0] + token_length = data[i]['token_shape'][0] + # drop utterances that are too long or too short on either the + # input or the output side, to prevent running out of memory + if length > max_length or length < min_length: + pass + elif token_length > token_max_length or token_length < token_min_length: + pass + else: + valid_data.append(data[i]) + data = valid_data + + self.minibatch = [] + num_data = len(data) + # Dynamic batch size + if batch_type == 'dynamic': + assert (max_frames_in_batch > 0) + self.minibatch.append([]) + num_frames_in_batch = 0 + for i in range(num_data): + length = data[i]['feat_shape'][0] + num_frames_in_batch += length + if num_frames_in_batch > max_frames_in_batch: + self.minibatch.append([]) + 
num_frames_in_batch = length + self.minibatch[-1].append(data[i]) + # Static batch size + else: + cur = 0 + while cur < num_data: + end = min(cur + batch_size, num_data) + item = [] + for i in range(cur, end): + item.append(data[i]) + self.minibatch.append(item) + cur = end + + def __len__(self): + return len(self.minibatch) + + def __getitem__(self, idx): + # each item of self.minibatch is a pre-built batch: a list of + # manifest records, so return the fields per record + instances = self.minibatch[idx] + return [(x["utt"], x["feat"], x["text"]) for x in instances] diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py index ce962a445318c924ac286fcd4b43e85ffb4bf91b..9548af0a2585d33c04f7127995323355fef080b0 100644 --- a/deepspeech/models/ds2/conv.py +++ b/deepspeech/models/ds2/conv.py @@ -106,11 +106,9 @@ class ConvBn(nn.Layer): # reset padding part to 0 masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 5f8f3255765a92e9c5ead2f70a4c73e7056ea9a6..dda26358b0b26f0d0f9f802cc683b8318fda2406 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, #Number of stacking RNN layers. rnn_layer_size=1024, #RNN layer size (number of RNN cells). use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) + share_rnn_weights=True, #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported. 
+ ctc_grad_norm_type='instance', )) if config is not None: config.merge_from_other_cfg(default) return default @@ -141,7 +141,9 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0, + ctc_grad_norm_type='instance'): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -156,10 +158,11 @@ self.decoder = CTCDecoder( odim=dict_size, # <blank> is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is <blank> + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=ctc_grad_norm_type) def forward(self, audio, audio_len, text, text_len): """Compute Model loss @@ -221,7 +224,8 @@ num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -246,7 +250,8 @@ num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights) + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) return model @@ -258,7 +263,8 @@ num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -266,7 +272,8 @@ num_rnn_layers=num_rnn_layers, rnn_size=rnn_size, use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) def forward(self, audio, audio_len): """export model function diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py index 3ff91d0afb652c81588b0da4bf254f526b3b0bd8..3fc52a378b3d684538522186c3e0702a2e513a2a 100644 --- a/deepspeech/models/ds2/rnn.py +++ b/deepspeech/models/ds2/rnn.py @@ -308,7 +308,8 @@ class RNNStack(nn.Layer): x, x_len = rnn(x, x_len) masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks + return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index f597a578305c15aa74b42386e3c19f970e5de6c3..29d207c44c03d292f790eba8f18489e78dbc34db 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer): num_fc_layers=2, fc_layers_size_list=[512, 256], use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+ blank_id=0, # index of blank in vocab.txt )) if config is not None: config.merge_from_other_cfg(default) return default @@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -284,10 +286,11 @@ class DeepSpeech2ModelOnline(nn.Layer): self.decoder = CTCDecoder( odim=dict_size, # <blank> is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is <blank> + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type='instance') def forward(self, audio, audio_len, text, text_len): """Compute Model loss @@ -353,7 +356,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction=config.model.rnn_direction, num_fc_layers=config.model.num_fc_layers, fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru) + use_gru=config.model.use_gru, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -380,7 +384,8 @@ rnn_direction=config.rnn_direction, num_fc_layers=config.num_fc_layers, fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru) + use_gru=config.use_gru, + blank_id=config.blank_id) return model @@ -394,7 +399,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -404,7 +410,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction=rnn_direction, num_fc_layers=num_fc_layers, fc_layers_size_list=fc_layers_size_list, - use_gru=use_gru) + use_gru=use_gru, + blank_id=blank_id) def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box): diff --git a/deepspeech/models/u2/__init__.py b/deepspeech/models/u2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9010f1d09263dc643d16308a8cefbd06744c958 --- /dev/null +++ b/deepspeech/models/u2/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
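Both DeepSpeech2 variants now read the CTC blank index from the model config instead of hard-coding 0, and the offline model additionally exposes the gradient-normalization strategy. A hedged sketch of the corresponding yacs keys (the values and exact layout here are illustrative, inferred from the merged defaults above):

    from yacs.config import CfgNode

    cfg = CfgNode()
    cfg.model = CfgNode()
    cfg.model.blank_id = 0                     # index of <blank> in vocab.txt
    cfg.model.ctc_grad_norm_type = 'instance'  # 'instance' | 'batch' | 'frame' | None
    cfg.freeze()
    # from_config()-style constructors then forward these values into CTCDecoder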
+from .u2 import U2InferModel +from .u2 import U2Model +from .updater import U2Evaluator +from .updater import U2Updater + +__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"] diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2/u2.py similarity index 96% rename from deepspeech/models/u2.py rename to deepspeech/models/u2/u2.py index c1a35560a707b34b3687a4a0849334e70d282cc4..46bbd102f39125820b0d556afb4a046125dbbc3c 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2/u2.py @@ -48,6 +48,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.utility import log_add +from deepspeech.utils.utility import UpdateConfig __all__ = ["U2Model", "U2InferModel"] @@ -115,7 +116,8 @@ class U2BaseModel(nn.Layer): ctc_weight: float=0.5, ignore_id: int=IGNORE_ID, lsm_weight: float=0.0, - length_normalized_loss: bool=False): + length_normalized_loss: bool=False, + **kwargs): assert 0.0 <= ctc_weight <= 1.0, ctc_weight super().__init__() @@ -162,10 +164,7 @@ class U2BaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. Attention-decoder branch loss_att = None @@ -299,8 +298,8 @@ class U2BaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) - encoder_dim = encoder_out.size(2) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] running_size = batch_size * beam_size encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) @@ -320,8 +319,7 @@ class U2BaseModel(nn.Layer): # 2. 
Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - # TODO(Hui Zhang): if end_flag.sum() == running_size: - if end_flag.cast(paddle.int64).sum() == running_size: + if end_flag.sum() == running_size: break # 2.1 Forward decoder step @@ -406,10 +404,8 @@ class U2BaseModel(nn.Layer): encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) - maxlen = encoder_out.size(1) - # (TODO Hui Zhang): bool no support reduce_sum - # encoder_out_lens = encoder_mask.squeeze(1).sum(1) - encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) + maxlen = encoder_out.shape[1] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) @@ -459,7 +455,7 @@ class U2BaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) @@ -587,7 +583,7 @@ class U2BaseModel(nn.Layer): encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool) + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) @@ -667,9 +663,7 @@ class U2BaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - # @jit.to_static([ - # paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D] - # ]) + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -696,13 +690,13 @@ class U2BaseModel(nn.Layer): Returns: paddle.Tensor: decoder output, (B, L) """ - assert encoder_out.size(0) == 1 - num_hyps = hyps.size(0) - assert hyps_lens.size(0) == num_hyps + assert encoder_out.shape[0] == 1 + num_hyps = hyps.shape[0] + assert hyps_lens.shape[0] == num_hyps encoder_out = encoder_out.repeat(num_hyps, 1, 1) # (B, 1, T) encoder_mask = paddle.ones( - [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) # (num_hyps, max_hyps_len, vocab_size) decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens) @@ -757,7 +751,7 @@ class U2BaseModel(nn.Layer): Returns: List[List[int]]: transcripts. 
""" - batch_size = feats.size(0) + batch_size = feats.shape[0] if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: logger.fatal( @@ -785,7 +779,7 @@ class U2BaseModel(nn.Layer): # result in List[int], change it to List[List[int]] for compatible # with other batch decoding mode elif decoding_method == 'ctc_prefix_beam_search': - assert feats.size(0) == 1 + assert feats.shape[0] == 1 hyp = self.ctc_prefix_beam_search( feats, feats_lengths, @@ -795,7 +789,7 @@ class U2BaseModel(nn.Layer): simulate_streaming=simulate_streaming) hyps = [hyp] elif decoding_method == 'attention_rescoring': - assert feats.size(0) == 1 + assert feats.shape[0] == 1 hyp = self.attention_rescoring( feats, feats_lengths, @@ -836,6 +830,7 @@ class U2Model(U2BaseModel): Returns: int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc """ + # cmvn if configs['cmvn_file'] is not None: mean, istd = load_cmvn(configs['cmvn_file'], configs['cmvn_file_type']) @@ -845,11 +840,13 @@ class U2Model(U2BaseModel): else: global_cmvn = None + # input & output dim input_dim = configs['input_dim'] vocab_size = configs['output_dim'] assert input_dim != 0, input_dim assert vocab_size != 0, vocab_size + # encoder encoder_type = configs.get('encoder', 'transformer') logger.info(f"U2 Encoder type: {encoder_type}") if encoder_type == 'transformer': @@ -861,16 +858,21 @@ class U2Model(U2BaseModel): else: raise ValueError(f"not support encoder type:{encoder_type}") + # decoder decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropoutrate'], reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, decoder, ctc @@ -902,10 +904,10 @@ class U2Model(U2BaseModel): Returns: DeepSpeech2Model: The model built from pretrained result. """ - config.defrost() - config.input_dim = dataloader.collate_fn.feature_size - config.output_dim = dataloader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + model = cls.from_config(config) if checkpoint_path: diff --git a/deepspeech/models/u2/updater.py b/deepspeech/models/u2/updater.py new file mode 100644 index 0000000000000000000000000000000000000000..7b70ca047d7d815652fd9ca196e0ba4b11bbe606 --- /dev/null +++ b/deepspeech/models/u2/updater.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from contextlib import nullcontext + +import paddle +from paddle import distributed as dist + +from deepspeech.training.extensions.evaluator import StandardEvaluator +from deepspeech.training.reporter import report +from deepspeech.training.timer import Timer +from deepspeech.training.updaters.standard_updater import StandardUpdater +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class U2Evaluator(StandardEvaluator): + def __init__(self, model, dataloader): + super().__init__(model, dataloader) + self.msg = "" + self.num_seen_utts = 0 + self.total_loss = 0.0 + + def evaluate_core(self, batch): + self.msg = "Valid: Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + self.num_seen_utts += num_utts + self.total_loss += float(loss) * num_utts + + losses_dict['loss'] = float(loss) + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + + for k, v in losses_dict.items(): + report("eval/" + k, v) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + logger.info(self.msg) + return self.total_loss, self.num_seen_utts + + +class U2Updater(StandardUpdater): + def __init__(self, + model, + optimizer, + scheduler, + dataloader, + init_state=None, + accum_grad=1, + **kwargs): + super().__init__( + model, optimizer, scheduler, dataloader, init_state=init_state) + self.accum_grad = accum_grad + self.forward_count = 0 + self.msg = "" + + def update_core(self, batch): + """One Step + + Args: + batch (List[Object]): utts, xs, xlens, ys, ylens + """ + losses_dict = {} + self.msg = "Rank: {}, ".format(dist.get_rank()) + + # forward + batch_size = batch[1].shape[0] + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + # loss div by `batch_size * accum_grad` + loss /= self.accum_grad + + # loss backward + if (self.forward_count + 1) != self.accum_grad: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # loss info + losses_dict['loss'] = float(loss) * self.accum_grad + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + # report loss + for k, v in losses_dict.items(): + report("train/" + k, v) + # loss msg + self.msg += "batch size: {}, ".format(batch_size) + self.msg += "accum: {}, ".format(self.accum_grad) + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + # Truncate the graph + loss.detach() + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + self.optimizer.step() + self.optimizer.clear_grad() + self.scheduler.step() + + def update(self): + # model is default in train mode + + # training for a step is implemented here + with Timer("data time cost:{}"): + batch = self.read_batch() + with Timer("step time cost:{}"): + self.update_core(batch) + + # #iterations with accum_grad > 1 + # Ref.: https://github.com/espnet/espnet/issues/777 + if self.forward_count == 0: + self.state.iteration += 1 + if self.updates_per_epoch is not None: + if self.state.iteration % self.updates_per_epoch == 0: + self.state.epoch += 1 diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py index b725cc359ebf1d285eb4d1305d3583aa363fa839..a3d99942fd90948d28c8a6c3b44f1e3a3b8b236d 100644 --- a/deepspeech/models/u2_st.py +++ b/deepspeech/models/u2_st.py @@ -42,6 +42,7 @@ from deepspeech.utils import layer_tools from deepspeech.utils.log import Log from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] @@ -163,10 +164,7 @@ class U2STBaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. ST-decoder branch start = time.time() @@ -342,8 +340,8 @@ class U2STBaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) - encoder_dim = encoder_out.size(2) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] running_size = batch_size * beam_size encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) @@ -363,8 +361,7 @@ class U2STBaseModel(nn.Layer): # 2. 
Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - # TODO(Hui Zhang): if end_flag.sum() == running_size: - if end_flag.cast(paddle.int64).sum() == running_size: + if end_flag.sum() == running_size: break # 2.1 Forward decoder step @@ -417,26 +414,26 @@ best_hyps = best_hyps[:, 1:] return best_hyps - @jit.to_static + # @jit.to_static def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - @jit.to_static + # @jit.to_static def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - @jit.to_static + # @jit.to_static def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - @jit.to_static + # @jit.to_static def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ @@ -472,7 +469,7 @@ xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - @jit.to_static + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -499,13 +496,13 @@ Returns: paddle.Tensor: decoder output, (B, L) """ - assert encoder_out.size(0) == 1 - num_hyps = hyps.size(0) - assert hyps_lens.size(0) == num_hyps + assert encoder_out.shape[0] == 1 + num_hyps = hyps.shape[0] + assert hyps_lens.shape[0] == num_hyps encoder_out = encoder_out.repeat(num_hyps, 1, 1) # (B, 1, T) encoder_mask = paddle.ones( - [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) # (num_hyps, max_hyps_len, vocab_size) decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens) @@ -560,7 +557,7 @@ Returns: List[List[int]]: transcripts. """ - batch_size = feats.size(0) + batch_size = feats.shape[0] if decoding_method == 'fullsentence': hyps = self.translate( @@ -647,13 +644,16 @@ decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropout_rate'], reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, (st_decoder, decoder, ctc) else: @@ -687,10 +687,10 @@ class U2STModel(U2STBaseModel): Returns: U2STModel: The model built from pretrained result. 
""" - config.defrost() - config.input_dim = dataloader.collate_fn.feature_size - config.output_dim = dataloader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + model = cls.from_config(config) if checkpoint_path: diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py index 30132775edcf86884ba06822225d3f1b553a25e2..3cb8729e1d81611587088e1b8f9c05806fd43be9 100644 --- a/deepspeech/modules/activation.py +++ b/deepspeech/modules/activation.py @@ -15,12 +15,13 @@ from collections import OrderedDict import paddle from paddle import nn +from paddle.nn import functional as F from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"] +__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"] def brelu(x, t_min=0.0, t_max=24.0, name=None): @@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None): return x.maximum(t_min).minimum(t_max) +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return F.glu(xs, axis=self.dim) + + class LinearGLUBlock(nn.Layer): """A linear Gated Linear Units (GLU) block.""" @@ -133,13 +145,18 @@ def get_activation(act): """Return activation function.""" # Lazy load to avoid unused import activation_funcs = { + "hardshrink": paddle.nn.Hardshrink, + "hardswish": paddle.nn.Hardswish, "hardtanh": paddle.nn.Hardtanh, "tanh": paddle.nn.Tanh, "relu": paddle.nn.ReLU, + "relu6": paddle.nn.ReLU6, + "leakyrelu": paddle.nn.LeakyReLU, "selu": paddle.nn.SELU, "swish": paddle.nn.Swish, "gelu": paddle.nn.GELU, - "brelu": brelu, + "glu": GLU, + "elu": paddle.nn.ELU, } return activation_funcs[act]() diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py index 4401a4a5552fa26e8141ef07c4fef1da095ec483..f94797282a5eff2ee5f0d5ae8b558ad53328acee 100644 --- a/deepspeech/modules/attention.py +++ b/deepspeech/modules/attention.py @@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer): paddle.Tensor: Transformed value tensor, size (#batch, n_head, time2, d_k). """ - n_batch = query.size(0) + n_batch = query.shape[0] q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) @@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer): paddle.Tensor: Transformed value weighted by the attention score, (#batch, time1, d_model). """ - n_batch = value.size(0) + n_batch = value.shape[0] if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) scores = scores.masked_fill(mask, -float('inf')) @@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer): p_attn = self.dropout(attn) x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose([0, 2, 1, 3]).contiguous().view( - n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h * + self.d_k) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) @@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): paddle.Tensor: Output tensor. 
(batch, head, time1, time1) """ zero_pad = paddle.zeros( - (x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) + (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) + x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, + x.shape[2]) x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: - ones = paddle.ones((x.size(2), x.size(3))) - x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + ones = paddle.ones((x.shape[2], x.shape[3])) + x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] return x @@ -205,7 +206,7 @@ q, k, v = self.forward_qkv(query, key, value) q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) - n_batch_pos = pos_emb.size(0) + n_batch_pos = pos_emb.shape[0] p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index 8bf48b2c80de27f6270f5858d03a90098ddb18f1..22a168800be17885c772f69d1ad3fd710f0abfe6 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -113,11 +113,9 @@ class ConvBn(nn.Layer): # reset padding part to 0 masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 31e489a3d4b11bdd38f4b2bb6bf62caa3faae6ca..b3ca28279b52c281a3b1668998a4f42bee44517b 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -16,15 +16,19 @@ from paddle import nn from paddle.nn import functional as F from typeguard import check_argument_types -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.modules.loss import CTCLoss from deepspeech.utils import ctc_utils from deepspeech.utils.log import Log logger = Log(__name__).getlog() +try: + from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 + from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401 + from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401 +except Exception as e: + logger.info("ctcdecoder not installed!") + __all__ = ['CTCDecoder'] @@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer): blank_id=0, dropout_rate: float=0.0, reduction: bool=True, - batch_average: bool=True): + batch_average: bool=True, + grad_norm_type: str="instance"): """CTC decoder Args: odim ([int]): text vocabulary size enc_n_units ([int]): encoder output dimension dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. + grad_norm_type (str): one of 'instance', 'batch', 'frame', None. 
""" assert check_argument_types() super().__init__() @@ -56,7 +62,8 @@ class CTCDecoder(nn.Layer): self.criterion = CTCLoss( blank=self.blank_id, reduction=reduction_type, - batch_average=batch_average) + batch_average=batch_average, + grad_norm_type=grad_norm_type) # CTCDecoder LM Score handle self._ext_scorer = None @@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer): results = [] for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) + probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer): num_processes=num_processes, ext_scoring_func=self._ext_scorer, cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) + cutoff_top_n=cutoff_top_n, + blank_id=self.blank_id) results = [result[0][1] for result in beam_search_results] return results def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, decoding_method): + if decoding_method == "ctc_beam_search": self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) @@ -229,7 +238,7 @@ class CTCDecoder(nn.Layer): """ctc decoding with probs. Args: - probs (Tenosr): activation after softmax + probs (Tenosr): activation after softmax logits_lens (Tenosr): audio output lens vocab_list ([type]): [description] decoding_method ([type]): [description] diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index 87c9fa492182b7822e92e1ab398d35a33eafd73b..8ca72894a934d44aef300c4b333586030b3de966 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -122,11 +122,9 @@ class TransformerDecoder(nn.Layer): # tgt_mask: (B, 1, L) tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) # m: (1, L, L) - m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) + m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) # tgt_mask: (B, L, L) - # TODO(Hui Zhang): not support & for tensor - # tgt_mask = tgt_mask & m - tgt_mask = tgt_mask.logical_and(m) + tgt_mask = tgt_mask & m x, _ = self.embed(tgt) for layer in self.decoders: @@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer): if self.use_output_layer: x = self.output_layer(x) - # TODO(Hui Zhang): reduce_sum not support bool type - # olens = tgt_mask.sum(1) - olens = tgt_mask.astype(paddle.int).sum(1) + olens = tgt_mask.sum(1) return x, olens def forward_one_step( diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index 98b4e1291415176dd36d54de185ef8acc68d9701..fbbda023c613a5328ac6a1e87938ed2428ff98aa 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer): paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) """ T = x.shape[1] - assert offset + x.size(1) < self.max_len + assert offset + x.shape[1] < self.max_len #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + T] x = x * self.xscale + pos_emb @@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). 
""" - assert offset + x.size(1) < self.max_len + assert offset + x.shape[1] < self.max_len x = x * self.xscale #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]] diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 71ec61a0e3bbfe03d40799fd8c9cdbc642d7d59a..d4a8275c3ddb577ace09735c2088eb956970509c 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -159,11 +159,10 @@ class BaseEncoder(nn.Layer): if self.global_cmvn is not None: xs = self.global_cmvn(xs) #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor - xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) + xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor masks = masks.astype(paddle.bool) - #TODO(Hui Zhang): mask_pad = ~masks - mask_pad = masks.logical_not() + mask_pad = ~masks chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, @@ -207,11 +206,11 @@ class BaseEncoder(nn.Layer): chunk computation List[paddle.Tensor]: conformer cnn cache """ - assert xs.size(0) == 1 # batch size must be one + assert xs.shape[0] == 1 # batch size must be one # tmp_masks is just for interface compatibility # TODO(Hui Zhang): stride_slice not support bool tensor # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) - tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32) + tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32) tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] if self.global_cmvn is not None: @@ -221,25 +220,25 @@ class BaseEncoder(nn.Layer): xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D) if subsampling_cache is not None: - cache_size = subsampling_cache.size(1) #T + cache_size = subsampling_cache.shape[1] #T xs = paddle.cat((subsampling_cache, xs), dim=1) else: cache_size = 0 # only used when using `RelPositionMultiHeadedAttention` pos_emb = self.embed.position_encoding( - offset=offset - cache_size, size=xs.size(1)) + offset=offset - cache_size, size=xs.shape[1]) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: - next_cache_start = xs.size(1) + next_cache_start = xs.shape[1] else: - next_cache_start = xs.size(1) - required_cache_size + next_cache_start = xs.shape[1] - required_cache_size r_subsampling_cache = xs[:, next_cache_start:, :] # Real mask for transformer/conformer layers - masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) + masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool) masks = masks.unsqueeze(1) #[B=1, L'=1, T] r_elayers_output_cache = [] r_conformer_cnn_cache = [] @@ -303,7 +302,7 @@ class BaseEncoder(nn.Layer): stride = subsampling * decoding_chunk_size decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) + num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks subsampling_cache: Optional[paddle.Tensor] = None elayers_output_cache: Optional[List[paddle.Tensor]] = None @@ -319,10 +318,10 @@ class BaseEncoder(nn.Layer): chunk_xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) outputs.append(y) - offset += y.size(1) + offset += y.shape[1] ys = paddle.cat(outputs, 1) # fake mask, just for jit script and compatibility with `forward` api - masks = paddle.ones([1, 
ys.size(1)], dtype=paddle.bool) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) masks = masks.unsqueeze(1) return ys, masks diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 8918ca669b02819cbaadef6f079da5e420a83d0c..2c58be7e3673e0fbf3608170a38b4a91f61dcb18 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum', batch_average=False): + def __init__(self, + blank=0, + reduction='sum', + batch_average=False, + grad_norm_type=None): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.batch_average = batch_average + logger.info( + f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") + + # instance for norm_by_times + # batch for norm_by_batchsize + # frame for norm_by_total_logits_len + assert grad_norm_type in ('instance', 'batch', 'frame', None) + self.norm_by_times = False + self.norm_by_batchsize = False + self.norm_by_total_logits_len = False + logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") + if grad_norm_type == 'instance': + self.norm_by_times = True + if grad_norm_type == 'batch': + self.norm_by_batchsize = True + if grad_norm_type == 'frame': + self.norm_by_total_logits_len = True def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. @@ -46,10 +67,15 @@ class CTCLoss(nn.Layer): # warp-ctc need activation with shape [T, B, V + 1] # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) - # (TODO:Hui Zhang) ctc loss does not support int64 labels ys_pad = ys_pad.astype(paddle.int32) loss = self.loss( - logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average) + logits, + ys_pad, + hlens, + ys_lens, + norm_by_times=self.norm_by_times, + norm_by_batchsize=self.norm_by_batchsize, + norm_by_total_logits_len=self.norm_by_total_logits_len) if self.batch_average: # Batch-size average loss = loss / B @@ -124,9 +150,9 @@ class LabelSmoothingLoss(nn.Layer): # use zeros_like instead of torch.no_grad() for true_dist, # since no_grad() can not be exported by JIT true_dist = paddle.full_like(x, self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) + ignore = (target == self.padding_idx) # (B,) - # target = target * (1 - ignore) # avoid -1 index + #TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index target = target.masked_fill(ignore, 0) # avoid -1 index # true_dist.scatter_(1, target.unsqueeze(1), self.confidence) target_mask = F.one_hot(target, self.size) @@ -135,10 +161,8 @@ class LabelSmoothingLoss(nn.Layer): kl = self.criterion(F.log_softmax(x, axis=1), true_dist) - #TODO(Hui Zhang): sum not support bool type - #total = len(target) - int(ignore.sum()) - total = len(target) - int(ignore.type_as(target).sum()) + total = len(target) - int(ignore.sum()) denom = total if self.normalize_length else B - #numer = (kl * (1 - ignore)).sum() + #TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() return numer / denom diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 05e86eb33b586a86abcebe646dbe7e34bbd7de64..6d46f5ba06b1ba306f7ad8eef1aa830a00b7d366 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] """ - #TODO(Hui Zhang): return ~make_pad_mask(lengths), not 
support ~ - return make_pad_mask(lengths).logical_not() + return ~make_pad_mask(lengths) def subsequent_mask(size: int) -> paddle.Tensor: @@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor: [1, 1, 1]] """ ret = paddle.ones([size, size], dtype=paddle.bool) - #TODO(Hui Zhang): tril not support bool - #return paddle.tril(ret) - ret = ret.astype(paddle.float) - ret = paddle.tril(ret) - ret = ret.astype(paddle.bool) - return ret + return paddle.tril(ret) def subsequent_chunk_mask( @@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor, chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - # chunk_masks = masks & chunk_masks # (B, L, L) - chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) elif static_chunk_size > 0: num_left_chunks = num_decoding_left_chunks chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - # chunk_masks = masks & chunk_masks # (B, L, L) - chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) else: chunk_masks = masks return chunk_masks diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py index 0d8c9fd2cd859748ec9f882893b026fa6840df46..8f8b2a18dd5455943d03e9634f958b14caa3fc57 100644 --- a/deepspeech/modules/rnn.py +++ b/deepspeech/modules/rnn.py @@ -308,7 +308,7 @@ class RNNStack(nn.Layer): x, x_len = rnn(x, x_len) masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index 7f4bb804832a1df7dcfabf8fb48ee39ddaed8d5a..e079293c75a09c6903d8d554d1464b6dcea563d3 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -14,6 +14,20 @@ import argparse +class ExtendAction(argparse.Action): + """ + Since Python 3.8, the "extend" action is available directly in the stdlib + (https://docs.python.org/3.8/library/argparse.html#action), + so if you only have to support 3.8+, defining it yourself is no longer required; + usage of the stdlib "extend" action is exactly the same. + """ + + def __call__(self, parser, namespace, values, option_string=None): + items = getattr(namespace, self.dest) or [] + items.extend(values) + setattr(namespace, self.dest, items) + + def default_argument_parser(): r"""A simple yet general argument parser for experiments with parakeet. @@ -30,7 +44,7 @@ def default_argument_parser(): The ``--checkpoint_path`` specifies the checkpoint to load from. - The ``--device`` and ``--nprocs`` specifies how to run the training. + The ``--nprocs`` specifies how to run the training.
See Also @@ -42,29 +56,53 @@ def default_argument_parser(): the parser """ parser = argparse.ArgumentParser() + parser.register('action', 'extend', ExtendAction) - # yapf: disable - # data and output - parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") - parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") - parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") - - # load from saved checkpoint - parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - - # running - parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], - help="device type to use, cpu and gpu are supported.") - parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") - - # overwrite extra config and default config - # parser.add_argument("--opts", nargs=argparse.REMAINDER, - # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") - parser.add_argument("--opts", type=str, default=[], nargs='+', - help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + train_group = parser.add_argument_group( + title='Train Options', description=None) + train_group.add_argument( + "--seed", + type=int, + default=None, + help="seed to use for paddle, np and random. None or 0 for random, else set seed." + ) + train_group.add_argument( + "--nprocs", + type=int, + default=1, + help="number of parallel processes. 0 for cpu.") + train_group.add_argument( + "--config", metavar="CONFIG_FILE", help="config file.") + train_group.add_argument( + "--output", metavar="CKPT_DIR", help="path to save checkpoint.") + train_group.add_argument( + "--checkpoint_path", type=str, help="path to load checkpoint") + train_group.add_argument( + "--opts", + action='extend', + nargs=2, + metavar=('key', 'val'), + help="overwrite --config field, passing (KEY VALUE) pairs") + train_group.add_argument( + "--dump-config", metavar="FILE", help="dump config to `this` file.") - parser.add_argument("--seed", type=int, default=None, - help="seed to use for paddle, np and random. None or 0 for random, else set seed.") - # yapd: enable + profile_group = parser.add_argument_group( + title='Benchmark Options', description=None) + profile_group.add_argument( + '--profiler-options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) + profile_group.add_argument( + '--benchmark-batch-size', + type=int, + default=None, + help='batch size for benchmark.') + profile_group.add_argument( + '--benchmark-max-step', + type=int, + default=None, + help='max iteration for benchmark.') return parser diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py index 96ff967f53a14203d313aaf024799d95b6fd307f..1026a4ec39e9257551dddb4d9beff9a415a82da3 100644 --- a/deepspeech/training/extensions/evaluator.py +++ b/deepspeech/training/extensions/evaluator.py @@ -13,14 +13,18 @@ # limitations under the License. from typing import Dict -import extension import paddle +from paddle import distributed as dist from paddle.io import DataLoader from paddle.nn import Layer +from . 
import extension from ..reporter import DictSummary +from ..reporter import ObsScope from ..reporter import report -from ..reporter import scope +from ..timer import Timer +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() class StandardEvaluator(extension.Extension): @@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension): def evaluate_core(self, batch): # compute self.model(batch) # you may report here + return + + def evaluate_sync(self, data): + # dist sync `evaluate_core` outputs + if data is None: + return + + numerator, denominator = data + if dist.get_world_size() > 1: + numerator = paddle.to_tensor(numerator) + denominator = paddle.to_tensor(denominator) + # the default operator in all_reduce function is sum. + dist.all_reduce(numerator) + dist.all_reduce(denominator) + value = numerator / denominator + value = float(value) + else: + value = numerator / denominator + # used for `snapshot` to do kbest save. + report("VALID/LOSS", value) + logger.info(f"Valid: all-reduce loss {value}") def evaluate(self): # switch to eval mode @@ -53,12 +78,16 @@ class StandardEvaluator(extension.Extension): summary = DictSummary() for batch in self.dataloader: observation = {} - with scope(observation): + with ObsScope(observation): # main evaluation computation here. with paddle.no_grad(): - self.evaluate_core(batch) + self.evaluate_sync(self.evaluate_core(batch)) summary.add(observation) summary = summary.compute_mean() + + # switch to train mode + for model in self.models.values(): + model.train() return summary def __call__(self, trainer=None): @@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension): # if it is used to extend a trainer, the metrics is reported to # to observation of the trainer # or otherwise, you can use your own observation - summary = self.evaluate() + with Timer("Eval Time Cost: {}"): + summary = self.evaluate() for k, v in summary.items(): report(k, v) diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py index cb4e6dfbff84d4dfed91af599b1bb040f37e9660..e81eb97fccf9612c240cb5e4dcfae792f8e68800 100644 --- a/deepspeech/training/extensions/snapshot.py +++ b/deepspeech/training/extensions/snapshot.py @@ -20,8 +20,9 @@ from typing import List import jsonlines -from deepspeech.training.extensions import extension -from deepspeech.training.updaters.trainer import Trainer +from . 
import extension +from ..reporter import get_observations +from ..updaters.trainer import Trainer from deepspeech.utils.log import Log from deepspeech.utils.mp_tools import rank_zero_only @@ -52,8 +53,19 @@ class Snapshot(extension.Extension): priority = -100 default_name = "snapshot" - def __init__(self, max_size: int=5, snapshot_on_error: bool=False): + def __init__(self, + mode='latest', + max_size: int=5, + indicator=None, + less_better=True, + snapshot_on_error: bool=False): self.records: List[Dict[str, Any]] = [] + assert mode in ('latest', 'kbest'), mode + if mode == 'kbest': + assert indicator is not None + self.mode = mode + self.indicator = indicator + self.less_is_better = less_better self.max_size = max_size self._snapshot_on_error = snapshot_on_error self._save_all = (max_size == -1) @@ -66,16 +78,17 @@ class Snapshot(extension.Extension): # load existing records record_path: Path = self.checkpoint_dir / "records.jsonl" if record_path.exists(): - logger.debug("Loading from an existing checkpoint dir") self.records = load_records(record_path) - trainer.updater.load(self.records[-1]['path']) + ckpt_path = self.records[-1]['path'] + logger.info(f"Loading from an existing checkpoint {ckpt_path}") + trainer.updater.load(ckpt_path) def on_error(self, trainer, exc, tb): if self._snapshot_on_error: - self.save_checkpoint_and_update(trainer) + self.save_checkpoint_and_update(trainer, 'latest') def __call__(self, trainer: Trainer): - self.save_checkpoint_and_update(trainer) + self.save_checkpoint_and_update(trainer, self.mode) def full(self): """Whether the number of snapshots it keeps track of is greater @@ -83,12 +96,12 @@ class Snapshot(extension.Extension): return (not self._save_all) and len(self.records) > self.max_size @rank_zero_only - def save_checkpoint_and_update(self, trainer: Trainer): + def save_checkpoint_and_update(self, trainer: Trainer, mode: str): """Saving new snapshot and remove the oldest snapshot if needed.""" iteration = trainer.updater.state.iteration epoch = trainer.updater.state.epoch num = epoch if self.trigger[1] == 'epoch' else iteration - path = self.checkpoint_dir / f"{num}.pdz" + path = self.checkpoint_dir / f"{num}.np" # add the new one trainer.updater.save(path) @@ -97,11 +110,17 @@ class Snapshot(extension.Extension): 'path': str(path.resolve()), # use absolute path 'iteration': iteration, 'epoch': epoch, + 'indicator': get_observations()[self.indicator] } self.records.append(record) # remove the earist if self.full(): + if mode == 'kbest': + self.records = sorted( + self.records, + key=lambda record: record['indicator'], + reverse=not self.less_is_better) eariest_record = self.records[0] os.remove(eariest_record["path"]) self.records.pop(0) diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py index b69e94aaf4bbd42bb1bb50010af86e419d7c7ddb..e5f456cac4ff2b1fd9623ec1948a9e7337b712f0 100644 --- a/deepspeech/training/extensions/visualizer.py +++ b/deepspeech/training/extensions/visualizer.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from deepspeech.training.extensions import extension -from deepspeech.training.updaters.trainer import Trainer +from visualdl import LogWriter + +from . 
import extension +from ..updaters.trainer import Trainer class VisualDL(extension.Extension): @@ -26,8 +28,8 @@ class VisualDL(extension.Extension): default_name = 'visualdl' priority = extension.PRIORITY_READER - def __init__(self, writer): - self.writer = writer + def __init__(self, output_dir): + self.writer = LogWriter(str(output_dir)) def __call__(self, trainer: Trainer): for k, v in trainer.observation.items(): diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index f46814eb0ae04887ceb4c7c1f674fc360f3644c0..87b36acaeccd9fecffba48a7b0c6c61a3ff782b2 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -47,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square = layers.reduce_sum(square) sum_square_list.append(sum_square) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") @@ -76,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py index 66a81adef1c47f8fe55ad8d608daaa2cb97545ff..7afc33f38966529c75831d45443c848ea0c12839 100644 --- a/deepspeech/training/reporter.py +++ b/deepspeech/training/reporter.py @@ -19,7 +19,7 @@ OBSERVATIONS = None @contextlib.contextmanager -def scope(observations): +def ObsScope(observations): # make `observation` the target to report to. # it is basically a dictionary that stores temporary observations global OBSERVATIONS diff --git a/deepspeech/training/timer.py b/deepspeech/training/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca9d6386de45074ac76c7c754f89cefa36c5459 --- /dev/null +++ b/deepspeech/training/timer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import datetime +import time + +from deepspeech.utils.log import Log + +__all__ = ["Timer"] + +logger = Log(__name__).getlog() + + +class Timer(): + """To be used like this: + with Timer("Message") as value: + do some thing + """ + + def __init__(self, message=None): + self.message = message + + def duration(self) -> str: + elapsed_time = time.time() - self.start + time_str = str(datetime.timedelta(seconds=elapsed_time)) + return time_str + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, type, value, traceback): + if self.message: + logger.info(self.message.format(self.duration())) + + def __call__(self) -> float: + return time.time() - self.start + + def __str__(self): + return self.duration() diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 3a922c6f4f88f03dadf20e8e978a84bcf436a58a..79b1562e4ddf419bc99e7c9f9dd82e81a28f3d9a 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,17 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys import time +from collections import OrderedDict from pathlib import Path import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter +from deepspeech.training.reporter import ObsScope +from deepspeech.training.reporter import report +from deepspeech.training.timer import Timer from deepspeech.utils import mp_tools +from deepspeech.utils import profiler from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log from deepspeech.utils.utility import seed_all +from deepspeech.utils.utility import UpdateConfig __all__ = ["Trainer"] @@ -79,7 +86,7 @@ class Trainer(): >>> config.merge_from_list(args.opts) >>> config.freeze() >>> - >>> if args.nprocs > 1 and args.device == "gpu": + >>> if args.nprocs > 0: >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) >>> else: >>> main_sp(config, args) @@ -94,15 +101,25 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + self.rank = dist.get_rank() + + logger.info(f"Rank: {self.rank}/{dist.get_world_size()}") if args.seed: seed_all(args.seed) logger.info(f"Set seed {args.seed}") + if self.args.benchmark_batch_size: + with UpdateConfig(self.config): + self.config.collator.batch_size = self.args.benchmark_batch_size + self.config.training.log_interval = 1 + logger.info( + f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") + def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') if self.parallel: self.init_parallel() @@ -122,7 +139,7 @@ class Trainer(): """A flag indicating whether the experiment should run with multiprocessing. """ - return self.args.device == "gpu" and self.args.nprocs > 1 + return self.args.nprocs > 0 def init_parallel(self): """Init environment for multiprocess training. 
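The benchmark and profiler flags defined in `deepspeech/training/cli.py` above are consumed by this `Trainer`. A minimal sketch of how the extended parser handles them, assuming only the flags added in this patch (the concrete values here are illustrative, not from the patch):

```python
# Sketch: exercising the new CLI surface of default_argument_parser().
from deepspeech.training.cli import default_argument_parser

parser = default_argument_parser()
args = parser.parse_args([
    "--config", "conf/deepspeech2_online.yaml",
    "--output", "exp/deepspeech2_online",
    "--nprocs", "1",                      # 0 would select CPU in Trainer.setup()
    "--benchmark-batch-size", "16",       # overrides collator.batch_size
    "--benchmark-max-step", "100",        # train loop exits after 100 steps
    "--profiler-options", "batch_range=[10,20];profile_path=/tmp/profile",
    "--opts", "training.n_epoch", "2",    # (KEY VALUE) pair via ExtendAction
])
print(args.opts)  # ['training.n_epoch', '2'], accumulated by the 'extend' action
```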
@@ -162,67 +179,108 @@ class Trainer(): checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) if infos: - # restore from ckpt + # just restore ckpt + # lr will resotre from optimizer ckpt self.iteration = infos["step"] self.epoch = infos["epoch"] scratch = False + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") else: self.iteration = 0 self.epoch = 0 scratch = True - + logger.info("Init from scratch!") return scratch - def new_epoch(self): - """Reset the train loader seed and increment `epoch`. - """ - self.epoch += 1 - if self.parallel and hasattr(self.train_loader, "batch_sampler"): + def maybe_batch_sampler_step(self): + """ batch_sampler seed by epoch """ + if hasattr(self.train_loader, "batch_sampler"): batch_sampler = self.train_loader.batch_sampler if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): batch_sampler.set_epoch(self.epoch) - def train(self): - """The training process control by epoch.""" + def before_train(self): from_scratch = self.resume_or_scratch() if from_scratch: - # save init model, i.e. 0 epoch + # scratch: save init model, i.e. 0 epoch self.save(tag='init', infos=None) - self.lr_scheduler.step(self.epoch) - if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + else: + # resume: train next_epoch and next_iteration + self.epoch += 1 + self.iteration += 1 + logger.info( + f"Resume train: epoch {self.epoch }, step {self.iteration}!") + + self.maybe_batch_sampler_step() + + def new_epoch(self): + """Reset the train loader seed and increment `epoch`. + """ + # `iteration` increased by train step + self.epoch += 1 + self.maybe_batch_sampler_step() + + def after_train_batch(self): + if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step: + profiler.add_profiler_step(self.args.profiler_options) + logger.info( + f"Reach benchmark-max-step: {self.args.benchmark_max_step}") + sys.exit( + f"Reach benchmark-max-step: {self.args.benchmark_max_step}") + + def train(self): + """The training process control by epoch.""" + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. 
- dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("lr", self.lr_scheduler()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips[sent./sec]'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += "," + logger.info(msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -231,6 +289,7 @@ class Trainer(): 'epoch', {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}, self.epoch) + # after epoch self.save(tag=self.epoch, infos={'val_loss': cv_loss}) # step lr every epoch self.lr_scheduler.step() @@ -240,14 +299,13 @@ class Trainer(): """The routine of the experiment after setup. This method is intended to be used by the user. """ - try: - self.train() - except KeyboardInterrupt: - self.save() - exit(-1) - finally: - self.destory() - logger.info("Training Done.") + with Timer("Training Done: {}"): + try: + self.train() + except KeyboardInterrupt: + exit(-1) + finally: + self.destory() def setup_output_dir(self): """Create a directory used for output. 
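The validation-loss aggregation in the loop above (and `evaluate_sync` in the evaluator extension) follows one pattern: all-reduce a summed numerator and denominator across ranks, then divide. A standalone sketch of that pattern, assuming `paddle` is installed; with a single process it degrades to a plain local average:

```python
import paddle
from paddle import distributed as dist

def all_reduce_average(total_loss: float, num_seen_utts: int) -> float:
    """Average a summed metric across ranks; plain division on one process."""
    if dist.get_world_size() > 1:
        total = paddle.to_tensor(total_loss)
        count = paddle.to_tensor(num_seen_utts)
        dist.all_reduce(total)  # the default operator in all_reduce is sum
        dist.all_reduce(count)
        return float(total / count)
    return total_loss / num_seen_utts

print(all_reduce_average(12.5, 5))  # single-process: 2.5
```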
diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py index fc758e93e7390694a1bbd26763db4c941ebc85dd..10c99e7fced7ae94cb09b630e38d21e4153ddbb2 100644 --- a/deepspeech/training/updaters/standard_updater.py +++ b/deepspeech/training/updaters/standard_updater.py @@ -14,12 +14,12 @@ from typing import Dict from typing import Optional -from paddle import Tensor +import paddle from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from paddle.nn import Layer from paddle.optimizer import Optimizer -from timer import timer +from paddle.optimizer.lr import LRScheduler from deepspeech.training.reporter import report from deepspeech.training.updaters.updater import UpdaterBase @@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase): def __init__(self, model: Layer, optimizer: Optimizer, + scheduler: LRScheduler, dataloader: DataLoader, init_state: Optional[UpdaterState]=None): + super().__init__(init_state) # it is designed to hold multiple models models = {"main": model} self.models: Dict[str, Layer] = models @@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase): self.optimizer = optimizer self.optimizers: Dict[str, Optimizer] = optimizers + # it is designed to hold multiple scheduler + schedulers = {"main": scheduler} + self.scheduler = scheduler + self.schedulers: Dict[str, LRScheduler] = schedulers + # dataloaders self.dataloader = dataloader - # init state - if init_state is None: - self.state = UpdaterState() - else: - self.state = init_state - self.train_iterator = iter(dataloader) def update(self): @@ -103,8 +104,10 @@ model.train() # training for a step is implemented here - batch = self.read_batch() - self.update_core(batch) + with Timer("data time cost:{}"): + batch = self.read_batch() + with Timer("step time cost:{}"): + self.update_core(batch) self.state.iteration += 1 if self.updates_per_epoch is not None: @@ -115,13 +118,14 @@ """A simple case for a training step. Basic assumptions are: Single model; Single optimizer; + Single scheduler, and update learning rate each step; A batch from the dataloader is just the input of the model; The model returns a single loss, or a dict containing several losses. Parameters updates at every batch, no gradient accumulation. 
""" loss = self.model(*batch) - if isinstance(loss, Tensor): + if isinstance(loss, paddle.Tensor): loss_dict = {"main": loss} else: # Dict[str, Tensor] @@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase): for name, loss_item in loss_dict.items(): report(name, float(loss_item)) - self.optimizer.clear_gradient() + self.optimizer.clear_grad() loss_dict["main"].backward() - self.optimizer.update() + self.optimizer.step() + self.scheduler.step() @property def updates_per_epoch(self): - """Number of updater per epoch, determined by the length of the - dataloader.""" + """Number of steps per epoch, + determined by the length of the dataloader.""" length_of_dataloader = None try: length_of_dataloader = len(self.dataloader) @@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase): def read_batch(self): """Read a batch from the data loader, auto renew when data is exhausted.""" - with timer() as t: - try: - batch = next(self.train_iterator) - except StopIteration: - self.new_epoch() - batch = next(self.train_iterator) - logger.debug( - f"Read a batch takes {t.elapse}s.") # replace it with logger + try: + batch = next(self.train_iterator) + except StopIteration: + self.new_epoch() + batch = next(self.train_iterator) return batch def state_dict(self): - """State dict of a Updater, model, optimizer and updater state are included.""" + """State dict of a Updater, model, optimizers/schedulers + and updater state are included.""" state_dict = super().state_dict() for name, model in self.models.items(): state_dict[f"{name}_params"] = model.state_dict() @@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase): def set_state_dict(self, state_dict): """Set state dict for a Updater. Parameters of models, states for - optimizers and UpdaterState are restored.""" + optimizers/schedulers and UpdaterState are restored.""" for name, model in self.models.items(): model.set_state_dict(state_dict[f"{name}_params"]) for name, optim in self.optimizers.items(): diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py index 954ce2604d18569b34c35be4fd517f74a59fc14e..077694659505a7d9e65b70db0f2a54198a03da09 100644 --- a/deepspeech/training/updaters/trainer.py +++ b/deepspeech/training/updaters/trainer.py @@ -24,7 +24,7 @@ import tqdm from deepspeech.training.extensions.extension import Extension from deepspeech.training.extensions.extension import PRIORITY_READER -from deepspeech.training.reporter import scope +from deepspeech.training.reporter import ObsScope from deepspeech.training.triggers import get_trigger from deepspeech.training.triggers.limit_trigger import LimitTrigger from deepspeech.training.updaters.updater import UpdaterBase @@ -140,11 +140,11 @@ class Trainer(): try: while not stop_trigger(self): self.observation = {} - # set observation as the report target - # you can use report freely in Updater.update() + # set observation as the `report` target + # you can use `report` freely in Updater.update() # updating parameters and state - with scope(self.observation): + with ObsScope(self.observation): update() p.update() diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py index 66fdc2bbc7aea7b7f08f1423a58736e6ffb3b068..e5dd65563dd846f8c01dbe68e2c01a931631a311 100644 --- a/deepspeech/training/updaters/updater.py +++ b/deepspeech/training/updaters/updater.py @@ -52,6 +52,7 @@ class UpdaterBase(): """ def __init__(self, init_state=None): + # init state if init_state is None: self.state = UpdaterState() else: diff --git 
a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index a59f8be796932c7fb4976178b6c5abb63b0d5ffd..8e31edfaeeb7eb4745df3492aa877c2b91bde734 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -114,13 +114,13 @@ class Checkpoint(): params_path = checkpoint_path + ".pdparams" model_dict = paddle.load(params_path) model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + logger.info("Rank {}: Restore model from {}".format(rank, params_path)) optimizer_path = checkpoint_path + ".pdopt" if optimizer and os.path.isfile(optimizer_path): optimizer_dict = paddle.load(optimizer_path) optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( + logger.info("Rank {}: Restore optimizer state from {}".format( rank, optimizer_path)) info_path = re.sub('.pdparams$', '.json', params_path) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 09543d48d45dec32cae5192913aa4c20636264d8..fc43a71f07e6ea803d810dd535561239262c0319 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -84,19 +84,19 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) + (ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero - # TODO(Hui Zhang): zeros not support paddle.int16 + + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + (ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1 ) # state path, Tuple((T, 2L+1)) # init start state - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb - log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): # T + for t in range(1, ctc_probs.shape[0]): # T for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: @@ -110,13 +110,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( - y_insert_blank[s])] + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ + y_insert_blank[s]] state_path[t, s] = prev_state[paddle.argmax(candidates)] - - # TODO(Hui Zhang): zeros not support paddle.int16 - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 + state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb @@ -124,11 +122,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ]) prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] state_seq[-1] = prev_state[paddle.argmax(candidates)] - for t in range(ctc_probs.size(0) - 2, -1, 
-1): + for t in range(ctc_probs.shape[0] - 2, -1, -1): state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] output_alignment = [] - for t in range(0, ctc_probs.size(0)): + for t in range(0, ctc_probs.shape[0]): output_alignment.append(y_insert_blank[state_seq[t, 0]]) return output_alignment diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py index 3fd7d24800383d287ad5c32f142cb7e0489fcc0e..7e8de600a9b46665b76d05dee8a811686d7e4cb3 100644 --- a/deepspeech/utils/log.py +++ b/deepspeech/utils/log.py @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import getpass -import logging import os import socket import sys +from loguru import logger from paddle import inference -FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' -DATE_FMT_STR = '%Y/%m/%d %H:%M:%S' - -logging.basicConfig( - level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR) - def find_log_dir(log_dir=None): """Returns the most suitable directory to put log files into. @@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None): class Log(): - - log_name = None - - def __init__(self, logger=None): - self.logger = logging.getLogger(logger) - self.logger.setLevel(logging.DEBUG) - - file_dir = os.getcwd() + '/log' - if not os.path.exists(file_dir): - os.mkdir(file_dir) - self.log_dir = file_dir - - actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( - program_name=None, log_dir=self.log_dir) - - basename = '%s.DEBUG.%d' % (file_prefix, os.getpid()) - filename = os.path.join(actual_log_dir, basename) - if Log.log_name is None: - Log.log_name = filename - - # Create a symlink to the log file with a canonical name. - symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG') - try: - if os.path.islink(symlink): - os.unlink(symlink) - os.symlink(os.path.basename(Log.log_name), symlink) - except EnvironmentError: - # If it fails, we're sad but it's no error. Commonly, this - # fails because the symlink was created by another user and so - # we can't modify it - pass - - if not self.logger.hasHandlers(): - formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR) - fh = logging.FileHandler(Log.log_name) - fh.setLevel(logging.DEBUG) - fh.setFormatter(formatter) - self.logger.addHandler(fh) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - self.logger.addHandler(ch) - - # stop propagate for propagating may print - # log multiple times - self.logger.propagate = False + """Default Logger for all.""" + logger.remove() + logger.add( + sys.stdout, + level='INFO', + enqueue=True, + filter=lambda record: record['level'].no >= 20) + _, file_prefix, _ = find_log_dir_and_names() + sink_prefix = os.path.join("exp/log", file_prefix) + sink_path = sink_prefix[:-3] + "{time}.log" + logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB") + + def __init__(self, name=None): + pass def getlog(self): - return self.logger + return logger class Autolog: + """Just used by fullchain project""" + def __init__(self, batch_size, model_name="DeepSpeech", diff --git a/deepspeech/utils/profiler.py b/deepspeech/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..5733f8ed5bc943d00de36a57852495aaf9320be6 --- /dev/null +++ b/deepspeech/utils/profiler.py @@ -0,0 +1,119 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +import paddle + +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + if not options_str: + return + + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. 
+ ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + logger.info(f"Profiler: {options_str}") + logger.info(f"Profiler: {_profiler_options._options}") + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler(_profiler_options['state'], + _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 9bff6b0f3f35ccb7a5392481efd2da7fb9e70ea1..61798816b5e4205c21ce540f61b06011ea8dad11 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor], # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] trailing_dims = max_size[1:] if max_size.ndim >= 2 else () - max_len = max([s.size(0) for s in sequences]) + max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims else: @@ -91,12 +91,22 @@ def pad_sequence(sequences: List[paddle.Tensor], out_tensor = sequences[0].new_full(out_dims, padding_value) for i, tensor in enumerate(sequences): - length = tensor.size(0) + length = tensor.shape[0] # use index notation to prevent duplicate references to the tensor if batch_first: - out_tensor[i, :length, ...] = tensor + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[i, :length, ...] = tensor + if length != 0: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[i, length, ...] = tensor else: - out_tensor[:length, i, ...] = tensor + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[:length, i, ...] = tensor + if length != 0: + out_tensor[:length, i, ...] = tensor + else: + out_tensor[length, i, ...] = tensor return out_tensor @@ -139,7 +149,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) - B = ys_pad.size(0) + B = ys_pad.shape[0] _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos ys_in = paddle.cat([_sos, ys_pad], dim=1) @@ -165,16 +175,10 @@ def th_accuracy(pad_outputs: paddle.Tensor, Returns: float: Accuracy value (0.0 - 1.0). 
""" - pad_pred = pad_outputs.view( - pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) + pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]).argmax(2) mask = pad_targets != ignore_label - #TODO(Hui Zhang): sum not support bool type - # numerator = paddle.sum( - # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = ( + numerator = paddle.sum( pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = paddle.sum(numerator.type_as(pad_targets)) - #TODO(Hui Zhang): sum not support bool type - # denominator = paddle.sum(mask) - denominator = paddle.sum(mask.type_as(pad_targets)) + denominator = paddle.sum(mask) return float(numerator) / float(denominator) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index e18fc1f775fc1281dbcd05ecd25bc3de6d1cbef1..6f84c41beff6ade9ac05af41a8aa9975ca78274a 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -16,15 +16,27 @@ import distutils.util import math import os import random +from contextlib import contextmanager from typing import List import numpy as np import paddle -__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"] +__all__ = [ + "UpdateConfig", "seed_all", 'print_arguments', 'add_arguments', "log_add" +] + + +@contextmanager +def UpdateConfig(config): + """Update yacs config""" + config.defrost() + yield + config.freeze() def seed_all(seed: int=210329): + """freeze random generator seed.""" np.random.seed(seed) random.seed(seed) paddle.seed(seed) diff --git a/doc/images/multi_gpu_speedup.png b/doc/images/multi_gpu_speedup.png deleted file mode 100755 index 286de51519203bb070ce9a539a21627808b7403c..0000000000000000000000000000000000000000 Binary files a/doc/images/multi_gpu_speedup.png and /dev/null differ diff --git a/doc/images/tuning_error_surface.png b/doc/images/tuning_error_surface.png deleted file mode 100644 index 2204cee2f5204d1d2d2e53fab8cdd0a1cb9ac47d..0000000000000000000000000000000000000000 Binary files a/doc/images/tuning_error_surface.png and /dev/null differ diff --git a/doc/src/benchmark.md b/doc/src/benchmark.md deleted file mode 100644 index 9c1c86fd7a83b5ae09f4dd4d5a95f5859a96845c..0000000000000000000000000000000000000000 --- a/doc/src/benchmark.md +++ /dev/null @@ -1,16 +0,0 @@ -# Benchmarks - -## Acceleration with Multi-GPUs - -We compare the training time with 1, 2, 4, 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars. - - - -| # of GPU | Acceleration Rate | -| -------- | --------------: | -| 1 | 1.00 X | -| 2 | 1.98 X | -| 4 | 3.73 X | -| 8 | 6.95 X | - -`utils/profile.sh` provides such a demo profiling tool, you can change it as need. diff --git a/doc/src/faq.md b/doc/src/faq.md deleted file mode 100644 index e29428176639358b2e619027750fd80ca76d31ed..0000000000000000000000000000000000000000 --- a/doc/src/faq.md +++ /dev/null @@ -1,37 +0,0 @@ -# FAQ - -1. 音频变速快慢到达什么晨读会影响识别率? - - 变速会提升识别效果,一般用0.9, 1.0, 1.1 的变速。 - -2. 音量大小到什么程度会影响识别率? - - 一般训练会固定音量到一个范围内,波动过大会影响训练,估计在10dB ~ 20dB吧。 - -3. 语音模型训练数据的最小时长要求时多少? - - Aishell-1大约178h的数据,数据越多越好。 - -4. 那些噪声或背景生会影响识别率? - - 主要是人生干扰和低信噪比会影响识别率。 - -5. 单条语音数据的长度限制是多少? - - 一般训练的语音长度会限制在1s~6s之间,和训练配置有关。 - -6. 背景声在识别前是否需要分离出来,或做降噪处理? - - 需要分离的,需要结合具体场景考虑。 - -7. 模型是否带有VAD人生激活识别能力? 
-
-   VAD is a separate model or module; this model does not include that capability.
-
-8. Is long-audio recognition supported?
-
-   Generally, recognition is performed after VAD segmentation.
-
-9. What hardware configuration does the Mandarin LM Large language model require?
-
-   Enough memory to hold the LM is sufficient.
diff --git a/doc/src/reference.md b/doc/src/reference.md deleted file mode 100644 index 69ff6ab88eb1bcf6aea2457c18f86e18e56ae682..0000000000000000000000000000000000000000 --- a/doc/src/reference.md +++ /dev/null @@ -1,3 +0,0 @@ -# Reference - -* [wenet](https://github.com/mobvoi/wenet) diff --git a/doc/src/released_model.md b/doc/src/released_model.md deleted file mode 100644 index 0919bba584205c38644ded21a4114fe228d9c43b..0000000000000000000000000000000000000000 --- a/doc/src/released_model.md +++ /dev/null @@ -1,9 +0,0 @@
-# Released Models
-
-## Language Model Released
-
-Language Model | Training Data | Token-based | Size | Descriptions
-:-------------:| :------------:| :-----: | -----: | :-----------------
-[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1; About 1.85 billion n-grams; 'trie' binary with '-a 22 -q 8 -b 8'
-[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4; About 0.13 billion n-grams; 'probing' binary with default settings
-[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning; About 3.7 billion n-grams; 'probing' binary with default settings
diff --git a/doc/src/server.md b/doc/src/server.md deleted file mode 100644 index 4918d5ebe08e9cc4e1843717daab80960cedfce9..0000000000000000000000000000000000000000 --- a/doc/src/server.md +++ /dev/null @@ -1,34 +0,0 @@
-
-# Trying Live Demo with Your Own Voice
-
-Until now, an ASR model is trained and tested qualitatively (`infer`) and quantitatively (`test`) with existing audio files, but it is not yet tested with your own speech. We build up a real-time demo ASR engine with the trained model, enabling you to test and play around with the demo, with your own voice.
-
-First, change your directory to `examples/aishell` and `source path.sh`.
-
-To start the demo's server, please run this in one console:
-
-```bash
-CUDA_VISIBLE_DEVICES=0 bash local/server.sh
-```
-
-For the machine (might not be the same machine) to run the demo's client, please do the following installation before moving on.
-
-For example, on MAC OS X:
-
-```bash
-brew install portaudio
-pip install pyaudio
-pip install keyboard
-```
-
-Then to start the client, please run this in another console:
-
-```bash
-CUDA_VISIBLE_DEVICES=0 bash local/client.sh
-```
-
-Now, in the client console, press the `whitespace` key, hold, and start speaking. When you finish your utterance, release the key and the speech-to-text results will be shown in the console. To quit the client, just press the `ESC` key.
-
-Notice that `deepspeech/exps/deepspeech2/deploy/client.py` must be run on a machine with a microphone device, while `deepspeech/exps/deepspeech2/deploy/server.py` could be run on one without any audio recording hardware, e.g. any remote server machine. Just be careful to set the `host_ip` and `host_port` arguments with the actual accessible IP address and port, if the server and client are running on two separate machines. Nothing needs to be done if they are running on one single machine.
-
-Please also refer to `examples/aishell/local/server.sh`, which will first download a pre-trained Chinese model (trained with AISHELL1) and then start the demo server with the model. By running `examples/aishell/local/client.sh`, you can speak Chinese to test it. If you would like to try some other models, just update the `--checkpoint_path` argument in the script.
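For reference, the language models listed in the deleted `released_model.md` table above are KenLM binaries. A minimal sketch of scoring text with one of them through the `kenlm` Python package (the package and the local model path are assumptions, not part of this patch):

```python
import kenlm  # assumed installed separately, e.g. via pip from the kpu/kenlm repo

# Placeholder path: one of the downloaded .klm binaries from the table above.
model = kenlm.Model("zh_giga.no_cna_cmn.prune01244.klm")

# Char-based Mandarin LMs expect space-separated characters as tokens.
print(model.score("深 度 语 音", bos=True, eos=True))  # log10 probability
```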
diff --git a/docs/images/ds2offlineModel.png b/docs/images/ds2offlineModel.png new file mode 100644 index 0000000000000000000000000000000000000000..0d8722ab00127074b04c03a6b27e4352ff15eb41 Binary files /dev/null and b/docs/images/ds2offlineModel.png differ diff --git a/docs/images/ds2onlineModel.png b/docs/images/ds2onlineModel.png new file mode 100644 index 0000000000000000000000000000000000000000..97a0e561961433d90db81f641a555f3612f7286e Binary files /dev/null and b/docs/images/ds2onlineModel.png differ diff --git a/doc/src/augmentation.md b/docs/src/augmentation.md similarity index 100% rename from doc/src/augmentation.md rename to docs/src/augmentation.md diff --git a/doc/src/data_preparation.md b/docs/src/data_preparation.md similarity index 100% rename from doc/src/data_preparation.md rename to docs/src/data_preparation.md diff --git a/docs/src/deepspeech_architecture.md b/docs/src/deepspeech_architecture.md new file mode 100644 index 0000000000000000000000000000000000000000..b93441222e2d2b8b0df8c745985e0ace9ede1393 --- /dev/null +++ b/docs/src/deepspeech_architecture.md @@ -0,0 +1,190 @@
+# Deepspeech2
+## Streaming
+
+The implemented architecture of the deepspeech2 online model is based on the [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes.
+The model is mainly composed of a 2D convolution subsampling layer and stacked single-direction rnn layers.
+
+To illustrate the model implementation clearly, 3 parts are described in detail.
+- Data Preparation
+- Encoder
+- Decoder
+
+In addition, the training process and the testing process are also introduced.
+
+The architecture of the model is shown in Fig.1.
+[figure: ds2onlineModel.png]
+Fig.1 The architecture of the deepspeech2 online model
+
+### Data Preparation
+#### Vocabulary
+For English data, the vocabulary dictionary is composed of 26 English characters with " ' ", space, \<blank\> and \<eos\>. The \<blank\> represents the blank label in CTC, the \<unk\> represents the unknown character and the \<eos\> represents the start and the end character. For Mandarin, the vocabulary dictionary is composed of the Chinese characters counted from the training set, and three additional characters are added: \<blank\>, \<unk\> and \<eos\>. For both English and Mandarin data, we set the default indices to \<blank\>=0, \<unk\>=1 and \<eos\>=last index.
+```
+ # The code to build the vocabulary
+ cd examples/aishell/s0
+ python3 ../../../utils/build_vocab.py \
+ --unit_type="char" \
+ --count_threshold=0 \
+ --vocab_path="data/vocab.txt" \
+ --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
+
+# vocabulary for aishell dataset (Mandarin)
+vi examples/aishell/s0/data/vocab.txt
+
+# vocabulary for librispeech dataset (English)
+vi examples/librispeech/s0/data/vocab.txt
+```
+
+#### CMVN
+For CMVN, a subset of the training set (or the full set) is chosen and used to compute the feature mean and std.
+```
+ # The code to compute the feature mean and std
+cd examples/aishell/s0
+python3 ../../../utils/compute_mean_std.py \
+ --manifest_path="data/manifest.train.raw" \
+ --specgram_type="linear" \
+ --delta_delta=false \
+ --stride_ms=10.0 \
+ --window_ms=20.0 \
+ --sample_rate=16000 \
+ --use_dB_normalization=True \
+ --num_samples=2000 \
+ --num_workers=10 \
+ --output_path="data/mean_std.json"
+
+```
+
+#### Feature Extraction
+ For feature extraction, three methods are implemented: linear (FFT without using a filter bank), fbank and mfcc.
+ Currently, the released deepspeech2 online model uses the linear feature extraction method.
+ ```
+ The code for feature extraction
+ vi deepspeech/frontend/featurizer/audio_featurizer.py
+ ```
+
+### Encoder
+The encoder is composed of two 2D convolution subsampling layers and a number of stacked single-direction rnn layers. The 2D convolution subsampling layers extract feature representations from the raw audio feature and reduce the length of the audio feature at the same time. After passing through the convolution subsampling layers, the feature representations are fed into the stacked rnn layers. For the stacked rnn layers, LSTM cells and GRU cells are provided. Adding one fully connected (fc) layer after the stacked rnn layers is optional; if the number of stacked rnn layers is less than 5, adding one fc layer after the stacked rnn layers is recommended.
+
+The code of the Encoder is in:
+```
+vi deepspeech/models/ds2_online/deepspeech2.py
+```
+
+### Decoder
+To get the character probabilities of each frame, the per-frame feature representations output by the encoder are fed into a projection layer, which is implemented as a dense layer whose output dim is the same as the vocabulary size. After the projection layer, the softmax function transforms the frame-level feature representations into character probabilities. At inference time, the character probabilities of each frame are fed into the CTC decoder to get the final speech recognition results.
+
+The code of the decoder is in:
+```
+# The code of constructing the decoder in model
+vi deepspeech/models/ds2_online/deepspeech2.py
+# The code of CTC Decoder
+vi deepspeech/modules/ctc.py
+```
+
+## Training Process
+Using the command below, you can train the deepspeech2 online model.
+ +## Training Process +Using the command below, you can train the deepspeech2 online model.
+``` + cd examples/aishell/s0 + bash run.sh --stage 0 --stop_stage 2 --model_type online --conf_path conf/deepspeech2_online.yaml +``` +The detailed commands are: +``` +# The code for training in run.sh +set -e +source path.sh + +gpus=2,3,5,7 +stage=0 +stop_stage=5 +conf_path=conf/deepspeech2_online.yaml # conf/deepspeech2.yaml | conf/deepspeech2_online.yaml +avg_num=1 +model_type=online # online | offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi +``` +
+By using the command above, the training process can be started. There are 5 stages in "run.sh", and the first 3 stages are used for the training process. Stage 0 is used for data preparation, in which the dataset is downloaded and the manifest files of the datasets, the vocabulary dictionary and the CMVN file are generated in "./data/". Stage 1 is used for training the model; the log files and model checkpoints are saved in "exp/deepspeech2_online/". Stage 2 is used to generate the final model for prediction by averaging the top-k model parameters based on the validation loss; in essence this is an element-wise mean over the selected checkpoints, as sketched below.
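A minimal sketch of that averaging step; the two checkpoint paths are hypothetical, and the recipe itself performs this through `avg.sh`.

```python
# A minimal sketch of what stage 2 (avg.sh) computes: an element-wise mean of
# the parameters of the k best checkpoints. The paths here are hypothetical.
import paddle

ckpt_paths = [
    "exp/deepspeech2_online/checkpoints/10.pdparams",
    "exp/deepspeech2_online/checkpoints/12.pdparams",
]
avg = None
for path in ckpt_paths:
    state = paddle.load(path)              # dict: parameter name -> Tensor
    if avg is None:
        avg = {k: v.astype("float64") for k, v in state.items()}
    else:
        for k, v in state.items():
            avg[k] += v.astype("float64")
avg = {k: (v / len(ckpt_paths)).astype("float32") for k, v in avg.items()}
paddle.save(avg, "exp/deepspeech2_online/checkpoints/avg_2.pdparams")
```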
+ +## Testing Process +Using the command below, you can test the deepspeech2 online model. + ``` + bash run.sh --stage 3 --stop_stage 5 --model_type online --conf_path conf/deepspeech2_online.yaml +``` +The detailed commands are: +``` +conf_path=conf/deepspeech2_online.yaml +avg_num=1 +model_type=online +avg_ckpt=avg_${avg_num} + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # export ckpt avg_n + CUDA_VISIBLE_DEVICES=5 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # test export ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1 +fi + ``` +After the training process, stages 3, 4 and 5 are used for the testing process. Stage 3 tests the model generated in stage 2 and reports the CER on the test set. Stage 4 transforms the model from a dynamic graph to a static graph by using the "paddle.jit" library; a minimal export sketch is shown below. Stage 5 tests the exported static-graph model.
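The export in stage 4 boils down to tracing the dynamic-graph model against declared input shapes and saving the result. A minimal sketch, in which `TinyModel`, the 161-dim feature shape and the output path are assumptions for illustration, not the recipe's exact code (that lives behind `./local/export.sh`):

```python
# A minimal sketch of the dynamic-to-static export that stage 4 performs with
# paddle.jit. TinyModel and the input shapes are stand-ins for illustration.
import paddle
from paddle.static import InputSpec

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(161, 8)  # 161 = linear feature dim

    def forward(self, feats, feat_lens):
        return self.proj(feats), feat_lens

model = TinyModel()
model.eval()
static_model = paddle.jit.to_static(
    model,
    input_spec=[
        InputSpec(shape=[None, None, 161], dtype="float32"),  # [B, T, D]
        InputSpec(shape=[None], dtype="int64"),               # feature lengths
    ])
paddle.jit.save(static_model, "exp/deepspeech2_online/checkpoints/avg_1.jit")
```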
+ +## Non-Streaming +The deepspeech2 offline model is similar to the deepspeech2 online model. The main difference between them is that the offline model uses stacked bi-directional rnn layers while the online model uses single-direction rnn layers, and the fc layer is not used in the offline model. For the stacked bi-directional rnn layers in the offline model, both the rnn cell and the gru cell can be used (the sketch after Fig.2 contrasts the two rnn directions). + +The architecture of the model is shown in Fig.2. +

+ +
Fig.2 The Architecture of the deepspeech2 offline model +

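The streaming/non-streaming split comes down to the RNN direction. A minimal sketch with toy sizes: `direction="bidirect"` doubles the output width and needs the whole utterance, while `direction="forward"` can run frame-synchronously (the direction keyword values are those of the Paddle 2.x API of this PR).

```python
# A minimal sketch contrasting the two RNN styles; all sizes are toy values.
import paddle

x = paddle.randn([2, 50, 161])  # [batch, time, feature]

# Online (streaming) style: forward-only GRU, usable frame by frame.
online_rnn = paddle.nn.GRU(input_size=161, hidden_size=64, direction="forward")
y_online, _ = online_rnn(x)
print(y_online.shape)   # [2, 50, 64]

# Offline (non-streaming) style: bidirectional GRU, needs the full utterance.
offline_rnn = paddle.nn.GRU(input_size=161, hidden_size=64, direction="bidirect")
y_offline, _ = offline_rnn(x)
print(y_offline.shape)  # [2, 50, 128]
```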
+ + + +For data preparation and the decoder, the deepspeech2 offline model is the same as the deepspeech2 online model. + +The code of the encoder and decoder for the deepspeech2 offline model is in: +``` +vi deepspeech/models/ds2/deepspeech2.py +``` + +The training process and the testing process of the deepspeech2 offline model are very similar to those of the deepspeech2 online model, and only a few changes should be noticed. + +For training and testing, the "model_type" and the "conf_path" must be set. + ``` +# Training offline +cd examples/aishell/s0 +bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deepspeech2.yaml +``` +``` +# Testing offline +cd examples/aishell/s0 +bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml +```
diff --git a/doc/src/feature_list.md b/docs/src/feature_list.md similarity index 79% rename from doc/src/feature_list.md rename to docs/src/feature_list.md index b675d810017c2ebe84b818fcfd5728354f63fdf6..4639ddd6fd78f689590839b88c0cad4801c21fe3 100644 --- a/doc/src/feature_list.md +++ b/docs/src/feature_list.md @@ -1,13 +1,20 @@ # Features +### Dataset +* Aishell +* Librispeech +* THCHS30 +* TIMIT + ### Speech Recognition -* Offline +* Non-Streaming * [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf) * [Transformer](https://arxiv.org/abs/1706.03762) * [Conformer](https://arxiv.org/abs/2005.08100) -* Online +* Streaming + * [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf) * [U2](https://arxiv.org/pdf/2012.05481.pdf) ### Language Model @@ -22,6 +29,15 @@ * beam search * attention rescore +### Deployment + +* Paddle Inference + +### Alignment + +* MFA +* CTC Alignment + ### Speech Frontend * Audio
diff --git a/doc/src/getting_started.md b/docs/src/getting_started.md similarity index 100% rename from doc/src/getting_started.md rename to docs/src/getting_started.md diff --git a/doc/src/install.md b/docs/src/install.md similarity index 95% rename from doc/src/install.md rename to docs/src/install.md index 01049a2fc5352c8b692e9c607e4a064a562e3623..8cecba125162e5f065990fe3ff025c22e1b7cdfd 100644 --- a/doc/src/install.md +++ b/docs/src/install.md @@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin ## Prerequisites - Python >= 3.7 -- PaddlePaddle 2.0.0 or later (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html)) +- PaddlePaddle latest version (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html)) -## Setup +## Setup (Important) - Make sure these libraries or tools are installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox`, and `swig`, e.g. installing them via `apt-get`: ```bash sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev ``` +The version of `swig` should be >= 3.0 or, installing them via `yum`: diff --git a/doc/src/ngram_lm.md b/docs/src/ngram_lm.md similarity index 64% rename from doc/src/ngram_lm.md rename to docs/src/ngram_lm.md index 119a3b21ccad01d12f987e679923940ebbefb64f..7872df22d6b421d05b0acbe38eb7a60ebd4c3dbe 100644 --- a/doc/src/ngram_lm.md +++ b/docs/src/ngram_lm.md @@ -35,52 +35,3 @@ Different from the English language model, the Mandarin language model is character-based. * A whitespace character between two tokens is inserted. Please notice that the released language models only contain simplified Chinese characters.
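The character-based preprocessing just described (one whitespace between tokens) can be done in a few lines. A minimal sketch; the file names are hypothetical.

```python
# A minimal sketch of character-level preprocessing for a Mandarin LM corpus:
# insert one space between consecutive characters so the LM toolkit sees one
# token per Chinese character. The file names below are hypothetical.
def to_char_tokens(line: str) -> str:
    return " ".join(ch for ch in line.strip() if not ch.isspace())

with open("corpus_raw.txt", encoding="utf-8") as fin, \
     open("corpus_char.txt", "w", encoding="utf-8") as fout:
    for line in fin:
        fout.write(to_char_tokens(line) + "\n")

print(to_char_tokens("今天天气很好"))  # -> "今 天 天 气 很 好"
```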
After preprocessing is done, we can begin to train the language model. The key training arguments are '-o 5 --prune 0 1 2 4 4' for the small LM and '-o 5' for the large LM. Please refer to the section above for the meaning of each argument. We also convert the arpa file to a binary file using default settings. - - - -## [KenLM](http://kheafield.com/code/kenlm/) - -There are quite a few statistical language model toolkits to choose from; SRILM and KenLM are currently among the best options. KenLM appeared later than SRILM, trains faster, and supports training on large data on a single machine. The usage of KenLM is introduced below. - -1. Download the toolkit from: http://kheafield.com/code/kenlm.tar.gz - -2. Usage. The tool is convenient to use under Linux. First make sure that Boost 1.36.0 and zlib are installed: - - ``` - boost: - yum install boost - yum install boost-devel - - zlib: - yum install zlib - yum install zlib-devel - ``` - - The gcc version needs to be 4.8.2 or above. - - ``` - wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz - mkdir kenlm/build - cd kenlm/build - cmake .. - make -j2 - ``` - -3. Training. Train with the following command: - - ``` - build/bin/lmplz -o 3 --verbose_header --text people2014corpus_words.txt --arpa result/people2014corpus_words.arps - ``` - - where: - 1) people2014corpus_words.txt must be a word-segmented file. - - The training corpus (People's Daily 2014) includes: 1) manually segmented and POS-tagged data people2014.tar.gz, 2) unsegmented text data people2014_words.txt, 3) a character-level KenLM language model file and its binary people2014corpus_chars.arps/klm, 4) a word-level KenLM language model file and its binary people2014corpus_words.arps/klm. - - 2) The number after -o is the n-gram order (5 means 5-gram); an order of 3 is usually enough, but choose it according to your actual needs. - -4. Compression. Compress the model into a binary file for fast loading: - - ``` - build/bin/build_binary ./result/people2014corpus_words.arps ./result/people2014corpus_words.klm - ```
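Once the binary model is built, it can be sanity-checked from Python with the `kenlm` package. A minimal sketch; the model path and test sentence are placeholders, and the sentence must be tokenized the same way as the training corpus.

```python
# A minimal sketch of querying a trained KenLM binary model from Python.
# Assumes the kenlm Python bindings are installed (e.g. `pip install kenlm`);
# the model path and sentence below are placeholders.
import kenlm

model = kenlm.Model("result/people2014corpus_words.klm")
sentence = "今天 天气 很 好"  # tokenized like the training corpus
print(model.score(sentence, bos=True, eos=True))  # total log10 probability
print(model.perplexity(sentence))                 # per-token perplexity
```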
diff --git a/docs/src/reference.md b/docs/src/reference.md new file mode 100644 index 0000000000000000000000000000000000000000..d3676fff2371ba82386d8e0f8da5c4ef5be5f780 --- /dev/null +++ b/docs/src/reference.md @@ -0,0 +1,8 @@ +# Reference + +We refer to these repos to build `model` and `engine`: + +* [delta](https://github.com/Delta-ML/delta.git) +* [espnet](https://github.com/espnet/espnet.git) +* [kaldi](https://github.com/kaldi-asr/kaldi.git) +* [wenet](https://github.com/mobvoi/wenet)
diff --git a/docs/src/released_model.md b/docs/src/released_model.md new file mode 100644 index 0000000000000000000000000000000000000000..61fd1560ee831705d103d0a72cff711402b62f31 --- /dev/null +++ b/docs/src/released_model.md @@ -0,0 +1,28 @@ +# Released Models + +## Acoustic Model Released in paddle 2.X +Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech +:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :--------- +[Ds2 Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds_online.5rnn.debug.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.0824 | 151 h +[Ds2 Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 | 151 h +[Conformer Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention + CTC | 0.0594 | 151 h +[Conformer Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0547 | 151 h +[Conformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | Word-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0325 | 960 h +[Transformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/transformer.release.tar.gz) | Librispeech Dataset | Word-based | 195 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention | 0.0544 | 960 h + +## Acoustic Model Transformed from paddle 1.8 +Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech +:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :--------- +[Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz)|Aishell Dataset| Char-based| 234 MB| 2 Conv + 3 bidirectional GRU layers| 0.0804 | 151 h| +[Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz)|Librispeech Dataset| Word-based| 307 MB| 2 Conv + 3 bidirectional sharing weight RNN layers | 0.0685| 960 h| +[Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz)|Baidu Internal English Dataset| Word-based| 273 MB| 2 Conv + 3 bidirectional GRU layers | 0.0541 | 8628 h| + + + +## Language Model Released + +Language Model | Training Data | Token-based | Size | Descriptions +:-------------:| :------------:| :-----: | -----: | :----------------- +[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1;<br/>About 1.85 billion n-grams;<br/>'trie' binary with '-a 22 -q 8 -b 8' +[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4;<br/>About 0.13 billion n-grams;<br/>'probing' binary with default settings +[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning;<br/>About 3.7 billion n-grams;<br/>'probing' binary with default settings
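The CER/WER numbers in these tables are edit-distance based. A minimal character-error-rate sketch follows; the reference/hypothesis strings are invented, and the repo's own implementation lives in `deepspeech/utils/error_rate.py`.

```python
# A minimal sketch of character error rate (CER): Levenshtein distance over
# characters divided by the reference length. Strings here are invented.
def edit_distance(ref: str, hyp: str) -> int:
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # deletion
                                     dp[j - 1] + 1,    # insertion
                                     prev + (r != h))  # substitution
    return dp[-1]

def cer(ref: str, hyp: str) -> float:
    return edit_distance(ref, hyp) / len(ref)

print(cer("今天天气很好", "今天气很好"))  # one deletion -> 1/6 ~= 0.167
```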
diff --git a/examples/1xt2x/.gitignore b/examples/1xt2x/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a9a5aecf429fd8a0d81fbd5fd37006bfa498d5c1 --- /dev/null +++ b/examples/1xt2x/.gitignore @@ -0,0 +1 @@ +tmp diff --git a/examples/1xt2x/README.md b/examples/1xt2x/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1f5fe8e3b17af649d361be4250cdd94c1795e00a --- /dev/null +++ b/examples/1xt2x/README.md @@ -0,0 +1,11 @@ +# 1xt2x + +Convert the DeepSpeech 1.8 released models to 2.x. + +## Model +* Deepspeech2x + +## Exp +* baidu_en8k +* aishell +* librispeech
diff --git a/examples/1xt2x/aishell/.gitignore b/examples/1xt2x/aishell/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7024e0e954e16122e3df2e2778949668c7692d72 --- /dev/null +++ b/examples/1xt2x/aishell/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/aishell/conf/augmentation.json b/examples/1xt2x/aishell/conf/augmentation.json new file mode 100644 index 0000000000000000000000000000000000000000..fe51488c7066f6687ef680d6bfaa4f7768ef205c --- /dev/null +++ b/examples/1xt2x/aishell/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/aishell/conf/deepspeech2.yaml b/examples/1xt2x/aishell/conf/deepspeech2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e745e9d1c9dd9e525d49b35fd19711aa343b2ee --- /dev/null +++ b/examples/1xt2x/aishell/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + use_gru: True + share_rnn_weights: False + blank_id: 4333 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: cer + decoding_method: ctc_beam_search + lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm + alpha: 2.6 + beta: 5.0 + beam_size: 300 + cutoff_prob: 0.99 + cutoff_top_n: 40 + num_proc_bsearch: 8
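These recipe YAML files are consumed through yacs `CfgNode` objects (the test entry point later in this patch does it via `get_cfg_defaults()` + `merge_from_file()`). A minimal loading sketch, assuming the keys of the file above:

```python
# A minimal sketch of loading a recipe YAML like conf/deepspeech2.yaml with
# yacs; the entry points in this PR layer defaults underneath via
# get_cfg_defaults() before merging.
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.merge_from_file("conf/deepspeech2.yaml")
config.freeze()
print(config.model.num_rnn_layers)      # -> 3
print(config.decoding.decoding_method)  # -> ctc_beam_search
```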
diff --git a/examples/1xt2x/aishell/local/data.sh b/examples/1xt2x/aishell/local/data.sh new file mode 100755 index 0000000000000000000000000000000000000000..1cde0c6ea795ed4332e33445919897422352baca --- /dev/null +++ b/examples/1xt2x/aishell/local/data.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + +bash local/download_model.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +tar xzvf aishell_model_v1.8_to_v2.x.tar.gz +mv aishell_v1.8.pdparams exp/deepspeech2/checkpoints/ +mv README.md exp/deepspeech2/ +mv mean_std.npz data/ +mv vocab.txt data/ +rm aishell_model_v1.8_to_v2.x.tar.gz -f + + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/aishell/aishell.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/aishell" + + if [ $? -ne 0 ]; then + echo "Prepare Aishell failed. Terminated." + exit 1 + fi + + for dataset in train dev test; do + mv data/manifest.${dataset} data/manifest.${dataset}.raw + done +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --specgram_type="linear" \ + --delta_delta=false \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --sample_rate=16000 \ + --use_dB_normalization=True \ + --num_samples=2000 \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for dataset in train dev test; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "char" \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${dataset}.raw" \ + --output_path="data/manifest.${dataset}" + + if [ $? -ne 0 ]; then + echo "Format manifest failed. Terminated." + exit 1 + fi + } & + done + wait +fi + +echo "Aishell data preparation done." +exit 0 diff --git a/examples/1xt2x/aishell/local/download_lm_ch.sh b/examples/1xt2x/aishell/local/download_lm_ch.sh new file mode 100755 index 0000000000000000000000000000000000000000..ac27a9076d9f7f3d1556c7dac10e31ea788ff622 --- /dev/null +++ b/examples/1xt2x/aishell/local/download_lm_ch.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +. ${MAIN_ROOT}/utils/utility.sh + +DIR=data/lm +mkdir -p ${DIR} + +URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm' +MD5="29e02312deb2e59b3c8686c7966d4fe3" +TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm + + +echo "Download language model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the language model!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/aishell/local/download_model.sh b/examples/1xt2x/aishell/local/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e4873ef6385fa2ad5af16e5ee4cd980c41b5899 --- /dev/null +++ b/examples/1xt2x/aishell/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. ${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz' +MD5=4ade113c69ea291b8ce5ec6a03296659 +TARGET=./aishell_model_v1.8_to_v2.x.tar.gz + + +echo "Download Aishell model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download Aishell model!" 
+ exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/aishell/local/test.sh b/examples/1xt2x/aishell/local/test.sh new file mode 100755 index 0000000000000000000000000000000000000000..2ae0740b3e8d44ab03e45f4c1b5dbb945657705e --- /dev/null +++ b/examples/1xt2x/aishell/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_ch.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/aishell/path.sh b/examples/1xt2x/aishell/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..080ab1f797f7bb516d7aad379d3a915515ec86e9 --- /dev/null +++ b/examples/1xt2x/aishell/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/aishell/run.sh b/examples/1xt2x/aishell/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..482ab2a09eadd9e4313a92befd02d8846d0fc844 --- /dev/null +++ b/examples/1xt2x/aishell/run.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=aishell_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi + diff --git a/examples/1xt2x/baidu_en8k/.gitignore b/examples/1xt2x/baidu_en8k/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7024e0e954e16122e3df2e2778949668c7692d72 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/baidu_en8k/conf/augmentation.json b/examples/1xt2x/baidu_en8k/conf/augmentation.json new file mode 100644 index 0000000000000000000000000000000000000000..fe51488c7066f6687ef680d6bfaa4f7768ef205c --- /dev/null +++ b/examples/1xt2x/baidu_en8k/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml b/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fbc7466f239d8597ff5001c0d684741d9921fc78 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: .inf # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + use_gru: True + share_rnn_weights: False + blank_id: 28 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 1.4 + beta: 0.35 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/1xt2x/baidu_en8k/local/data.sh b/examples/1xt2x/baidu_en8k/local/data.sh new file mode 100755 index 0000000000000000000000000000000000000000..8f9468b139484b6c183ee0bb06855b74c0450949 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/data.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + + +bash local/download_model.sh +if [ $? 
-ne 0 ]; then + exit 1 +fi + +tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz +mv baidu_en8k_v1.8.pdparams exp/deepspeech2/checkpoints/ +mv README.md exp/deepspeech2/ +mv mean_std.npz data/ +mv vocab.txt data/ +rm baidu_en8k_v1.8_to_v2.x.tar.gz -f + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/librispeech/librispeech.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/librispeech" \ + --full_download="True" + + if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 + fi + + for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + mv data/manifest.${set} data/manifest.${set}.raw + done + + rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw + for set in train-clean-100 train-clean-360 train-other-500; do + cat data/manifest.${set}.raw >> data/manifest.train.raw + done + + for set in dev-clean dev-other; do + cat data/manifest.${set}.raw >> data/manifest.dev.raw + done + + for set in test-clean test-other; do + cat data/manifest.${set}.raw >> data/manifest.test.raw + done +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=2000 \ + --specgram_type="linear" \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --use_dB_normalization=True \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test dev-clean dev-other test-clean test-other; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "char" \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Format manifest.${set} failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "LibriSpeech Data preparation done." +exit 0 + diff --git a/examples/1xt2x/baidu_en8k/local/download_lm_en.sh b/examples/1xt2x/baidu_en8k/local/download_lm_en.sh new file mode 100755 index 0000000000000000000000000000000000000000..dc1bdf665ac7783bc1e7344fbcbddc0b9744f44b --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/download_lm_en.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +. ${MAIN_ROOT}/utils/utility.sh + +DIR=data/lm +mkdir -p ${DIR} + +URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm +MD5="099a601759d467cd0a8523ff939819c5" +TARGET=${DIR}/common_crawl_00.prune01111.trie.klm + +echo "Download language model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the language model!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/local/download_model.sh b/examples/1xt2x/baidu_en8k/local/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..6d06e3d6f5610b9a655f49a1d853a6c496049dc8 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. 
${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz' +MD5=fdabeb6c96963ac85d9188f0275c6a1b +TARGET=./baidu_en8k_v1.8_to_v2.x.tar.gz + + +echo "Download BaiduEn8k model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download BaiduEn8k model!" + exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/local/test.sh b/examples/1xt2x/baidu_en8k/local/test.sh new file mode 100755 index 0000000000000000000000000000000000000000..4d00f30b852da5a370f5d4934f3caadd2b833c00 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_en.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/path.sh b/examples/1xt2x/baidu_en8k/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..080ab1f797f7bb516d7aad379d3a915515ec86e9 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/baidu_en8k/run.sh b/examples/1xt2x/baidu_en8k/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..c590312d17d9dd2c9d48cf37682941dd756a85b3 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/run.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=baidu_en8k_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi + diff --git a/examples/1xt2x/deepspeech2x/__init__.py b/examples/1xt2x/deepspeech2x/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d85a3dde7d44a388878a0b0f411f4a2bd594800d --- /dev/null +++ b/examples/1xt2x/deepspeech2x/__init__.py @@ -0,0 +1,370 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any +from typing import List +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn +from paddle.fluid import core +from paddle.nn import functional as F + +from deepspeech.utils.log import Log + +#TODO(Hui Zhang): remove fluid import +logger = Log(__name__).getlog() + +########### hack logging ############# +logger.warn = logger.warning + +########### hack paddle ############# +paddle.half = 'float16' +paddle.float = 'float32' +paddle.double = 'float64' +paddle.short = 'int16' +paddle.int = 'int32' +paddle.long = 'int64' +paddle.uint16 = 'uint16' +paddle.cdouble = 'complex128' + + +def convert_dtype_to_string(tensor_dtype): + """ + Convert a framework data type to the corresponding Paddle dtype. + Args: + tensor_dtype(core.VarDesc.VarType): the framework data type. + Returns: + the corresponding paddle dtype (e.g. paddle.float32). + """ + dtype = tensor_dtype + if dtype == core.VarDesc.VarType.FP32: + return paddle.float32 + elif dtype == core.VarDesc.VarType.FP64: + return paddle.float64 + elif dtype == core.VarDesc.VarType.FP16: + return paddle.float16 + elif dtype == core.VarDesc.VarType.INT32: + return paddle.int32 + elif dtype == core.VarDesc.VarType.INT16: + return paddle.int16 + elif dtype == core.VarDesc.VarType.INT64: + return paddle.int64 + elif dtype == core.VarDesc.VarType.BOOL: + return paddle.bool + elif dtype == core.VarDesc.VarType.BF16: + # since there is still no support for bfloat16 in NumPy, + # uint16 is used for casting bfloat16 + return paddle.uint16 + elif dtype == core.VarDesc.VarType.UINT8: + return paddle.uint8 + elif dtype == core.VarDesc.VarType.INT8: + return paddle.int8 + elif dtype == core.VarDesc.VarType.COMPLEX64: + return paddle.complex64 + elif dtype == core.VarDesc.VarType.COMPLEX128: + return paddle.complex128 + else: + raise ValueError("Not supported tensor dtype %s" % dtype) + + +if not hasattr(paddle, 'softmax'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle, 'log_softmax'): + logger.warn("register user log_softmax to paddle, remove this when fixed!") + setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) + +if not hasattr(paddle, 'sigmoid'): + logger.warn("register user sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle, 'log_sigmoid'): + logger.warn("register user log_sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) + +if not hasattr(paddle, 'relu'): + logger.warn("register user relu to paddle, remove this when fixed!") + setattr(paddle, 'relu', paddle.nn.functional.relu) + + +def cat(xs, dim=0): + return paddle.concat(xs, axis=dim) + + +if not hasattr(paddle, 'cat'): + logger.warn( + "override cat of paddle if exists or register, remove this when fixed!") + paddle.cat = cat + + +########### hack paddle.Tensor ############# +def item(x: paddle.Tensor): + return x.numpy().item() + + +if not hasattr(paddle.Tensor, 'item'): + logger.warn( + 
"override item of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.item = item + + +def func_long(x: paddle.Tensor): + return paddle.cast(x, paddle.long) + + +if not hasattr(paddle.Tensor, 'long'): + logger.warn( + "override long of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.long = func_long + +if not hasattr(paddle.Tensor, 'numel'): + logger.warn( + "override numel of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.numel = paddle.numel + + +def new_full(x: paddle.Tensor, + size: Union[List[int], Tuple[int], paddle.Tensor], + fill_value: Union[float, int, bool, paddle.Tensor], + dtype=None): + return paddle.full(size, fill_value, dtype=x.dtype) + + +if not hasattr(paddle.Tensor, 'new_full'): + logger.warn( + "override new_full of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.new_full = new_full + + +def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: + if convert_dtype_to_string(xs.dtype) == paddle.bool: + xs = xs.astype(paddle.int) + return xs.equal( + paddle.to_tensor( + ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)) + + +if not hasattr(paddle.Tensor, 'eq'): + logger.warn( + "override eq of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.eq = eq + +if not hasattr(paddle, 'eq'): + logger.warn( + "override eq of paddle if exists or register, remove this when fixed!") + paddle.eq = eq + + +def contiguous(xs: paddle.Tensor) -> paddle.Tensor: + return xs + + +if not hasattr(paddle.Tensor, 'contiguous'): + logger.warn( + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.contiguous = contiguous + + +def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + nargs = len(args) + assert (nargs <= 1) + s = paddle.shape(xs) + if nargs == 1: + return s[args[0]] + else: + return s + + +#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. +logger.warn( + "override size of paddle.Tensor " + "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" 
+) +paddle.Tensor.size = size + + +def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + return xs.reshape(args) + + +if not hasattr(paddle.Tensor, 'view'): + logger.warn("register user view to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view = view + + +def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: + return xs.reshape(ys.size()) + + +if not hasattr(paddle.Tensor, 'view_as'): + logger.warn( + "register user view_as to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view_as = view_as + + +def is_broadcastable(shp1, shp2): + for a, b in zip(shp1[::-1], shp2[::-1]): + if a == 1 or b == 1 or a == b: + pass + else: + return False + return True + + +def masked_fill(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]): + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + xs = paddle.where(mask, trues, xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill'): + logger.warn( + "register user masked_fill to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill = masked_fill + + +def masked_fill_(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]) -> paddle.Tensor: + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + ret = paddle.where(mask, trues, xs) + paddle.assign(ret.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill_'): + logger.warn( + "register user masked_fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill_ = masked_fill_ + + +def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: + val = paddle.full_like(xs, value) + paddle.assign(val.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'fill_'): + logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.fill_ = fill_ + + +def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: + return paddle.tile(xs, size) + + +if not hasattr(paddle.Tensor, 'repeat'): + logger.warn( + "register user repeat to paddle.Tensor, remove this when fixed!") + paddle.Tensor.repeat = repeat + +if not hasattr(paddle.Tensor, 'softmax'): + logger.warn( + "register user softmax to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle.Tensor, 'sigmoid'): + logger.warn( + "register user sigmoid to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle.Tensor, 'relu'): + logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) + + +def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: + return x.astype(other.dtype) + + +if not hasattr(paddle.Tensor, 'type_as'): + logger.warn( + "register user type_as to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'type_as', type_as) + + +def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: + assert len(args) == 1 + if isinstance(args[0], str): # dtype + return x.astype(args[0]) + elif isinstance(args[0], paddle.Tensor): #Tensor + return x.astype(args[0].dtype) + else: # Device + return x + + +if not hasattr(paddle.Tensor, 'to'): + logger.warn("register user to to 
paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'to', to) + + +def func_float(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.float) + + +if not hasattr(paddle.Tensor, 'float'): + logger.warn("register user float to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'float', func_float) + + +def func_int(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.int) + + +if not hasattr(paddle.Tensor, 'int'): + logger.warn("register user int to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'int', func_int) + + +def tolist(x: paddle.Tensor) -> List[Any]: + return x.numpy().tolist() + + +if not hasattr(paddle.Tensor, 'tolist'): + logger.warn( + "register user tolist to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'tolist', tolist) + + +########### hack paddle.nn ############# +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return F.glu(xs, axis=self.dim) + + +if not hasattr(paddle.nn, 'GLU'): + logger.warn("register user GLU to paddle.nn, remove this when fixed!") + setattr(paddle.nn, 'GLU', GLU) diff --git a/examples/1xt2x/deepspeech2x/bin/test.py b/examples/1xt2x/deepspeech2x/bin/test.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa0a61de95196411c978de85fa2aa6f9821217c --- /dev/null +++ b/examples/1xt2x/deepspeech2x/bin/test.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Evaluation for DeepSpeech2 model.""" +from deepspeech2x.model import DeepSpeech2Tester as Tester + +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--model_type") + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + args = parser.parse_args() + print_arguments(args, globals()) + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/examples/1xt2x/deepspeech2x/model.py b/examples/1xt2x/deepspeech2x/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbc502d2e8e29479b091afb7a3a02fe9c2c85b1 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/model.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains DeepSpeech2 and DeepSpeech2Online model.""" +import time +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +from deepspeech2x.models.ds2 import DeepSpeech2InferModel +from deepspeech2x.models.ds2 import DeepSpeech2Model +from paddle import distributed as dist +from paddle.io import DataLoader +from yacs.config import CfgNode + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.trainer import Trainer +from deepspeech.utils import error_rate +from deepspeech.utils import layer_tools +from deepspeech.utils import mp_tools +from deepspeech.utils.log import Log +#from deepspeech.utils.log import Autolog + +logger = Log(__name__).getlog() + + +class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def train_batch(self, batch_index, batch_data, msg): + train_conf = self.config.training + start = time.time() + + # forward + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) + losses_np = { + 'train_loss': float(loss), + } + + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step + if (batch_index + 1) % train_conf.accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.iteration += 1 + + iteration_time = time.time() - start + + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config.collator.batch_size) + msg += "accum: {}, ".format(train_conf.accum_grad) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + logger.info(msg) + + if dist.get_rank() == 0 and self.visualizer: + for k, v in losses_np.items(): + # `step -1` since we update `step` after optimizer.step(). 
+ self.visualizer.add_scalar("train/{}".format(k), v, + self.iteration - 1) + + @paddle.no_grad() + def valid(self): + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + self.model.eval() + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(loss) * num_utts + valid_losses['val_loss'].append(float(loss)) + + if (i + 1) % self.config.training.log_interval == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump['val_history_loss'] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_dump.items()) + logger.info(msg) + + logger.info('Rank {} Val info val_loss {}'.format( + dist.get_rank(), total_loss / num_seen_utts)) + return total_loss, num_seen_utts + + def setup_model(self): + config = self.config.clone() + config.defrost() + config.model.feat_size = self.train_loader.collate_fn.feature_size + #config.model.dict_size = self.train_loader.collate_fn.vocab_size + config.model.dict_size = len(self.train_loader.collate_fn.vocab_list) + config.freeze() + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") + if self.parallel: + model = paddle.DataParallel(model) + + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + grad_clip = ClipGradByGlobalNormWithLog( + config.training.global_grad_clip) + lr_scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=config.training.lr, + gamma=config.training.lr_decay, + verbose=True) + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=paddle.regularizer.L2Decay( + config.training.weight_decay), + grad_clip=grad_clip) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + 
collate_fn_dev = SpeechCollator.from_config(config) + + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + collate_fn_test = SpeechCollator.from_config(config) + + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test) + if "" in self.test_loader.collate_fn.vocab_list: + self.test_loader.collate_fn.vocab_list.remove("") + if "" in self.valid_loader.collate_fn.vocab_list: + self.valid_loader.collate_fn.vocab_list.remove("") + if "" in self.train_loader.collate_fn.vocab_list: + self.train_loader.collate_fn.vocab_list.remove("") + logger.info("Setup train/valid/test Dataloader!") + + +class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """Convert transcript ord() ids back to text with chr().""" + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + + vocab_list = self.test_loader.collate_fn.vocab_list + if "<space>" in vocab_list: + space_id = vocab_list.index("<space>") + vocab_list[space_id] = " " + + target_transcripts = self.ordid2token(texts, texts_len) + + result_transcripts = self.compute_result_transcripts(audio, audio_len, + vocab_list, cfg) + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("Current error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) + + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + result_transcripts = self.model.decode( + 
audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + return result_transcripts + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + self.model.eval() + cfg = self.config + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + + # self.autolog.report() + + def run_test(self): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + + def export(self): + if self.args.model_type == 'offline': + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + elif self.args.model_type == 'online': + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + raise Exception("wrong model type") + + infer_model.eval() + feat_dim = self.test_loader.collate_fn.feature_size + static_model = infer_model.export() + logger.info(f"Export code: {static_model.forward.code}") + paddle.jit.save(static_model, self.args.export_path) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir diff --git a/examples/1xt2x/deepspeech2x/models/__init__.py b/examples/1xt2x/deepspeech2x/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/1xt2x/deepspeech2x/models/ds2/__init__.py b/examples/1xt2x/deepspeech2x/models/ds2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39bea5bf9da14bd4ebd89518dd68789534cfd266 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .deepspeech2 import DeepSpeech2InferModel +from .deepspeech2 import DeepSpeech2Model + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] diff --git a/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py b/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py new file mode 100644 index 0000000000000000000000000000000000000000..f154ddb54354a546702ed950f423a95930df1e15 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py @@ -0,0 +1,314 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
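# NOTE (illustrative, not part of the original patch): the encoder below
# flattens conv feature maps with `x.reshape([0, 0, -1])`. In Paddle's
# reshape, a 0 keeps the matching input dimension and -1 is inferred, so
# [B, T, C, D] -> [B, T, C*D] without reading the dynamic batch/time
# sizes, which is what keeps the graph exportable under paddle.jit.
# A minimal sketch, assuming toy sizes:
#
#     x = paddle.zeros([4, 10, 32, 41])   # [B, T, C, D]
#     y = x.reshape([0, 0, -1])           # shape [4, 10, 1312]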
+"""Deepspeech2 ASR Model""" +from typing import Optional + +import paddle +from deepspeech2x.models.ds2.rnn import RNNStack +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2.conv import ConvStack +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + + @property + def output_size(self): + return self.rnn_size * 2 + + def forward(self, audio, audio_len): + """Compute Encoder outputs + + Args: + audio (Tensor): [B, Tmax, D] + text (Tensor): [B, Umax] + audio_len (Tensor): [B] + text_len (Tensor): [B] + Returns: + x (Tensor): encoder outputs, [B, T, D] + x_lens (Tensor): encoder length, [B] + """ + # [B, T, D] -> [B, D, T] + audio = audio.transpose([0, 2, 1]) + # [B, D, T] -> [B, C=1, D, T] + x = audio.unsqueeze(1) + x_lens = audio_len + + # convolution group + x, x_lens = self.conv(x, x_lens) + x_val = x.numpy() + + # convert data from convolution feature map to sequence of vectors + #B, C, D, T = paddle.shape(x) # not work under jit + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit + x = x.reshape([0, 0, -1]) #[B, T, C*D] + + # remove padding part + x, x_lens = self.rnn(x, x_lens) #[B, T, D] + return x, x_lens + + +class DeepSpeech2Model(nn.Layer): + """The DeepSpeech2 network structure. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param masks: Masks data layer to reset padding. + :type masks: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward direction RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+            share_rnn_weights=True  # Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
+        ))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self,
+                 feat_size,
+                 dict_size,
+                 num_conv_layers=2,
+                 num_rnn_layers=3,
+                 rnn_size=1024,
+                 use_gru=False,
+                 share_rnn_weights=True,
+                 blank_id=0):
+        super().__init__()
+        self.encoder = CRNNEncoder(
+            feat_size=feat_size,
+            dict_size=dict_size,
+            num_conv_layers=num_conv_layers,
+            num_rnn_layers=num_rnn_layers,
+            rnn_size=rnn_size,
+            use_gru=use_gru,
+            share_rnn_weights=share_rnn_weights)
+        assert (self.encoder.output_size == rnn_size * 2)
+
+        self.decoder = CTCDecoder(
+            odim=dict_size,  # <blank> is in vocab
+            enc_n_units=self.encoder.output_size,
+            blank_id=blank_id,  # first token is <blank>
+            dropout_rate=0.0,
+            reduction=True,  # sum
+            batch_average=True)  # sum / batch_size
+
+    def forward(self, audio, audio_len, text, text_len):
+        """Compute Model loss
+
+        Args:
+            audio (Tensor): [B, T, D]
+            audio_len (Tensor): [B]
+            text (Tensor): [B, U]
+            text_len (Tensor): [B]
+
+        Returns:
+            loss (Tensor): [1]
+        """
+        eouts, eouts_len = self.encoder(audio, audio_len)
+        loss = self.decoder(eouts, eouts_len, text, text_len)
+        return loss
+
+    @paddle.no_grad()
+    def decode(self, audio, audio_len, vocab_list, decoding_method,
+               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
+               cutoff_top_n, num_processes):
+        # init once
+        # decoders only accept string encoded in utf-8
+        self.decoder.init_decode(
+            beam_alpha=beam_alpha,
+            beam_beta=beam_beta,
+            lang_model_path=lang_model_path,
+            vocab_list=vocab_list,
+            decoding_method=decoding_method)
+
+        eouts, eouts_len = self.encoder(audio, audio_len)
+        probs = self.decoder.softmax(eouts)
+        print("probs.shape", probs.shape)
+        return self.decoder.decode_probs(
+            probs.numpy(), eouts_len, vocab_list, decoding_method,
+            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
+            cutoff_top_n, num_processes)
+
+    def decode_probs_split(self, probs_split, vocab_list, decoding_method,
+                           lang_model_path, beam_alpha, beam_beta, beam_size,
+                           cutoff_prob, cutoff_top_n, num_processes):
+        self.decoder.init_decode(
+            beam_alpha=beam_alpha,
+            beam_beta=beam_beta,
+            lang_model_path=lang_model_path,
+            vocab_list=vocab_list,
+            decoding_method=decoding_method)
+        return self.decoder.decode_probs_split(
+            probs_split, vocab_list, decoding_method, lang_model_path,
+            beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
+            num_processes)
+
+    @classmethod
+    def from_pretrained(cls, dataloader, config, checkpoint_path):
+        """Build a DeepSpeech2Model model from a pretrained model.
+        Parameters
+        ----------
+        dataloader: paddle.io.DataLoader
+
+        config: yacs.config.CfgNode
+            model configs
+
+        checkpoint_path: Path or str
+            the path of pretrained model checkpoint, without extension name
+
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from pretrained result.
+ """ + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=len(dataloader.collate_fn.vocab_list), + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. + """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) + return model + + +class DeepSpeech2InferModel(DeepSpeech2Model): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) + + def forward(self, audio, audio_len): + """export model function + + Args: + audio (Tensor): [B, T, D] + audio_len (Tensor): [B] + + Returns: + probs: probs after softmax + """ + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/examples/1xt2x/deepspeech2x/models/ds2/rnn.py b/examples/1xt2x/deepspeech2x/models/ds2/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e45db7c053d4128bf51ed1787af3a4f78c9081e4 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/rnn.py @@ -0,0 +1,334 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['RNNStack'] + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. 
math::
+        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
+        y_{t} & = h_{t}
+
+    where :math:`act` is for :attr:`activation`.
+    """
+
+    def __init__(self,
+                 hidden_size: int,
+                 activation="tanh",
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super().__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_hh = self.create_parameter(
+            (hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = None
+        self.bias_hh = self.create_parameter(
+            (hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+
+        self.hidden_size = hidden_size
+        if activation not in ["tanh", "relu", "brelu"]:
+            raise ValueError(
+                "activation for SimpleRNNCell should be tanh, relu or brelu, "
+                "but got {}".format(activation))
+        self.activation = activation
+        self._activation_fn = paddle.tanh \
+            if activation == "tanh" \
+            else F.relu
+        if activation == 'brelu':
+            self._activation_fn = brelu
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_h = states
+        i2h = inputs
+        if self.bias_ih is not None:
+            i2h += self.bias_ih
+        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h2h += self.bias_hh
+        h = self._activation_fn(i2h + h2h)
+        return h, h
+
+    @property
+    def state_shape(self):
+        return (self.hidden_size, )
+
+
+class GRUCell(nn.RNNCellBase):
+    r"""
+    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
+    it computes the outputs and updates states.
+    The formula for GRU used is as follows:
+    .. math::
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and * is the elementwise
+    multiplication operator.
+ """ + + def __init__(self, + input_size: int, + hidden_size: int, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.relu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states # shape [batch_size, hidden_size] + + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + bias_u, bias_r, bias_c = paddle.split( + self.bias_hh, num_or_sections=3, axis=0) + + weight_hh = paddle.transpose( + self.weight_hh, + perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size] + w_u_r_c = paddle.flatten(weight_hh) + size_u_r = self.hidden_size * 2 * self.hidden_size + w_u_r = paddle.reshape(w_u_r_c[:size_u_r], + (self.hidden_size, self.hidden_size * 2)) + w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1) + w_c = paddle.reshape(w_u_r_c[size_u_r:], + (self.hidden_size, self.hidden_size)) + + h_u = paddle.matmul( + pre_hidden, w_u, + transpose_y=False) + bias_u #shape [batch_size, hidden_size] + h_r = paddle.matmul( + pre_hidden, w_r, + transpose_y=False) + bias_r #shape [batch_size, hidden_size] + + x_u, x_r, x_c = paddle.split( + x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size] + + u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size] + r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size] + c = self._activation( + x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) + + bias_c) # [batch_size, hidden_size] + + h = (1 - u) * pre_hidden + u * c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int, share_weights: bool): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. 
+ self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + # batch norm is only performed on input-state projection + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = self.fw_fc + self.bw_bn = self.fw_bn + else: + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + + self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class BiGRUWithBN(nn.Layer): + """Bidirectonal gru layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer. + :type name: string + :param input: Input layer. + :type input: Variable + :param size: Dimension of GRU cells. + :type size: int + :param act: Activation type. + :type act: string + :return: Bidirectional GRU layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int): + super().__init__() + hidden_size = h_size * 3 + + self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + + self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x, x_len): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class RNNStack(nn.Layer): + """RNN group with stacked bidirectional simple RNN or GRU layers. + + :param input: Input layer. + :type input: Variable + :param size: Dimension of RNN cells in each layer. + :type size: int + :param num_stacks: Number of stacked rnn layers. + :type num_stacks: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: Output layer of the RNN group. 
+ :rtype: Variable + """ + + def __init__(self, + i_size: int, + h_size: int, + num_stacks: int, + use_gru: bool, + share_rnn_weights: bool): + super().__init__() + rnn_stacks = [] + for i in range(num_stacks): + if use_gru: + #default:GRU using tanh + rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) + else: + rnn_stacks.append( + BiRNNWithBN( + i_size=i_size, + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 + + self.rnn_stacks = nn.LayerList(rnn_stacks) + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + """ + x: shape [B, T, D] + x_len: shpae [B] + """ + for i, rnn in enumerate(self.rnn_stacks): + x, x_len = rnn(x, x_len) + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(-1) # [B, T, 1] + # TODO(Hui Zhang): not support bool multiply + masks = masks.astype(x.dtype) + x = x.multiply(masks) + return x, x_len diff --git a/examples/1xt2x/librispeech/.gitignore b/examples/1xt2x/librispeech/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7024e0e954e16122e3df2e2778949668c7692d72 --- /dev/null +++ b/examples/1xt2x/librispeech/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/librispeech/conf/augmentation.json b/examples/1xt2x/librispeech/conf/augmentation.json new file mode 100644 index 0000000000000000000000000000000000000000..fe51488c7066f6687ef680d6bfaa4f7768ef205c --- /dev/null +++ b/examples/1xt2x/librispeech/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/librispeech/conf/deepspeech2.yaml b/examples/1xt2x/librispeech/conf/deepspeech2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..edef07972b41cabe98ae54edfe9b560299b2e19f --- /dev/null +++ b/examples/1xt2x/librispeech/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: 1000.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 2048 + use_gru: False + share_rnn_weights: True + blank_id: 28 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/1xt2x/librispeech/local/data.sh b/examples/1xt2x/librispeech/local/data.sh new file mode 100755 index 0000000000000000000000000000000000000000..22a86bb2edb31f23ab8a161066bf7102bde54c68 --- /dev/null +++ b/examples/1xt2x/librispeech/local/data.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source 
${MAIN_ROOT}/utils/parse_options.sh
+
+# token unit for format_data.py; the released v1.8 LibriSpeech vocab is char based
+unit_type=char
+
+mkdir -p data
+TARGET_DIR=${MAIN_ROOT}/examples/dataset
+mkdir -p ${TARGET_DIR}
+
+
+bash local/download_model.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+tar xzvf librispeech_v1.8_to_v2.x.tar.gz
+mv librispeech_v1.8.pdparams exp/deepspeech2/checkpoints/
+mv README.md exp/deepspeech2/
+mv mean_std.npz data/
+mv vocab.txt data/
+rm librispeech_v1.8_to_v2.x.tar.gz -f
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/librispeech/librispeech.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/librispeech" \
+    --full_download="True"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare LibriSpeech failed. Terminated."
+        exit 1
+    fi
+
+    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+        mv data/manifest.${set} data/manifest.${set}.raw
+    done
+
+    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
+    for set in train-clean-100 train-clean-360 train-other-500; do
+        cat data/manifest.${set}.raw >> data/manifest.train.raw
+    done
+
+    for set in dev-clean dev-other; do
+        cat data/manifest.${set}.raw >> data/manifest.dev.raw
+    done
+
+    for set in test-clean test-other; do
+        cat data/manifest.${set}.raw >> data/manifest.test.raw
+    done
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    num_workers=$(nproc)
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --num_samples=2000 \
+    --specgram_type="linear" \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=20.0 \
+    --use_dB_normalization=True \
+    --num_workers=${num_workers} \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for set in train dev test dev-clean dev-other test-clean test-other; do
+    {
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type ${unit_type} \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${set}.raw" \
+        --output_path="data/manifest.${set}"
+
+        if [ $? -ne 0 ]; then
+            echo "Format manifest.${set} failed. Terminated."
+            exit 1
+        fi
+    }&
+    done
+    wait
+fi
+
+echo "LibriSpeech Data preparation done."
+exit 0
+
diff --git a/examples/1xt2x/librispeech/local/download_lm_en.sh b/examples/1xt2x/librispeech/local/download_lm_en.sh
new file mode 100755
index 0000000000000000000000000000000000000000..dc1bdf665ac7783bc1e7344fbcbddc0b9744f44b
--- /dev/null
+++ b/examples/1xt2x/librispeech/local/download_lm_en.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+DIR=data/lm
+mkdir -p ${DIR}
+
+URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
+MD5="099a601759d467cd0a8523ff939819c5"
+TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
+
+echo "Download language model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Failed to download the language model!"
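    # NOTE (illustrative, not part of the original patch): `download`
    # comes from utils/utility.sh and is, roughly, a wget of $URL to
    # $TARGET followed by an md5sum check against $MD5, so a partial or
    # corrupted LM file returns non-zero here and aborts before decoding.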
+ exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/librispeech/local/download_model.sh b/examples/1xt2x/librispeech/local/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc6a9ec7536a1d49ac5b95b199c78116346220b3 --- /dev/null +++ b/examples/1xt2x/librispeech/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. ${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz' +MD5=7b0f582fe2f5a840b840e7ee52246bc5 +TARGET=./librispeech_v1.8_to_v2.x.tar.gz + + +echo "Download LibriSpeech model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download LibriSpeech model!" + exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/librispeech/local/test.sh b/examples/1xt2x/librispeech/local/test.sh new file mode 100755 index 0000000000000000000000000000000000000000..4d00f30b852da5a370f5d4934f3caadd2b833c00 --- /dev/null +++ b/examples/1xt2x/librispeech/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_en.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/librispeech/path.sh b/examples/1xt2x/librispeech/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..080ab1f797f7bb516d7aad379d3a915515ec86e9 --- /dev/null +++ b/examples/1xt2x/librispeech/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/librispeech/run.sh b/examples/1xt2x/librispeech/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..05706a428c876e71a3af8efba91f5a6f025eb13f --- /dev/null +++ b/examples/1xt2x/librispeech/run.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=librispeech_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index e5ebfcbaf14ee900d22b1bc1967ddd4099ef6e33..ee0f1405e82d4f48f5c1afcfc59ce2ded52175fd 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -10,8 +10,11 @@ | Model | Params | Release | Config | Test set | Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| --- | --- | --- | --- | --- | --- | --- | | DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 7f0a1462f286e82ebb1ae5d0f1e2fd1b77bb1820..9560930acb9b831ad231279da3dc3bf4b9651a39 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -40,9 +40,12 @@ model: rnn_layer_size: 1024 use_gru: True share_rnn_weights: False + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 80 + accum_grad: 1 lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index fdc3a5365e37200f3286a7038f46d2cc8ebf82c3..7e87594ccbfe0de36d09fcc1bbdb9d9a932603fe 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -36,17 +36,20 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 3 + num_rnn_layers: 5 rnn_layer_size: 1024 rnn_direction: forward # [forward, bidirect] - num_fc_layers: 1 - fc_layers_size_list: 512, + num_fc_layers: 0 + fc_layers_size_list: -1, use_gru: False - + blank_id: 0 + ctc_grad_norm_type: instance + training: n_epoch: 50 + accum_grad: 1 lr: 2e-3 - lr_decay: 0.91 # 0.83 + lr_decay: 0.9 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 @@ -59,7 +62,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 1.9 + alpha: 2.2 #1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/aishell/s0/local/client.sh b/examples/aishell/s0/local/client.sh deleted file mode 100755 index 3b59ad3dff202a11cdae438661fbb43348502ec6..0000000000000000000000000000000000000000 --- a/examples/aishell/s0/local/client.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -source path.sh - -# run on MacOS -# brew install portaudio -# pip install pyaudio -# pip install keyboard - -# start demo client -python3 -u ${BIN_DIR}/deploy/client.py \ ---host_ip="localhost" \ ---host_port=8086 \ - -if [ $? 
-ne 0 ]; then - echo "Failed in starting demo client!" - exit 1 -fi - -exit 0 diff --git a/examples/aishell/s0/local/export.sh b/examples/aishell/s0/local/export.sh index 2e09e5f5e76a7f7cdf9cca8fbb91d66bb48aea0c..a5e62c28d2fa23a5ef9b9e2a0281b025aa943a30 100755 --- a/examples/aishell/s0/local/export.sh +++ b/examples/aishell/s0/local/export.sh @@ -13,13 +13,7 @@ ckpt_path_prefix=$2 jit_model_export_path=$3 model_type=$4 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/aishell/s0/local/server.sh b/examples/aishell/s0/local/server.sh deleted file mode 100755 index 2b88109932e619881128054bd23a012b27cb672d..0000000000000000000000000000000000000000 --- a/examples/aishell/s0/local/server.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -# TODO: replace the model with a mandarin model - -if [[ $# != 1 ]];then - echo "usage: $1 checkpoint_path" - exit -1 -fi - -source path.sh - -# download language model -bash local/download_lm_ch.sh -if [ $? -ne 0 ]; then - exit 1 -fi - -# download well-trained model -#bash local/download_model.sh -#if [ $? -ne 0 ]; then -# exit 1 -#fi - -# start demo server -CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${BIN_DIR}/deploy/server.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---host_ip="localhost" \ ---host_port=8086 \ ---speech_save_dir="demo_cache" \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in starting demo server!" - exit 1 -fi - - -exit 0 diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index 9fd0bc8d5bcdded1d33990a8ae20101d0a538441..2ae0740b3e8d44ab03e45f4c1b5dbb945657705e 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh index b6d580979c9797d7de59b7555c7304e00658b591..a9a6b122df8055f872f9f0a68717b57241d99359 100755 --- a/examples/aishell/s0/local/test_export.sh +++ b/examples/aishell/s0/local/test_export.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 jit_model_export_path=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? 
-ne 0 ]; then fi python3 -u ${BIN_DIR}/test_export.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${jit_model_export_path}.rsl \ --export_path ${jit_model_export_path} \ diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index 3438a7357c65ac6366aee95b8669b9117dc6da2e..edbf3383070ed0c8a71a125c0e77371d0a4412b6 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -12,27 +12,22 @@ config_path=$1 ckpt_name=$2 model_type=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - mkdir -p exp +# seed may break model convergence seed=10086 -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/aishell/s0/local/tune.sh b/examples/aishell/s0/local/tune.sh deleted file mode 100755 index 59406cd5b3b28f17a3a0d4cbb1e01c97f6cb3703..0000000000000000000000000000000000000000 --- a/examples/aishell/s0/local/tune.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# grid-search for hyper-parameters in language model -python3 -u ${BIN_DIR}/tune.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---num_batches=10 \ ---batch_size=128 \ ---beam_size=300 \ ---num_proc_bsearch=8 \ ---num_alphas=10 \ ---num_betas=10 \ ---alpha_from=0.0 \ ---alpha_to=5.0 \ ---beta_from=-6 \ ---beta_to=6 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in tuning!" - exit 1 -fi - - -exit 0 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index e5ab12a591e8827a8a9d5c1eaceaf680d365d55d..71191c3ac60081eb128cc93a95e58c6feb17d510 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -27,7 +27,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 3e606788ef86eee601233fab2039923b1ee8cb34..6f8ae135f6210757208dde85cad85e5ee776f381 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 4b1430c58848a5cac7303518021a6256b52d525d..a4248459c261d442a3e23e18cf538927cda5236a 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -71,6 +71,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/aishell/s1/local/align.sh b/examples/aishell/s1/local/align.sh index ad6c84bc8cf398cddecfadbbb47b7ef9c60e9158..279461aafe19967ac054df815b08e60e351fcc7f 100755 --- a/examples/aishell/s1/local/align.sh +++ b/examples/aishell/s1/local/align.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo 
$CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -22,8 +18,7 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/alignment.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s1/local/export.sh b/examples/aishell/s1/local/export.sh index f99a15bade1c89f968e84a6c10d500466f884d5b..b562218e7a76c91d3b906d2c099218bf492627a8 100755 --- a/examples/aishell/s1/local/export.sh +++ b/examples/aishell/s1/local/export.sh @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/aishell/s1/local/test.sh b/examples/aishell/s1/local/test.sh index f7e99ad7f668d804a108e295faf9ac2e58760595..c87412c9b67533f1ff3275097099f628f964a608 100755 --- a/examples/aishell/s1/local/test.sh +++ b/examples/aishell/s1/local/test.sh @@ -8,11 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - config_path=$1 ckpt_prefix=$2 @@ -39,8 +34,7 @@ for type in attention ctc_greedy_search; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ @@ -58,8 +52,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index ec17054ab1e6fe6b831846b2d7b10a4b9fb05479..71af3a006deec9032fcefd561fc46cea274d1e10 100755 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -1,37 +1,43 @@ #!/bin/bash +profiler_options= +benchmark_batch_size=0 +benchmark_max_step=0 + +# seed may break model convergence +seed=0 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True + echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." +fi + if [ $# != 2 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" exit -1 fi -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." - config_path=$1 ckpt_name=$2 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi -echo "using ${device}..." 
- mkdir -p exp -seed=1024 -if [ ${seed} ]; then - export FLAGS_cudnn_deterministic=True -fi - python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ +--seed ${seed} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--profiler-options "${profiler_options}" \ +--benchmark-batch-size ${benchmark_batch_size} \ +--benchmark-max-step ${benchmark_max_step} + -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index d55d47ea626b56751eb89b2f861029e4e1b7bcdf..e3c008234340aac108b26e59075a39acb7b41d13 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/callcenter/s1/local/align.sh b/examples/callcenter/s1/local/align.sh index f2c878c20c6d5bb9a89b437c6c0d5b11b1b80c4f..b679e2ea7fcb1d5ebd5a50dc88cb90e59e3e4789 100755 --- a/examples/callcenter/s1/local/align.sh +++ b/examples/callcenter/s1/local/align.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -20,7 +16,6 @@ ckpt_name=$(basename ${ckpt_prefxi}) mkdir -p exp - batch_size=1 output_dir=${ckpt_prefix} mkdir -p ${output_dir} @@ -28,8 +23,7 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/alignment.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/callcenter/s1/local/export.sh b/examples/callcenter/s1/local/export.sh index d171899cdbf4220436bb71ad07e3a1a5e9ea8bc2..d5f912e9033b1d71423d87ef0a9cf4b1991d3563 100755 --- a/examples/callcenter/s1/local/export.sh +++ b/examples/callcenter/s1/local/export.sh @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/callcenter/s1/local/test.sh b/examples/callcenter/s1/local/test.sh index 7a5b1cdb1c6f62fe25e6349226198ea6d8868a5e..dca3137dd0cf1faadd6b3c383dcb178881bb9080 100755 --- a/examples/callcenter/s1/local/test.sh +++ b/examples/callcenter/s1/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." 
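# NOTE (illustrative, not part of the original patch): `avg.sh` now takes
# a mode argument ahead of the checkpoint dir; `best` appears to average
# the top-N checkpoints ranked by validation loss, versus `latest` (the
# last N by step), which is why each recipe's stage-2 line becomes
# `avg.sh best exp/${ckpt}/checkpoints ${avg_num}` in this patch.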
-device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -32,8 +28,7 @@ for type in attention ctc_greedy_search; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ @@ -51,8 +46,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/callcenter/s1/local/train.sh b/examples/callcenter/s1/local/train.sh index 928c6492c41302b1e863fbd0b3b3be7dafca7103..eb8f86626f8988890c89dc29accb5ae1f172e1d6 100755 --- a/examples/callcenter/s1/local/train.sh +++ b/examples/callcenter/s1/local/train.sh @@ -11,27 +11,23 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi echo "using ${device}..." mkdir -p exp -seed=1024 -if [ ${seed} ]; then +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/callcenter/s1/run.sh b/examples/callcenter/s1/run.sh index 52dd44eca36f92567a18377797c570de4481e8e1..305021f1919a50c6a2360ce51c7f7abcf851b14e 100644 --- a/examples/callcenter/s1/run.sh +++ b/examples/callcenter/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/cc-cedict/README.md b/examples/cc-cedict/README.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..513fca5335cd5ef5ab7ec0005d53b11142fb2d90 100644 --- a/examples/cc-cedict/README.md +++ b/examples/cc-cedict/README.md @@ -0,0 +1,58 @@ +# [CC-CEDICT](https://cc-cedict.org/wiki/) + +What is CC-CEDICT? +CC-CEDICT is a continuation of the CEDICT project. +The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary. +CEDICT was started by Paul Andrew Denisowski in October 1997. +For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort, +with users providing entries and corrections to the main file. + + +## Parse CC-CEDICT to Json format + +1. Parse to Json + +``` +run.sh +``` + +2. Result + +``` +exp/ +|-- cedict +`-- cedict.json + +0 directories, 2 files +``` + +``` +4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict +3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json +``` + +``` +==> exp/cedict <== +# CC-CEDICT +# Community maintained free Chinese-English dictionary. 
+# +# Published by MDBG +# +# License: +# Creative Commons Attribution-ShareAlike 4.0 International License +# https://creativecommons.org/licenses/by-sa/4.0/ +# +# Referenced works: + +==> exp/cedict.json <== +{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"} +{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"} +{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"} +{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"} +{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"} +{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"} +{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"} +{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"} +{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"} +{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"} +``` diff --git a/examples/chinese_g2p/README.md b/examples/chinese_g2p/README.md deleted file mode 100644 index e3fdfe6841c0c1f1f1653324854d01c9054ea6cf..0000000000000000000000000000000000000000 --- a/examples/chinese_g2p/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Download Baker dataset - -Baker dataset has to be downloaded mannually and moved to 'data/', because you will have to pass the CATTCHA from a browswe to download the dataset. - -Download URL https://test.data-baker.com/#/data/index/source. diff --git a/examples/chinese_g2p/.gitignore b/examples/g2p/.gitignore similarity index 100% rename from examples/chinese_g2p/.gitignore rename to examples/g2p/.gitignore diff --git a/examples/g2p/README.md b/examples/g2p/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4ec5922b31cafab8b83a6409383578a781c7931d --- /dev/null +++ b/examples/g2p/README.md @@ -0,0 +1,3 @@ +# G2P + +* zh - Chinese G2P diff --git a/examples/g2p/zh/README.md b/examples/g2p/zh/README.md new file mode 100644 index 0000000000000000000000000000000000000000..de5573565962082941b82504a083f84d8dec5529 --- /dev/null +++ b/examples/g2p/zh/README.md @@ -0,0 +1,93 @@ +# G2P + +* WS +jieba +* G2P +pypinyin +* Tone sandhi +simple + +We recommend using [Paraket](https://github.com/PaddlePaddle/Parakeet] [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P. +The phoneme set should be changed, you can reference `examples/thchs30/a0/data/dict/syllable.lexicon`. + +## Download Baker dataset + +[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded mannually and moved to './data', +because you will have to pass the `CATTCHA` from a browswe to download the dataset. + + +## RUN + +``` +. 
path.sh +./run.sh +``` + +## Result + +``` +exp/ +|-- 000001-010000.txt +|-- ref.pinyin +|-- trans.jieba.pinyin +`-- trans.pinyin + +0 directories, 4 files +``` + +``` +4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt +01707896391c2de9b6fc4a39654be942 exp/ref.pinyin +43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin +8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin +``` + +``` +==> exp/000001-010000.txt <== +000001 卡尔普#2陪外孙#1玩滑梯#4。 + ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 假语村言#2别再#1拥抱我#4。 + jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 宝马#1配挂#1跛骡鞍#3,貂蝉#1怨枕#2董翁榻#4。 + bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 邓小平#2与#1撒切尔#2会晤#4。 + deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4 +000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。 + lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3 + +==> exp/ref.pinyin <== +000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4 +000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 + +==> exp/trans.jieba.pinyin <== +000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4 +000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 + +==> exp/trans.pinyin <== +000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4 +000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 +``` diff --git a/examples/chinese_g2p/local/convert_transcription.py b/examples/g2p/zh/local/convert_transcription.py similarity index 100% rename from examples/chinese_g2p/local/convert_transcription.py rename to examples/g2p/zh/local/convert_transcription.py diff --git a/examples/chinese_g2p/local/extract_pinyin_label.py b/examples/g2p/zh/local/extract_pinyin_label.py similarity index 100% rename from examples/chinese_g2p/local/extract_pinyin_label.py rename to examples/g2p/zh/local/extract_pinyin_label.py diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/g2p/zh/local/ignore_sandhi.py similarity index 100% rename from examples/chinese_g2p/local/ignore_sandhi.py rename to examples/g2p/zh/local/ignore_sandhi.py diff --git a/examples/chinese_g2p/local/prepare_dataset.sh b/examples/g2p/zh/local/prepare_dataset.sh similarity index 100% rename from 
examples/chinese_g2p/local/prepare_dataset.sh rename to examples/g2p/zh/local/prepare_dataset.sh diff --git a/examples/chinese_g2p/path.sh b/examples/g2p/zh/path.sh similarity index 82% rename from examples/chinese_g2p/path.sh rename to examples/g2p/zh/path.sh index 482177dc6367acc8a4eb7fb71d66c3271f763645..f475ed8331508f8a86ab1d6a4bf7f9a88b85aeaf 100644 --- a/examples/chinese_g2p/path.sh +++ b/examples/g2p/zh/path.sh @@ -1,4 +1,4 @@ -export MAIN_ROOT=`realpath ${PWD}/../../` +export MAIN_ROOT=`realpath ${PWD}/../../../` export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C diff --git a/examples/chinese_g2p/requirements.txt b/examples/g2p/zh/requirements.txt similarity index 100% rename from examples/chinese_g2p/requirements.txt rename to examples/g2p/zh/requirements.txt diff --git a/examples/chinese_g2p/run.sh b/examples/g2p/zh/run.sh similarity index 82% rename from examples/chinese_g2p/run.sh rename to examples/g2p/zh/run.sh index 8197dce4bea7f39182d40b510883be222939ddeb..25b713110ebc92f5047bb7a648c918270bcf0650 100755 --- a/examples/chinese_g2p/run.sh +++ b/examples/g2p/zh/run.sh @@ -6,16 +6,19 @@ stage=-1 stop_stage=100 exp_dir=exp -data_dir=data +data=data source ${MAIN_ROOT}/utils/parse_options.sh || exit -1 mkdir -p ${exp_dir} +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then + test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in ${data}; exit -1; } +fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then echo "stage 0: Extracting Prosody Labeling" - bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir} + bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data} fi # convert transcription in chinese into pinyin with pypinyin or jieba+pypinyin diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 5603d3c8aa1743b8aceae59f459f66d162fee24e..11bcf5f65df1ceb169096968ec9fa9ead23b58ff 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -1,10 +1,17 @@ # LibriSpeech +## Data +| Data Subset | Duration in Seconds | +| --- | --- | +| data/manifest.train | 0.83s ~ 29.735s | +| data/manifest.dev | 1.065 ~ 35.155s | +| data/manifest.test-clean | 1.285s ~ 34.955s | + ## Deepspeech2 | Model | Params | release | Config | Test set | Loss | WER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | 14.49190807 | test-clean | 0.067283 | -| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | -| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | +| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 | +| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 | +| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 | | DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index dab8d0462a7275e560d5afcf0b969d6042deefa7..3f1a376f181bf2cd7066ac0d9d4e858864d661db 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean min_input_len: 0.0 - max_input_len: 27.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 max_output_len: .inf 
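  # NOTE (illustrative, not part of the original patch): raising
  # max_input_len from 27.0 to 30.0 s keeps the longest training
  # utterances (~29.7 s per the duration table added to this recipe's
  # README) from being dropped by the length filter.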
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index dab8d0462a7275e560d5afcf0b969d6042deefa7..3f1a376f181bf2cd7066ac0d9d4e858864d661db 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean min_input_len: 0.0 - max_input_len: 27.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 @@ -40,9 +40,12 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 50 + accum_grad: 1 lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml index 2e4aed40ab07f5dc495e6c3062f72a9535c798db..180a6205f2af0429b3e42de8fed3772c4e8471b2 100644 --- a/examples/librispeech/s0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml @@ -4,14 +4,14 @@ data: dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean min_input_len: 0.0 - max_input_len: 27.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf collator: - batch_size: 20 + batch_size: 15 mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt @@ -42,9 +42,12 @@ model: num_fc_layers: 2 fc_layers_size_list: 512, 256 use_gru: False + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 50 + accum_grad: 4 lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/librispeech/s0/local/export.sh b/examples/librispeech/s0/local/export.sh index 2e09e5f5e76a7f7cdf9cca8fbb91d66bb48aea0c..a5e62c28d2fa23a5ef9b9e2a0281b025aa943a30 100755 --- a/examples/librispeech/s0/local/export.sh +++ b/examples/librispeech/s0/local/export.sh @@ -13,13 +13,7 @@ ckpt_path_prefix=$2 jit_model_export_path=$3 model_type=$4 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/librispeech/s0/local/test.sh b/examples/librispeech/s0/local/test.sh index b5b68c599c45ab50aa12ee35c120e02fb68740b4..4d00f30b852da5a370f5d4934f3caadd2b833c00 100755 --- a/examples/librispeech/s0/local/test.sh +++ b/examples/librispeech/s0/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \
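A note on the `accum_grad` keys added to the two configs above: gradient accumulation runs several micro-batches before each optimizer update, so the online config's `accum_grad: 4` together with its reduced `batch_size: 15` behaves roughly like a per-update batch of 60. A toy sketch of the pattern — the model, data, and loop are placeholders, not the DeepSpeech2 trainer:

```python
import paddle
from paddle import nn

model = nn.Linear(16, 2)  # stand-in for DeepSpeech2 and its CTC loss
optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())

accum_grad = 4  # the value added to deepspeech2_online.yaml
for step in range(16):
    x = paddle.randn([8, 16])
    loss = model(x).mean() / accum_grad  # scale so summed grads match one large batch
    loss.backward()                      # gradients accumulate until clear_grad()
    if (step + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.clear_grad()
```

Dividing the loss by `accum_grad` keeps the effective learning rate unchanged when the accumulated gradients are finally applied.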
diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index dcd21df34345e29b7dccadcb32aa2bbedf38ffbb..519df7fe9387a215f6d8835d90869ec0d4d7f5e5 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -12,28 +12,22 @@ config_path=$1 ckpt_name=$2 model_type=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi -echo "using ${device}..." - mkdir -p exp -seed=1024 -if [ ${seed} ]; then +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/librispeech/s0/local/tune.sh b/examples/librispeech/s0/local/tune.sh deleted file mode 100755 index c344e77e58a52c803bd69e1bd27804c8d0072415..0000000000000000000000000000000000000000 --- a/examples/librispeech/s0/local/tune.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -if [ $# != 1 ];then - echo "usage: tune ckpt_path" - exit 1 -fi - -# grid-search for hyper-parameters in language model -python3 -u ${BIN_DIR}/tune.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---num_batches=-1 \ ---batch_size=128 \ ---beam_size=500 \ ---num_proc_bsearch=12 \ ---num_alphas=45 \ ---num_betas=8 \ ---alpha_from=1.0 \ ---alpha_to=3.2 \ ---beta_from=0.1 \ ---beta_to=0.45 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in tuning!" - exit 1 -fi - - -exit 0 diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index c7902a56a882abecddd36b6440f9c88e134d628b..af47fb9b8b70d7d30eead78722303d0de58c4b06 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s1/cmd.sh b/examples/librispeech/s1/cmd.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b70ef5e06e550d466e41668a44a86263952262c --- /dev/null +++ b/examples/librispeech/s1/cmd.sh @@ -0,0 +1,89 @@ +# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ====== +# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...> +# e.g. +# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB +# +# Options: +# --time