“2214c0d27c27d2db8b7e99a6a9c948b3f8f5c45f”上不存在“paddlespeech/s2t/git@gitcode.net:paddlepaddle/DeepSpeech.git”
提交 472cf70e 编写于 作者: H Hui Zhang

refactor egs; add utils; add tools; rm notebook;add speechnn; more docs;

上级 5ef4a34e
unset GREP_OPTIONS
# https://zhuanlan.zhihu.com/p/33050965
alias nvs='nvidia-smi'
alias his='history'
alias jobs='jobs -l'
alias ports='netstat -tulanp'
alias wget='wget -c'
## Colorize the grep command output for ease of use (good for log files)##
alias grep='grep --color=auto'
alias egrep='egrep --color=auto'
alias fgrep='fgrep --color=auto'
......@@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
select =
E,
......
......@@ -10,8 +10,15 @@
.ipynb_checkpoints
*.npz
*.done
*.whl
tools/venv
tools/kenlm
tools/sox-14.4.2
tools/soxbindings
tools/montreal-forced-aligner/
tools/Montreal-Forced-Aligner/
tools/sctk
tools/sctk-20159b5/
*output/
......@@ -87,3 +87,9 @@ pull_request_rules:
actions:
label:
add: ["Docker"]
- name: "auto add label=Deployment"
conditions:
- files~=^speechnn/
actions:
label:
add: ["Deployment"]
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"L = nn.Linear(256, 2048)\n",
"L2 = nn.Linear(2048, 256)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n",
"Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n",
" [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n",
" [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n",
" ...,\n",
" [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n",
" [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n",
" [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n",
"\n",
" [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n",
" [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n",
" [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n",
" ...,\n",
" [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n",
" [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n",
" [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n",
"tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n",
" [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n",
" [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n",
" ...,\n",
" [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n",
" [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n",
" [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n",
"\n",
" [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n",
" [ 0.1022, 0.8130, 2.3687, ..., 0.1682, 0.2924, 1.0541],\n",
" [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n",
" ...,\n",
" [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n",
" [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n",
" [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"print(px)\n",
"print(tx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "mechanical-prisoner",
"metadata": {},
"outputs": [],
"source": [
"data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"t_norm_ff = data['norm_ff']\n",
"t_ff_out = data['ff_out']\n",
"t_ff_l_x = data['ff_l_x']\n",
"t_ff_l_a_x = data['ff_l_a_x']\n",
"t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"t_ps = data['ps']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "indie-marriage",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "assured-zambia",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n",
"L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n",
"\n",
"ps = []\n",
"for n, p in L.named_parameters():\n",
" ps.append(p)\n",
"\n",
"for n, p in L2.state_dict().items():\n",
" ps.append(p)\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p.numpy(), tp.T))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "committed-jacob",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "extreme-traffic",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n",
"# t_norm_ff = data['norm_ff']\n",
"# t_ff_out = data['ff_out']\n",
"# t_ff_l_x = data['ff_l_x']\n",
"# t_ff_l_a_x = data['ff_l_a_x']\n",
"# t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"# t_ps = data['ps']\n",
"TL = torch.nn.Linear(256, 2048)\n",
"TL2 = torch.nn.Linear(2048, 256)\n",
"TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n",
"TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n",
"\n",
"# for n, p in TL.named_parameters():\n",
"# print(n, p)\n",
"# for n, p in TL2.named_parameters():\n",
"# print(n, p)\n",
"\n",
"ps = []\n",
"for n, p in TL.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for n, p in TL2.state_dict().items():\n",
" ps.append(p.data.numpy())\n",
" \n",
"for p, tp in zip(ps, t_ps):\n",
" print(np.allclose(p, tp))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n",
" -0.5787626 ]\n",
" [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n",
" -0.03350811]\n",
" [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n",
" -1.012989 ]\n",
" ...\n",
" [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n",
" 0.04070369]\n",
" [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n",
" -0.8450228 ]\n",
" [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n",
" -0.06663495]]\n",
"\n",
" [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n",
" -0.8161046 ]\n",
" [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n",
" -0.20766167]\n",
" [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n",
" -0.26811332]\n",
" ...\n",
" [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n",
" 0.06020184]\n",
" [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n",
" 0.15847842]\n",
" [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n",
" -0.6460344 ]]]\n",
"[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n",
" -0.5787625 ]\n",
" [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n",
" -0.03350812]\n",
" [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n",
" -1.0129892 ]\n",
" ...\n",
" [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n",
" 0.04070365]\n",
" [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n",
" -0.84502304]\n",
" [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n",
" -0.066635 ]]\n",
"\n",
" [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n",
" -0.8161046 ]\n",
" [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n",
" -0.20766166]\n",
" [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n",
" -0.26811326]\n",
" ...\n",
" [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n",
" 0.06020182]\n",
" [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n",
" 0.15847899]\n",
" [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n",
" -0.6460344 ]]]\n",
"True\n",
"False\n"
]
}
],
"source": [
"y = L(px)\n",
"print(y.numpy())\n",
"\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "incorrect-allah",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.04476918 0.554463 -0.3027508 ... -0.49600336 0.3751858\n",
" 0.8254095 ]\n",
" [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n",
" -0.06974669]]\n",
"[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n",
" 0.8254096 ]\n",
" [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n",
" -0.0697467 ]]\n",
"True\n",
"False\n",
"True\n"
]
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"px = paddle.to_tensor(x, dtype='float32')\n",
"tx = torch.tensor(x, dtype=torch.float32)\n",
"y = L(px)\n",
"print(y.numpy())\n",
"ty = TL(tx)\n",
"print(ty.data.numpy())\n",
"print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy()))\n",
"print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "improved-civilization",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5e7e7c9fde8350084abf1898cf52651cfc84b17a\n"
]
}
],
"source": [
"print(paddle.version.commit)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d1e2d3b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c880c719",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.1.0\n"
]
}
],
"source": [
"print(paddle.version.full_version)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f26977bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n",
"None\n"
]
}
],
"source": [
"print(paddle.version.show())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "04ad47f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.6.0\n"
]
}
],
"source": [
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e1e03830",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" '__version__',\n",
" 'cuda',\n",
" 'debug',\n",
" 'git_version',\n",
" 'hip']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(torch.version)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4ad0389b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'b31f58de6fa8bbda5353b3c77d9be4914399724d'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.git_version"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7870ea10",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'10.2'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.version.cuda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db8ee5a7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6321ec2a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d6a0e098",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"import torch\n",
"from torch.optim.lr_scheduler import _LRScheduler\n",
"\n",
"from typeguard import check_argument_types\n",
"\n",
"\n",
"class WarmupLR(_LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" optimizer: torch.optim.Optimizer,\n",
" warmup_steps: Union[int, float] = 25000,\n",
" last_epoch: int = -1,\n",
" ):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
"\n",
" # __init__() must be invoked before setting field\n",
" # because step() is also invoked in __init__()\n",
" super().__init__(optimizer, last_epoch)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return [\n",
" lr\n",
" * self.warmup_steps ** 0.5\n",
" * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n",
" for lr in self.base_lrs\n",
" ]\n",
"\n",
" def set_step(self, step: int):\n",
" self.last_epoch = step"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0d496677",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"model = torch.nn.Linear(10, 200)\n",
"optimizer = optim.Adam(model.parameters())\n",
"scheduler = WarmupLR(optimizer, warmup_steps=25000)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3e3f3dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.0 -1\n"
]
}
],
"source": [
"infos = {}\n",
"start_epoch = infos.get('epoch', -1) + 1\n",
"cv_loss = infos.get('cv_loss', 0.0)\n",
"step = infos.get('step', -1)\n",
"print(start_epoch, cv_loss, step)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc3d550c",
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_step(step)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e527634e",
"metadata": {},
"outputs": [],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" scheduler.step()\n",
" lrs.append(scheduler.get_lr())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f1452db9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting matplotlib\n",
" Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n",
"\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n",
"\u001b[?25hCollecting kiwisolver>=1.0.1\n",
" Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n",
"\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n",
"Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n",
"Collecting cycler>=0.10\n",
" Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n",
"Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n",
"Installing collected packages: kiwisolver, cycler, matplotlib\n",
"Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n"
]
}
],
"source": [
"!pip install matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0f36d04f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0c39aa82e0>]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wzuUATwG/bquTc+4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3N+3l3Y2VvLiiAoDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHShYwUAr1IBQBGJIgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz7cQmDt4lj7zd7VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4f4e282c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"from typing import Union\n",
"\n",
"from paddle.optimizer.lr import LRScheduler\n",
"from typeguard import check_argument_types\n",
"\n",
"class WarmupLR(LRScheduler):\n",
" \"\"\"The WarmupLR scheduler\n",
" This scheduler is almost same as NoamLR Scheduler except for following\n",
" difference:\n",
" NoamLR:\n",
" lr = optimizer.lr * model_size ** -0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" WarmupLR:\n",
" lr = optimizer.lr * warmup_step ** 0.5\n",
" * min(step ** -0.5, step * warmup_step ** -1.5)\n",
" Note that the maximum lr equals to optimizer.lr in this scheduler.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" warmup_steps: Union[int, float]=25000,\n",
" learning_rate=1.0,\n",
" last_epoch=-1,\n",
" verbose=False):\n",
" assert check_argument_types()\n",
" self.warmup_steps = warmup_steps\n",
" super().__init__(learning_rate, last_epoch, verbose)\n",
"\n",
" def __repr__(self):\n",
" return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n",
"\n",
" def get_lr(self):\n",
" step_num = self.last_epoch + 1\n",
" return self.base_lr * self.warmup_steps**0.5 * min(\n",
" step_num**-0.5, step_num * self.warmup_steps**-1.5)\n",
"\n",
" def set_step(self, step: int):\n",
" self.step(step)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8c40b202",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1\n"
]
}
],
"source": [
"sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n",
"print(step)\n",
"#sc.set_step(step)\n",
"sc.set_step(0)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ecbc7e37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f0ba6dd9c40>]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA/gj8JumOjnnnnLO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0gQVGKf87ghAyvrak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBRt6GQBwLEqACgi4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfpN39cQkoMo7/FwIXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"lrs=[]\n",
"for i in range(100000):\n",
" sc.step()\n",
" lrs.append(sc.get_lr())\n",
"xs = list(range(100000))\n",
"plt.plot(xs, lrs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e613fe16",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0fd9f40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "purple-consequence",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defensive-mason",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "patient-convention",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n"
]
}
],
"source": [
"import argparse\n",
"import functools\n",
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.utils.utility import add_arguments\n",
"from deepspeech.utils.utility import print_arguments\n",
"\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool,\n",
" False,\n",
" \"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train.raw',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('num_workers',\n",
" default=16,\n",
" type=int,\n",
" help='num of subprocess workers for processing')\n",
"add_arg('output_path', str,\n",
" 'data/librispeech/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(args)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "enormous-currency",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for feat in batch:\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n",
" self.feature_func = feature_func\n",
" self._rng = rng\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.sample(manifest, num_samples)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" key = self.items[idx]['feat']\n",
" audioseg = AudioSegment.from_file(key)\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
" return feat"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "armed-semester",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"process 1000 wavs,450739 frames\n",
"process 2000 wavs,887447 frames\n",
"process 3000 wavs,1354148 frames\n",
"process 4000 wavs,1816494 frames\n",
"process 5000 wavs,2359211 frames\n",
"process 6000 wavs,2828455 frames\n",
"process 7000 wavs,3276186 frames\n",
"process 8000 wavs,3692234 frames\n",
"process 9000 wavs,4139360 frames\n",
"process 10000 wavs,4591528 frames\n",
"process 11000 wavs,5020114 frames\n",
"process 12000 wavs,5459523 frames\n",
"process 13000 wavs,5899534 frames\n",
"process 14000 wavs,6323242 frames\n",
"process 15000 wavs,6736597 frames\n",
"process 16000 wavs,7207686 frames\n",
"process 17000 wavs,7637800 frames\n",
"process 18000 wavs,8093004 frames\n",
"process 19000 wavs,8529518 frames\n",
"process 20000 wavs,8906022 frames\n",
"process 21000 wavs,9352652 frames\n",
"process 22000 wavs,9807495 frames\n",
"process 23000 wavs,10247938 frames\n",
"process 24000 wavs,10700011 frames\n",
"process 25000 wavs,11126134 frames\n",
"process 26000 wavs,11558061 frames\n",
"process 27000 wavs,12010359 frames\n",
"process 28000 wavs,12470938 frames\n",
"process 29000 wavs,12916013 frames\n",
"process 30000 wavs,13345816 frames\n",
"process 31000 wavs,13752365 frames\n",
"process 32000 wavs,14174801 frames\n",
"process 33000 wavs,14642170 frames\n",
"process 34000 wavs,15053557 frames\n",
"process 35000 wavs,15531890 frames\n",
"process 36000 wavs,16022711 frames\n",
"process 37000 wavs,16437688 frames\n",
"process 38000 wavs,16859517 frames\n",
"process 39000 wavs,17307676 frames\n",
"process 40000 wavs,17796629 frames\n",
"process 41000 wavs,18264151 frames\n",
"process 42000 wavs,18711898 frames\n",
"process 43000 wavs,19159890 frames\n",
"process 44000 wavs,19576435 frames\n",
"process 45000 wavs,19992793 frames\n",
"process 46000 wavs,20464449 frames\n",
"process 47000 wavs,20886021 frames\n",
"process 48000 wavs,21317318 frames\n",
"process 49000 wavs,21738034 frames\n",
"process 50000 wavs,22171890 frames\n",
"process 51000 wavs,22622238 frames\n",
"process 52000 wavs,23100734 frames\n",
"process 53000 wavs,23526901 frames\n",
"process 54000 wavs,23969746 frames\n",
"process 55000 wavs,24418691 frames\n",
"process 56000 wavs,24862546 frames\n",
"process 57000 wavs,25336448 frames\n",
"process 58000 wavs,25778435 frames\n",
"process 59000 wavs,26216199 frames\n",
"process 60000 wavs,26694692 frames\n",
"process 61000 wavs,27148978 frames\n",
"process 62000 wavs,27617088 frames\n",
"process 63000 wavs,28064946 frames\n",
"process 64000 wavs,28519843 frames\n",
"process 65000 wavs,28989722 frames\n",
"process 66000 wavs,29470156 frames\n",
"process 67000 wavs,29952931 frames\n",
"process 68000 wavs,30360555 frames\n",
"process 69000 wavs,30797929 frames\n",
"process 70000 wavs,31218227 frames\n",
"process 71000 wavs,31663934 frames\n",
"process 72000 wavs,32107468 frames\n",
"process 73000 wavs,32541943 frames\n",
"process 74000 wavs,33010702 frames\n",
"process 75000 wavs,33448082 frames\n",
"process 76000 wavs,33886812 frames\n",
"process 77000 wavs,34338108 frames\n",
"process 78000 wavs,34761495 frames\n",
"process 79000 wavs,35199730 frames\n",
"process 80000 wavs,35669630 frames\n",
"process 81000 wavs,36122402 frames\n",
"process 82000 wavs,36604561 frames\n",
"process 83000 wavs,37085552 frames\n",
"process 84000 wavs,37517500 frames\n",
"process 85000 wavs,37987196 frames\n",
"process 86000 wavs,38415721 frames\n",
"process 87000 wavs,38889467 frames\n",
"process 88000 wavs,39337809 frames\n",
"process 89000 wavs,39792342 frames\n",
"process 90000 wavs,40287946 frames\n",
"process 91000 wavs,40719461 frames\n",
"process 92000 wavs,41178919 frames\n",
"process 93000 wavs,41659635 frames\n",
"process 94000 wavs,42132985 frames\n",
"process 95000 wavs,42584564 frames\n",
"process 96000 wavs,43018598 frames\n",
"process 97000 wavs,43480662 frames\n",
"process 98000 wavs,43973670 frames\n",
"process 99000 wavs,44448190 frames\n",
"process 100000 wavs,44935034 frames\n",
"process 101000 wavs,45379812 frames\n",
"process 102000 wavs,45821207 frames\n",
"process 103000 wavs,46258420 frames\n",
"process 104000 wavs,46743733 frames\n",
"process 105000 wavs,47206922 frames\n",
"process 106000 wavs,47683041 frames\n",
"process 107000 wavs,48122809 frames\n",
"process 108000 wavs,48594623 frames\n",
"process 109000 wavs,49086358 frames\n",
"process 110000 wavs,49525568 frames\n",
"process 111000 wavs,49985820 frames\n",
"process 112000 wavs,50428262 frames\n",
"process 113000 wavs,50897957 frames\n",
"process 114000 wavs,51344589 frames\n",
"process 115000 wavs,51774621 frames\n",
"process 116000 wavs,52243372 frames\n",
"process 117000 wavs,52726025 frames\n",
"process 118000 wavs,53170026 frames\n",
"process 119000 wavs,53614141 frames\n",
"process 120000 wavs,54071271 frames\n"
]
}
],
"source": [
"\n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc()\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" augment_and_featurize, \n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader()):\n",
" #for batch in data_loader():\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "danish-executive",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, -1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n"
]
}
],
"source": [
"print(cmvn_info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "accurate-terminal",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dominant-abuse",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 1000 wavs,450240 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 2000 wavs,886411 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 3000 wavs,1352580 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 4000 wavs,1814397 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 5000 wavs,2356587 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 6000 wavs,2825310 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 7000 wavs,3272506 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 8000 wavs,3688045 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 9000 wavs,4134669 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 10000 wavs,4586357 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 11000 wavs,5014429 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 12000 wavs,5453334 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 13000 wavs,5892888 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 14000 wavs,6316059 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 15000 wavs,6728870 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 16000 wavs,7199442 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 17000 wavs,7629055 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 18000 wavs,8083729 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 19000 wavs,8519732 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 20000 wavs,8895694 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 21000 wavs,9341778 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 22000 wavs,9796126 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 23000 wavs,10236057 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 24000 wavs,10687461 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 25000 wavs,11113082 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 26000 wavs,11544482 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 27000 wavs,11996273 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 28000 wavs,12456350 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 29000 wavs,12900895 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 30000 wavs,13330353 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 31000 wavs,13736568 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 32000 wavs,14158472 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 33000 wavs,14625316 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 34000 wavs,15036206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 35000 wavs,15514001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 36000 wavs,16004323 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 37000 wavs,16418799 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 38000 wavs,16840100 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 39000 wavs,17287752 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 40000 wavs,17776206 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 41000 wavs,18243209 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 42000 wavs,18690449 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 43000 wavs,19137940 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 44000 wavs,19553966 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 45000 wavs,19969813 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 46000 wavs,20440963 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 47000 wavs,20862022 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 48000 wavs,21292801 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 49000 wavs,21713004 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 50000 wavs,22146346 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 51000 wavs,22596172 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 52000 wavs,23074160 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 53000 wavs,23499823 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 54000 wavs,23942151 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 55000 wavs,24390566 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 56000 wavs,24833905 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 57000 wavs,25307270 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 58000 wavs,25748720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 59000 wavs,26185964 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 60000 wavs,26663953 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 61000 wavs,27117720 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 62000 wavs,27585349 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 63000 wavs,28032693 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 64000 wavs,28487074 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 65000 wavs,28956462 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 66000 wavs,29436358 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 67000 wavs,29918569 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 68000 wavs,30325682 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 69000 wavs,30762528 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 70000 wavs,31182319 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 71000 wavs,31627526 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 72000 wavs,32070556 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 73000 wavs,32504534 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 74000 wavs,32972775 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 75000 wavs,33409637 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 76000 wavs,33847861 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 77000 wavs,34298647 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 78000 wavs,34721536 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 79000 wavs,35159236 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 80000 wavs,35628628 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 81000 wavs,36080909 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 82000 wavs,36562496 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 83000 wavs,37042976 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 84000 wavs,37474403 frames\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 85000 wavs,37943596 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 86000 wavs,38371620 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 87000 wavs,38844874 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 88000 wavs,39292686 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 89000 wavs,39746715 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 90000 wavs,40241800 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 91000 wavs,40672817 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 92000 wavs,41131773 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 93000 wavs,41612001 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 94000 wavs,42084822 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 95000 wavs,42535878 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 96000 wavs,42969365 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 97000 wavs,43430890 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 98000 wavs,43923378 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 99000 wavs,44397370 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 100000 wavs,44883695 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 101000 wavs,45327968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 102000 wavs,45768860 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 103000 wavs,46205602 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 104000 wavs,46690407 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 105000 wavs,47153089 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 106000 wavs,47628699 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 107000 wavs,48067945 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 108000 wavs,48539256 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 109000 wavs,49030485 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 110000 wavs,49469189 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 111000 wavs,49928968 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 112000 wavs,50370921 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 113000 wavs,50840090 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 114000 wavs,51286249 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 115000 wavs,51715786 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 116000 wavs,52184017 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 117000 wavs,52666156 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 118000 wavs,53109645 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 119000 wavs,53553253 frames\n",
"<class 'int'> <class 'paddle.VarBase'> <class 'paddle.VarBase'>\n",
"process 120000 wavs,54009877 frames\n",
"{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n"
]
}
],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import paddle\n",
"from paddle.io import DataLoader\n",
"from paddle.io import Dataset\n",
"\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"# https://github.com/PaddlePaddle/Paddle/pull/31481\n",
"class CollateFunc(object):\n",
" ''' Collate function for AudioDataset\n",
" '''\n",
" def __init__(self, feature_func):\n",
" self.feature_func = feature_func\n",
" \n",
" def __call__(self, batch):\n",
" mean_stat = None\n",
" var_stat = None\n",
" number = 0\n",
" for item in batch:\n",
" audioseg = AudioSegment.from_file(item['feat'])\n",
" feat = self.feature_func(audioseg) #(D, T)\n",
"\n",
" sums = np.sum(feat, axis=1)\n",
" if mean_stat is None:\n",
" mean_stat = sums\n",
" else:\n",
" mean_stat += sums\n",
"\n",
" square_sums = np.sum(np.square(feat), axis=1)\n",
" if var_stat is None:\n",
" var_stat = square_sums\n",
" else:\n",
" var_stat += square_sums\n",
"\n",
" number += feat.shape[1]\n",
" return number, mean_stat, var_stat\n",
"\n",
"\n",
"class AudioDataset(Dataset):\n",
" def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n",
" self._rng = rng if rng else np.random.RandomState(random_seed)\n",
" manifest = read_manifest(manifest_path)\n",
" if num_samples == -1:\n",
" sampled_manifest = manifest\n",
" else:\n",
" sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n",
" self.items = sampled_manifest\n",
"\n",
" def __len__(self):\n",
" return len(self.items)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.items[idx]\n",
" \n",
" \n",
"augmentation_pipeline = AugmentationPipeline('{}')\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20)\n",
"\n",
"def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
"\n",
"collate_func = CollateFunc(augment_and_featurize)\n",
"\n",
"dataset = AudioDataset(\n",
" args.manifest_path,\n",
" args.num_samples)\n",
"\n",
"batch_size = 20\n",
"data_loader = DataLoader(\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=args.num_workers,\n",
" collate_fn=collate_func)\n",
"\n",
"with paddle.no_grad():\n",
" all_mean_stat = None\n",
" all_var_stat = None\n",
" all_number = 0\n",
" wav_number = 0\n",
" for i, batch in enumerate(data_loader):\n",
" number, mean_stat, var_stat = batch\n",
" if i == 0:\n",
" all_mean_stat = mean_stat\n",
" all_var_stat = var_stat\n",
" else:\n",
" all_mean_stat += mean_stat\n",
" all_var_stat += var_stat\n",
" all_number += number\n",
" wav_number += batch_size\n",
"\n",
" if wav_number % 1000 == 0:\n",
" print('process {} wavs,{} frames'.format(wav_number,\n",
" all_number))\n",
"\n",
"cmvn_info = {\n",
" 'mean_stat': list(all_mean_stat.tolist()),\n",
" 'var_stat': list(all_var_stat.tolist()),\n",
" 'frame_num': all_number\n",
"}\n",
"print(cmvn_info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unlike-search",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n",
"from data_utils.utility import read_manifest\n",
"from data_utils.augmentor.augmentation import AugmentationPipeline\n",
"from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from data_utils.speech import SpeechSegment\n",
"from data_utils.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from data_utils.dataset import (\n",
" DeepSpeech2Dataset,\n",
" DeepSpeech2DistributedBatchSampler,\n",
" DeepSpeech2BatchSampler,\n",
" SpeechCollator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader(manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config='{}',\t\n",
" max_duration=float('inf'),\t\n",
" min_duration=0.0,\t\n",
" stride_ms=10.0,\t\n",
" window_ms=20.0,\t\n",
" max_freq=None,\t\n",
" specgram_type='linear',\t\n",
" use_dB_normalization=True,\t\n",
" random_seed=0,\t\n",
" keep_transcription_text=False,\t\n",
" is_training=False,\t\n",
" batch_size=1,\t\n",
" num_workers=0,\t\n",
" sortagrad=False,\t\n",
" shuffle_method=None,\t\n",
" dist=False):\t\n",
"\n",
" dataset = DeepSpeech2Dataset(\t\n",
" manifest_path,\t\n",
" vocab_filepath,\t\n",
" mean_std_filepath,\t\n",
" augmentation_config=augmentation_config,\t\n",
" max_duration=max_duration,\t\n",
" min_duration=min_duration,\t\n",
" stride_ms=stride_ms,\t\n",
" window_ms=window_ms,\t\n",
" max_freq=max_freq,\t\n",
" specgram_type=specgram_type,\t\n",
" use_dB_normalization=use_dB_normalization,\t\n",
" random_seed=random_seed,\t\n",
" keep_transcription_text=keep_transcription_text)\t\n",
"\n",
" if dist:\t\n",
" batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n",
" dataset,\t\n",
" batch_size,\t\n",
" num_replicas=None,\t\n",
" rank=None,\t\n",
" shuffle=is_training,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
" else:\t\n",
" batch_sampler = DeepSpeech2BatchSampler(\t\n",
" dataset,\t\n",
" shuffle=is_training,\t\n",
" batch_size=batch_size,\t\n",
" drop_last=is_training,\t\n",
" sortagrad=is_training,\t\n",
" shuffle_method=shuffle_method)\t\n",
"\n",
" def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n",
" \"\"\"\t\n",
" Padding audio features with zeros to make them have the same shape (or\t\n",
" a user-defined shape) within one bach.\t\n",
"\n",
" If ``padding_to`` is -1, the maximun shape in the batch will be used\t\n",
" as the target shape for padding. Otherwise, `padding_to` will be the\t\n",
" target shape (only refers to the second axis).\t\n",
"\n",
" If `flatten` is True, features will be flatten to 1darray.\t\n",
" \"\"\"\t\n",
" new_batch = []\t\n",
" # get target shape\t\n",
" max_length = max([audio.shape[1] for audio, text in batch])\t\n",
" if padding_to != -1:\t\n",
" if padding_to < max_length:\t\n",
" raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n",
" \"than any instance's shape in the batch\")\t\n",
" max_length = padding_to\t\n",
" max_text_length = max([len(text) for audio, text in batch])\t\n",
" # padding\t\n",
" padded_audios = []\t\n",
" audio_lens = []\t\n",
" texts, text_lens = [], []\t\n",
" for audio, text in batch:\t\n",
" padded_audio = np.zeros([audio.shape[0], max_length])\t\n",
" padded_audio[:, :audio.shape[1]] = audio\t\n",
" if flatten:\t\n",
" padded_audio = padded_audio.flatten()\t\n",
" padded_audios.append(padded_audio)\t\n",
" audio_lens.append(audio.shape[1])\t\n",
"\n",
" padded_text = np.zeros([max_text_length])\n",
" if is_training:\n",
" padded_text[:len(text)] = text\t# ids\n",
" else:\n",
" padded_text[:len(text)] = [ord(t) for t in text] # string\n",
" \n",
" texts.append(padded_text)\t\n",
" text_lens.append(len(text))\t\n",
"\n",
" padded_audios = np.array(padded_audios).astype('float32')\t\n",
" audio_lens = np.array(audio_lens).astype('int64')\t\n",
" texts = np.array(texts).astype('int32')\t\n",
" text_lens = np.array(text_lens).astype('int64')\t\n",
" return padded_audios, texts, audio_lens, text_lens\t\n",
"\n",
" loader = DataLoader(\t\n",
" dataset,\t\n",
" batch_sampler=batch_sampler,\t\n",
" collate_fn=partial(padding_batch, is_training=is_training),\t\n",
" num_workers=num_workers)\t\n",
" return loader"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bearing-physics",
"metadata": {},
"outputs": [],
"source": [
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[22823, 26102, 20195, 37324, 0 , 0 ],\n",
" [22238, 26469, 23601, 22909, 0 , 0 ],\n",
" [20108, 26376, 22235, 26085, 0 , 0 ],\n",
" [36824, 35201, 20445, 25345, 32654, 24863],\n",
" [29042, 27748, 21463, 23456, 0 , 0 ]])\n",
"test raw 大时代里\n",
"test raw 煲汤受宠\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "minus-modern",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "medieval-monday",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/tiny/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/tiny/s1/data/manifest.tiny',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/tiny/s1/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/tiny/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/tiny/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "wired-principal",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/aishell/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/s1/data/manifest.test',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" '',\n",
" \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bearing-physics",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from deepspeech.frontend.speech import SpeechSegment\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from deepspeech.io.collator import SpeechCollator\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"from deepspeech.io.sampler import (\n",
" SortagradDistributedBatchSampler,\n",
" SortagradBatchSampler,\n",
")\n",
"from deepspeech.io import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[ -750 -1254 -1107 ... 2276 1889 2067] int16\n",
"fbank\n",
"[ -127 -199 -149 ... -5243 -5065 -5398] int16\n",
"fbank\n",
"[ 465 783 677 ... 980 903 1008] int16\n",
"fbank\n",
"[ 90 160 157 ... -2 -16 -21] int16\n",
"fbank\n",
"[ 213 345 295 ... 2483 2246 2501] int16\n",
"fbank\n",
"[ -86 -159 -131 ... 270 258 290] int16\n",
"fbank\n",
"[-1023 -1714 -1505 ... 1532 1596 1575] int16\n",
"fbank\n",
"[-366 -602 -527 ... 374 370 379] int16\n",
"fbank\n",
"[ 761 1275 1127 ... 369 413 295] int16\n",
"fbank\n",
"[382 621 550 ... 161 161 174] int16\n",
"fbank\n",
"[ -28 -91 -120 ... 28 34 11] int16\n",
"fbank\n",
"[ -5 -5 -5 ... 268 294 341] int16\n",
"fbank\n",
"[240 417 684 ... 267 262 219] int16\n",
"fbank\n",
"[131 206 194 ... 383 320 343] int16\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n",
" [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n",
" [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n",
" [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n",
" [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n",
"fbank\n",
"[1155 1890 1577 ... 1092 989 1130] int16\n",
"fbank\n",
"[296 358 296 ... 140 140 168] int16\n",
"fbank\n",
"[-50 -91 -63 ... 104 104 86] int16\n",
"fbank\n",
"[-37 -66 -50 ... -31 -45 -52] int16\n",
"fbank\n",
"[-401 -652 -547 ... -339 -307 -344] int16\n",
"fbank\n",
"[-21 -47 -51 ... 94 81 107] int16\n",
"fbank\n",
"[ 533 887 755 ... 3074 2853 3254] int16\n",
"fbank\n",
"[ 44 71 66 ... -628 -733 -601] int16\n",
"fbank\n",
"[ 50 86 79 ... 129 116 138] int16\n",
"fbank\n",
"[ 92 146 126 ... -208 -193 -179] int16\n",
"test raw: 祝可爱的你\n",
"test raw: 去行政化\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n",
" [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n",
" [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n",
" [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n",
" [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n",
" [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n",
" [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n",
" [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n",
" [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n",
" [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n",
" [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n",
" ...,\n",
" [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n",
" [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n",
" [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "minus-modern",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fbank\n",
"[232 387 331 ... 249 249 262] int16\n",
"fbank\n",
"[-138 -219 -192 ... 338 324 351] int16\n",
"fbank\n",
"[ 694 1175 1022 ... 553 514 627] int16\n",
"fbank\n",
"[-39 -79 -53 ... 139 172 99] int16\n",
"fbank\n",
"[-277 -480 -425 ... 758 767 739] int16\n",
"fbank\n",
"test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n",
" [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n",
" [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n",
" [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n",
" [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n",
"[ 399 693 609 ... 1291 1270 1291] int16\n",
"test raw: ઇǹज৹©\n",
"test raw: ǝണٕƜ\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [5, 6, 7, 3, 4])\n",
"audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n",
" [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n",
" [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n",
" [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n",
" [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n",
" [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n",
" [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n",
" [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n",
" [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n",
" [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n",
" [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n",
" ...,\n",
" [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n",
" [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n",
" [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [184, 194, 196, 204, 207])\n"
]
}
],
"source": [
"keep_transcription_text=False\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=keep_transcription_text,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" print('audio len:', audio_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "competitive-mounting",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "knowing-military",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"\n",
"add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n",
"add_arg('specgram_type', str, 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc, fbank.\",\n",
" choices=['linear', 'mfcc', 'fbank'])\n",
"add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n",
"add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n",
"add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n",
"add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n",
"add_arg('sample_rate', int, 16000, \"target sample rate.\")\n",
"add_arg('manifest_path', str,\n",
" 'examples/aishell/s1/data/manifest.train',\n",
" \"Filepath of manifest to compute normalizer's mean and stddev.\")\n",
"add_arg('output_path', str,\n",
" 'examples/aishell/s1/data/mean_std.npz',\n",
" \"Filepath of write mean and stddev to (.npz).\")\n",
"args = parser.parse_args([])\n",
"print(vars(args))\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "unnecessary-province",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"from deepspeech.frontend.audio import AudioSegment\n",
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.frontend.utility import read_manifest\n",
"\n",
"\n",
"\n",
"def mean(args):\n",
" augmentation_pipeline = AugmentationPipeline('{}')\n",
" audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=True,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"\n",
" def augment_and_featurize(audio_segment):\n",
" augmentation_pipeline.transform_audio(audio_segment)\n",
" return audio_featurizer.featurize(audio_segment)\n",
"\n",
" normalizer = FeatureNormalizer(\n",
" mean_std_filepath=None,\n",
" manifest_path=args.manifest_path,\n",
" featurize_func=augment_and_featurize,\n",
" num_samples=args.num_samples)\n",
" normalizer.write_to_file(args.output_path)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "interested-camping",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.]\n",
"29746\n",
"fbank\n",
"[54 90 77 ... 58 58 61] int16\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 80) float64\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n"
]
}
],
"source": [
"wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n",
"test='祝可爱的你'\n",
"audio_featurizer = AudioFeaturizer(\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" stride_ms=args.stride_ms,\n",
" window_ms=args.window_ms,\n",
" n_fft=None,\n",
" max_freq=None,\n",
" target_sample_rate=args.sample_rate,\n",
" use_dB_normalization=False,\n",
" target_dB=-20,\n",
" dither=0.0)\n",
"samples = AudioSegment.from_file(wav)\n",
"print(samples._samples)\n",
"print(samples._samples * 2**15)\n",
"print(len(samples._samples))\n",
"feat = audio_featurizer.featurize(samples, False, False)\n",
"feat = feat.T\n",
"print(feat.shape, feat.dtype)\n",
"print(feat)\n",
"\n",
"from python_speech_features import logfbank\n",
"max_freq = args.sample_rate / 2\n",
"fbank_feat = logfbank(\n",
" signal=samples.to('int16'),\n",
" samplerate=args.sample_rate,\n",
" winlen=0.001 * args.window_ms,\n",
" winstep=0.001 * args.stride_ms,\n",
" nfilt=args.feat_dim,\n",
" nfft=512,\n",
" lowfreq=20,\n",
" highfreq=max_freq,\n",
" preemph=0.97,\n",
" dither=0.0,\n",
" wintype='povey')\n",
"print(fbank_feat.shape, fbank_feat.dtype)\n",
"print(fbank_feat)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "numeric-analyst",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 160)\n",
"[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n",
" 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n",
" 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n",
" 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n",
" 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n",
" 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n",
" 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n",
" 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n",
" 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n",
" 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n",
" 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n",
" 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n",
" 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n",
" 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n",
" 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n",
" 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n",
" 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n",
" 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n",
" 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n",
" 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n",
" 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n",
" 13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n",
" 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n",
" 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n",
" 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n",
" 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n",
" 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n",
" 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n",
" 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n",
" 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n",
" 6.94653756 7.19874283 7.28515661 7.17574078]\n",
"(184,)\n",
"(184,)\n",
"[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n",
" 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n",
" 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n",
" 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n",
" 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n",
" 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n",
" 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n",
" 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n",
" 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n",
" 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n",
" 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n",
" 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n",
" 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n",
" 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n",
" 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n",
" 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n",
" 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n",
" 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n",
" 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n",
" 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n",
" 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n",
" 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n",
" 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n",
" 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n",
" 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n",
" 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n",
" 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n",
" 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n",
" 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n",
" 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n",
" 1.67784678 1.72180861 1.60703743 1.64850255]\n"
]
}
],
"source": [
"a = np.hstack([feat, feat])\n",
"print(a.shape)\n",
"m = np.mean(a, axis=1)\n",
"print(m)\n",
"print(m.shape)\n",
"std = np.std(a, axis=1)\n",
"print(std.shape)\n",
"print(std)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "nonprofit-potato",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "hispanic-ethics",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"import torchaudio.compliance.kaldi as kaldi\n",
"import torchaudio.sox_effects as sox_effects\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"torchaudio.set_audio_backend(\"sox\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "changing-calvin",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 29746])\n",
"tensor([[54., 90., 77., ..., 58., 58., 61.]])\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n",
"-----------\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"(184, 80)\n",
"[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n",
" [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n",
" [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n",
" ...\n",
" [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n",
" [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n",
" [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n",
"**************\n",
"[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n",
"[54. 90. 77. ... 58. 58. 61.] float32\n",
"(184, 80)\n",
"[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n",
" [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n",
" [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n",
" ...\n",
" [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n",
" [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n",
" [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use \"torchaudio.load\".\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
}
],
"source": [
"waveform, sample_rate = torchaudio.load_wav(wav)\n",
"print(waveform.shape)\n",
"print(waveform)\n",
"mat = kaldi.fbank(\n",
" waveform,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('-----------')\n",
"print(samples._samples)\n",
"aud = torch.tensor(samples._samples).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)\n",
"\n",
"print('**************')\n",
"print(samples._samples)\n",
"tmp = samples.to('int16').astype('float32')\n",
"print(tmp, tmp.dtype)\n",
"aud = torch.tensor(tmp).view(1, -1)\n",
"mat = kaldi.fbank(\n",
" aud,\n",
" num_mel_bins=80,\n",
" frame_length=25,\n",
" frame_shift=10,\n",
" dither=0,\n",
" energy_floor=0.0,\n",
" sample_frequency=sample_rate\n",
" )\n",
"mat = mat.detach().numpy()\n",
"print(mat.shape)\n",
"print(mat)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "buried-dependence",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "silver-printing",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "outer-space",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(29746,)\n",
"[54 90 77 ... 58 58 61]\n",
"(184, 80)\n",
"[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n",
" 7.80671114]\n",
" [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n",
" 8.83860079]\n",
" [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n",
" 7.96168491]\n",
" ...\n",
" [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n",
" 8.75616521]\n",
" [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n",
" 7.69287184]\n",
" [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n",
" 8.16007108]]\n",
"(184, 13)\n",
"[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ 11.92813809 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"from python_speech_features import mfcc\n",
"from python_speech_features import delta\n",
"from python_speech_features import logfbank\n",
"import scipy.io.wavfile as iowav\n",
"\n",
"(rate,sig) = iowav.read(wav)\n",
"print(sig.shape)\n",
"print(sig)\n",
"\n",
"# note that generally nfilt=40 is used for speech recognition\n",
"fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"# the computed fbank coefficents of english.wav with dimension [110,23]\n",
"# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n",
"# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n",
"\n",
"mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)\n",
"\n",
"# the computed mfcc coefficents of english.wav with dimension [110,13]\n",
"# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n",
"# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n",
"# ...\n",
"# ...\n",
"# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "sporting-school",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(184, 80)\n",
"[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n",
" -12.98770428]\n",
" [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n",
" -11.95581463]\n",
" [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n",
" -12.83273051]\n",
" ...\n",
" [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n",
" -12.03825021]\n",
" [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n",
" -13.10154358]\n",
" [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n",
" -12.63434434]]\n",
"(184, 13)\n",
"[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n",
" 8.86862748]\n",
" [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n",
" 6.78353115]\n",
" [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n",
" -0.05919222]\n",
" ...\n",
" [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n",
" 5.59045842]\n",
" [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n",
" 12.0785146 ]\n",
" [ -8.86627732 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n",
" 0.88337965]]\n"
]
}
],
"source": [
"fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n",
"print(fbank_feat.shape)\n",
"print(fbank_feat)\n",
"\n",
"mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n",
"print(mfcc_feat.shape)\n",
"print(mfcc_feat)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "restricted-license",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "specialized-threat",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "breeding-haven",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "appropriate-theta",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LICENSE deepspeech examples\t\t requirements.txt tools\r\n",
"README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n",
"README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "entire-bloom",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n",
"WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n"
]
}
],
"source": [
"from deepspeech.modules import loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "governmental-aircraft",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import paddle"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "proprietary-disaster",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function deepspeech.modules.repeat(xs: paddle.VarBase, *size: Any) -> paddle.VarBase>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.repeat"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "first-diagram",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<property at 0x7fb515eeeb88>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.size"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "intelligent-david",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function paddle.tensor.manipulation.concat(x, axis=0, name=None)>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.Tensor.cat"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bronze-tenant",
"metadata": {},
"outputs": [],
"source": [
"a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "balanced-bearing",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "extreme-republic",
"metadata": {},
"outputs": [],
"source": [
"def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n",
" nargs = len(args)\n",
" assert (nargs <= 1)\n",
" s = paddle.shape(xs)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"\n",
"# logger.warn(\n",
"# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n",
"# )\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "gross-addiction",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size(0)\n",
"a.size()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "adverse-dining",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [7])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "popular-potato",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x\n"
]
},
{
"data": {
"text/plain": [
"'/home/ssd5/zhanghui/DeepSpeech2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n",
"2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n",
"2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"import argparse\n",
"import functools\n",
"import paddle\n",
"import numpy as np\n",
"\n",
"from deepspeech.utils.socket_server import warm_up_test\n",
"from deepspeech.utils.socket_server import AsrTCPServer\n",
"from deepspeech.utils.socket_server import AsrRequestHandler\n",
"\n",
"from deepspeech.training.cli import default_argument_parser\n",
"from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2Model\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"\n",
"\n",
"\n",
"from deepspeech.frontend.utility import read_manifest"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0.0\n",
"e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"OFF\n",
"OFF\n",
"commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n",
"None\n",
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"data": {
"text/plain": [
"['__builtins__',\n",
" '__cached__',\n",
" '__doc__',\n",
" '__file__',\n",
" '__loader__',\n",
" '__name__',\n",
" '__package__',\n",
" '__spec__',\n",
" 'commit',\n",
" 'full_version',\n",
" 'istaged',\n",
" 'major',\n",
" 'minor',\n",
" 'mkl',\n",
" 'patch',\n",
" 'rc',\n",
" 'show',\n",
" 'with_mkl']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(paddle.__version__)\n",
"print(paddle.version.commit)\n",
"print(paddle.version.with_mkl)\n",
"print(paddle.version.mkl())\n",
"print(paddle.version.show())\n",
"print(paddle.version.patch)\n",
"dir(paddle.version)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data:\n",
" augmentation_config: conf/augmentation.config\n",
" batch_size: 64\n",
" dev_manifest: data/manifest.dev\n",
" keep_transcription_text: False\n",
" max_duration: 27.0\n",
" max_freq: None\n",
" mean_std_filepath: examples/aishell/data/mean_std.npz\n",
" min_duration: 0.0\n",
" n_fft: None\n",
" num_workers: 0\n",
" random_seed: 0\n",
" shuffle_method: batch_shuffle\n",
" sortagrad: True\n",
" specgram_type: linear\n",
" stride_ms: 10.0\n",
" target_dB: -20\n",
" target_sample_rate: 16000\n",
" test_manifest: examples/aishell/data/manifest.test\n",
" train_manifest: data/manifest.train\n",
" use_dB_normalization: True\n",
" vocab_filepath: examples/aishell/data/vocab.txt\n",
" window_ms: 20.0\n",
"decoding:\n",
" alpha: 2.6\n",
" batch_size: 128\n",
" beam_size: 300\n",
" beta: 5.0\n",
" cutoff_prob: 0.99\n",
" cutoff_top_n: 40\n",
" decoding_method: ctc_beam_search\n",
" error_rate_type: cer\n",
" lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n",
" num_proc_bsearch: 10\n",
"model:\n",
" num_conv_layers: 2\n",
" num_rnn_layers: 3\n",
" rnn_layer_size: 1024\n",
" share_rnn_weights: False\n",
" use_gru: True\n",
"training:\n",
" global_grad_clip: 5.0\n",
" lr: 0.0005\n",
" lr_decay: 0.83\n",
" n_epoch: 30\n",
" weight_decay: 1e-06\n",
"----------- Configuration Arguments -----------\n",
"checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n",
"config: examples/aishell/conf/deepspeech2.yaml\n",
"device: gpu\n",
"dump_config: None\n",
"export_path: None\n",
"host_ip: localhost\n",
"host_port: 8086\n",
"model_dir: None\n",
"model_file: examples/aishell/jit.model.pdmodel\n",
"nprocs: 1\n",
"opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n",
"output: None\n",
"params_file: examples/aishell/jit.model.pdiparams\n",
"speech_save_dir: demo_cache\n",
"use_gpu: False\n",
"warmup_manifest: examples/aishell/data/manifest.test\n",
"------------------------------------------------\n"
]
}
],
"source": [
"parser = default_argument_parser()\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"add_arg('host_ip', str,\n",
" 'localhost',\n",
" \"Server's IP address.\")\n",
"add_arg('host_port', int, 8086, \"Server's IP port.\")\n",
"add_arg('speech_save_dir', str,\n",
" 'demo_cache',\n",
" \"Directory to save demo audios.\")\n",
"add_arg('warmup_manifest', \n",
" str, \n",
" \"examples/aishell/data/manifest.test\", \n",
" \"Filepath of manifest to warm up.\")\n",
"add_arg(\n",
" \"--model_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdmodel\",\n",
" help=\"Model filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--params_file\",\n",
" type=str,\n",
" default=\"examples/aishell/jit.model.pdiparams\",\n",
" help=\n",
" \"Parameter filename, Specify this when your model is a combined model.\"\n",
")\n",
"add_arg(\n",
" \"--model_dir\",\n",
" type=str,\n",
" default=None,\n",
" help=\n",
" \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n",
")\n",
"add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n",
"\n",
"\n",
"args = parser.parse_args(\n",
" \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n",
")\n",
"\n",
"\n",
"config = get_cfg_defaults()\n",
"if args.config:\n",
" config.merge_from_file(args.config)\n",
"if args.opts:\n",
" config.merge_from_list(args.opts)\n",
"config.freeze()\n",
"print(config)\n",
"\n",
"args.warmup_manifest = config.data.test_manifest\n",
"\n",
"print_arguments(args)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dataset = ManifestDataset(\n",
" config.data.test_manifest,\n",
" config.data.unit_type,\n",
" config.data.vocab_filepath,\n",
" config.data.mean_std_filepath,\n",
" augmentation_config=\"{}\",\n",
" max_duration=config.data.max_duration,\n",
" min_duration=config.data.min_duration,\n",
" stride_ms=config.data.stride_ms,\n",
" window_ms=config.data.window_ms,\n",
" n_fft=config.data.n_fft,\n",
" max_freq=config.data.max_freq,\n",
" target_sample_rate=config.data.target_sample_rate,\n",
" specgram_type=config.data.specgram_type,\n",
" feat_dim=config.data.feat_dim,\n",
" delta_delta=config.data.delat_delta,\n",
" use_dB_normalization=config.data.use_dB_normalization,\n",
" target_dB=config.data.target_dB,\n",
" random_seed=config.data.random_seed,\n",
" keep_transcription_text=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"layer summary:\n",
"encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n",
"encoder.conv.conv_in.bn.weight|[32]|32\n",
"encoder.conv.conv_in.bn.bias|[32]|32\n",
"encoder.conv.conv_in.bn._mean|[32]|32\n",
"encoder.conv.conv_in.bn._variance|[32]|32\n",
"encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n",
"encoder.conv.conv_stack.0.bn.weight|[32]|32\n",
"encoder.conv.conv_stack.0.bn.bias|[32]|32\n",
"encoder.conv.conv_stack.0.bn._mean|[32]|32\n",
"encoder.conv.conv_stack.0.bn._variance|[32]|32\n",
"encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n",
"encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n",
"encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n",
"encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n",
"decoder.ctc_lo.weight|[2048, 4300]|8806400\n",
"decoder.ctc_lo.bias|[4300]|4300\n",
"layer has 66 parameters, 80148012 elements.\n"
]
}
],
"source": [
"model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n",
" args.checkpoint_path)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"examples/aishell/jit.model.pdmodel\n",
"examples/aishell/jit.model.pdiparams\n",
"0\n",
"False\n"
]
}
],
"source": [
"\n",
"from paddle.inference import Config\n",
"from paddle.inference import PrecisionType\n",
"from paddle.inference import create_predictor\n",
"\n",
"args.use_gpu=False\n",
"paddle.set_device('cpu')\n",
"\n",
"def init_predictor(args):\n",
" if args.model_dir is not None:\n",
" config = Config(args.model_dir)\n",
" else:\n",
" config = Config(args.model_file, args.params_file)\n",
"\n",
" if args.use_gpu:\n",
" config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n",
"# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n",
"# use_calib_mode=True) # 开启TensorRT预测,精度为fp32,开启int8离线量化\n",
" else:\n",
" # If not specific mkldnn, you can set the blas thread.\n",
" # The thread num should not be greater than the number of cores in the CPU.\n",
" config.set_cpu_math_library_num_threads(1)\n",
" config.enable_mkldnn()\n",
" \n",
" config.enable_memory_optim()\n",
" config.switch_ir_optim(True)\n",
" \n",
" print(config.model_dir())\n",
" print(config.prog_file())\n",
" print(config.params_file())\n",
" print(config.gpu_device_id())\n",
" print(args.use_gpu)\n",
" predictor = create_predictor(config)\n",
" return predictor\n",
"\n",
"def run(predictor, audio, audio_len):\n",
" # copy img data to input tensor\n",
" input_names = predictor.get_input_names()\n",
" for i, name in enumerate(input_names):\n",
" print(\"input:\", i, name)\n",
" \n",
" audio_tensor = predictor.get_input_handle('audio')\n",
" audio_tensor.reshape(audio.shape)\n",
" audio_tensor.copy_from_cpu(audio.copy())\n",
" \n",
" audiolen_tensor = predictor.get_input_handle('audio_len')\n",
" audiolen_tensor.reshape(audio_len.shape)\n",
" audiolen_tensor.copy_from_cpu(audio_len.copy())\n",
"\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" print(\"output:\", i, name)\n",
"\n",
" # do the inference\n",
" predictor.run()\n",
"\n",
" results = []\n",
" # get out data from output tensor\n",
" output_names = predictor.get_output_names()\n",
" for i, name in enumerate(output_names):\n",
" output_tensor = predictor.get_output_handle(name)\n",
" output_data = output_tensor.copy_to_cpu()\n",
" results.append(output_data)\n",
"\n",
" return results\n",
"\n",
"\n",
"predictor = init_predictor(args)\n",
"\n",
"def file_to_transcript(filename):\n",
" print(filename)\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" \n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0], type(i_probs[0]))\n",
" \n",
" audio = paddle.to_tensor(audio)\n",
" audio_len = paddle.to_tensor(audio_len)\n",
" print(audio.shape)\n",
" print(audio_len.shape)\n",
" \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.softmax(eouts)\n",
" probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n",
" \n",
" flag = np.allclose(i_probs[0], probs.numpy())\n",
" print(flag)\n",
" \n",
" return probs\n",
"\n",
"# result_transcript = model.decode(\n",
"# audio,\n",
"# audio_len,\n",
"# vocab_list=dataset.vocab_list,\n",
"# decoding_method=config.decoding.decoding_method,\n",
"# lang_model_path=config.decoding.lang_model_path,\n",
"# beam_alpha=config.decoding.alpha,\n",
"# beam_beta=config.decoding.beta,\n",
"# beam_size=config.decoding.beam_size,\n",
"# cutoff_prob=config.decoding.cutoff_prob,\n",
"# cutoff_top_n=config.decoding.cutoff_top_n,\n",
"# num_processes=config.decoding.num_proc_bsearch)\n",
"# return result_transcript[0]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n",
" 8.91573707e-12 4.64317296e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n",
" 1.24587264e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n",
" 3.89252270e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]] <class 'numpy.ndarray'>\n",
"[1, 161, 522]\n",
"[1]\n",
"paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n",
"False\n"
]
}
],
"source": [
"manifest = read_manifest(args.warmup_manifest)\n",
"\n",
"for idx, sample in enumerate(manifest[:1]):\n",
" print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n",
" start_time = time.time()\n",
" transcript = file_to_transcript(sample['audio_filepath'])\n",
" finish_time = time.time()\n",
"# print(\"Response Time: %f, Transcript: %s\" %\n",
"# (finish_time - start_time, transcript))\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 161, 522) (1,)\n",
"input: 0 audio\n",
"input: 1 audio_len\n",
"output: 0 tmp_75\n",
"jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n",
" 8.91577090e-12 4.64319072e-08]\n",
" [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n",
" 1.55891342e-15 9.99992609e-01]\n",
" [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n",
" 1.24587735e-17 1.00000000e+00]\n",
" ...\n",
" [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n",
" 4.37354747e-15 1.00000000e+00]\n",
" [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n",
" 3.89253761e-13 1.00000000e+00]\n",
" [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n",
" 1.00334095e-10 9.99998808e-01]]]\n"
]
}
],
"source": [
"def test(filename):\n",
" feature = dataset.process_utterance(filename, \"\")\n",
" audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n",
" audio_len = feature[0].shape[1]\n",
" audio_len = np.array([audio_len]).astype('int64') # [1]\n",
" \n",
" print(audio.shape, audio_len.shape)\n",
"\n",
" i_probs = run(predictor, audio, audio_len)\n",
" print('jit:', i_probs[0])\n",
" return i_probs\n",
" \n",
"probs = test(sample['audio_filepath'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 32,
"id": "academic-surname",
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "fundamental-treasure",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n",
"Parameter containing:\n",
"Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n"
]
}
],
"source": [
"L = nn.LayerNorm(256, epsilon=1e-12)\n",
"for p in L.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "consolidated-elephant",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "moderate-noise",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float64\n"
]
}
],
"source": [
"x = np.random.randn(2, 51, 256)\n",
"print(x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "cooked-progressive",
"metadata": {},
"outputs": [],
"source": [
"y = L(paddle.to_tensor(x, dtype='float32'))"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "optimum-milwaukee",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "viral-indian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1.], requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" requires_grad=True)\n"
]
}
],
"source": [
"TL = torch.nn.LayerNorm(256, eps=1e-12)\n",
"for p in TL.parameters():\n",
" print(p)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "skilled-vietnamese",
"metadata": {},
"outputs": [],
"source": [
"ty = TL(torch.tensor(x, dtype=torch.float32))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "incorrect-allah",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-cameroon",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 52,
"id": "governmental-surge",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.randn(2, 256)\n",
"y = L(paddle.to_tensor(x, dtype='float32'))\n",
"ty = TL(torch.tensor(x, dtype=torch.float32))\n",
"np.allclose(y.numpy(), ty.detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "confidential-jacket",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "primary-organic",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "stopped-semester",
"metadata": {},
"outputs": [],
"source": [
"def mask_finished_scores(score: torch.Tensor,\n",
" flag: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.size(-1)\n",
" zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
" if beam_size > 1:\n",
" unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
" dim=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "agreed-portuguese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ True],\n",
" [False]])\n",
"tensor([[-0.8841, 0.7381, -0.9986],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ True, True],\n",
" [False, False]])\n"
]
}
],
"source": [
"score = torch.randn((2, 3))\n",
"flag = torch.ones((2, 1), dtype=torch.bool)\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.repeat([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "clean-aspect",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[False, True, True],\n",
" [False, False, False]])\n",
"tensor([[ True, False, False],\n",
" [False, False, False]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n",
"tensor([[ 0.0000, -inf, -inf],\n",
" [ 0.2675, -0.7971, 0.3798]])\n"
]
}
],
"source": [
"r = mask_finished_scores(score, flag)\n",
"print(r)\n",
"print(score)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "thrown-airline",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True ],\n",
" [False]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"import paddle\n",
"\n",
"score = paddle.randn((2, 3))\n",
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag)\n",
"print(score)\n",
"print(flag.tile([1, 2]))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "internal-patent",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[False, True , True ],\n",
" [False, False, False]])\n",
"Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , False, False],\n",
" [False, False, False]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, 1.87704289, 0.01988174],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 2.05994511, -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n",
"Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])\n"
]
}
],
"source": [
"paddle.bool = 'bool'\n",
"\n",
"def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print(xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" xs = paddle.where(mask, trues, xs)\n",
" return xs\n",
"\n",
"def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
" print('x', xs)\n",
" trues = paddle.ones_like(xs) * value\n",
" assert xs.shape == mask.shape\n",
" ret = paddle.where(mask, trues, xs)\n",
" print('2', xs)\n",
" paddle.assign(ret, output=xs)\n",
" print('3', xs)\n",
"\n",
"paddle.Tensor.masked_fill = masked_fill\n",
"paddle.Tensor.masked_fill_ = masked_fill_\n",
"\n",
"def mask_finished_scores_pd(score: paddle.Tensor,\n",
" flag: paddle.Tensor) -> paddle.Tensor:\n",
" \"\"\"\n",
" If a sequence is finished, we only allow one alive branch. This function\n",
" aims to give one branch a zero score and the rest -inf score.\n",
" Args:\n",
" score (torch.Tensor): A real value array with shape\n",
" (batch_size * beam_size, beam_size).\n",
" flag (torch.Tensor): A bool array with shape\n",
" (batch_size * beam_size, 1).\n",
" Returns:\n",
" torch.Tensor: (batch_size * beam_size, beam_size).\n",
" \"\"\"\n",
" beam_size = score.shape[-1]\n",
" zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
" if beam_size > 1:\n",
" unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
" axis=1)\n",
" else:\n",
" unfinished = zero_mask\n",
" finished = flag\n",
" print(unfinished)\n",
" print(finished)\n",
" \n",
" #score.masked_fill_(unfinished, -float('inf'))\n",
" #score.masked_fill_(finished, 0)\n",
"# infs = paddle.ones_like(score) * -float('inf')\n",
"# score = paddle.where(unfinished, infs, score)\n",
"# score = paddle.where(finished, paddle.zeros_like(score), score)\n",
"\n",
"# score = score.masked_fill(unfinished, -float('inf'))\n",
"# score = score.masked_fill(finished, 0)\n",
" score.masked_fill_(unfinished, -float('inf'))\n",
" score.masked_fill_(finished, 0)\n",
" return score\n",
"\n",
"r = mask_finished_scores_pd(score, flag)\n",
"print(r)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "vocal-prime",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method PyCapsule.value of Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[ 0. , -inf. , -inf. ],\n",
" [-0.40165186, 0.77547729, -0.64469045]])>"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"score.value"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "bacterial-adolescent",
"metadata": {},
"outputs": [],
"source": [
"from typing import Union, Any"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "absent-fiber",
"metadata": {},
"outputs": [],
"source": [
"def repeat(xs : paddle.Tensor, *size: Any):\n",
" print(size)\n",
" return paddle.tile(xs, size)\n",
"paddle.Tensor.repeat = repeat"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "material-harbor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "acute-brighton",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [1]), 2)\n",
"Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[True , True ],\n",
" [False, False]])\n"
]
}
],
"source": [
"flag = paddle.ones((2, 1), dtype='bool')\n",
"flag[1] = False\n",
"print(flag.repeat(paddle.to_tensor(1), 2))"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "european-rugby",
"metadata": {},
"outputs": [],
"source": [
"def size(xs, *args: int):\n",
" nargs = len(args)\n",
" s = paddle.shape(xs)\n",
" assert(nargs <= 1)\n",
" if nargs == 1:\n",
" return s[args[0]]\n",
" else:\n",
" return s\n",
"paddle.Tensor.size = size"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "moral-special",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2, 1])"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "ahead-coach",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [1])"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(1)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "incomplete-fitness",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
" [2])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flag.size(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "upset-connectivity",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "designing-borough",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"True\n",
"True\n"
]
}
],
"source": [
"import torch\n",
"import math\n",
"import numpy as np\n",
"\n",
"max_len=100\n",
"d_model=256\n",
"\n",
"pe = torch.zeros(max_len, d_model)\n",
"position = torch.arange(0, max_len,\n",
" dtype=torch.float32).unsqueeze(1)\n",
"toruch_position = position\n",
"div_term = torch.exp(\n",
" torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
" -(math.log(10000.0) / d_model))\n",
"tourch_div_term = div_term.cpu().detach().numpy()\n",
"\n",
"\n",
"\n",
"torhc_sin = torch.sin(position * div_term)\n",
"torhc_cos = torch.cos(position * div_term)\n",
"print(torhc_sin.cpu().detach().numpy())\n",
"np_sin = np.sin((position * div_term).cpu().detach().numpy())\n",
"np_cos = np.cos((position * div_term).cpu().detach().numpy())\n",
"print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n",
"pe[:, 0::2] = torhc_sin\n",
"pe[:, 1::2] = torhc_cos\n",
"tourch_pe = pe.cpu().detach().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "swiss-referral",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"False\n",
"False\n",
"False\n",
"False\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
"----\n",
"[[ 1. 1. 1. ... 1. 1.\n",
" 1. ]\n",
" [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n",
" 1. ]\n",
" [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n",
" 1. ]\n",
" ...\n",
" [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n",
" 0.9999457 ]\n",
" [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n",
" 0.99994457]\n",
" [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n",
" 0.99994344]]\n",
")))))))\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n",
"----\n",
"[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
" [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n",
" 1.1547816e-04 1.0746076e-04]\n",
" [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n",
" 2.3095631e-04 2.1492151e-04]\n",
" ...\n",
" [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n",
" 1.1201146e-02 1.0423505e-02]\n",
" [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n",
" 1.1316618e-02 1.0530960e-02]\n",
" [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n",
" 1.1432089e-02 1.0638415e-02]]\n"
]
}
],
"source": [
"import paddle\n",
"paddle.set_device('cpu')\n",
"ppe = paddle.zeros((max_len, d_model), dtype='float32')\n",
"position = paddle.arange(0, max_len,\n",
" dtype='float32').unsqueeze(1)\n",
"print(np.allclose(position.numpy(), toruch_position))\n",
"div_term = paddle.exp(\n",
" paddle.arange(0, d_model, 2, dtype='float32') *\n",
" -(math.log(10000.0) / d_model))\n",
"print(np.allclose(div_term.numpy(), tourch_div_term))\n",
"\n",
"\n",
"\n",
"p_sin = paddle.sin(position * div_term)\n",
"p_cos = paddle.cos(position * div_term)\n",
"print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n",
"print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n",
"ppe[:, 0::2] = p_sin\n",
"ppe[:, 1::2] = p_cos\n",
"print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n",
"print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n",
"print(p_cos.numpy())\n",
"print(\"----\")\n",
"print(torhc_cos.cpu().detach().numpy())\n",
"print(\")))))))\")\n",
"print(p_sin.numpy())\n",
"print(\"----\")\n",
"print(torhc_sin.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "integrated-boards",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"print(np.allclose(ppe.numpy(), pe.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "flying-reserve",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "revised-divide",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
因为 它太大了无法显示 source diff 。你可以改为 查看blob
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cloudy-glass",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISISBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "grand-stephen",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "isolated-prize",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "romance-samuel",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "timely-bikini",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "organized-warrior",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[14 , 34 , 322 , 233 , 0 , 0 ],\n",
" [238 , 38 , 122 , 164 , 0 , 0 ],\n",
" [8 , 52 , 49 , 42 , 0 , 0 ],\n",
" [109 , 47 , 146 , 193 , 210 , 479 ],\n",
" [3330, 1751, 208 , 1923, 0 , 0 ]])\n",
"test raw 大时代里的的\n",
"test raw 煲汤受宠的的\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
" for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "confidential-radius",
"metadata": {},
"outputs": [],
"source": [
"# reader = batch_reader()\n",
"# audio, test , audio_len, text_len = reader.next()\n",
"# print('test', text)\n",
"# print('t len', text_len) #[B, T]\n",
"# print('audio len', audio_len)\n",
"# print(audio)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "future-vermont",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"煲汤受宠\n"
]
}
],
"source": [
"print(u'\\u7172\\u6c64\\u53d7\\u5ba0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dental-sweden",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "sunrise-contact",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hispanic-asthma",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hearing-leadership",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "skilled-friday",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "copyrighted-measure",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "employed-lightweight",
"metadata": {},
"outputs": [],
"source": [
"from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n",
"\n",
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"from paddle.nn import functional as F\n",
"from paddle.nn import initializer as I\n",
"\n",
"import math\n",
"\n",
"def brelu(x, t_min=0.0, t_max=24.0, name=None):\n",
" t_min = paddle.to_tensor(t_min)\n",
" t_max = paddle.to_tensor(t_max)\n",
" return x.maximum(t_min).minimum(t_max)\n",
"\n",
"def sequence_mask(x_len, max_len=None, dtype='float32'):\n",
" max_len = max_len or x_len.max()\n",
" x_len = paddle.unsqueeze(x_len, -1)\n",
" row_vector = paddle.arange(max_len)\n",
" mask = row_vector > x_len # maybe a bug\n",
" mask = paddle.cast(mask, dtype)\n",
" print(f'seq mask: {mask}')\n",
" return mask\n",
"\n",
"\n",
"class ConvBn(nn.Layer):\n",
" def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n",
" padding, act):\n",
"\n",
" super().__init__()\n",
" self.kernel_size = kernel_size\n",
" self.stride = stride\n",
" self.padding = padding\n",
"\n",
" self.conv = nn.Conv2D(\n",
" num_channels_in,\n",
" num_channels_out,\n",
" kernel_size=kernel_size,\n",
" stride=stride,\n",
" padding=padding,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
"\n",
" self.bn = nn.BatchNorm2D(\n",
" num_channels_out,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
" self.act = F.relu if act == 'relu' else brelu\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x(Tensor): audio, shape [B, C, D, T]\n",
" \"\"\"\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.act(x)\n",
"\n",
" x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n",
" ) // self.stride[1] + 1\n",
"\n",
" # reset padding part to 0\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n",
" x = x.multiply(masks)\n",
"\n",
" return x, x_len\n",
"\n",
"\n",
"class ConvStack(nn.Layer):\n",
" def __init__(self, feat_size, num_stacks):\n",
" super().__init__()\n",
" self.feat_size = feat_size # D\n",
" self.num_stacks = num_stacks\n",
"\n",
" self.conv_in = ConvBn(\n",
" num_channels_in=1,\n",
" num_channels_out=32,\n",
" kernel_size=(41, 11), #[D, T]\n",
" stride=(2, 3),\n",
" padding=(20, 5),\n",
" act='brelu')\n",
"\n",
" out_channel = 32\n",
" self.conv_stack = nn.Sequential([\n",
" ConvBn(\n",
" num_channels_in=32,\n",
" num_channels_out=out_channel,\n",
" kernel_size=(21, 11),\n",
" stride=(2, 1),\n",
" padding=(10, 5),\n",
" act='brelu') for i in range(num_stacks - 1)\n",
" ])\n",
"\n",
" # conv output feat_dim\n",
" output_height = (feat_size - 1) // 2 + 1\n",
" for i in range(self.num_stacks - 1):\n",
" output_height = (output_height - 1) // 2 + 1\n",
" self.output_height = out_channel * output_height\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, C, D, T]\n",
" x_len : shape [B]\n",
" \"\"\"\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = self.conv_in(x, x_len)\n",
" for i, conv in enumerate(self.conv_stack):\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = conv(x, x_len)\n",
" print(f\"conv out: {x_len}\")\n",
" return x, x_len\n",
" \n",
" \n",
"\n",
"class RNNCell(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n",
" computes the outputs and updates states.\n",
" The formula used is as follows:\n",
" .. math::\n",
" h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`act` is for :attr:`activation`.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" hidden_size,\n",
" activation=\"tanh\",\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n",
" raise ValueError(\n",
" \"activation for SimpleRNNCell should be tanh or relu, \"\n",
" \"but get {}\".format(activation))\n",
" self.activation = activation\n",
" self._activation_fn = paddle.tanh \\\n",
" if activation == \"tanh\" \\\n",
" else F.relu\n",
" if activation == 'brelu':\n",
" self._activation_fn = brelu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
" pre_h = states\n",
" i2h = inputs\n",
" if self.bias_ih is not None:\n",
" i2h += self.bias_ih\n",
" h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h2h += self.bias_hh\n",
" h = self._activation_fn(i2h + h2h)\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class GRUCellShare(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n",
" it computes the outputs and updates states.\n",
" The formula for GRU used is as follows:\n",
" .. math::\n",
" r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n",
" z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n",
" \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n",
" h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n",
" multiplication operator.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" input_size,\n",
" hidden_size,\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (3 * hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (3 * hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (3 * hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" self.input_size = input_size\n",
" self._gate_activation = F.sigmoid\n",
" #self._activation = paddle.tanh\n",
" self._activation = F.relu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
"\n",
" pre_hidden = states\n",
" x_gates = inputs\n",
" if self.bias_ih is not None:\n",
" x_gates = x_gates + self.bias_ih\n",
" h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h_gates = h_gates + self.bias_hh\n",
"\n",
" x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n",
" h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n",
"\n",
" r = self._gate_activation(x_r + h_r)\n",
" z = self._gate_activation(x_z + h_z)\n",
" c = self._activation(x_c + r * h_c) # apply reset gate after mm\n",
" h = (pre_hidden - c) * z + c\n",
" # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n",
" #h = (1-z) * pre_hidden + z * c\n",
"\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" r\"\"\"\n",
" The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n",
" size would be automatically inserted into shape). The shape corresponds\n",
" to the shape of :math:`h_{t-1}`.\n",
" \"\"\"\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class BiRNNWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer parameters.\n",
" :type name: string\n",
" :param size: Dimension of RNN cells.\n",
" :type size: int\n",
" :param share_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" :type share_weights: bool\n",
" :return: Bidirectional simple rnn layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, share_weights):\n",
" super().__init__()\n",
" self.share_weights = share_weights\n",
" if self.share_weights:\n",
" #input-hidden weights shared between bi-directional rnn.\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" # batch norm is only performed on input-state projection\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = self.fw_fc\n",
" self.bw_bn = self.fw_bn\n",
" else:\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class BiGRUWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer.\n",
" :type name: string\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of GRU cells.\n",
" :type size: int\n",
" :param act: Activation type.\n",
" :type act: string\n",
" :return: Bidirectional GRU layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, act):\n",
" super().__init__()\n",
" hidden_size = h_size * 3\n",
" self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class RNNStack(nn.Layer):\n",
" \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n",
"\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of RNN cells in each layer.\n",
" :type size: int\n",
" :param num_stacks: Number of stacked rnn layers.\n",
" :type num_stacks: int\n",
" :param use_gru: Use gru if set True. Use simple rnn if set False.\n",
" :type use_gru: bool\n",
" :param share_rnn_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" It is only available when use_gru=False.\n",
" :type share_weights: bool\n",
" :return: Output layer of the RNN group.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n",
" super().__init__()\n",
" self.rnn_stacks = nn.LayerList()\n",
" for i in range(num_stacks):\n",
" if use_gru:\n",
" #default:GRU using tanh\n",
" self.rnn_stacks.append(\n",
" BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n",
" else:\n",
" self.rnn_stacks.append(\n",
" BiRNNWithBN(\n",
" i_size=i_size,\n",
" h_size=h_size,\n",
" share_weights=share_rnn_weights))\n",
" i_size = h_size * 2\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, T, D]\n",
" x_len: shpae [B]\n",
" \"\"\"\n",
" for i, rnn in enumerate(self.rnn_stacks):\n",
" x, x_len = rnn(x, x_len)\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(-1) # [B, T, 1]\n",
" x = x.multiply(masks)\n",
" return x, x_len\n",
"\n",
" \n",
"class DeepSpeech2Test(DeepSpeech2):\n",
" def __init__(self,\n",
" feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True):\n",
" super().__init__(feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True)\n",
" self.feat_size = feat_size # 161 for linear\n",
" self.dict_size = dict_size\n",
"\n",
" self.conv = ConvStack(feat_size, num_conv_layers)\n",
" \n",
"# self.fc = nn.Linear(1312, dict_size + 1)\n",
"\n",
" i_size = self.conv.output_height # H after conv stack\n",
" self.rnn = RNNStack(\n",
" i_size=i_size,\n",
" h_size=rnn_size,\n",
" num_stacks=num_rnn_layers,\n",
" use_gru=use_gru,\n",
" share_rnn_weights=share_rnn_weights)\n",
" \n",
" self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n",
" \n",
" def infer(self, audio, audio_len):\n",
" # [B, D, T] -> [B, C=1, D, T]\n",
" audio = audio.unsqueeze(1)\n",
"\n",
" # convolution group\n",
" x, audio_len = self.conv(audio, audio_len)\n",
" print('conv out', x.shape)\n",
"\n",
" # convert data from convolution feature map to sequence of vectors\n",
" B, C, D, T = paddle.shape(x)\n",
" x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n",
" x = x.reshape([B, T, C * D]) #[B, T, C*D]\n",
" print('rnn input', x.shape)\n",
"\n",
" # remove padding part\n",
" x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n",
" print('rnn output', x.shape)\n",
"\n",
" logits = self.fc(x) #[B, T, V + 1]\n",
"\n",
" #ctcdecoder need probs, not log_probs\n",
" probs = F.softmax(logits)\n",
"\n",
" return logits, probs, audio_len\n",
"\n",
" def forward(self, audio, audio_len, text, text_len):\n",
" \"\"\"\n",
" audio: shape [B, D, T]\n",
" text: shape [B, T]\n",
" audio_len: shape [B]\n",
" text_len: shape [B]\n",
" \"\"\"\n",
" return self.infer(audio, audio_len)\n",
" \n",
"\n",
"feat_dim=161\n",
"\n",
"model = DeepSpeech2Test(\n",
" feat_size=feat_dim,\n",
" dict_size=batch_reader.dataset.vocab_size,\n",
" num_conv_layers=args.num_conv_layers,\n",
" num_rnn_layers=args.num_rnn_layers,\n",
" rnn_size=1024,\n",
" use_gru=args.use_gru,\n",
" share_rnn_weights=args.share_rnn_weights,\n",
" )\n",
"dp_model = model\n",
"#dp_model = paddle.DataParallel(model)\n",
"\n",
"loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "divided-incentive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"id": "discrete-conjunction",
"metadata": {},
"outputs": [],
"source": [
"audio, audio_len, text, text_len = None, None, None, None\n",
"\n",
"for idx, inputs in enumerate(batch_reader):\n",
" audio, audio_len, text, text_len = inputs\n",
"# print(idx)\n",
"# print('a', audio.shape, audio.place)\n",
"# print('t', text)\n",
"# print('al', audio_len)\n",
"# print('tl', text_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "protected-announcement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"conv out [5, 32, 41, 62]\n",
"rnn input [5, 62, 1312]\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"rnn output [5, 62, 2048]\n",
"logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [2316.82153320])\n"
]
}
],
"source": [
"outputs = dp_model(audio, audio_len, text, text_len)\n",
"logits, _, logits_len = outputs\n",
"print('logits len', logits_len)\n",
"loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
"print('loss', loss)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "universal-myrtle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n"
]
}
],
"source": [
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "referenced-double",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n",
" 4.757795 ]\n",
" [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n",
" 5.1787405 ]\n",
" [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n",
" 3.5026932 ]\n",
" ...\n",
" [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n",
" 2.359191 ]\n",
" [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n",
" 4.829253 ]\n",
" [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n",
" 3.7477999 ]]]\n",
"\n",
"\n",
" [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n",
" 1.0245132 ]\n",
" [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n",
" -0.41603735]\n",
" [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n",
" -0.96788657]\n",
" ...\n",
" [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n",
" 2.6364415 ]\n",
" [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n",
" 3.3839993 ]\n",
" [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n",
" 4.1078367 ]]]\n",
"\n",
"\n",
" [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n",
" -6.139402 ]\n",
" [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n",
" -6.158736 ]\n",
" [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n",
" -5.416684 ]\n",
" ...\n",
" [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n",
" -0.8772557 ]\n",
" [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n",
" -3.6650617 ]\n",
" [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n",
" -2.5240796 ]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n",
" 2.101063 ]\n",
" [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n",
" 1.4128599 ]\n",
" [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n",
" 1.6222539 ]\n",
" ...\n",
" [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n",
" 2.7494807 ]\n",
" [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n",
" 2.2483966 ]\n",
" [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n",
" 1.7467536 ]]]\n",
"\n",
"\n",
" [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n",
" -2.6987422 ]\n",
" [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n",
" -3.6595795 ]\n",
" [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n",
" -3.481941 ]\n",
" ...\n",
" [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n",
" -1.2167063 ]\n",
" [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n",
" 0.0715822 ]\n",
" [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n",
" 1.4772662 ]]]\n",
"\n",
"\n",
" [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n",
" -3.5771027 ]\n",
" [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n",
" -2.0113711 ]\n",
" [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n",
" -4.3225455 ]\n",
" ...\n",
" [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n",
" 2.6247358 ]\n",
" [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n",
" -0.9618193 ]\n",
" [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n",
" -0.79297894]]]]\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n",
" 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n",
" 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n",
" -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n",
" 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n",
" 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n",
" -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n",
" 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n",
" -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n",
" 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n",
" 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n",
" 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n",
" -5.628145 -1.0894046 ]\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n",
" -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n",
" -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n",
" 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n",
" 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n",
" -4.116946 -0.9909375 ]\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n",
" 6.71104014e-01 -1.95339048e+00]\n",
" [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n",
" -9.42245364e-01 -3.34277248e+00]\n",
" [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n",
" 1.36039746e+00 -2.94621849e+00]\n",
" ...\n",
" [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n",
" -3.23889256e-02 -2.30832791e+00]\n",
" [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n",
" -3.51897454e+00 -2.10093403e+00]\n",
" [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n",
" -7.38576293e-01 -9.44581270e-01]]\n",
"\n",
" [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n",
" -4.09437418e-01 -1.74107194e+00]\n",
" [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n",
" -3.03250861e+00 -2.37099743e+00]\n",
" [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n",
" -2.08650708e+00 -1.82109618e+00]\n",
" ...\n",
" [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n",
" -1.31862307e+00 -2.07174897e+00]\n",
" [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n",
" -9.57888126e-01 -3.82121086e-01]\n",
" [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n",
" 1.78124309e-02 -3.40287232e+00]]\n",
"\n",
" [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n",
" 1.65331578e+00 2.58466601e-02]\n",
" [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n",
" 1.94632411e+00 -1.06698799e+00]\n",
" [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n",
" 4.27859402e+00 3.66381669e+00]\n",
" ...\n",
" [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n",
" 4.30536079e+00 -2.49779320e+00]\n",
" [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n",
" 2.55930424e-01 4.50821817e-02]\n",
" [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n",
" 1.09565187e+00 -4.96661961e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n",
" 2.56748104e+00 1.67592216e+00]\n",
" [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n",
" 2.92086434e+00 2.08308399e-01]\n",
" [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n",
" 3.55370259e+00 4.75085378e-01]\n",
" ...\n",
" [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n",
" -1.66419971e+00 -1.70375630e-01]\n",
" [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n",
" 9.49141383e-02 8.88008535e-01]\n",
" [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n",
" 1.93200278e+00 8.19505215e-01]]\n",
"\n",
" [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n",
" 3.19132543e+00 3.17435169e+00]\n",
" [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n",
" 2.76059532e+00 4.83479440e-01]\n",
" [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n",
" 2.80596399e+00 1.52781498e+00]\n",
" ...\n",
" [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n",
" -2.26546407e-01 -1.18434143e+00]\n",
" [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n",
" -1.31319761e-01 -3.85830402e-01]\n",
" [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n",
" 1.68037796e+00 -6.50003672e-01]]\n",
"\n",
" [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n",
" 2.07197094e+00 4.48752761e-01]\n",
" [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n",
" 1.88659012e+00 1.48252523e+00]\n",
" [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n",
" 3.75945044e+00 -3.09408635e-01]\n",
" ...\n",
" [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n",
" -1.26373076e+00 -1.37826729e+00]\n",
" [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n",
" 1.50002241e-02 1.84633851e+00]\n",
" [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n",
" 1.53240371e+00 7.47349262e-02]]]\n",
"\n",
"\n",
" [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n",
" -1.08945215e+00 4.19084758e-01]\n",
" [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n",
" 8.21905077e-01 1.89097953e+00]\n",
" [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n",
" -2.78253704e-01 6.96343541e-01]\n",
" ...\n",
" [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n",
" -9.71581042e-02 1.71877682e+00]\n",
" [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n",
" -1.15296638e+00 4.35728967e-01]\n",
" [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n",
" -5.39051890e-01 1.13107657e+00]]\n",
"\n",
" [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n",
" 2.95773447e-02 1.27177119e+00]\n",
" [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n",
" 3.85564029e-01 4.57195908e-01]\n",
" [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n",
" -6.58698231e-02 -3.61778140e-01]\n",
" ...\n",
" [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n",
" 1.21484184e+00 2.22764325e+00]\n",
" [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n",
" 2.39702511e+00 2.43942714e+00]\n",
" [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n",
" 2.50086927e+00 1.91134131e+00]]\n",
"\n",
" [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n",
" -3.79540777e+00 -1.09796596e+00]\n",
" [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n",
" -2.59143281e+00 7.74220943e-01]\n",
" [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n",
" -1.23293114e+00 1.58148706e+00]\n",
" ...\n",
" [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n",
" 2.04232907e+00 1.97716308e+00]\n",
" [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n",
" 2.27003932e+00 -1.11734614e-01]\n",
" [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n",
" 5.78273535e-01 -2.59924948e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n",
" -1.53534544e+00 -6.67162359e-01]\n",
" [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n",
" 2.68609226e-01 -1.35124445e+00]\n",
" [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n",
" -5.04359961e-01 -1.81241560e+00]\n",
" ...\n",
" [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n",
" 3.29260826e-01 5.83679557e-01]\n",
" [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n",
" 2.91730791e-01 1.33400351e-01]\n",
" [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n",
" 4.90157723e-01 2.40562439e+00]]\n",
"\n",
" [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n",
" 1.91490650e+00 1.76274943e+00]\n",
" [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n",
" 2.53548074e+00 2.44395185e+00]\n",
" [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n",
" 2.94634581e+00 2.21452284e+00]\n",
" ...\n",
" [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n",
" -2.00367713e+00 -7.09145784e-01]\n",
" [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n",
" -7.83945560e-01 -1.01922750e-01]\n",
" [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n",
" 1.60286129e-02 -7.26446986e-01]]\n",
"\n",
" [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n",
" 8.52760911e-01 1.27920091e-01]\n",
" [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n",
" 3.97556186e-01 1.69182217e+00]\n",
" [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n",
" -1.07720590e+00 -3.81854475e-01]\n",
" ...\n",
" [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n",
" -4.48192549e+00 -6.14600539e-01]\n",
" [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n",
" -4.44984436e+00 -2.28429958e-01]\n",
" [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n",
" -1.29259872e+00 1.30360842e-01]]]\n",
"\n",
"\n",
" [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n",
" -6.68072224e-01 -1.13282454e+00]\n",
" [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n",
" -1.00205469e+00 -3.58956242e+00]\n",
" [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n",
" 2.25133196e-01 -1.02802217e+00]\n",
" ...\n",
" [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n",
" -1.63268471e+00 6.00247443e-01]\n",
" [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n",
" -1.47013855e+00 -1.45779204e+00]\n",
" [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n",
" -1.36285520e+00 -1.68160331e+00]]\n",
"\n",
" [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n",
" -2.32403350e+00 -4.25943947e+00]\n",
" [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n",
" -1.16508257e+00 -1.30353463e+00]\n",
" [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n",
" -1.91106403e+00 -1.07801402e+00]\n",
" ...\n",
" [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n",
" -5.27612507e-01 1.44604075e+00]\n",
" [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n",
" 4.67946082e-01 -7.24248111e-01]\n",
" [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n",
" -7.51657963e-01 -1.00530767e+00]]\n",
"\n",
" [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n",
" -2.35717297e-02 -1.71374834e+00]\n",
" [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n",
" -1.41204190e+00 -7.58325040e-01]\n",
" [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n",
" -1.07798588e+00 -1.48433352e+00]\n",
" ...\n",
" [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n",
" 1.24199319e+00 1.89949930e-01]\n",
" [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n",
" -5.78772545e-01 2.62665659e-01]\n",
" [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n",
" 2.79776216e-01 7.63269484e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n",
" 1.16390383e+00 -1.47593319e-01]\n",
" [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n",
" 1.89411759e+00 1.03220642e-01]\n",
" [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n",
" 2.07282662e+00 -2.56800652e-01]\n",
" ...\n",
" [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n",
" 1.80851030e+00 8.32634747e-01]\n",
" [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n",
" 1.30439472e+00 3.61029744e-01]\n",
" [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n",
" -9.16651785e-02 9.93353873e-02]]\n",
"\n",
" [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n",
" 3.94066262e+00 4.89832020e+00]\n",
" [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n",
" 3.22004294e+00 2.99872303e+00]\n",
" [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n",
" 2.98195982e+00 3.37315011e+00]\n",
" ...\n",
" [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n",
" 2.13971066e+00 1.94932449e+00]\n",
" [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n",
" 2.29117441e+00 2.52368808e+00]\n",
" [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n",
" 1.86903965e+00 2.27776885e+00]]\n",
"\n",
" [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n",
" 7.87163258e-01 1.20610237e+00]\n",
" [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n",
" 9.66498017e-01 1.18883348e+00]\n",
" [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n",
" 5.25399208e-01 -1.16850233e+00]\n",
" ...\n",
" [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n",
" -1.49128008e+00 -1.85064209e+00]\n",
" [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n",
" 9.67048705e-02 -4.37043965e-01]\n",
" [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n",
" -7.83308804e-01 -1.64871216e+00]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n",
" -1.29671764e+00 -1.61429560e+00]\n",
" [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n",
" -1.26383066e+00 -2.27464601e-01]\n",
" [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n",
" -3.68550658e-01 -2.58849680e-01]\n",
" ...\n",
" [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n",
" 2.67662477e+00 2.31014514e+00]\n",
" [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n",
" 2.77418542e+00 1.22875953e+00]\n",
" [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n",
" 1.95602298e+00 7.64306366e-01]]\n",
"\n",
" [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n",
" 1.90896463e+00 8.02283168e-01]\n",
" [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n",
" 1.53455496e+00 1.94702053e+00]\n",
" [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n",
" 7.79394627e-01 2.02670670e+00]\n",
" ...\n",
" [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n",
" 9.61961091e-01 1.40342009e+00]\n",
" [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n",
" 9.43485856e-01 2.34856772e+00]\n",
" [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n",
" 1.38143206e+00 2.00714707e+00]]\n",
"\n",
" [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n",
" -2.52133775e+00 -8.87715697e-01]\n",
" [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n",
" -6.10453069e-01 2.06534386e+00]\n",
" [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n",
" -1.12928271e+00 4.87097502e-02]\n",
" ...\n",
" [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n",
" 1.06069386e+00 1.36404145e+00]\n",
" [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n",
" 1.48393130e+00 2.69196725e+00]\n",
" [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n",
" 3.67927134e-01 2.55976987e+00]]\n",
"\n",
" ...\n",
"\n",
" [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n",
" 2.57500267e+00 -1.38083458e-01]\n",
" [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n",
" 2.34770632e+00 6.48133337e-01]\n",
" [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n",
" 4.34810448e+00 1.31588542e+00]\n",
" ...\n",
" [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n",
" 1.06280947e+00 -5.93325794e-02]\n",
" [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n",
" 8.06131065e-02 8.08142900e-01]\n",
" [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n",
" 3.13675582e-01 2.13851762e+00]]\n",
"\n",
" [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n",
" 3.17874146e+00 1.47012353e+00]\n",
" [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n",
" 2.73698759e+00 2.17057824e+00]\n",
" [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n",
" 3.43153954e+00 1.21802723e+00]\n",
" ...\n",
" [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n",
" 1.64856291e+00 1.15853786e+00]\n",
" [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n",
" 2.24890351e+00 2.22156644e+00]\n",
" [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n",
" 3.71970820e+00 1.69852686e+00]]\n",
"\n",
" [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n",
" 1.48759770e+00 -1.72001183e+00]\n",
" [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n",
" 1.74327165e-01 -1.83029580e+00]\n",
" [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n",
" 1.05627227e+00 -8.02519619e-01]\n",
" ...\n",
" [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n",
" 4.66731453e+00 4.98701715e+00]\n",
" [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n",
" 3.32219076e+00 3.83035636e+00]\n",
" [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n",
" 2.13785577e+00 4.33214951e+00]]]\n",
"\n",
"\n",
" [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n",
" 2.72914767e-01 3.92112255e-01]\n",
" [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n",
" -1.44200325e-01 4.07531261e-02]\n",
" [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n",
" -2.28934169e+00 -7.38609374e-01]\n",
" ...\n",
" [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n",
" -1.23878837e+00 2.21006930e-01]\n",
" [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n",
" -6.22699618e-01 -1.73162484e+00]\n",
" [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n",
" -1.15777385e+00 -1.13671994e+00]]\n",
"\n",
" [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n",
" 8.81283879e-02 -1.37762165e+00]\n",
" [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n",
" -9.98012185e-01 -1.11348581e+00]\n",
" [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n",
" -3.65677416e-01 -2.25250268e+00]\n",
" ...\n",
" [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n",
" -2.52748466e+00 -2.16131711e+00]\n",
" [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n",
" -2.44736362e+00 -8.78590107e-01]\n",
" [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n",
" -2.45255780e+00 -1.82384634e+00]]\n",
"\n",
" [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n",
" 4.13422853e-01 1.12761199e+00]\n",
" [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n",
" -8.72656107e-01 1.52415514e-01]\n",
" [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n",
" 1.87347198e+00 9.48901057e-01]\n",
" ...\n",
" [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n",
" -6.83587790e-01 1.49154460e+00]\n",
" [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n",
" 1.54248989e+00 -1.46732473e+00]\n",
" [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n",
" 1.12405360e-01 -6.87821358e-02]]\n",
"\n",
" ...\n",
"\n",
" [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n",
" 4.03424358e+00 3.73335361e+00]\n",
" [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n",
" 2.29408574e+00 3.14105940e+00]\n",
" [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n",
" 3.91481900e+00 5.32456112e+00]\n",
" ...\n",
" [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n",
" 3.87909842e+00 3.33359623e+00]\n",
" [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n",
" 1.55827272e+00 2.50741863e+00]\n",
" [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n",
" 3.28739667e+00 3.02895570e+00]]\n",
"\n",
" [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n",
" 1.52944064e+00 2.08810711e+00]\n",
" [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n",
" 7.46028006e-01 1.74789345e+00]\n",
" [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n",
" 1.49640536e+00 2.16613841e+00]\n",
" ...\n",
" [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n",
" 3.87935114e+00 3.69385266e+00]\n",
" [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n",
" 5.36551809e+00 2.94915986e+00]\n",
" [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n",
" 3.93462896e+00 3.68200874e+00]]\n",
"\n",
" [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n",
" 1.60273385e+00 4.90113163e+00]\n",
" [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n",
" 1.40544355e+00 2.84833241e+00]\n",
" [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n",
" 6.87886119e-01 2.58609629e+00]\n",
" ...\n",
" [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n",
" 2.30406880e+00 1.29892707e+00]\n",
" [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n",
" 8.59923482e-01 1.38731599e+00]\n",
" [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n",
" 1.44120216e+00 1.79993320e+00]]]\n",
"\n",
"\n",
" [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n",
" 1.65043330e+00 3.45897645e-01]\n",
" [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n",
" 9.12241280e-01 1.27429694e-01]\n",
" [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n",
" 1.00041127e+00 5.87104559e-02]\n",
" ...\n",
" [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n",
" 2.35047388e+00 1.61532843e+00]\n",
" [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n",
" 2.32770801e+00 9.96662021e-01]\n",
" [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n",
" 2.05776072e+00 5.34575820e-01]]\n",
"\n",
" [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n",
" -9.85817850e-01 -2.05517030e+00]\n",
" [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n",
" -1.79659498e+00 -1.58266973e+00]\n",
" [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n",
" -1.55507576e+00 -9.37123299e-01]\n",
" ...\n",
" [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n",
" 1.27850962e+00 1.35487700e+00]\n",
" [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n",
" 1.50716826e-01 1.33641326e+00]\n",
" [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n",
" 1.60940981e+00 1.78091609e+00]]\n",
"\n",
" [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n",
" -2.18273020e+00 -1.98621082e+00]\n",
" [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n",
" -2.12960863e+00 -3.45255351e+00]\n",
" [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n",
" -3.72989595e-01 -1.36516929e+00]\n",
" ...\n",
" [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n",
" 2.13982654e+00 -1.81909311e+00]\n",
" [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n",
" 2.07899213e+00 -3.71748984e-01]\n",
" [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n",
" -2.81265080e-01 -3.69353056e+00]]\n",
"\n",
" ...\n",
"\n",
" [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n",
" -1.75104141e-02 -3.54720926e+00]\n",
" [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n",
" -3.58282638e+00 -2.40778995e+00]\n",
" [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n",
" -1.67449510e+00 -1.14655948e+00]\n",
" ...\n",
" [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n",
" -1.48933101e+00 -1.36807609e+00]\n",
" [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n",
" -8.00529659e-01 -8.39800835e-01]\n",
" [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n",
" 1.67839837e+00 8.47821474e-01]]\n",
"\n",
" [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n",
" 8.83528054e-01 2.57416844e-01]\n",
" [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n",
" 2.38878536e+00 1.84097290e+00]\n",
" [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n",
" 1.61432290e+00 1.13214278e+00]\n",
" ...\n",
" [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n",
" -2.18670201e+00 -3.91811800e+00]\n",
" [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n",
" -1.59393203e+00 -9.08105671e-01]\n",
" [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n",
" 9.73127842e-01 1.19991803e+00]]\n",
"\n",
" [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n",
" 2.25974560e+00 1.02396965e+00]\n",
" [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n",
" -4.19382691e-01 2.91669458e-01]\n",
" [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n",
" -1.51922226e-01 7.44265199e-01]\n",
" ...\n",
" [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n",
" 5.72486281e-01 -4.35207397e-01]\n",
" [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n",
" 1.01334639e-02 -3.20627165e+00]\n",
" [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n",
" -1.73590891e-02 -1.80156028e+00]]]]\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n",
" 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n",
" -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n",
" -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n",
" -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n",
" -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n",
" 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n",
" -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n",
" 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n",
" -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n",
" -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n",
" -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n",
" -0.24313992 0.31392363]\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n",
" 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n",
" -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n",
" -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n",
" -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n",
" -0.40964937 -1.4454535 ]\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n",
" -0.0658682 ]\n",
" [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n",
" -0.05464617]\n",
" [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n",
" -0.40110156]\n",
" ...\n",
" [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n",
" 0.37071258]\n",
" [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n",
" 0.32735693]\n",
" [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n",
" 0.0752247 ]]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n",
" 0.3020359 ]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n",
" 0.16263628]\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n",
" -0.25661892]\n",
" [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n",
" -0.17545304]\n",
" [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n",
" 0.05315325]\n",
" ...\n",
" [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n",
" -0.01544908]\n",
" [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n",
" 0.14791614]\n",
" [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n",
" 0.26846847]]\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n",
" 0.16263627]\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n",
" -0.30598596]\n",
" [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n",
" -0.5335691 ]\n",
" [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n",
" 0.25522667]\n",
" ...\n",
" [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n",
" -0.07482046]\n",
" [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n",
" -0.2011079 ]\n",
" [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n",
" -0.17761081]]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n",
" -0.10503208]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n",
" 0.11344345]\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n",
" -0.29202533]\n",
" [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n",
" 0.05170784]\n",
" [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n",
" -0.08760723]\n",
" ...\n",
" [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n",
" -0.14726155]\n",
" [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n",
" 0.02569831]\n",
" [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n",
" 0.41056633]]\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n",
" 0.11344352]\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n",
" 0.05862035]\n",
" [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n",
" -0.3197178 ]\n",
" [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n",
" -0.44295847]\n",
" ...\n",
" [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n",
" -0.15528557]\n",
" [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n",
" -0.61484563]\n",
" [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n",
" 0.17856634]]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n",
" 1.6333911 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n",
" 2.069446 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n",
" -0.00845779]\n",
" [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n",
" -0.855925 ]\n",
" [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n",
" 0.11303265]\n",
" ...\n",
" [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n",
" -0.74844164]\n",
" [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n",
" 0.71526295]\n",
" [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n",
" 1.2852117 ]]\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n",
" 2.0694466 ]\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n",
" 7.4421698e-03 -2.3925617e+01]\n",
" [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n",
" 1.8560058e-02 -5.0687141e+01]\n",
" [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n",
" 2.1536341e-02 -6.1036335e+01]\n",
" ...\n",
" [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n",
" 1.3860047e-02 -5.2543671e+01]\n",
" [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n",
" 1.7470375e-02 -4.3619675e+01]\n",
" [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n",
" 2.2168703e-02 -6.7901680e+01]]\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n",
" 4.3404289e-02 -1.3512518e+02]\n"
]
}
],
"source": [
"loss.backward(retain_graph=False)\n",
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "selected-crazy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.]\n"
]
}
],
"source": [
"print(loss.grad)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bottom-engineer",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "stuffed-yeast",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
\ No newline at end of file
因为 它太大了无法显示 source diff 。你可以改为 查看blob
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-grade",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "broke-broad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"register user softmax to paddle, remove this when fixed!\n",
"register user log_softmax to paddle, remove this when fixed!\n",
"register user sigmoid to paddle, remove this when fixed!\n",
"register user log_sigmoid to paddle, remove this when fixed!\n",
"register user relu to paddle, remove this when fixed!\n",
"override cat of paddle if exists or register, remove this when fixed!\n",
"override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle if exists or register, remove this when fixed!\n",
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"register user view to paddle.Tensor, remove this when fixed!\n",
"register user view_as to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"register user fill_ to paddle.Tensor, remove this when fixed!\n",
"register user repeat to paddle.Tensor, remove this when fixed!\n",
"register user softmax to paddle.Tensor, remove this when fixed!\n",
"register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"register user relu to paddle.Tensor, remove this when fixed!\n",
"register user type_as to paddle.Tensor, remove this when fixed!\n",
"register user to to paddle.Tensor, remove this when fixed!\n",
"register user float to paddle.Tensor, remove this when fixed!\n",
"register user tolist to paddle.Tensor, remove this when fixed!\n",
"register user glu to paddle.nn.functional, remove this when fixed!\n",
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"register user Module to paddle.nn, remove this when fixed!\n",
"register user ModuleList to paddle.nn, remove this when fixed!\n",
"register user GLU to paddle.nn, remove this when fixed!\n",
"register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"register user export to paddle.jit, remove this when fixed!\n"
]
}
],
"source": [
"import numpy as np\n",
"import paddle\n",
"from yacs.config import CfgNode as CN\n",
"\n",
"from deepspeech.models.u2 import U2Model\n",
"from deepspeech.utils.layer_tools import print_params\n",
"from deepspeech.utils.layer_tools import summary"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "permanent-summary",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n",
"[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n",
"attention_heads: 4\n",
"dropout_rate: 0.1\n",
"input_layer: conv2d\n",
"linear_units: 2048\n",
"normalize_before: True\n",
"num_blocks: 12\n",
"output_size: 256\n",
"positional_dropout_rate: 0.1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"encoder.embed.conv.0.bias | [256] | 256 | True\n",
"encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"encoder.embed.conv.2.bias | [256] | 256 | True\n",
"encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n",
"encoder.embed.out.0.bias | [256] | 256 | True\n",
"encoder.after_norm.weight | [256] | 256 | True\n",
"encoder.after_norm.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm1.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm1.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm2.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm2.bias | [256] | 256 | True\n",
"encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n",
"decoder.embed.0.weight | [4233, 256] | 1083648 | True\n",
"decoder.after_norm.weight | [256] | 256 | True\n",
"decoder.after_norm.bias | [256] | 256 | True\n",
"decoder.output_layer.weight | [256, 4233] | 1083648 | True\n",
"decoder.output_layer.bias | [4233] | 4233 | True\n",
"decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n",
"ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n",
"ctc.ctc_lo.bias | [4233] | 4233 | True\n",
"Total parameters: 411.0, 32.01M elements.\n"
]
}
],
"source": [
"conf_str='examples/tiny/s1/conf/transformer.yaml'\n",
"cfg = CN().load_cfg(open(conf_str))\n",
"cfg.model.input_dim = 83\n",
"cfg.model.output_dim = 4233\n",
"cfg.model.cmvn_file = None\n",
"cfg.model.cmvn_file_type = 'json'\n",
"#cfg.model.encoder_conf.concat_after=True\n",
"cfg.freeze()\n",
"model = U2Model(cfg.model)\n",
"\n",
"print_params(model)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "sapphire-agent",
"metadata": {},
"outputs": [],
"source": [
"#summary(model)\n",
"#print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ruled-invitation",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fossil-means",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "45c2b75f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state\n",
"odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n"
]
}
],
"source": [
"#!pip install torch\n",
"import torch\n",
"\n",
"e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n",
"for k in e_model.files:\n",
" print(k)\n",
"state_dict = e_model['state']\n",
"state_dict = state_dict.tolist()\n",
"print(state_dict.keys())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f187bb55",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n",
"# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n",
"# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n",
"# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n",
"# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n",
"# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n",
"# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n",
"# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n",
"# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n",
"# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"# espnet transformer\tconcat_linear只是保存了,但是未使用\n",
"\t\n",
"# \tpaddle transformer"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2a0428ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> encoder.embed.conv.0.weight\n",
"-> encoder.embed.conv.0.bias\n",
"-> encoder.embed.conv.2.weight\n",
"-> encoder.embed.conv.2.bias\n",
"-> encoder.embed.out.0.weight\n",
"encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n",
"-> encoder.embed.out.0.bias\n",
"-> encoder.encoders.0.self_attn.linear_q.weight\n",
"encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_q.bias\n",
"-> encoder.encoders.0.self_attn.linear_k.weight\n",
"encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_k.bias\n",
"-> encoder.encoders.0.self_attn.linear_v.weight\n",
"encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_v.bias\n",
"-> encoder.encoders.0.self_attn.linear_out.weight\n",
"encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.0.self_attn.linear_out.bias\n",
"-> encoder.encoders.0.feed_forward.w_1.weight\n",
"encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.0.feed_forward.w_1.bias\n",
"-> encoder.encoders.0.feed_forward.w_2.weight\n",
"encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.0.feed_forward.w_2.bias\n",
"-> encoder.encoders.0.norm1.weight\n",
"-> encoder.encoders.0.norm1.bias\n",
"-> encoder.encoders.0.norm2.weight\n",
"-> encoder.encoders.0.norm2.bias\n",
"-> encoder.encoders.1.self_attn.linear_q.weight\n",
"encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_q.bias\n",
"-> encoder.encoders.1.self_attn.linear_k.weight\n",
"encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_k.bias\n",
"-> encoder.encoders.1.self_attn.linear_v.weight\n",
"encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_v.bias\n",
"-> encoder.encoders.1.self_attn.linear_out.weight\n",
"encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.1.self_attn.linear_out.bias\n",
"-> encoder.encoders.1.feed_forward.w_1.weight\n",
"encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.1.feed_forward.w_1.bias\n",
"-> encoder.encoders.1.feed_forward.w_2.weight\n",
"encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.1.feed_forward.w_2.bias\n",
"-> encoder.encoders.1.norm1.weight\n",
"-> encoder.encoders.1.norm1.bias\n",
"-> encoder.encoders.1.norm2.weight\n",
"-> encoder.encoders.1.norm2.bias\n",
"-> encoder.encoders.2.self_attn.linear_q.weight\n",
"encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_q.bias\n",
"-> encoder.encoders.2.self_attn.linear_k.weight\n",
"encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_k.bias\n",
"-> encoder.encoders.2.self_attn.linear_v.weight\n",
"encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_v.bias\n",
"-> encoder.encoders.2.self_attn.linear_out.weight\n",
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.2.self_attn.linear_out.bias\n",
"-> encoder.encoders.2.feed_forward.w_1.weight\n",
"encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.2.feed_forward.w_1.bias\n",
"-> encoder.encoders.2.feed_forward.w_2.weight\n",
"encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.2.feed_forward.w_2.bias\n",
"-> encoder.encoders.2.norm1.weight\n",
"-> encoder.encoders.2.norm1.bias\n",
"-> encoder.encoders.2.norm2.weight\n",
"-> encoder.encoders.2.norm2.bias\n",
"-> encoder.encoders.3.self_attn.linear_q.weight\n",
"encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_q.bias\n",
"-> encoder.encoders.3.self_attn.linear_k.weight\n",
"encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_k.bias\n",
"-> encoder.encoders.3.self_attn.linear_v.weight\n",
"encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_v.bias\n",
"-> encoder.encoders.3.self_attn.linear_out.weight\n",
"encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.3.self_attn.linear_out.bias\n",
"-> encoder.encoders.3.feed_forward.w_1.weight\n",
"encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.3.feed_forward.w_1.bias\n",
"-> encoder.encoders.3.feed_forward.w_2.weight\n",
"encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.3.feed_forward.w_2.bias\n",
"-> encoder.encoders.3.norm1.weight\n",
"-> encoder.encoders.3.norm1.bias\n",
"-> encoder.encoders.3.norm2.weight\n",
"-> encoder.encoders.3.norm2.bias\n",
"-> encoder.encoders.4.self_attn.linear_q.weight\n",
"encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_q.bias\n",
"-> encoder.encoders.4.self_attn.linear_k.weight\n",
"encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_k.bias\n",
"-> encoder.encoders.4.self_attn.linear_v.weight\n",
"encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_v.bias\n",
"-> encoder.encoders.4.self_attn.linear_out.weight\n",
"encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.4.self_attn.linear_out.bias\n",
"-> encoder.encoders.4.feed_forward.w_1.weight\n",
"encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.4.feed_forward.w_1.bias\n",
"-> encoder.encoders.4.feed_forward.w_2.weight\n",
"encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.4.feed_forward.w_2.bias\n",
"-> encoder.encoders.4.norm1.weight\n",
"-> encoder.encoders.4.norm1.bias\n",
"-> encoder.encoders.4.norm2.weight\n",
"-> encoder.encoders.4.norm2.bias\n",
"-> encoder.encoders.5.self_attn.linear_q.weight\n",
"encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_q.bias\n",
"-> encoder.encoders.5.self_attn.linear_k.weight\n",
"encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_k.bias\n",
"-> encoder.encoders.5.self_attn.linear_v.weight\n",
"encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_v.bias\n",
"-> encoder.encoders.5.self_attn.linear_out.weight\n",
"encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.5.self_attn.linear_out.bias\n",
"-> encoder.encoders.5.feed_forward.w_1.weight\n",
"encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.5.feed_forward.w_1.bias\n",
"-> encoder.encoders.5.feed_forward.w_2.weight\n",
"encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.5.feed_forward.w_2.bias\n",
"-> encoder.encoders.5.norm1.weight\n",
"-> encoder.encoders.5.norm1.bias\n",
"-> encoder.encoders.5.norm2.weight\n",
"-> encoder.encoders.5.norm2.bias\n",
"-> encoder.encoders.6.self_attn.linear_q.weight\n",
"encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_q.bias\n",
"-> encoder.encoders.6.self_attn.linear_k.weight\n",
"encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_k.bias\n",
"-> encoder.encoders.6.self_attn.linear_v.weight\n",
"encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_v.bias\n",
"-> encoder.encoders.6.self_attn.linear_out.weight\n",
"encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.6.self_attn.linear_out.bias\n",
"-> encoder.encoders.6.feed_forward.w_1.weight\n",
"encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.6.feed_forward.w_1.bias\n",
"-> encoder.encoders.6.feed_forward.w_2.weight\n",
"encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.6.feed_forward.w_2.bias\n",
"-> encoder.encoders.6.norm1.weight\n",
"-> encoder.encoders.6.norm1.bias\n",
"-> encoder.encoders.6.norm2.weight\n",
"-> encoder.encoders.6.norm2.bias\n",
"-> encoder.encoders.7.self_attn.linear_q.weight\n",
"encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_q.bias\n",
"-> encoder.encoders.7.self_attn.linear_k.weight\n",
"encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_k.bias\n",
"-> encoder.encoders.7.self_attn.linear_v.weight\n",
"encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_v.bias\n",
"-> encoder.encoders.7.self_attn.linear_out.weight\n",
"encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.7.self_attn.linear_out.bias\n",
"-> encoder.encoders.7.feed_forward.w_1.weight\n",
"encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.7.feed_forward.w_1.bias\n",
"-> encoder.encoders.7.feed_forward.w_2.weight\n",
"encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.7.feed_forward.w_2.bias\n",
"-> encoder.encoders.7.norm1.weight\n",
"-> encoder.encoders.7.norm1.bias\n",
"-> encoder.encoders.7.norm2.weight\n",
"-> encoder.encoders.7.norm2.bias\n",
"-> encoder.encoders.8.self_attn.linear_q.weight\n",
"encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_q.bias\n",
"-> encoder.encoders.8.self_attn.linear_k.weight\n",
"encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_k.bias\n",
"-> encoder.encoders.8.self_attn.linear_v.weight\n",
"encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_v.bias\n",
"-> encoder.encoders.8.self_attn.linear_out.weight\n",
"encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.8.self_attn.linear_out.bias\n",
"-> encoder.encoders.8.feed_forward.w_1.weight\n",
"encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.8.feed_forward.w_1.bias\n",
"-> encoder.encoders.8.feed_forward.w_2.weight\n",
"encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.8.feed_forward.w_2.bias\n",
"-> encoder.encoders.8.norm1.weight\n",
"-> encoder.encoders.8.norm1.bias\n",
"-> encoder.encoders.8.norm2.weight\n",
"-> encoder.encoders.8.norm2.bias\n",
"-> encoder.encoders.9.self_attn.linear_q.weight\n",
"encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_q.bias\n",
"-> encoder.encoders.9.self_attn.linear_k.weight\n",
"encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_k.bias\n",
"-> encoder.encoders.9.self_attn.linear_v.weight\n",
"encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_v.bias\n",
"-> encoder.encoders.9.self_attn.linear_out.weight\n",
"encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.9.self_attn.linear_out.bias\n",
"-> encoder.encoders.9.feed_forward.w_1.weight\n",
"encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.9.feed_forward.w_1.bias\n",
"-> encoder.encoders.9.feed_forward.w_2.weight\n",
"encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.9.feed_forward.w_2.bias\n",
"-> encoder.encoders.9.norm1.weight\n",
"-> encoder.encoders.9.norm1.bias\n",
"-> encoder.encoders.9.norm2.weight\n",
"-> encoder.encoders.9.norm2.bias\n",
"-> encoder.encoders.10.self_attn.linear_q.weight\n",
"encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_q.bias\n",
"-> encoder.encoders.10.self_attn.linear_k.weight\n",
"encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_k.bias\n",
"-> encoder.encoders.10.self_attn.linear_v.weight\n",
"encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_v.bias\n",
"-> encoder.encoders.10.self_attn.linear_out.weight\n",
"encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.10.self_attn.linear_out.bias\n",
"-> encoder.encoders.10.feed_forward.w_1.weight\n",
"encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.10.feed_forward.w_1.bias\n",
"-> encoder.encoders.10.feed_forward.w_2.weight\n",
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.10.feed_forward.w_2.bias\n",
"-> encoder.encoders.10.norm1.weight\n",
"-> encoder.encoders.10.norm1.bias\n",
"-> encoder.encoders.10.norm2.weight\n",
"-> encoder.encoders.10.norm2.bias\n",
"-> encoder.encoders.11.self_attn.linear_q.weight\n",
"encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_q.bias\n",
"-> encoder.encoders.11.self_attn.linear_k.weight\n",
"encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_k.bias\n",
"-> encoder.encoders.11.self_attn.linear_v.weight\n",
"encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_v.bias\n",
"-> encoder.encoders.11.self_attn.linear_out.weight\n",
"encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"-> encoder.encoders.11.self_attn.linear_out.bias\n",
"-> encoder.encoders.11.feed_forward.w_1.weight\n",
"encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"-> encoder.encoders.11.feed_forward.w_1.bias\n",
"-> encoder.encoders.11.feed_forward.w_2.weight\n",
"encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"-> encoder.encoders.11.feed_forward.w_2.bias\n",
"-> encoder.encoders.11.norm1.weight\n",
"-> encoder.encoders.11.norm1.bias\n",
"-> encoder.encoders.11.norm2.weight\n",
"-> encoder.encoders.11.norm2.bias\n",
"-> encoder.after_norm.weight\n",
"-> encoder.after_norm.bias\n"
]
}
],
"source": [
"# dump torch model to paddle\n",
"#state_dict = model.state_dict()\n",
"paddle_state_dict = {}\n",
"\n",
"for n, p in state_dict.items():\n",
" if 'encoder' not in n:\n",
" continue \n",
" print(f'-> {n}')\n",
" \n",
" \n",
" name_change=True\n",
" if 'norm.running_mean' in n:\n",
" new_n = n.replace('norm.running_', 'norm._')\n",
" elif 'norm.running_var' in n:\n",
" new_n = n.replace('norm.running_var', 'norm._variance')\n",
" else:\n",
" name_change=False\n",
" new_n = n\n",
" if name_change:\n",
" print(f\"{n} -> {new_n}\")\n",
" \n",
" \n",
" p = p.cpu().detach().numpy()\n",
" if n.endswith('weight') and p.ndim == 2:\n",
" new_p = p.T\n",
" print(f\"{n}: {p.shape} -> {new_p.shape}\")\n",
" else:\n",
" new_p = p\n",
" \n",
" if 'global_cmvn.mean' in n:\n",
" print(p, p.dtype)\n",
" \n",
" paddle_state_dict[new_n] = new_p\n",
" \n",
"# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n",
"# state=paddle_state_dict)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a1d97e9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. decoder.decoders.0.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
"/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n",
" warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n"
]
}
],
"source": [
"model.set_state_dict(paddle_state_dict)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fc7edf1e",
"metadata": {},
"outputs": [],
"source": [
"e_state = model.encoder.state_dict()\n",
"for key, value in e_state.items():\n",
" if 'concat_linear' in key:\n",
" continue\n",
" if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n",
" print(key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "572097d0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "748250b7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "91e5deee",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fleet-despite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n",
"encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n"
]
}
],
"source": [
"%ls .notebook/espnet"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "abroad-oracle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 57, 83)\n",
"(8, 1, 57)\n",
"[57 50 48 38 32 31 28 25]\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n",
"xs=data['xs']\n",
"masks=data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"xs_lens = masks.sum(axis=-1).squeeze()\n",
"print(xs_lens)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "false-instrument",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[8, 13, 256]\n",
"[8, 1, 13]\n"
]
}
],
"source": [
"# ecnoder\n",
"xs = paddle.to_tensor(xs, dtype='float32')\n",
"x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n",
"model.eval()\n",
"encoder_out, encoder_mask = model.encoder(xs, x_lens)\n",
"print(encoder_out.shape)\n",
"print(encoder_mask.shape)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "arctic-proxy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(8, 13, 256)\n",
"(8, 1, 13)\n",
"False\n",
"False\n",
"True\n",
"True\n"
]
}
],
"source": [
"data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n",
"xs = data['xs']\n",
"masks = data['masks']\n",
"print(xs.shape)\n",
"print(masks.shape)\n",
"print(np.allclose(xs, encoder_out.numpy()))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n",
"print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n",
"print(np.allclose(masks, encoder_mask.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "seasonal-switch",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n",
" -0.6526459 ]\n",
" [ 2.1926146 2.1373641 -0.6548196 ... -0.897318 0.6044322\n",
" -0.63332295]\n",
" [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n",
" -0.05243584]\n",
" ...\n",
" [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n",
" -0.18188554]\n",
" [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n",
" -0.72916794]\n",
" [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n",
" -0.32597935]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n",
" [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n",
" [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n",
" ...,\n",
" [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n",
" [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n",
" [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n"
]
}
],
"source": [
"print(xs[0])\n",
"print(encoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "defined-brooks",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n",
" -0.16355833]\n",
" [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n",
" 0.11874156]\n",
" [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n",
" 0.27621832]\n",
" ...\n",
" [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n",
" -0.9794884 ]\n",
" [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n",
" -1.3076135 ]\n",
" [ 2.160009 0.98425585 -1.2231126 ... -0.03393313 1.9141548\n",
" -1.0099151 ]]\n",
"Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n",
" [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n",
" [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n",
" ...,\n",
" [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n",
" [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n",
" [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n"
]
}
],
"source": [
"print(xs[1])\n",
"print(encoder_out[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0504e3f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
因为 它太大了无法显示 source diff 。你可以改为 查看blob
[中文版](README_cn.md)
# PaddlePaddle ASR toolkit
# PaddlePaddle Speech to Any toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
*PaddleASR* is an open-source implementation of end-to-end Automatic Speech Recognition (ASR) engine, with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient, samller and scalable implementation, including training, inference & testing module, and deployment.
*DeepSpeech* is an open-source implementation of end-to-end Automatic Speech Recognition engine, with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient, samller and scalable implementation, including training, inference & testing module, and deployment.
## Features
See [feature list](doc/src/feature_list.md) for more information.
See [feature list](docs/src/feature_list.md) for more information.
## Setup
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.0
* paddlepaddle>=2.1.2
Please see [install](doc/src/install.md).
Please see [install](docs/src/install.md).
## Getting Started
Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
Please see [Getting Started](docs/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
## More Information
* [Data Prepration](doc/src/data_preparation.md)
* [Data Augmentation](doc/src/augmentation.md)
* [Ngram LM](doc/src/ngram_lm.md)
* [Server Demo](doc/src/server.md)
* [Benchmark](doc/src/benchmark.md)
* [Relased Model](doc/src/released_model.md)
* [FAQ](doc/src/faq.md)
* [Data Prepration](docs/src/data_preparation.md)
* [Data Augmentation](docs/src/augmentation.md)
* [Ngram LM](docs/src/ngram_lm.md)
* [Benchmark](docs/src/benchmark.md)
* [Relased Model](docs/src/released_model.md)
## Questions and Help
......@@ -43,8 +41,8 @@ You are welcome to submit questions in [Github Discussions](https://github.com/P
## License
DeepASR is provided under the [Apache-2.0 License](./LICENSE).
DeepSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Acknowledgement
We depends on many open source repos. See [References](doc/src/reference.md) for more information.
We depends on many open source repos. See [References](docs/src/reference.md) for more information.
[English](README.md)
# PaddlePaddle ASR toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
*PaddleASR*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别(ASR)引擎的开源项目,
我们的愿景是为语音识别在工业应用和学术研究上,提供易于使用、高效、小型化和可扩展的工具,包括训练,推理,以及 部署。
## 特性
参看 [特性列表](doc/src/feature_list.md)
## 安装
* python>=3.7
* paddlepaddle>=2.1.0
参看 [安装](doc/src/install.md)
## 开始
请查看 [开始](doc/src/getting_started.md)[tiny egs](examples/tiny/s0/README.md)
## 更多信息
* [数据处理](doc/src/data_preparation.md)
* [数据增强](doc/src/augmentation.md)
* [语言模型](doc/src/ngram_lm.md)
* [服务部署](doc/src/server.md)
* [Benchmark](doc/src/benchmark.md)
* [Relased Model](doc/src/released_model.md)
* [FAQ](doc/src/faq.md)
## 问题和帮助
欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。
## License
DeepASR 遵循[Apache-2.0开源协议](./LICENSE)
## 感谢
开发中参考一些优秀的仓库,详情参见 [References](doc/src/reference.md)
# Benchmarks
## Acceleration with Multi-GPUs
We compare the training time with 1, 2, 4, 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars.
<img src="../images/multi_gpu_speedup.png" width=450>
| # of GPU | Acceleration Rate |
| -------- | --------------: |
| 1 | 1.00 X |
| 2 | 1.98 X |
| 4 | 3.73 X |
| 8 | 6.95 X |
`utils/profile.sh` provides such a demo profiling tool, you can change it as need.
# FAQ
1. 音频变速快慢到达什么晨读会影响识别率?
变速会提升识别效果,一般用0.9, 1.0, 1.1 的变速。
2. 音量大小到什么程度会影响识别率?
一般训练会固定音量到一个范围内,波动过大会影响训练,估计在10dB ~ 20dB吧。
3. 语音模型训练数据的最小时长要求时多少?
Aishell-1大约178h的数据,数据越多越好。
4. 那些噪声或背景生会影响识别率?
主要是人生干扰和低信噪比会影响识别率。
5. 单条语音数据的长度限制是多少?
一般训练的语音长度会限制在1s~6s之间,和训练配置有关。
6. 背景声在识别前是否需要分离出来,或做降噪处理?
需要分离的,需要结合具体场景考虑。
7. 模型是否带有VAD人生激活识别能力?
VAD是单独的模型或模块,模型不包含此能力。
8. 是否支持长语音识别?
一般过VAD后识别。
9. Mandarin LM Large语言模型需要的硬件配置时怎样的?
内存能放得下LM即可。
# Reference
* [wenet](https://github.com/mobvoi/wenet)
# Released Models
## Language Model Released
Language Model | Training Data | Token-based | Size | Descriptions
:-------------:| :------------:| :-----: | -----: | :-----------------
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1; <br/> About 1.85 billion n-grams; <br/> 'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4; <br/> About 0.13 billion n-grams; <br/> 'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning; <br/> About 3.7 billion n-grams; <br/> 'probing' binary with default settings
# Trying Live Demo with Your Own Voice
Until now, an ASR model is trained and tested qualitatively (`infer`) and quantitatively (`test`) with existing audio files. But it is not yet tested with your own speech. We build up a real-time demo ASR engine with the trained model, enabling you to test and play around with the demo, with your own voice.
First, change your directory to `examples/aishell` and `source path.sh`.
To start the demo's server, please run this in one console:
```bash
CUDA_VISIBLE_DEVICES=0 bash local/server.sh
```
For the machine (might not be the same machine) to run the demo's client, please do the following installation before moving on.
For example, on MAC OS X:
```bash
brew install portaudio
pip install pyaudio
pip install keyboard
```
Then to start the client, please run this in another console:
```bash
CUDA_VISIBLE_DEVICES=0 bash local/client.sh
```
Now, in the client console, press the `whitespace` key, hold, and start speaking. Until finishing your utterance, release the key to let the speech-to-text results shown in the console. To quit the client, just press `ESC` key.
Notice that `deepspeech/exps/deepspeech2/deploy/client.py` must be run on a machine with a microphone device, while `deepspeech/exps/deepspeech2/deploy/server.py` could be run on one without any audio recording hardware, e.g. any remote server machine. Just be careful to set the `host_ip` and `host_port` argument with the actual accessible IP address and port, if the server and client are running with two separate machines. Nothing should be done if they are running on one single machine.
Please also refer to `examples/aishell/local/server.sh`, which will first download a pre-trained Chinese model (trained with AISHELL1) and then start the demo server with the model. With running `examples/aishell/local/client.sh`, you can speak Chinese to test it. If you would like to try some other models, just update `--checkpoint_path` argument in the script.  
......@@ -21,7 +21,7 @@ To perform z-score normalization (zero-mean, unit stddev) upon audio features, w
```bash
python3 utils/compute_mean_std.py \
--num_samples 2000 \
--specgram_type linear \
--spectrum_type linear \
--manifest_path examples/librispeech/data/manifest.train \
--output_path examples/librispeech/data/mean_std.npz
```
......
# Deepspeech2
## Streaming
The implemented arcitecure of Deepspeech2 online model is based on [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes.
The model is mainly composed of 2D convolution subsampling layer and stacked single direction rnn layers.
To illustrate the model implementation clearly, 3 parts are described in detail.
- Data Preparation
- Encoder
- Decoder
In addition, the training process and the testing process are also introduced.
The arcitecture of the model is shown in Fig.1.
<p align="center">
<img src="../images/ds2onlineModel.png" width=800>
<br/>Fig.1 The Arcitecture of deepspeech2 online model
</p>
### Data Preparation
#### Vocabulary
For English data, the vocabulary dictionary is composed of 26 English characters with " ' ", space, \<blank\> and \<eos\>. The \<blank\> represents the blank label in CTC, the \<unk\> represents the unknown character and the \<eos\> represents the start and the end characters. For mandarin, the vocabulary dictionary is composed of chinese characters statisticed from the training set and three additional characters are added. The added characters are \<blank\>, \<unk\> and \<eos\>. For both English and mandarin data, we set the default indexs that \<blank\>=0, \<unk\>=1 and \<eos\>= last index.
```
# The code to build vocabulary
cd examples/aishell/s0
python3 ../../../utils/build_vocab.py \
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
# vocabulary for aishell dataset (Mandarin)
vi examples/aishell/s0/data/vocab.txt
# vocabulary for librispeech dataset (English)
vi examples/librispeech/s0/data/vocab.txt
```
#### CMVN
For CMVN, a subset or the full of traininig set is chosed and be used to compute the feature mean and std.
```
# The code to compute the feature mean and std
cd examples/aishell/s0
python3 ../../../utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--spectrum_type="linear" \
--delta_delta=false \
--stride_ms=10.0 \
--window_ms=20.0 \
--sample_rate=16000 \
--use_dB_normalization=True \
--num_samples=2000 \
--num_workers=10 \
--output_path="data/mean_std.json"
```
#### Feature Extraction
For feature extraction, three methods are implemented, which are linear (FFT without using filter bank), fbank and mfcc.
Currently, the released deepspeech2 online model use the linear feature extraction method.
```
The code for feature extraction
vi deepspeech/frontend/featurizer/audio_featurizer.py
```
### Encoder
The encoder is composed of two 2D convolution subsampling layers and a number of stacked single direction rnn layers. The 2D convolution subsampling layers extract feature representation from the raw audio feature and reduce the length of audio feature at the same time. After passing through the convolution subsampling layers, then the feature representation are input into the stacked rnn layers. For the stacked rnn layers, LSTM cell and GRU cell are provided to use. Adding one fully connected (fc) layer after the stacked rnn layers is optional. If the number of stacked rnn layers is less than 5, adding one fc layer after stacked rnn layers is recommand.
The code of Encoder is in:
```
vi deepspeech/models/ds2_online/deepspeech2.py
```
### Decoder
To got the character possibilities of each frame, the feature representation of each frame output from the encoder are input into a projection layer which is implemented as a dense layer to do feature projection. The output dim of the projection layer is same with the vocabulary size. After projection layer, the softmax function is used to transform the frame-level feature representation be the possibilities of characters. While making model inference, the character possibilities of each frame are input into the CTC decoder to get the final speech recognition results.
The code of the decoder is in:
```
# The code of constructing the decoder in model
vi deepspeech/models/ds2_online/deepspeech2.py
# The code of CTC Decoder
vi deepspeech/modules/ctc.py
```
## Training Process
Using the command below, you can train the deepspeech2 online model.
```
cd examples/aishell/s0
bash run.sh --stage 0 --stop_stage 2 --model_type online --conf_path conf/deepspeech2_online.yaml
```
The detail commands are:
```
# The code for training in run.sh
set -e
source path.sh
gpus=2,3,5,7
stage=0
stop_stage=5
conf_path=conf/deepspeech2_online.yaml # conf/deepspeech2.yaml | conf/deepspeech2_online.yaml
avg_num=1
model_type=online # online | offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
```
By using the command above, the training process can be started. There are 5 stages in "run.sh", and the first 3 stages are used for training process. The stage 0 is used for data preparation, in which the dataset will be downloaded, and the manifest files of the datasets, vocabulary dictionary and CMVN file will be generated in "./data/". The stage 1 is used for training the model, the log files and model checkpoint is saved in "exp/deepspeech2_online/". The stage 2 is used to generated final model for predicting by averaging the top-k model parameters based on validation loss.
## Testing Process
Using the command below, you can test the deepspeech2 online model.
```
bash run.sh --stage 3 --stop_stage 5 --model_type online --conf_path conf/deepspeech2_online.yaml
```
The detail commands are:
```
conf_path=conf/deepspeech2_online.yaml
avg_num=1
model_type=online
avg_ckpt=avg_${avg_num}
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=5 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1
fi
```
After the training process, we use stage 3,4,5 for testing process. The stage 3 is for testing the model generated in the stage 2 and provided the CER index of the test set. The stage 4 is for transforming the model from dynamic graph to static graph by using "paddle.jit" library. The stage 5 is for testing the model in static graph.
## Non-Streaming
The deepspeech2 offline model is similarity to the deepspeech2 online model. The main difference between them is the offline model use the stacked bi-directional rnn layers while the online model use the single direction rnn layers and the fc layer is not used. For the stacked bi-directional rnn layers in the offline model, the rnn cell and gru cell are provided to use.
The arcitecture of the model is shown in Fig.2.
<p align="center">
<img src="../images/ds2offlineModel.png" width=800>
<br/>Fig.2 The Arcitecture of deepspeech2 offline model
</p>
For data preparation and decoder, the deepspeech2 offline model is same with the deepspeech2 online model.
The code of encoder and decoder for deepspeech2 offline model is in:
```
vi deepspeech/models/ds2/deepspeech2.py
```
The training process and testing process of deepspeech2 offline model is very similary to deepspeech2 online model.
Only some changes should be noticed.
For training and testing, the "model_type" and the "conf_path" must be set.
```
# Training offline
cd examples/aishell/s0
bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deepspeech2.yaml
```
```
# Testing offline
cd examples/aishell/s0
bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml
```
# Featrues
# Features
### Dataset
* Aishell
* Librispeech
* THCHS30
* TIMIT
### Speech Recognition
* Offline
* Non-Streaming
* [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf)
* [Transformer](https://arxiv.org/abs/1706.03762)
* [Conformer](https://arxiv.org/abs/2005.08100)
* Online
* Streaming
* [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf)
* [U2](https://arxiv.org/pdf/2012.05481.pdf)
### Language Model
......@@ -22,6 +29,15 @@
* beam search
* attention rescore
### Deployment
* Paddle Inference
### Aligment
* MFA
* CTC Aligment
### Speech Frontend
* Audio
......
......@@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin
## Prerequisites
- Python >= 3.7
- PaddlePaddle 2.0.0 or later (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
- PaddlePaddle latest version (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
## Setup
## Setup (Important)
- Make sure these libraries or tools installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox, and `swig`, e.g. installing them via `apt-get`:
```bash
sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
```
The version of `swig` should >= 3.0
or, installing them via `yum`:
......
......@@ -35,52 +35,3 @@ Different from the English language model, Mandarin language model is character-
* A whitespace character between two tokens is inserted.
Please notice that the released language models only contain Chinese simplified characters. After preprocessing done we can begin to train the language model. The key training arguments for small LM is '-o 5 --prune 0 1 2 4 4' and '-o 5' for large LM. Please refer above section for the meaning of each argument. We also convert the arpa file to binary file using default settings.
## [KenLM](http://kheafield.com/code/kenlm/)
统计语言模型工具有比较多的选择,目前使用比较好的有srilm及kenlm,其中kenlm比srilm晚出来,训练速度也更快,而且支持单机大数据的训练。现在介绍一下kenlm的使用方法。
1. 工具包的下载地址:http://kheafield.com/code/kenlm.tar.gz
2. 使用。该工具在linux环境下使用方便。 先确保linux环境已经按照1.36.0的Boost和zlib
```
boost:
yum install boost
yum install boost-devel
zlib:
yum install zlib
yum install zlib-devel
```
然后gcc版本需要是4.8.2及以上。
```
wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz
mkdir kenlm/build
cd kenlm/build
cmake ..
make -j2
```
3. 训练。使用如下命令进行训练:
```
build/bin/lmplz -o 3 --verbose_header --text people2014corpus_words.txt --arpa result/people2014corpus_words.arps
```
其中,
1)people2014corpus_words.txt文件必须是分词以后的文件。
训练语料<人民日报2014版熟语料>,包括: 1)标准人工切词及词性数据people2014.tar.gz, 2)未切词文本数据people2014_words.txt, 3)kenlm训练字粒度语言模型文件及其二进制文件people2014corpus_chars.arps/klm, 4)kenlm词粒度语言模型文件及其二进制文件people2014corpus_words.arps/klm。
2)-o后面的5表示的是5-gram,一般取到3即可,但可以结合自己实际情况判断。
4. 压缩。压缩模型为二进制,方便模型快速加载:
```
build/bin/build_binary ./result/people2014corpus_words.arps ./result/people2014corpus_words.klm
```
# Reference
We refer these repos to build `model` and `engine`:
* [delta](https://github.com/Delta-ML/delta.git)
* [espnet](https://github.com/espnet/espnet.git)
* [kaldi](https://github.com/kaldi-asr/kaldi.git)
* [wenet](https://github.com/mobvoi/wenet)
# Released Models
## Acoustic Model Released in paddle 2.X
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech
:-------------:| :------------:| :-----: | -----: | :----------------- |:--------- | :---------- | :---------
[Ds2 Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds_online.5rnn.debug.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.0824 |-| 151 h
[Ds2 Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 |-| 151 h
[Conformer Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention + CTC | 0.0594 |-| 151 h
[Conformer Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0547 |-| 151 h
[Conformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | Word-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention |-| 0.0325 | 960 h
[Transformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/transformer.release.tar.gz) | Librispeech Dataset | Word-based | 195 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention |-| 0.0544 | 960 h
## Acoustic Model Transformed from paddle 1.8
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech
:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :---------- | :---------
[Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz)|Aishell Dataset| Char-based| 234 MB| 2 Conv + 3 bidirectional GRU layers| 0.0804 |-| 151 h|
[Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz)|Librispeech Dataset| Word-based| 307 MB| 2 Conv + 3 bidirectional sharing weight RNN layers |-| 0.0685| 960 h|
[Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz)|Baidu Internal English Dataset| Word-based| 273 MB| 2 Conv + 3 bidirectional GRU layers |-| 0.0541 | 8628 h|
## Language Model Released
Language Model | Training Data | Token-based | Size | Descriptions
:-------------:| :------------:| :-----: | -----: | :-----------------
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1; <br/> About 1.85 billion n-grams; <br/> 'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4; <br/> About 0.13 billion n-grams; <br/> 'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning; <br/> About 3.7 billion n-grams; <br/> 'probing' binary with default settings
export MAIN_ROOT=${PWD}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:/usr/local/bin:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# 1xt2x
Convert Deepspeech 1.8 released model to 2.x.
## Model
* Deepspeech2x
## Exp
* baidu_en8k
* aishell
* librispeech
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 4333
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf aishell_model_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated."
exit 1
fi
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
fi
echo "Aishell data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=87e7577d4bea737dbf3e8daab37aa808
TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz
echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
gpus=2
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=aishell_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: .inf # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 28
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.4
beta: 0.35
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=c1676be8505cee436e6f312823e9008c
TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz
echo "Download BaiduEn8k model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download BaiduEn8k model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
gpus=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=baidu_en8k_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 1000.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 28
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf librispeech_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=a06d9aadb560ea113984dc98d67232c8
TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz
echo "Download LibriSpeech model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download LibriSpeech model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
gpus=1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=librispeech_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddle.fluid import core
from paddle.nn import functional as F
from deepspeech.utils.log import Log
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()
########### hcak logging #############
logger.warn = logger.warning
########### hcak paddle #############
paddle.half = 'float16'
paddle.float = 'float32'
paddle.double = 'float64'
paddle.short = 'int16'
paddle.int = 'int32'
paddle.long = 'int64'
paddle.uint16 = 'uint16'
paddle.cdouble = 'complex128'
def convert_dtype_to_string(tensor_dtype):
"""
Convert the data type in numpy to the data type in Paddle
Args:
tensor_dtype(core.VarDesc.VarType): the data type in numpy.
Returns:
core.VarDesc.VarType: the data type in Paddle.
"""
dtype = tensor_dtype
if dtype == core.VarDesc.VarType.FP32:
return paddle.float32
elif dtype == core.VarDesc.VarType.FP64:
return paddle.float64
elif dtype == core.VarDesc.VarType.FP16:
return paddle.float16
elif dtype == core.VarDesc.VarType.INT32:
return paddle.int32
elif dtype == core.VarDesc.VarType.INT16:
return paddle.int16
elif dtype == core.VarDesc.VarType.INT64:
return paddle.int64
elif dtype == core.VarDesc.VarType.BOOL:
return paddle.bool
elif dtype == core.VarDesc.VarType.BF16:
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
return paddle.uint16
elif dtype == core.VarDesc.VarType.UINT8:
return paddle.uint8
elif dtype == core.VarDesc.VarType.INT8:
return paddle.int8
elif dtype == core.VarDesc.VarType.COMPLEX64:
return paddle.complex64
elif dtype == core.VarDesc.VarType.COMPLEX128:
return paddle.complex128
else:
raise ValueError("Not supported tensor dtype %s" % dtype)
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'log_softmax'):
logger.warn("register user log_softmax to paddle, remove this when fixed!")
setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'log_sigmoid'):
logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
def cat(xs, dim=0):
return paddle.concat(xs, axis=dim)
if not hasattr(paddle, 'cat'):
logger.warn(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
########### hcak paddle.Tensor #############
def item(x: paddle.Tensor):
return x.numpy().item()
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
def func_long(x: paddle.Tensor):
return paddle.cast(x, paddle.long)
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
def new_full(x: paddle.Tensor,
size: Union[List[int], Tuple[int], paddle.Tensor],
fill_value: Union[float, int, bool, paddle.Tensor],
dtype=None):
return paddle.full(size, fill_value, dtype=x.dtype)
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if convert_dtype_to_string(xs.dtype) == paddle.bool:
xs = xs.astype(paddle.int)
return xs.equal(
paddle.to_tensor(
ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
if not hasattr(paddle, 'eq'):
logger.warn(
"override eq of paddle if exists or register, remove this when fixed!")
paddle.eq = eq
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args)
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
return xs.reshape(ys.size())
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
pass
else:
return False
return True
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]) -> paddle.Tensor:
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
val = paddle.full_like(xs, value)
paddle.assign(val.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
return paddle.tile(xs, size)
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
return x.astype(other.dtype)
if not hasattr(paddle.Tensor, 'type_as'):
logger.warn(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1
if isinstance(args[0], str): # dtype
return x.astype(args[0])
elif isinstance(args[0], paddle.Tensor): #Tensor
return x.astype(args[0].dtype)
else: # Device
return x
if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.float)
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
def func_int(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.int)
if not hasattr(paddle.Tensor, 'int'):
logger.warn("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
def tolist(x: paddle.Tensor) -> List[Any]:
return x.numpy().tolist()
if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hcak paddle.nn #############
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return F.glu(xs, axis=self.dim)
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from src_deepspeech2x.test_model import DeepSpeech2Tester as Tester
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
from typing import Optional
import paddle
from paddle import nn
from src_deepspeech2x.models.ds2.rnn import RNNStack
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack
self.rnn = RNNStack(
i_size=i_size,
h_size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
@property
def output_size(self):
return self.rnn_size * 2
def forward(self, audio, audio_len):
"""Compute Encoder outputs
Args:
audio (Tensor): [B, Tmax, D]
text (Tensor): [B, Umax]
audio_len (Tensor): [B]
text_len (Tensor): [B]
Returns:
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
"""
# [B, T, D] -> [B, D, T]
audio = audio.transpose([0, 2, 1])
# [B, D, T] -> [B, C=1, D, T]
x = audio.unsqueeze(1)
x_lens = audio_len
# convolution group
x, x_lens = self.conv(x, x_lens)
x_val = x.numpy()
# convert data from convolution feature map to sequence of vectors
#B, C, D, T = paddle.shape(x) # not work under jit
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
#x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit
x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part
x, x_lens = self.rnn(x, x_lens) #[B, T, D]
return x, x_lens
class DeepSpeech2Model(nn.Layer):
"""The DeepSpeech2 network structure.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param masks: Masks data layer to reset padding.
:type masks: Variable
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward direction RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
assert (self.encoder.output_size == rnn_size * 2)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=blank_id, # first token is <blank>
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
audio (Tenosr): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
loss (Tenosr): [1]
"""
eouts, eouts_len = self.encoder(audio, audio_len)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2Model
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=len(dataloader.collate_fn.vocab_list),
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2Model from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def forward(self, audio, audio_len):
"""export model function
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs, eouts_len
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, self.encoder.feat_size],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
return static_model
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.relu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states # shape [batch_size, hidden_size]
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
bias_u, bias_r, bias_c = paddle.split(
self.bias_hh, num_or_sections=3, axis=0)
weight_hh = paddle.transpose(
self.weight_hh,
perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size]
w_u_r_c = paddle.flatten(weight_hh)
size_u_r = self.hidden_size * 2 * self.hidden_size
w_u_r = paddle.reshape(w_u_r_c[:size_u_r],
(self.hidden_size, self.hidden_size * 2))
w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)
w_c = paddle.reshape(w_u_r_c[size_u_r:],
(self.hidden_size, self.hidden_size))
h_u = paddle.matmul(
pre_hidden, w_u,
transpose_y=False) + bias_u #shape [batch_size, hidden_size]
h_r = paddle.matmul(
pre_hidden, w_r,
transpose_y=False) + bias_r #shape [batch_size, hidden_size]
x_u, x_r, x_c = paddle.split(
x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size]
u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size]
r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size]
c = self._activation(
x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) +
bias_c) # [batch_size, hidden_size]
h = (1 - u) * pre_hidden + u * c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
x_len: shpae [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class DeepSpeech2Trainer(Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# training config
default = CfgNode(
dict(
lr=5e-4, # learning rate
lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
losses_np = {
'train_loss': float(loss),
}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
# `step -1` since we update `step` after optimizer.step().
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration - 1)
@paddle.no_grad()
def valid(self):
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
self.model.eval()
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
if (i + 1) % self.config.training.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def setup_model(self):
config = self.config.clone()
config.defrost()
config.model.feat_size = self.train_loader.collate_fn.feature_size
#config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.model.dict_size = len(self.train_loader.collate_fn.vocab_list)
config.freeze()
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config.model)
else:
raise Exception("wrong model type")
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
grad_clip = ClipGradByGlobalNormWithLog(
config.training.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=config.training.lr,
gamma=config.training.lr_decay,
verbose=True)
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
if "<eos>" in self.test_loader.collate_fn.vocab_list:
self.test_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.valid_loader.collate_fn.vocab_list:
self.valid_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.train_loader.collate_fn.vocab_list:
self.train_loader.collate_fn.vocab_list.remove("<eos>")
logger.info("Setup train/valid/test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# testing config
default = CfgNode(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self, config, args):
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins,
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
]
return result_transcripts
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self):
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
elif self.args.model_type == 'online':
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]
......@@ -52,16 +52,18 @@
{
"type": "specaug",
"params": {
"W": 80,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": false
},
"prob": 0.0
"prob": 1.0
}
]
# [CC-CEDICT](https://cc-cedict.org/wiki/)
What is CC-CEDICT?
CC-CEDICT is a continuation of the CEDICT project.
The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary.
CEDICT was started by Paul Andrew Denisowski in October 1997.
For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort,
with users providing entries and corrections to the main file.
## Parse CC-CEDICT to Json format
1. Parse to Json
```
run.sh
```
2. Result
```
exp/
|-- cedict
`-- cedict.json
0 directories, 2 files
```
```
4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict
3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json
```
```
==> exp/cedict <==
# CC-CEDICT
# Community maintained free Chinese-English dictionary.
#
# Published by MDBG
#
# License:
# Creative Commons Attribution-ShareAlike 4.0 International License
# https://creativecommons.org/licenses/by-sa/4.0/
#
# Referenced works:
==> exp/cedict.json <==
{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"}
{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"}
{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"}
{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"}
{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"}
{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"}
{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"}
{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"}
{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"}
{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"}
```
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
# Download Baker dataset
Baker dataset has to be downloaded mannually and moved to 'data/', because you will have to pass the CATTCHA from a browswe to download the dataset.
Download URL https://test.data-baker.com/#/data/index/source.
*.tgz
manifest.*
*.meta
aidatatang_200zh/
\ No newline at end of file
# [Aidatatang_200zh](http://www.openslr.org/62/)
Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 200 hours of acoustic data, which is mostly mobile recorded data.
* 600 speakers from different accent areas in China are invited to participate in the recording.
* The transcription accuracy for each sentence is larger than 98%.
* Recordings are conducted in a quiet indoor environment.
* The database is divided into training set, validation set, and testing set in a ratio of 7: 1: 2.
* Detail information such as speech data coding and speaker information is preserved in the metadata file.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, voiceprint recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare aidatatang_200zh mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/62'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62'
DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz'
MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/aidatatang_200zh",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aidatatang_200_zh_transcript.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
# remove withespace, charactor text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'corpus/', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
if not fname.endswith('.wav'):
continue
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, subset)
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'corpus')
for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)):
for sub in dirlist:
print(f"unpack dir {sub}...")
for folder, _, filelist in sorted(
os.walk(os.path.join(subfolder, sub))):
for ftar in filelist:
unpack(os.path.join(folder, ftar), folder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix,
subset='aidatatang_200zh')
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
data_aishell*
*.meta
manifest.*
*.tgz
resource_aishell
# [Aishell1](http://www.openslr.org/33/)
This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including smart home, autonomous driving, and industrial production. The whole recording was put in quiet indoor environment, using 3 different devices at the same time: high fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit), iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. ( This database is free for academic research, not in the commerce, if without permission. )
......@@ -31,9 +31,11 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/33'
URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
RESOURCE_URL = URL_ROOT + '/resource_aishell.tgz'
MD5_RESOURCE = '957d480a0fcac85fc18e550756f624e5'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -60,18 +62,22 @@ def create_manifest(data_dir, manifest_path_prefix):
if line == '':
continue
audio_id, text = line.split(' ', 1)
# remove withespace
# remove withespace, charactor text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'wav', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
......@@ -81,22 +87,34 @@ def create_manifest(data_dir, manifest_path_prefix):
json_lines.append(
json.dumps(
{
'utt':
os.path.splitext(os.path.basename(audio_path))[0],
'feat':
audio_path,
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text':
text
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
......@@ -110,7 +128,9 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path):
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
if manifest_path:
create_manifest(data_dir, manifest_path)
def main():
......@@ -123,6 +143,14 @@ def main():
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
prepare_dataset(
url=RESOURCE_URL,
md5sum=MD5_RESOURCE,
target_dir=args.target_dir,
manifest_path=None)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
# [Aishell3](http://www.openslr.org/93/)
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus which could be used to train multi-speaker Text-to-Speech (TTS) systems. The corpus contains roughly **85 hours** of emotion-neutral recordings spoken by 218 native Chinese mandarin speakers and total 88035 utterances. Their auxiliary attributes such as gender, age group and native accents are explicitly marked and provided in the corpus. Accordingly, transcripts in Chinese character-level and pinyin-level are provided along with the recordings. The word & tone transcription accuracy rate is above 98%, through professional speech annotation and strict quality inspection for tone and prosody. ( This database is free for academic research, not in the commerce, if without permission. )
# [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
```
git clone https://github.com/SpeechColab/GigaSpeech.git
cd GigaSpeech
utils/gigaspeech_download.sh /disk1/audio_data/gigaspeech
toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
cd ..
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -e
curdir=$PWD
test -d GigaSpeech || git clone https://github.com/SpeechColab/GigaSpeech.git
pushd GigaSpeech
source env_vars.sh
./utils/download_gigaspeech.sh ${curdir}/
#toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
popd
dev-clean/
dev-other/
test-clean/
test-other/
train-clean-100/
train-clean-360/
train-other-500/
dev-clean
dev-other
test-clean
test-other
train-clean-100
train-clean-360
train-other-500
*.meta
manifest.*
......@@ -77,6 +77,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
......@@ -86,7 +90,9 @@ def create_manifest(data_dir, manifest_path):
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.join(subfolder, segments[0] + '.flac')
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.flac'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
......@@ -99,10 +105,27 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
......
# [MagicData](http://www.openslr.org/68/)
MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 755 hours of speech data, which is mostly mobile recorded data.
* 1080 speakers from different accent areas in China are invited to participate in the recording.
* The sentence transcription accuracy is higher than 98%.
* Recordings are conducted in a quiet indoor environment.
* The database is divided into training set, validation set, and testing set in a ratio of 51: 1: 2.
* Detail information such as speech data coding and speaker information is preserved in the metadata file.
* The domain of recording texts is diversified, including interactive Q&A, music search, SNS messages, home command and control, etc.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, speaker recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use.
# multi-cn
This is a Chinese speech recognition recipe that trains on all Chinese corpora on OpenSLR, including:
* Aidatatang (140 hours)
* Aishell (151 hours)
* MagicData (712 hours)
* Primewords (99 hours)
* ST-CMDS (110 hours)
* THCHS-30 (26 hours)
* optional AISHELL2 (~1000 hours) if available
# [Primewords](http://www.openslr.org/47/)
This free Chinese Mandarin speech corpus set is released by Shanghai Primewords Information Technology Co., Ltd.
The corpus is recorded by smart mobile phones from 296 native Chinese speakers. The transcription accuracy is larger than 98%, at the confidence level of 95%. It is free for academic use.
The mapping between the transcript and utterance is given in JSON format.
# [FreeST](http://www.openslr.org/38/)
*.tar.gz.*
manifest.*
*.md
EN-ZH/
train-split/
test-segment/
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Ted-En-Zh speech translation dataset
Create manifest files from splited datased.
dev set: tst2010, test set: tst2015
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--src_dir",
default="",
type=str,
help="Directory to kaldi splited data. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types_infos = [
('train', 'train-split/train-segment', 'En-Zh/train.en-zh'),
('dev', 'test-segment/tst2010', 'En-Zh/tst2010.en-zh'),
('test', 'test-segment/tst2015', 'En-Zh/tst2015.en-zh')
]
for data_info in data_types_infos:
dtype, audio_relative_dir, text_relative_path = data_info
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
text_path = os.path.join(data_dir, text_relative_path)
audio_dir = os.path.join(data_dir, audio_relative_dir)
for line in codecs.open(text_path, 'r', 'utf-8', errors='ignore'):
line = line.strip()
if len(line) < 1:
continue
audio_id, trancription, translation = line.split('\t')
utt = audio_id.split('.')[0]
audio_path = os.path.join(audio_dir, audio_id)
if os.path.exists(audio_path):
if os.path.getsize(audio_path) < 30000:
continue
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
json_lines.append(
json.dumps(
{
'utt': utt,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': " ".join(translation.split()),
'text1': " ".join(trancription.split())
},
ensure_ascii=False))
total_sec += duration
total_text += len(translation.split())
total_num += 1
if not total_num % 1000:
print(dtype, 'Processed:', total_num)
manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(src_dir, manifest_path=None):
"""create manifest file."""
if os.path.isdir(manifest_path):
manifest_path = os.path.join(manifest_path, 'manifest')
if manifest_path:
create_manifest(src_dir, manifest_path)
def main():
if args.src_dir.startswith('~'):
args.src_dir = os.path.expanduser(args.src_dir)
prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)
print("manifest prepare done!")
if __name__ == '__main__':
main()
*.tgz
manifest.*
data_thchs30
resource
test-noise
*.meta
# [THCHS30](http://www.openslr.org/18/)
This is the *data part* of the `THCHS30 2015` acoustic data
& scripts dataset.
The dataset is described in more detail in the paper ``THCHS-30 : A Free
Chinese Speech Corpus`` by Dong Wang, Xuewei Zhang.
A paper (if it can be called a paper) 13 years ago regarding the database:
Dong Wang, Dalei Wu, Xiaoyan Zhu, ``TCMSD: A new Chinese Continuous Speech Database``,
International Conference on Chinese Computing (ICCC'01), 2001, Singapore.
The layout of this data pack is the following:
``data``
``*.wav``
audio data
``*.wav.trn``
transcriptions
``{train,dev,test}``
contain symlinks into the ``data`` directory for both audio and
transcription files. Contents of these directories define the
train/dev/test split of the data.
``{lm_word}``
``word.3gram.lm``
trigram LM based on word
``lexicon.txt``
lexicon based on word
``{lm_phone}``
``phone.3gram.lm``
trigram LM based on phone
``lexicon.txt``
lexicon based on phone
``README.TXT``
this file
Data statistics
===============
Statistics for the data are as follows:
=========== ========== ========== ===========
**dataset** **audio** **#sents** **#words**
=========== ========== ========== ===========
train 25 10,000 198,252
dev 2:14 893 17,743
test 6:15 2,495 49,085
=========== ========== ========== ===========
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare THCHS-30 mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from multiprocessing.pool import Pool
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/18'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18'
DATA_URL = URL_ROOT + '/data_thchs30.tgz'
TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz'
RESOURCE_URL = URL_ROOT + '/resource.tgz'
MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90'
MD5_TEST_NOISE = '7e8a985fb965b84141b68c68556c2030'
MD5_RESOURCE = 'c0b2a565b4970a0c4fe89fefbf2d97e1'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/THCHS30",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def read_trn(filepath):
"""read trn file.
word text in first line.
syllable text in second line.
phoneme text in third line.
Args:
filepath (str): trn path.
Returns:
list(str): (word, syllable, phone)
"""
texts = []
with open(filepath, 'r') as f:
lines = f.read().strip().split('\n')
assert len(lines) == 3, lines
# charactor text, remove withespace
texts.append(''.join(lines[0].split()))
texts.extend(lines[1:])
return texts
def resolve_symlink(filepath):
"""resolve symlink which content is norm file.
Args:
filepath (str): norm file symlink.
"""
sym_path = Path(filepath)
relative_link = sym_path.read_text().strip()
relative = Path(relative_link)
relpath = sym_path.parent / relative
return relpath.resolve()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
file_path = os.path.join(subfolder, fname)
if file_path.endswith('.wav'):
audio_path = os.path.abspath(file_path)
text_path = resolve_symlink(audio_path + '.trn')
else:
continue
assert os.path.exists(audio_path) and os.path.exists(text_path)
audio_id = os.path.basename(audio_path)[:-4]
word_text, syllable_text, phone_text = read_trn(text_path)
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
# not dump alignment infos
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': word_text, # charactor
'syllable': syllable_text,
'phone': phone_text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
datadir = os.path.join(target_dir, subset)
if not os.path.exists(datadir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
if subset == 'data_thchs30':
create_manifest(datadir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
tasks = [
(DATA_URL, MD5_DATA, args.target_dir, args.manifest_prefix,
"data_thchs30"),
(TEST_NOISE_URL, MD5_TEST_NOISE, args.target_dir, args.manifest_prefix,
"test-noise"),
(RESOURCE_URL, MD5_RESOURCE, args.target_dir, args.manifest_prefix,
"resource"),
]
with Pool(7) as pool:
pool.starmap(prepare_dataset, tasks)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
TIMIT.*
TIMIT
manifest.*
*.meta
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Librispeech ASR datasets.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import re
import string
from pathlib import Path
import soundfile
from utils.utility import unzip
URL_ROOT = ""
MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default='~/.cache/paddle/dataset/speech/timit',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
#: A string containing Chinese punctuation marks (non-stops).
non_stops = (
# Fullwidth ASCII variants
'\uFF02\uFF03\uFF04\uFF05\uFF06\uFF07\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D'
'\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF20\uFF3B\uFF3C\uFF3D\uFF3E\uFF3F'
'\uFF40\uFF5B\uFF5C\uFF5D\uFF5E\uFF5F\uFF60'
# Halfwidth CJK punctuation
'\uFF62\uFF63\uFF64'
# CJK symbols and punctuation
'\u3000\u3001\u3003'
# CJK angle and corner brackets
'\u3008\u3009\u300A\u300B\u300C\u300D\u300E\u300F\u3010\u3011'
# CJK brackets and symbols/punctuation
'\u3014\u3015\u3016\u3017\u3018\u3019\u301A\u301B\u301C\u301D\u301E\u301F'
# Other CJK symbols
'\u3030'
# Special CJK indicators
'\u303E\u303F'
# Dashes
'\u2013\u2014'
# Quotation marks and apostrophe
'\u2018\u2019\u201B\u201C\u201D\u201E\u201F'
# General punctuation
'\u2026\u2027'
# Overscores and underscores
'\uFE4F'
# Small form variants
'\uFE51\uFE54'
# Latin punctuation
'\u00B7')
#: A string of Chinese stops.
stops = (
'\uFF01' # Fullwidth exclamation mark
'\uFF1F' # Fullwidth question mark
'\uFF61' # Halfwidth ideographic full stop
'\u3002' # Ideographic full stop
)
#: A string containing all Chinese punctuation.
punctuation = non_stops + stops
def tn(text):
# lower text
text = text.lower()
# remove punc
text = re.sub(f'[{punctuation}{string.punctuation}]', "", text)
return text
def read_txt(filepath: str) -> str:
with open(filepath, 'r') as f:
line = f.read().strip().split(maxsplit=2)[2]
return tn(line)
def read_algin(filepath: str) -> str:
"""read word or phone alignment file.
<start-sample> <end-sample> <token><newline>
Args:
filepath (str): [description]
Returns:
str: token sepearte by <space>
"""
aligns = [] # (start, end, token)
with open(filepath, 'r') as f:
for line in f:
items = line.strip().split()
# for phone: (Note: beginning and ending silence regions are marked with h#)
if items[2].strip() == 'h#':
continue
aligns.append(items)
return ' '.join([item[2] for item in aligns])
def create_manifest(data_dir, manifest_path_prefix):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
utts = set()
data_types = ['TRAIN', 'TEST']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = Path(os.path.join(data_dir, dtype))
for fname in sorted(audio_dir.rglob('*.WAV')):
audio_path = fname.resolve() # .WAV
audio_id = audio_path.stem
# if uttid exits, then skipped
if audio_id in utts:
continue
utts.add(audio_id)
text_path = audio_path.with_suffix('.TXT')
phone_path = audio_path.with_suffix('.PHN')
word_path = audio_path.with_suffix('.WRD')
audio_data, samplerate = soundfile.read(
str(audio_path), dtype='int16')
duration = float(len(audio_data) / samplerate)
word_text = read_txt(text_path)
phone_text = read_algin(phone_path)
gender_spk = str(audio_path.parent.stem)
spk = gender_spk[1:]
gender = gender_spk[0]
utt_id = '_'.join([spk, gender, audio_id])
# not dump alignment infos
json_lines.append(
json.dumps(
{
'utt': utt_id,
'feat': str(audio_path),
'feat_shape': (duration, ), # second
'text': word_text, # word
'phone': phone_text,
'spk': spk,
'gender': gender,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text.split())
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype.lower()
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype.lower()) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
filepath = os.path.join(target_dir, "TIMIT.zip")
if not os.path.exists(filepath):
print(f"Please download TIMIT.zip into {target_dir}.")
raise FileNotFoundError
if not os.path.exists(os.path.join(target_dir, "TIMIT")):
# check md5sum
assert check_md5sum(filepath, md5sum)
# unpack
unzip(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(os.path.join(target_dir, "TIMIT"), manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(URL_ROOT, MD5_DATA, args.target_dir, args.manifest_prefix)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare TIMIT dataset (Standard split from Kaldi)
Create manifest files from splited datased.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--src_dir",
default="",
type=str,
help="Directory to kaldi splited data. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
phn_path = os.path.join(data_dir, dtype + '.text')
phn_dict = {}
for line in codecs.open(phn_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
phn_dict[audio_id] = text
audio_dir = os.path.join(data_dir, dtype + '_sph.scp')
for line in codecs.open(audio_dir, 'r', 'utf-8'):
audio_id, audio_path = line.strip().split()
# if no transcription for audio then raise error
assert audio_id in phn_dict
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = phn_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(src_dir, manifest_path=None):
"""create manifest file."""
if os.path.isdir(manifest_path):
manifest_path = os.path.join(manifest_path, 'manifest')
if manifest_path:
create_manifest(src_dir, manifest_path)
def main():
if args.src_dir.startswith('~'):
args.src_dir = os.path.expanduser(args.src_dir)
prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)
print("manifest prepare done!")
if __name__ == '__main__':
main()
# G2P
* zh - Chinese G2P
# G2P
* WS
jieba
* G2P
pypinyin
* Tone sandhi
simple
We recommend using [Paraket](https://github.com/PaddlePaddle/Parakeet] [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P.
The phoneme set should be changed, you can reference `examples/thchs30/a0/data/dict/syllable.lexicon`.
## Download Baker dataset
[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded mannually and moved to './data',
because you will have to pass the `CATTCHA` from a browswe to download the dataset.
## RUN
```
. path.sh
./run.sh
```
## Result
```
exp/
|-- 000001-010000.txt
|-- ref.pinyin
|-- trans.jieba.pinyin
`-- trans.pinyin
0 directories, 4 files
```
```
4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt
01707896391c2de9b6fc4a39654be942 exp/ref.pinyin
43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin
8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin
```
```
==> exp/000001-010000.txt <==
000001 卡尔普#2陪外孙#1玩滑梯#4。
ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 假语村言#2别再#1拥抱我#4。
jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 宝马#1配挂#1跛骡鞍#3,貂蝉#1怨枕#2董翁榻#4。
bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 邓小平#2与#1撒切尔#2会晤#4。
deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。
lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
==> exp/ref.pinyin <==
000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.jieba.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
```
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
......@@ -6,16 +6,20 @@ stage=-1
stop_stage=100
exp_dir=exp
data_dir=data
data=data
source ${MAIN_ROOT}/utils/parse_options.sh || exit -1
mkdir -p ${exp_dir}
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then
mkdir -p ${data}
test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in "${data}; exit -1; }
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then
echo "stage 0: Extracting Prosody Labeling"
bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir}
bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data}
fi
# convert transcription in chinese into pinyin with pypinyin or jieba+pypinyin
......
# Ngram LM
* s0 - kenlm ngram lm
# Ngram LM
Train chinese chararctor ngram lm by [kenlm](https://github.com/kpu/kenlm).
```
bash run.sh
```
# Ngram LM
Train chinese chararctor ngram lm by [kenlm](https://github.com/kpu/kenlm).
## Run
```
. path.sh
bash run.sh
```
## Results
```
exp/
|-- text
|-- text.char.tn
|-- text.word.tn
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
|-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
`-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
0 directories, 7 files
```
```
3ae083627b9b6cef1a82d574d8483f97 exp/text
d97da252d2a63a662af22f98af30cb8c exp/text.char.tn
c18b03005bd094dbfd9b46442be361fd exp/text.word.tn
73dbf50097896eda33985e11e1ba9a3a exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
01334e2044c474b99c4f2ffbed790626 exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
36a42de548045b54662411ae7982c77f exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
332422803ffd73dd7ffd16cd2b0abcd5 exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
```
```
==> exp/text <==
少先队员因该为老人让坐
祛痘印可以吗?有效果吗?
不知这款牛奶口感怎样? 小孩子喝行吗!
是转基因油?
我家宝宝13斤用多大码的
会起坨吗?
请问给送上楼吗?
亲是送赁上门吗
送货时候有外包装没有还是直接发货过来
会不会有坏的?
==> exp/text.char.tn <==
少 先 队 员 因 该 为 老 人 让 坐
祛 痘 印 可 以 吗 有 效 果 吗
不 知 这 款 牛 奶 口 感 怎 样 小 孩 子 喝 行 吗
是 转 基 因 油
我 家 宝 宝 十 三 斤 用 多 大 码 的
会 起 坨 吗
请 问 给 送 上 楼 吗
亲 是 送 赁 上 门 吗
送 货 时 候 有 外 包 装 没 有 还 是 直 接 发 货 过 来
会 不 会 有 坏 的
==> exp/text.word.tn <==
少先队员 因该 为 老人 让 坐
祛痘 印 可以 吗 有 效果 吗
不知 这 款 牛奶 口感 怎样 小孩子 喝行 吗
是 转基因 油
我家 宝宝 十三斤 用多大码 的
会起 坨 吗
请问 给 送 上楼 吗
亲是 送赁 上门 吗
送货 时候 有 外包装 没有 还是 直接 发货 过来
会 不会 有坏 的
==> exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa <==
\data\
ngram 1=587
ngram 2=395
ngram 3=100
ngram 4=2
ngram 5=0
\1-grams:
-3.272324 <unk> 0
0 <s> -0.36706257
==> exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa <==
\data\
ngram 1=689
ngram 2=1398
ngram 3=1506
\1-grams:
-3.1755018 <unk> 0
0 <s> -0.23069073
-1.2318869 </s> 0
-3.067262 少先队员 -0.051341705
```
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......@@ -7,4 +7,4 @@ export LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}
\ No newline at end of file
export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}
# Punctation Restoration
Please using [PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask] to do this task.
# [SentencePiece Model](https://github.com/google/sentencepiece)
## Run
Train a `spm` model for English tokenizer.
```
. path.sh
bash run.sh
```
## Results
```
data/
└── lang_char
├── input.bpe
├── input.decode
├── input.txt
├── train_unigram100.model
├── train_unigram100_units.txt
└── train_unigram100.vocab
1 directory, 6 files
```
```
b5a230c26c61db5c36f34e503102f936 data/lang_char/input.bpe
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.decode
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.txt
124bf3fe7ce3b73b1994234c15268577 data/lang_char/train_unigram100.model
0df2488cc8eaace95eb12713facb5cf0 data/lang_char/train_unigram100_units.txt
46360cac35c751310e8e8ffd3a034cb5 data/lang_char/train_unigram100.vocab
```
```
==> data/lang_char/input.bpe <==
▁mi ster ▁quilter ▁ is ▁the ▁a p ost le ▁o f ▁the ▁mi d d le ▁c las s es ▁ and ▁we ▁ar e ▁g l a d ▁ to ▁we l c om e ▁h is ▁g o s pe l
▁ n or ▁ is ▁mi ster ▁quilter ' s ▁ma nne r ▁ l ess ▁in ter es t ing ▁tha n ▁h is ▁ma t ter
▁h e ▁ t e ll s ▁us ▁tha t ▁ at ▁ t h is ▁f es t ive ▁ s e ason ▁o f ▁the ▁ y e ar ▁w ith ▁ ch r is t m a s ▁ and ▁ro a s t ▁be e f ▁ l o om ing ▁be fore ▁us ▁ s i mile s ▁d r a w n ▁f r om ▁ e at ing ▁ and ▁it s ▁re s u l t s ▁o c c ur ▁m ost ▁re a di l y ▁ to ▁the ▁ mind
▁h e ▁ ha s ▁g r a v e ▁d o u b t s ▁w h e t h er ▁ s i r ▁f r e d er ic k ▁ l eig h to n ' s ▁w or k ▁ is ▁re all y ▁gre e k ▁a f ter ▁ all ▁ and ▁c a n ▁di s c o v er ▁in ▁it ▁b u t ▁li t t le ▁o f ▁ro ck y ▁it ha c a
▁li nne ll ' s ▁ p ic tur es ▁ar e ▁a ▁ s or t ▁o f ▁ u p ▁g u ar d s ▁ and ▁ at ▁ em ▁painting s ▁ and ▁m ason ' s ▁ e x q u is i t e ▁ i d y ll s ▁ar e ▁a s ▁ n at ion a l ▁a s ▁a ▁ j ing o ▁ p o em ▁mi ster ▁b i r k e t ▁f o ster ' s ▁ l and s c a pe s ▁ s mile ▁ at ▁on e ▁m u ch ▁in ▁the ▁ s a m e ▁w a y ▁tha t ▁mi ster ▁c ar k er ▁us e d ▁ to ▁f las h ▁h is ▁ t e e t h ▁ and ▁mi ster ▁ j o h n ▁c o ll i er ▁g ive s ▁h is ▁ s i t ter ▁a ▁ ch e er f u l ▁ s l a p ▁on ▁the ▁b a ck ▁be fore ▁h
e ▁ s a y s ▁li k e ▁a ▁ s ha m p o o er ▁in ▁a ▁ tur k is h ▁b at h ▁ n e x t ▁ma n
▁it ▁ is ▁o b v i o u s l y ▁ u nne c ess ar y ▁for ▁us ▁ to ▁ p o i n t ▁o u t ▁h o w ▁ l u m i n o u s ▁the s e ▁c rit ic is m s ▁ar e ▁h o w ▁d e l ic at e ▁in ▁ e x p r ess ion
▁on ▁the ▁g e n er a l ▁ p r i n c i p l es ▁o f ▁ar t ▁mi ster ▁quilter ▁w rit es ▁w ith ▁ e qual ▁ l u c i di t y
▁painting ▁h e ▁ t e ll s ▁us ▁ is ▁o f ▁a ▁di f f er e n t ▁ qual i t y ▁ to ▁ma t h em at ic s ▁ and ▁f i nish ▁in ▁ar t ▁ is ▁a d d ing ▁m or e ▁f a c t
▁a s ▁for ▁ e t ch ing s ▁the y ▁ar e ▁o f ▁ t w o ▁ k i n d s ▁b rit is h ▁ and ▁for eig n
▁h e ▁ l a ment s ▁m ost ▁b i t ter l y ▁the ▁di v or c e ▁tha t ▁ ha s ▁be e n ▁ma d e ▁be t w e e n ▁d e c or at ive ▁ar t ▁ and ▁w ha t ▁we ▁us u all y ▁c all ▁ p ic tur es ▁ma k es ▁the ▁c u s t om ar y ▁a p pe a l ▁ to ▁the ▁ las t ▁ j u d g ment ▁ and ▁re mind s ▁us ▁tha t ▁in ▁the ▁gre at ▁d a y s ▁o f ▁ar t ▁mi c ha e l ▁a n g e l o ▁w a s ▁the ▁f ur nish ing ▁ u p h o l ster er
==> data/lang_char/input.decode <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/input.txt <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/train_unigram100_units.txt <==
<blank> 0
<unk> 1
' 2
a 3
all 4
and 5
ar 6
ason 7
at 8
b 9
==> data/lang_char/train_unigram100.vocab <==
<unk> 0
<s> 0
</s> 0
▁ -2.01742
e -2.7203
s -2.82989
t -2.99689
l -3.53267
n -3.84935
o -3.88229
```
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......
# thchs30
* a0 for mfa alignment
# THCHS-30 数据集强制对齐实验
-----
本实验对 THCHS-30 中文数据集用 [Montreal-Forced-Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/index.html) 进行强制对齐。
THCHS-30 的文本标注数据分为:
1. 汉字级别(word),该数据集用空格对词进行了划分,我们在使用时按照将不同字之间按空格划分
2. 音节级别(syllable),即汉语中的一个拼音
3. 音素级别(phone),一个拼音有多个音素组成,汉语的声母韵母可以理解为音素,不同的数据集有各自的音素标准,THCHS-30 数据集与标贝 BZNSYP 数据集的音素标准略有不同
数据 A11_0 文本示例如下:
```
绿 是 阳春 烟 景 大块 文章 的 底色 四月 的 林 峦 更是 绿 得 鲜活 秀媚 诗意 盎然↩
lv4 shi4 yang2 chun1 yan1 jing3 da4 kuai4 wen2 zhang1 de5 di3 se4 si4 yue4 de5 lin2 luan2 geng4 shi4 lv4 de5 xian1 huo2 xiu4 mei4 shi1 yi4 ang4 ran2↩
l v4 sh ix4 ii iang2 ch un1 ii ian1 j ing3 d a4 k uai4 uu un2 zh ang1 d e5 d i3 s e4 s iy4 vv ve4 d e5 l in2 l uan2 g eng4 sh ix4 l v4 d e5 x ian1 h uo2 x iu4 m ei4 sh ix1 ii i4 aa ang4 r an2
```
## 开始实验
---
在本项目的 根目录/tools 执行
```
make
```
下载 MFA 的可执行包(也会同时下载本项目所需的其他工具)
执行如下命令:
```
cd a0
./run.sh
```
应用程序会自动下载 THCHS-30数据集,处理成 MFA 所需的文件格式并开始训练,您可以修改 `run.sh` 中的参数 `LEXICON_NAME` 来决定您需要强制对齐的级别(word、syllable 和 phone)
## MFA 所使用的字典
---
MFA 字典的格式请参考: [MFA 官方文档 Dictionary format ](https://montreal-forced-aligner.readthedocs.io/en/latest/dictionary.html)
phone.lexicon 直接使用的是 `THCHS-30/data_thchs30/lm_phone/lexicon.txt`
word.lexicon 考虑到了中文的多音字,使用**带概率的字典**, 生成规则请参考 `local/gen_word2phone.py`
`syllable.lexicon` 获取自 [DNSun/thchs30-pinyin2tone](https://github.com/DNSun/thchs30-pinyin2tone)
## 对齐结果
---
我们提供了三种级别 MFA 训练好的对齐结果、模型和字典(`syllable.lexicon``data/dict` 中,`phone.lexicon`` word.lexicon` 运行数据预处理代码后会自动从原始数据集复制或生成)
**phone 级别:** [phone.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/phone.lexicon)[对齐结果](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/thchs30_alignment.tar.gz)[模型](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/phone/thchs30_model.zip)
**syllabel 级别:** [syllable.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/syllable.lexicon)[对齐结果](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/thchs30_alignment.tar.gz)[模型](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/syllable/thchs30_model.zip)
**word 级别:** [word.lexicon](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/word.lexicon)[对齐结果](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/thchs30_alignment.tar.gz)[模型](https://paddlespeech.bj.bcebos.com/MFA/THCHS30/word/thchs30_model.zip)
随后,您可以参考 [MFA 官方文档 Align using pretrained models](https://montreal-forced-aligner.readthedocs.io/en/stable/aligning.html#align-using-pretrained-models) 使用我们给您提供好的模型直接对自己的数据集进行强制对齐,注意,您需要使用和模型对应的 lexicon 文件,当文本是汉字时,您需要用空格把不同的**汉字**(而不是词语)分开
A0 aa a0
A1 aa a1
A2 aa a2
A3 aa a3
A4 aa a4
AI0 aa ai0
AI1 aa ai1
AI2 aa ai2
AI3 aa ai3
AI4 aa ai4
AN0 aa an0
AN1 aa an1
AN2 aa an2
AN3 aa an3
AN4 aa an4
ANG0 aa ang0
ANG1 aa ang1
ANG2 aa ang2
ANG3 aa ang3
ANG4 aa ang4
AO0 aa ao0
AO1 aa ao1
AO2 aa ao2
AO3 aa ao3
AO4 aa ao4
BA0 b a0
BA1 b a1
BA2 b a2
BA3 b a3
BA4 b a4
BAI0 b ai0
BAI1 b ai1
BAI2 b ai2
BAI3 b ai3
BAI4 b ai4
BAN0 b an0
BAN1 b an1
BAN2 b an2
BAN3 b an3
BAN4 b an4
BANG0 b ang0
BANG1 b ang1
BANG2 b ang2
BANG3 b ang3
BANG4 b ang4
BAO0 b ao0
BAO1 b ao1
BAO2 b ao2
BAO3 b ao3
BAO4 b ao4
BEI0 b ei0
BEI1 b ei1
BEI2 b ei2
BEI3 b ei3
BEI4 b ei4
BEN0 b en0
BEN1 b en1
BEN2 b en2
BEN3 b en3
BEN4 b en4
BENG0 b eng0
BENG1 b eng1
BENG2 b eng2
BENG3 b eng3
BENG4 b eng4
BI0 b i0
BI1 b i1
BI2 b i2
BI3 b i3
BI4 b i4
BIAN0 b ian0
BIAN1 b ian1
BIAN2 b ian2
BIAN3 b ian3
BIAN4 b ian4
BIAO0 b iao0
BIAO1 b iao1
BIAO2 b iao2
BIAO3 b iao3
BIAO4 b iao4
BIE0 b ie0
BIE1 b ie1
BIE2 b ie2
BIE3 b ie3
BIE4 b ie4
BIN0 b in0
BIN1 b in1
BIN2 b in2
BIN3 b in3
BIN4 b in4
BING0 b ing0
BING1 b ing1
BING2 b ing2
BING3 b ing3
BING4 b ing4
BO0 b o0
BO1 b o1
BO2 b o2
BO3 b o3
BO4 b o4
BU0 b u0
BU1 b u1
BU2 b u2
BU3 b u3
BU4 b u4
CA0 c a0
CA1 c a1
CA2 c a2
CA3 c a3
CA4 c a4
CAI0 c ai0
CAI1 c ai1
CAI2 c ai2
CAI3 c ai3
CAI4 c ai4
CAN0 c an0
CAN1 c an1
CAN2 c an2
CAN3 c an3
CAN4 c an4
CANG0 c ang0
CANG1 c ang1
CANG2 c ang2
CANG3 c ang3
CANG4 c ang4
CAO0 c ao0
CAO1 c ao1
CAO2 c ao2
CAO3 c ao3
CAO4 c ao4
CE0 c e0
CE1 c e1
CE2 c e2
CE3 c e3
CE4 c e4
CEN0 c en0
CEN1 c en1
CEN2 c en2
CEN3 c en3
CEN4 c en4
CENG0 c eng0
CENG1 c eng1
CENG2 c eng2
CENG3 c eng3
CENG4 c eng4
CHA0 ch a0
CHA1 ch a1
CHA2 ch a2
CHA3 ch a3
CHA4 ch a4
CHAI0 ch ai0
CHAI1 ch ai1
CHAI2 ch ai2
CHAI3 ch ai3
CHAI4 ch ai4
CHAN0 ch an0
CHAN1 ch an1
CHAN2 ch an2
CHAN3 ch an3
CHAN4 ch an4
CHANG0 ch ang0
CHANG1 ch ang1
CHANG2 ch ang2
CHANG3 ch ang3
CHANG4 ch ang4
CHAO0 ch ao0
CHAO1 ch ao1
CHAO2 ch ao2
CHAO3 ch ao3
CHAO4 ch ao4
CHE0 ch e0
CHE1 ch e1
CHE2 ch e2
CHE3 ch e3
CHE4 ch e4
CHEN0 ch en0
CHEN1 ch en1
CHEN2 ch en2
CHEN3 ch en3
CHEN4 ch en4
CHENG0 ch eng0
CHENG1 ch eng1
CHENG2 ch eng2
CHENG3 ch eng3
CHENG4 ch eng4
CHI0 ch ix0
CHI1 ch ix1
CHI2 ch ix2
CHI3 ch ix3
CHI4 ch ix4
CHONG0 ch ong0
CHONG1 ch ong1
CHONG2 ch ong2
CHONG3 ch ong3
CHONG4 ch ong4
CHOU0 ch ou0
CHOU1 ch ou1
CHOU2 ch ou2
CHOU3 ch ou3
CHOU4 ch ou4
CHU0 ch u0
CHU1 ch u1
CHU2 ch u2
CHU3 ch u3
CHU4 ch u4
CHUAI0 ch uai0
CHUAI1 ch uai1
CHUAI2 ch uai2
CHUAI3 ch uai3
CHUAI4 ch uai4
CHUAN0 ch uan0
CHUAN1 ch uan1
CHUAN2 ch uan2
CHUAN3 ch uan3
CHUAN4 ch uan4
CHUANG0 ch uang0
CHUANG1 ch uang1
CHUANG2 ch uang2
CHUANG3 ch uang3
CHUANG4 ch uang4
CHUI0 ch ui0
CHUI1 ch ui1
CHUI2 ch ui2
CHUI3 ch ui3
CHUI4 ch ui4
CHUN0 ch un0
CHUN1 ch un1
CHUN2 ch un2
CHUN3 ch un3
CHUN4 ch un4
CHUO0 ch uo0
CHUO1 ch uo1
CHUO2 ch uo2
CHUO3 ch uo3
CHUO4 ch uo4
CI0 c iy0
CI1 c iy1
CI2 c iy2
CI3 c iy3
CI4 c iy4
CONG0 c ong0
CONG1 c ong1
CONG2 c ong2
CONG3 c ong3
CONG4 c ong4
COU0 c ou0
COU1 c ou1
COU2 c ou2
COU3 c ou3
COU4 c ou4
CU0 c u0
CU1 c u1
CU2 c u2
CU3 c u3
CU4 c u4
CUAN0 c uan0
CUAN1 c uan1
CUAN2 c uan2
CUAN3 c uan3
CUAN4 c uan4
CUI0 c ui0
CUI1 c ui1
CUI2 c ui2
CUI3 c ui3
CUI4 c ui4
CUN0 c un0
CUN1 c un1
CUN2 c un2
CUN3 c un3
CUN4 c un4
CUO0 c uo0
CUO1 c uo1
CUO2 c uo2
CUO3 c uo3
CUO4 c uo4
DA0 d a0
DA1 d a1
DA2 d a2
DA3 d a3
DA4 d a4
DAI0 d ai0
DAI1 d ai1
DAI2 d ai2
DAI3 d ai3
DAI4 d ai4
DAN0 d an0
DAN1 d an1
DAN2 d an2
DAN3 d an3
DAN4 d an4
DANG0 d ang0
DANG1 d ang1
DANG2 d ang2
DANG3 d ang3
DANG4 d ang4
DAO0 d ao0
DAO1 d ao1
DAO2 d ao2
DAO3 d ao3
DAO4 d ao4
DE0 d e0
DE1 d e1
DE2 d e2
DE3 d e3
DE4 d e4
DEI0 d ei0
DEI1 d ei1
DEI2 d ei2
DEI3 d ei3
DEI4 d ei4
DEN0 d en0
DEN1 d en1
DEN2 d en2
DEN3 d en3
DEN4 d en4
DENG0 d eng0
DENG1 d eng1
DENG2 d eng2
DENG3 d eng3
DENG4 d eng4
DI0 d i0
DI1 d i1
DI2 d i2
DI3 d i3
DI4 d i4
DIA0 d ia0
DIA1 d ia1
DIA2 d ia2
DIA3 d ia3
DIA4 d ia4
DIAN0 d ian0
DIAN1 d ian1
DIAN2 d ian2
DIAN3 d ian3
DIAN4 d ian4
DIAO0 d iao0
DIAO1 d iao1
DIAO2 d iao2
DIAO3 d iao3
DIAO4 d iao4
DIE0 d ie0
DIE1 d ie1
DIE2 d ie2
DIE3 d ie3
DIE4 d ie4
DING0 d ing0
DING1 d ing1
DING2 d ing2
DING3 d ing3
DING4 d ing4
DIU0 d iu0
DIU1 d iu1
DIU2 d iu2
DIU3 d iu3
DIU4 d iu4
DONG0 d ong0
DONG1 d ong1
DONG2 d ong2
DONG3 d ong3
DONG4 d ong4
DOU0 d ou0
DOU1 d ou1
DOU2 d ou2
DOU3 d ou3
DOU4 d ou4
DU0 d u0
DU1 d u1
DU2 d u2
DU3 d u3
DU4 d u4
DUAN0 d uan0
DUAN1 d uan1
DUAN2 d uan2
DUAN3 d uan3
DUAN4 d uan4
DUI0 d ui0
DUI1 d ui1
DUI2 d ui2
DUI3 d ui3
DUI4 d ui4
DUN0 d un0
DUN1 d un1
DUN2 d un2
DUN3 d un3
DUN4 d un4
DUO0 d uo0
DUO1 d uo1
DUO2 d uo2
DUO3 d uo3
DUO4 d uo4
E0 ee e0
E1 ee e1
E2 ee e2
E3 ee e3
E4 ee e4
EN0 ee en0
EN1 ee en1
EN2 ee en2
EN3 ee en3
EN4 ee en4
ER0 ee er0
ER1 ee er1
ER2 ee er2
ER3 ee er3
ER4 ee er4
FA0 f a0
FA1 f a1
FA2 f a2
FA3 f a3
FA4 f a4
FAN0 f an0
FAN1 f an1
FAN2 f an2
FAN3 f an3
FAN4 f an4
FANG0 f ang0
FANG1 f ang1
FANG2 f ang2
FANG3 f ang3
FANG4 f ang4
FEI0 f ei0
FEI1 f ei1
FEI2 f ei2
FEI3 f ei3
FEI4 f ei4
FEN0 f en0
FEN1 f en1
FEN2 f en2
FEN3 f en3
FEN4 f en4
FENG0 f eng0
FENG1 f eng1
FENG2 f eng2
FENG3 f eng3
FENG4 f eng4
FO0 f o0
FO1 f o1
FO2 f o2
FO3 f o3
FO4 f o4
FOU0 f ou0
FOU1 f ou1
FOU2 f ou2
FOU3 f ou3
FOU4 f ou4
FU0 f u0
FU1 f u1
FU2 f u2
FU3 f u3
FU4 f u4
GA0 g a0
GA1 g a1
GA2 g a2
GA3 g a3
GA4 g a4
GAI0 g ai0
GAI1 g ai1
GAI2 g ai2
GAI3 g ai3
GAI4 g ai4
GAN0 g an0
GAN1 g an1
GAN2 g an2
GAN3 g an3
GAN4 g an4
GANG0 g ang0
GANG1 g ang1
GANG2 g ang2
GANG3 g ang3
GANG4 g ang4
GAO0 g ao0
GAO1 g ao1
GAO2 g ao2
GAO3 g ao3
GAO4 g ao4
GE0 g e0
GE1 g e1
GE2 g e2
GE3 g e3
GE4 g e4
GEI0 g ei0
GEI1 g ei1
GEI2 g ei2
GEI3 g ei3
GEI4 g ei4
GEN0 g en0
GEN1 g en1
GEN2 g en2
GEN3 g en3
GEN4 g en4
GENG0 g eng0
GENG1 g eng1
GENG2 g eng2
GENG3 g eng3
GENG4 g eng4
GONG0 g ong0
GONG1 g ong1
GONG2 g ong2
GONG3 g ong3
GONG4 g ong4
GOU0 g ou0
GOU1 g ou1
GOU2 g ou2
GOU3 g ou3
GOU4 g ou4
GU0 g u0
GU1 g u1
GU2 g u2
GU3 g u3
GU4 g u4
GUA0 g ua0
GUA1 g ua1
GUA2 g ua2
GUA3 g ua3
GUA4 g ua4
GUAI0 g uai0
GUAI1 g uai1
GUAI2 g uai2
GUAI3 g uai3
GUAI4 g uai4
GUAN0 g uan0
GUAN1 g uan1
GUAN2 g uan2
GUAN3 g uan3
GUAN4 g uan4
GUANG0 g uang0
GUANG1 g uang1
GUANG2 g uang2
GUANG3 g uang3
GUANG4 g uang4
GUI0 g ui0
GUI1 g ui1
GUI2 g ui2
GUI3 g ui3
GUI4 g ui4
GUN0 g un0
GUN1 g un1
GUN2 g un2
GUN3 g un3
GUN4 g un4
GUO0 g uo0
GUO1 g uo1
GUO2 g uo2
GUO3 g uo3
GUO4 g uo4
HA0 h a0
HA1 h a1
HA2 h a2
HA3 h a3
HA4 h a4
HAI0 h ai0
HAI1 h ai1
HAI2 h ai2
HAI3 h ai3
HAI4 h ai4
HAN0 h an0
HAN1 h an1
HAN2 h an2
HAN3 h an3
HAN4 h an4
HANG0 h ang0
HANG1 h ang1
HANG2 h ang2
HANG3 h ang3
HANG4 h ang4
HAO0 h ao0
HAO1 h ao1
HAO2 h ao2
HAO3 h ao3
HAO4 h ao4
HE0 h e0
HE1 h e1
HE2 h e2
HE3 h e3
HE4 h e4
HEI0 h ei0
HEI1 h ei1
HEI2 h ei2
HEI3 h ei3
HEI4 h ei4
HEN0 h en0
HEN1 h en1
HEN2 h en2
HEN3 h en3
HEN4 h en4
HENG0 h eng0
HENG1 h eng1
HENG2 h eng2
HENG3 h eng3
HENG4 h eng4
HONG0 h ong0
HONG1 h ong1
HONG2 h ong2
HONG3 h ong3
HONG4 h ong4
HOU0 h ou0
HOU1 h ou1
HOU2 h ou2
HOU3 h ou3
HOU4 h ou4
HU0 h u0
HU1 h u1
HU2 h u2
HU3 h u3
HU4 h u4
HUA0 h ua0
HUA1 h ua1
HUA2 h ua2
HUA3 h ua3
HUA4 h ua4
HUAI0 h uai0
HUAI1 h uai1
HUAI2 h uai2
HUAI3 h uai3
HUAI4 h uai4
HUAN0 h uan0
HUAN1 h uan1
HUAN2 h uan2
HUAN3 h uan3
HUAN4 h uan4
HUANG0 h uang0
HUANG1 h uang1
HUANG2 h uang2
HUANG3 h uang3
HUANG4 h uang4
HUI0 h ui0
HUI1 h ui1
HUI2 h ui2
HUI3 h ui3
HUI4 h ui4
HUN0 h un0
HUN1 h un1
HUN2 h un2
HUN3 h un3
HUN4 h un4
HUO0 h uo0
HUO1 h uo1
HUO2 h uo2
HUO3 h uo3
HUO4 h uo4
JI0 j i0
JI1 j i1
JI2 j i2
JI3 j i3
JI4 j i4
JIA0 j ia0
JIA1 j ia1
JIA2 j ia2
JIA3 j ia3
JIA4 j ia4
JIAN0 j ian0
JIAN1 j ian1
JIAN2 j ian2
JIAN3 j ian3
JIAN4 j ian4
JIANG0 j iang0
JIANG1 j iang1
JIANG2 j iang2
JIANG3 j iang3
JIANG4 j iang4
JIAO0 j iao0
JIAO1 j iao1
JIAO2 j iao2
JIAO3 j iao3
JIAO4 j iao4
JIE0 j ie0
JIE1 j ie1
JIE2 j ie2
JIE3 j ie3
JIE4 j ie4
JIN0 j in0
JIN1 j in1
JIN2 j in2
JIN3 j in3
JIN4 j in4
JING0 j ing0
JING1 j ing1
JING2 j ing2
JING3 j ing3
JING4 j ing4
JIONG0 j iong0
JIONG1 j iong1
JIONG2 j iong2
JIONG3 j iong3
JIONG4 j iong4
JIU0 j iu0
JIU1 j iu1
JIU2 j iu2
JIU3 j iu3
JIU4 j iu4
JU0 j v0
JU1 j v1
JU2 j v2
JU3 j v3
JU4 j v4
JUAN0 j van0
JUAN1 j van1
JUAN2 j van2
JUAN3 j van3
JUAN4 j van4
JUE0 j ve0
JUE1 j ve1
JUE2 j ve2
JUE3 j ve3
JUE4 j ve4
JUN0 j vn0
JUN1 j vn1
JUN2 j vn2
JUN3 j vn3
JUN4 j vn4
KA0 k a0
KA1 k a1
KA2 k a2
KA3 k a3
KA4 k a4
KAI0 k ai0
KAI1 k ai1
KAI2 k ai2
KAI3 k ai3
KAI4 k ai4
KAN0 k an0
KAN1 k an1
KAN2 k an2
KAN3 k an3
KAN4 k an4
KANG0 k ang0
KANG1 k ang1
KANG2 k ang2
KANG3 k ang3
KANG4 k ang4
KAO0 k ao0
KAO1 k ao1
KAO2 k ao2
KAO3 k ao3
KAO4 k ao4
KE0 k e0
KE1 k e1
KE2 k e2
KE3 k e3
KE4 k e4
KEI0 k ei0
KEI1 k ei1
KEI2 k ei2
KEI3 k ei3
KEI4 k ei4
KEN0 k en0
KEN1 k en1
KEN2 k en2
KEN3 k en3
KEN4 k en4
KENG0 k eng0
KENG1 k eng1
KENG2 k eng2
KENG3 k eng3
KENG4 k eng4
KONG0 k ong0
KONG1 k ong1
KONG2 k ong2
KONG3 k ong3
KONG4 k ong4
KOU0 k ou0
KOU1 k ou1
KOU2 k ou2
KOU3 k ou3
KOU4 k ou4
KU0 k u0
KU1 k u1
KU2 k u2
KU3 k u3
KU4 k u4
KUA0 k ua0
KUA1 k ua1
KUA2 k ua2
KUA3 k ua3
KUA4 k ua4
KUAI0 k uai0
KUAI1 k uai1
KUAI2 k uai2
KUAI3 k uai3
KUAI4 k uai4
KUAN0 k uan0
KUAN1 k uan1
KUAN2 k uan2
KUAN3 k uan3
KUAN4 k uan4
KUANG0 k uang0
KUANG1 k uang1
KUANG2 k uang2
KUANG3 k uang3
KUANG4 k uang4
KUI0 k ui0
KUI1 k ui1
KUI2 k ui2
KUI3 k ui3
KUI4 k ui4
KUN0 k un0
KUN1 k un1
KUN2 k un2
KUN3 k un3
KUN4 k un4
KUO0 k uo0
KUO1 k uo1
KUO2 k uo2
KUO3 k uo3
KUO4 k uo4
LA0 l a0
LA1 l a1
LA2 l a2
LA3 l a3
LA4 l a4
LAI0 l ai0
LAI1 l ai1
LAI2 l ai2
LAI3 l ai3
LAI4 l ai4
LAN0 l an0
LAN1 l an1
LAN2 l an2
LAN3 l an3
LAN4 l an4
LANG0 l ang0
LANG1 l ang1
LANG2 l ang2
LANG3 l ang3
LANG4 l ang4
LAO0 l ao0
LAO1 l ao1
LAO2 l ao2
LAO3 l ao3
LAO4 l ao4
LE0 l e0
LE1 l e1
LE2 l e2
LE3 l e3
LE4 l e4
LEI0 l ei0
LEI1 l ei1
LEI2 l ei2
LEI3 l ei3
LEI4 l ei4
LENG0 l eng0
LENG1 l eng1
LENG2 l eng2
LENG3 l eng3
LENG4 l eng4
LI0 l i0
LI1 l i1
LI2 l i2
LI3 l i3
LI4 l i4
LIA0 l ia0
LIA1 l ia1
LIA2 l ia2
LIA3 l ia3
LIA4 l ia4
LIAN0 l ian0
LIAN1 l ian1
LIAN2 l ian2
LIAN3 l ian3
LIAN4 l ian4
LIANG0 l iang0
LIANG1 l iang1
LIANG2 l iang2
LIANG3 l iang3
LIANG4 l iang4
LIAO0 l iao0
LIAO1 l iao1
LIAO2 l iao2
LIAO3 l iao3
LIAO4 l iao4
LIE0 l ie0
LIE1 l ie1
LIE2 l ie2
LIE3 l ie3
LIE4 l ie4
LIN0 l in0
LIN1 l in1
LIN2 l in2
LIN3 l in3
LIN4 l in4
LING0 l ing0
LING1 l ing1
LING2 l ing2
LING3 l ing3
LING4 l ing4
LIU0 l iu0
LIU1 l iu1
LIU2 l iu2
LIU3 l iu3
LIU4 l iu4
LONG0 l ong0
LONG1 l ong1
LONG2 l ong2
LONG3 l ong3
LONG4 l ong4
LOU0 l ou0
LOU1 l ou1
LOU2 l ou2
LOU3 l ou3
LOU4 l ou4
LU0 l u0
LU1 l u1
LU2 l u2
LU3 l u3
LU4 l u4
LUAN0 l uan0
LUAN1 l uan1
LUAN2 l uan2
LUAN3 l uan3
LUAN4 l uan4
LUE0 l ve0
LUE1 l ve1
LUE2 l ve2
LUE3 l ve3
LUE4 l ve4
LVE0 l ve0
LVE1 l ve1
LVE2 l ve2
LVE3 l ve3
LVE4 l ve4
LUN0 l un0
LUN1 l un1
LUN2 l un2
LUN3 l un3
LUN4 l un4
LUO0 l uo0
LUO1 l uo1
LUO2 l uo2
LUO3 l uo3
LUO4 l uo4
LV0 l v0
LV1 l v1
LV2 l v2
LV3 l v3
LV4 l v4
MA0 m a0
MA1 m a1
MA2 m a2
MA3 m a3
MA4 m a4
MAI0 m ai0
MAI1 m ai1
MAI2 m ai2
MAI3 m ai3
MAI4 m ai4
MAN0 m an0
MAN1 m an1
MAN2 m an2
MAN3 m an3
MAN4 m an4
MANG0 m ang0
MANG1 m ang1
MANG2 m ang2
MANG3 m ang3
MANG4 m ang4
MAO0 m ao0
MAO1 m ao1
MAO2 m ao2
MAO3 m ao3
MAO4 m ao4
ME0 m e0
ME1 m e1
ME2 m e2
ME3 m e3
ME4 m e4
MEI0 m ei0
MEI1 m ei1
MEI2 m ei2
MEI3 m ei3
MEI4 m ei4
MEN0 m en0
MEN1 m en1
MEN2 m en2
MEN3 m en3
MEN4 m en4
MENG0 m eng0
MENG1 m eng1
MENG2 m eng2
MENG3 m eng3
MENG4 m eng4
MI0 m i0
MI1 m i1
MI2 m i2
MI3 m i3
MI4 m i4
MIAN0 m ian0
MIAN1 m ian1
MIAN2 m ian2
MIAN3 m ian3
MIAN4 m ian4
MIAO0 m iao0
MIAO1 m iao1
MIAO2 m iao2
MIAO3 m iao3
MIAO4 m iao4
MIE0 m ie0
MIE1 m ie1
MIE2 m ie2
MIE3 m ie3
MIE4 m ie4
MIN0 m in0
MIN1 m in1
MIN2 m in2
MIN3 m in3
MIN4 m in4
MING0 m ing0
MING1 m ing1
MING2 m ing2
MING3 m ing3
MING4 m ing4
MIU0 m iu0
MIU1 m iu1
MIU2 m iu2
MIU3 m iu3
MIU4 m iu4
MO0 m o0
MO1 m o1
MO2 m o2
MO3 m o3
MO4 m o4
MOU0 m ou0
MOU1 m ou1
MOU2 m ou2
MOU3 m ou3
MOU4 m ou4
MU0 m u0
MU1 m u1
MU2 m u2
MU3 m u3
MU4 m u4
NA0 n a0
NA1 n a1
NA2 n a2
NA3 n a3
NA4 n a4
NAI0 n ai0
NAI1 n ai1
NAI2 n ai2
NAI3 n ai3
NAI4 n ai4
NAN0 n an0
NAN1 n an1
NAN2 n an2
NAN3 n an3
NAN4 n an4
NANG0 n ang0
NANG1 n ang1
NANG2 n ang2
NANG3 n ang3
NANG4 n ang4
NAO0 n ao0
NAO1 n ao1
NAO2 n ao2
NAO3 n ao3
NAO4 n ao4
NE0 n e0
NE1 n e1
NE2 n e2
NE3 n e3
NE4 n e4
NEI0 n ei0
NEI1 n ei1
NEI2 n ei2
NEI3 n ei3
NEI4 n ei4
NEN0 n en0
NEN1 n en1
NEN2 n en2
NEN3 n en3
NEN4 n en4
NENG0 n eng0
NENG1 n eng1
NENG2 n eng2
NENG3 n eng3
NENG4 n eng4
NI0 n i0
NI1 n i1
NI2 n i2
NI3 n i3
NI4 n i4
NIAN0 n ian0
NIAN1 n ian1
NIAN2 n ian2
NIAN3 n ian3
NIAN4 n ian4
NIANG0 n iang0
NIANG1 n iang1
NIANG2 n iang2
NIANG3 n iang3
NIANG4 n iang4
NIAO0 n iao0
NIAO1 n iao1
NIAO2 n iao2
NIAO3 n iao3
NIAO4 n iao4
NIE0 n ie0
NIE1 n ie1
NIE2 n ie2
NIE3 n ie3
NIE4 n ie4
NIN0 n in0
NIN1 n in1
NIN2 n in2
NIN3 n in3
NIN4 n in4
NING0 n ing0
NING1 n ing1
NING2 n ing2
NING3 n ing3
NING4 n ing4
NIU0 n iu0
NIU1 n iu1
NIU2 n iu2
NIU3 n iu3
NIU4 n iu4
NONG0 n ong0
NONG1 n ong1
NONG2 n ong2
NONG3 n ong3
NONG4 n ong4
NU0 n u0
NU1 n u1
NU2 n u2
NU3 n u3
NU4 n u4
NUAN0 n uan0
NUAN1 n uan1
NUAN2 n uan2
NUAN3 n uan3
NUAN4 n uan4
NUE0 n ve0
NUE1 n ve1
NUE2 n ve2
NUE3 n ve3
NUE4 n ve4
NVE0 n ve0
NVE1 n ve1
NVE2 n ve2
NVE3 n ve3
NVE4 n ve4
NUO0 n uo0
NUO1 n uo1
NUO2 n uo2
NUO3 n uo3
NUO4 n uo4
NV0 n v0
NV1 n v1
NV2 n v2
NV3 n v3
NV4 n v4
O0 oo o0
O1 oo o1
O2 oo o2
O3 oo o3
O4 oo o4
OU0 oo ou0
OU1 oo ou1
OU2 oo ou2
OU3 oo ou3
OU4 oo ou4
PA0 p a0
PA1 p a1
PA2 p a2
PA3 p a3
PA4 p a4
PAI0 p ai0
PAI1 p ai1
PAI2 p ai2
PAI3 p ai3
PAI4 p ai4
PAN0 p an0
PAN1 p an1
PAN2 p an2
PAN3 p an3
PAN4 p an4
PANG0 p ang0
PANG1 p ang1
PANG2 p ang2
PANG3 p ang3
PANG4 p ang4
PAO0 p ao0
PAO1 p ao1
PAO2 p ao2
PAO3 p ao3
PAO4 p ao4
PEI0 p ei0
PEI1 p ei1
PEI2 p ei2
PEI3 p ei3
PEI4 p ei4
PEN0 p en0
PEN1 p en1
PEN2 p en2
PEN3 p en3
PEN4 p en4
PENG0 p eng0
PENG1 p eng1
PENG2 p eng2
PENG3 p eng3
PENG4 p eng4
PI0 p i0
PI1 p i1
PI2 p i2
PI3 p i3
PI4 p i4
PIAN0 p ian0
PIAN1 p ian1
PIAN2 p ian2
PIAN3 p ian3
PIAN4 p ian4
PIAO0 p iao0
PIAO1 p iao1
PIAO2 p iao2
PIAO3 p iao3
PIAO4 p iao4
PIE0 p ie0
PIE1 p ie1
PIE2 p ie2
PIE3 p ie3
PIE4 p ie4
PIN0 p in0
PIN1 p in1
PIN2 p in2
PIN3 p in3
PIN4 p in4
PING0 p ing0
PING1 p ing1
PING2 p ing2
PING3 p ing3
PING4 p ing4
PO0 p o0
PO1 p o1
PO2 p o2
PO3 p o3
PO4 p o4
POU0 p ou0
POU1 p ou1
POU2 p ou2
POU3 p ou3
POU4 p ou4
PU0 p u0
PU1 p u1
PU2 p u2
PU3 p u3
PU4 p u4
QI0 q i0
QI1 q i1
QI2 q i2
QI3 q i3
QI4 q i4
QIA0 q ia0
QIA1 q ia1
QIA2 q ia2
QIA3 q ia3
QIA4 q ia4
QIAN0 q ian0
QIAN1 q ian1
QIAN2 q ian2
QIAN3 q ian3
QIAN4 q ian4
QIANG0 q iang0
QIANG1 q iang1
QIANG2 q iang2
QIANG3 q iang3
QIANG4 q iang4
QIAO0 q iao0
QIAO1 q iao1
QIAO2 q iao2
QIAO3 q iao3
QIAO4 q iao4
QIE0 q ie0
QIE1 q ie1
QIE2 q ie2
QIE3 q ie3
QIE4 q ie4
QIN0 q in0
QIN1 q in1
QIN2 q in2
QIN3 q in3
QIN4 q in4
QING0 q ing0
QING1 q ing1
QING2 q ing2
QING3 q ing3
QING4 q ing4
QIONG0 q iong0
QIONG1 q iong1
QIONG2 q iong2
QIONG3 q iong3
QIONG4 q iong4
QIU0 q iu0
QIU1 q iu1
QIU2 q iu2
QIU3 q iu3
QIU4 q iu4
QU0 q v0
QU1 q v1
QU2 q v2
QU3 q v3
QU4 q v4
QUAN0 q van0
QUAN1 q van1
QUAN2 q van2
QUAN3 q van3
QUAN4 q van4
QUE0 q ve0
QUE1 q ve1
QUE2 q ve2
QUE3 q ve3
QUE4 q ve4
QUN0 q vn0
QUN1 q vn1
QUN2 q vn2
QUN3 q vn3
QUN4 q vn4
RAN0 r an0
RAN1 r an1
RAN2 r an2
RAN3 r an3
RAN4 r an4
RANG0 r ang0
RANG1 r ang1
RANG2 r ang2
RANG3 r ang3
RANG4 r ang4
RAO0 r ao0
RAO1 r ao1
RAO2 r ao2
RAO3 r ao3
RAO4 r ao4
RE0 r e0
RE1 r e1
RE2 r e2
RE3 r e3
RE4 r e4
REN0 r en0
REN1 r en1
REN2 r en2
REN3 r en3
REN4 r en4
RENG0 r eng0
RENG1 r eng1
RENG2 r eng2
RENG3 r eng3
RENG4 r eng4
RI0 r iz0
RI1 r iz1
RI2 r iz2
RI3 r iz3
RI4 r iz4
RONG0 r ong0
RONG1 r ong1
RONG2 r ong2
RONG3 r ong3
RONG4 r ong4
ROU0 r ou0
ROU1 r ou1
ROU2 r ou2
ROU3 r ou3
ROU4 r ou4
RU0 r u0
RU1 r u1
RU2 r u2
RU3 r u3
RU4 r u4
RUAN0 r uan0
RUAN1 r uan1
RUAN2 r uan2
RUAN3 r uan3
RUAN4 r uan4
RUI0 r ui0
RUI1 r ui1
RUI2 r ui2
RUI3 r ui3
RUI4 r ui4
RUN0 r un0
RUN1 r un1
RUN2 r un2
RUN3 r un3
RUN4 r un4
RUO0 r uo0
RUO1 r uo1
RUO2 r uo2
RUO3 r uo3
RUO4 r uo4
SA0 s a0
SA1 s a1
SA2 s a2
SA3 s a3
SA4 s a4
SAI0 s ai0
SAI1 s ai1
SAI2 s ai2
SAI3 s ai3
SAI4 s ai4
SAN0 s an0
SAN1 s an1
SAN2 s an2
SAN3 s an3
SAN4 s an4
SANG0 s ang0
SANG1 s ang1
SANG2 s ang2
SANG3 s ang3
SANG4 s ang4
SAO0 s ao0
SAO1 s ao1
SAO2 s ao2
SAO3 s ao3
SAO4 s ao4
SE0 s e0
SE1 s e1
SE2 s e2
SE3 s e3
SE4 s e4
SEN0 s en0
SEN1 s en1
SEN2 s en2
SEN3 s en3
SEN4 s en4
SENG0 s eng0
SENG1 s eng1
SENG2 s eng2
SENG3 s eng3
SENG4 s eng4
SHA0 sh a0
SHA1 sh a1
SHA2 sh a2
SHA3 sh a3
SHA4 sh a4
SHAI0 sh ai0
SHAI1 sh ai1
SHAI2 sh ai2
SHAI3 sh ai3
SHAI4 sh ai4
SHAN0 sh an0
SHAN1 sh an1
SHAN2 sh an2
SHAN3 sh an3
SHAN4 sh an4
SHANG0 sh ang0
SHANG1 sh ang1
SHANG2 sh ang2
SHANG3 sh ang3
SHANG4 sh ang4
SHAO0 sh ao0
SHAO1 sh ao1
SHAO2 sh ao2
SHAO3 sh ao3
SHAO4 sh ao4
SHE0 sh e0
SHE1 sh e1
SHE2 sh e2
SHE3 sh e3
SHE4 sh e4
SHEI0 sh ei0
SHEI1 sh ei1
SHEI2 sh ei2
SHEI3 sh ei3
SHEI4 sh ei4
SHEN0 sh en0
SHEN1 sh en1
SHEN2 sh en2
SHEN3 sh en3
SHEN4 sh en4
SHENG0 sh eng0
SHENG1 sh eng1
SHENG2 sh eng2
SHENG3 sh eng3
SHENG4 sh eng4
SHI0 sh ix0
SHI1 sh ix1
SHI2 sh ix2
SHI3 sh ix3
SHI4 sh ix4
SHOU0 sh ou0
SHOU1 sh ou1
SHOU2 sh ou2
SHOU3 sh ou3
SHOU4 sh ou4
SHU0 sh u0
SHU1 sh u1
SHU2 sh u2
SHU3 sh u3
SHU4 sh u4
SHUA0 sh ua0
SHUA1 sh ua1
SHUA2 sh ua2
SHUA3 sh ua3
SHUA4 sh ua4
SHUAI0 sh uai0
SHUAI1 sh uai1
SHUAI2 sh uai2
SHUAI3 sh uai3
SHUAI4 sh uai4
SHUAN0 sh uan0
SHUAN1 sh uan1
SHUAN2 sh uan2
SHUAN3 sh uan3
SHUAN4 sh uan4
SHUANG0 sh uang0
SHUANG1 sh uang1
SHUANG2 sh uang2
SHUANG3 sh uang3
SHUANG4 sh uang4
SHUI0 sh ui0
SHUI1 sh ui1
SHUI2 sh ui2
SHUI3 sh ui3
SHUI4 sh ui4
SHUN0 sh un0
SHUN1 sh un1
SHUN2 sh un2
SHUN3 sh un3
SHUN4 sh un4
SHUO0 sh uo0
SHUO1 sh uo1
SHUO2 sh uo2
SHUO3 sh uo3
SHUO4 sh uo4
SI0 s iy0
SI1 s iy1
SI2 s iy2
SI3 s iy3
SI4 s iy4
SONG0 s ong0
SONG1 s ong1
SONG2 s ong2
SONG3 s ong3
SONG4 s ong4
SOU0 s ou0
SOU1 s ou1
SOU2 s ou2
SOU3 s ou3
SOU4 s ou4
SU0 s u0
SU1 s u1
SU2 s u2
SU3 s u3
SU4 s u4
SUAN0 s uan0
SUAN1 s uan1
SUAN2 s uan2
SUAN3 s uan3
SUAN4 s uan4
SUI0 s ui0
SUI1 s ui1
SUI2 s ui2
SUI3 s ui3
SUI4 s ui4
SUN0 s un0
SUN1 s un1
SUN2 s un2
SUN3 s un3
SUN4 s un4
SUO0 s uo0
SUO1 s uo1
SUO2 s uo2
SUO3 s uo3
SUO4 s uo4
TA0 t a0
TA1 t a1
TA2 t a2
TA3 t a3
TA4 t a4
TAI0 t ai0
TAI1 t ai1
TAI2 t ai2
TAI3 t ai3
TAI4 t ai4
TAN0 t an0
TAN1 t an1
TAN2 t an2
TAN3 t an3
TAN4 t an4
TANG0 t ang0
TANG1 t ang1
TANG2 t ang2
TANG3 t ang3
TANG4 t ang4
TAO0 t ao0
TAO1 t ao1
TAO2 t ao2
TAO3 t ao3
TAO4 t ao4
TE0 t e0
TE1 t e1
TE2 t e2
TE3 t e3
TE4 t e4
TENG0 t eng0
TENG1 t eng1
TENG2 t eng2
TENG3 t eng3
TENG4 t eng4
TI0 t i0
TI1 t i1
TI2 t i2
TI3 t i3
TI4 t i4
TIAN0 t ian0
TIAN1 t ian1
TIAN2 t ian2
TIAN3 t ian3
TIAN4 t ian4
TIAO0 t iao0
TIAO1 t iao1
TIAO2 t iao2
TIAO3 t iao3
TIAO4 t iao4
TIE0 t ie0
TIE1 t ie1
TIE2 t ie2
TIE3 t ie3
TIE4 t ie4
TING0 t ing0
TING1 t ing1
TING2 t ing2
TING3 t ing3
TING4 t ing4
TONG0 t ong0
TONG1 t ong1
TONG2 t ong2
TONG3 t ong3
TONG4 t ong4
TOU0 t ou0
TOU1 t ou1
TOU2 t ou2
TOU3 t ou3
TOU4 t ou4
TU0 t u0
TU1 t u1
TU2 t u2
TU3 t u3
TU4 t u4
TUAN0 t uan0
TUAN1 t uan1
TUAN2 t uan2
TUAN3 t uan3
TUAN4 t uan4
TUI0 t ui0
TUI1 t ui1
TUI2 t ui2
TUI3 t ui3
TUI4 t ui4
TUN0 t un0
TUN1 t un1
TUN2 t un2
TUN3 t un3
TUN4 t un4
TUO0 t uo0
TUO1 t uo1
TUO2 t uo2
TUO3 t uo3
TUO4 t uo4
WA0 uu ua0
WA1 uu ua1
WA2 uu ua2
WA3 uu ua3
WA4 uu ua4
WAI0 uu uai0
WAI1 uu uai1
WAI2 uu uai2
WAI3 uu uai3
WAI4 uu uai4
WAN0 uu uan0
WAN1 uu uan1
WAN2 uu uan2
WAN3 uu uan3
WAN4 uu uan4
WANG0 uu uang0
WANG1 uu uang1
WANG2 uu uang2
WANG3 uu uang3
WANG4 uu uang4
WEI0 uu ui0
WEI1 uu ui1
WEI2 uu ui2
WEI3 uu ui3
WEI4 uu ui4
WEN0 uu un0
WEN1 uu un1
WEN2 uu un2
WEN3 uu un3
WEN4 uu un4
WENG0 uu ueng0
WENG1 uu ueng1
WENG2 uu ueng2
WENG3 uu ueng3
WENG4 uu ueng4
WO0 uu uo0
WO1 uu uo1
WO2 uu uo2
WO3 uu uo3
WO4 uu uo4
WU0 uu u0
WU1 uu u1
WU2 uu u2
WU3 uu u3
WU4 uu u4
XI0 x i0
XI1 x i1
XI2 x i2
XI3 x i3
XI4 x i4
XIA0 x ia0
XIA1 x ia1
XIA2 x ia2
XIA3 x ia3
XIA4 x ia4
XIAN0 x ian0
XIAN1 x ian1
XIAN2 x ian2
XIAN3 x ian3
XIAN4 x ian4
XIANG0 x iang0
XIANG1 x iang1
XIANG2 x iang2
XIANG3 x iang3
XIANG4 x iang4
XIAO0 x iao0
XIAO1 x iao1
XIAO2 x iao2
XIAO3 x iao3
XIAO4 x iao4
XIE0 x ie0
XIE1 x ie1
XIE2 x ie2
XIE3 x ie3
XIE4 x ie4
XIN0 x in0
XIN1 x in1
XIN2 x in2
XIN3 x in3
XIN4 x in4
XING0 x ing0
XING1 x ing1
XING2 x ing2
XING3 x ing3
XING4 x ing4
XIONG0 x iong0
XIONG1 x iong1
XIONG2 x iong2
XIONG3 x iong3
XIONG4 x iong4
XIU0 x iu0
XIU1 x iu1
XIU2 x iu2
XIU3 x iu3
XIU4 x iu4
XU0 x v0
XU1 x v1
XU2 x v2
XU3 x v3
XU4 x v4
XUAN0 x van0
XUAN1 x van1
XUAN2 x van2
XUAN3 x van3
XUAN4 x van4
XUE0 x ve0
XUE1 x ve1
XUE2 x ve2
XUE3 x ve3
XUE4 x ve4
XUN0 x vn0
XUN1 x vn1
XUN2 x vn2
XUN3 x vn3
XUN4 x vn4
YA0 ii ia0
YA1 ii ia1
YA2 ii ia2
YA3 ii ia3
YA4 ii ia4
YAN0 ii ian0
YAN1 ii ian1
YAN2 ii ian2
YAN3 ii ian3
YAN4 ii ian4
YANG0 ii iang0
YANG1 ii iang1
YANG2 ii iang2
YANG3 ii iang3
YANG4 ii iang4
YAO0 ii iao0
YAO1 ii iao1
YAO2 ii iao2
YAO3 ii iao3
YAO4 ii iao4
YE0 ii ie0
YE1 ii ie1
YE2 ii ie2
YE3 ii ie3
YE4 ii ie4
YI0 ii i0
YI1 ii i1
YI2 ii i2
YI3 ii i3
YI4 ii i4
YIN0 ii in0
YIN1 ii in1
YIN2 ii in2
YIN3 ii in3
YIN4 ii in4
YING0 ii ing0
YING1 ii ing1
YING2 ii ing2
YING3 ii ing3
YING4 ii ing4
YO0 ii ou0
YO1 ii ou1
YO2 ii ou2
YO3 ii ou3
YO4 ii ou4
YONG0 ii iong0
YONG1 ii iong1
YONG2 ii iong2
YONG3 ii iong3
YONG4 ii iong4
YOU0 ii iu0
YOU1 ii iu1
YOU2 ii iu2
YOU3 ii iu3
YOU4 ii iu4
YU0 vv v0
YU1 vv v1
YU2 vv v2
YU3 vv v3
YU4 vv v4
YUAN0 vv van0
YUAN1 vv van1
YUAN2 vv van2
YUAN3 vv van3
YUAN4 vv van4
YUE0 vv ve0
YUE1 vv ve1
YUE2 vv ve2
YUE3 vv ve3
YUE4 vv ve4
YUN0 vv vn0
YUN1 vv vn1
YUN2 vv vn2
YUN3 vv vn3
YUN4 vv vn4
YUO0 ii ou0
YUO1 ii ou1
YUO2 ii ou2
YUO3 ii ou3
YUO4 ii ou4
ZA0 z a0
ZA1 z a1
ZA2 z a2
ZA3 z a3
ZA4 z a4
ZAI0 z ai0
ZAI1 z ai1
ZAI2 z ai2
ZAI3 z ai3
ZAI4 z ai4
ZAN0 z an0
ZAN1 z an1
ZAN2 z an2
ZAN3 z an3
ZAN4 z an4
ZANG0 z ang0
ZANG1 z ang1
ZANG2 z ang2
ZANG3 z ang3
ZANG4 z ang4
ZAO0 z ao0
ZAO1 z ao1
ZAO2 z ao2
ZAO3 z ao3
ZAO4 z ao4
ZE0 z e0
ZE1 z e1
ZE2 z e2
ZE3 z e3
ZE4 z e4
ZEI0 z ei0
ZEI1 z ei1
ZEI2 z ei2
ZEI3 z ei3
ZEI4 z ei4
ZEN0 z en0
ZEN1 z en1
ZEN2 z en2
ZEN3 z en3
ZEN4 z en4
ZENG0 z eng0
ZENG1 z eng1
ZENG2 z eng2
ZENG3 z eng3
ZENG4 z eng4
ZHA0 zh a0
ZHA1 zh a1
ZHA2 zh a2
ZHA3 zh a3
ZHA4 zh a4
ZHAI0 zh ai0
ZHAI1 zh ai1
ZHAI2 zh ai2
ZHAI3 zh ai3
ZHAI4 zh ai4
ZHAN0 zh an0
ZHAN1 zh an1
ZHAN2 zh an2
ZHAN3 zh an3
ZHAN4 zh an4
ZHANG0 zh ang0
ZHANG1 zh ang1
ZHANG2 zh ang2
ZHANG3 zh ang3
ZHANG4 zh ang4
ZHAO0 zh ao0
ZHAO1 zh ao1
ZHAO2 zh ao2
ZHAO3 zh ao3
ZHAO4 zh ao4
ZHE0 zh e0
ZHE1 zh e1
ZHE2 zh e2
ZHE3 zh e3
ZHE4 zh e4
ZHEI0 zh ei0
ZHEI1 zh ei1
ZHEI2 zh ei2
ZHEI3 zh ei3
ZHEI4 zh ei4
ZHEN0 zh en0
ZHEN1 zh en1
ZHEN2 zh en2
ZHEN3 zh en3
ZHEN4 zh en4
ZHENG0 zh eng0
ZHENG1 zh eng1
ZHENG2 zh eng2
ZHENG3 zh eng3
ZHENG4 zh eng4
ZHI0 zh ix0
ZHI1 zh ix1
ZHI2 zh ix2
ZHI3 zh ix3
ZHI4 zh ix4
ZHONG0 zh ong0
ZHONG1 zh ong1
ZHONG2 zh ong2
ZHONG3 zh ong3
ZHONG4 zh ong4
ZHOU0 zh ou0
ZHOU1 zh ou1
ZHOU2 zh ou2
ZHOU3 zh ou3
ZHOU4 zh ou4
ZHU0 zh u0
ZHU1 zh u1
ZHU2 zh u2
ZHU3 zh u3
ZHU4 zh u4
ZHUA0 zh ua0
ZHUA1 zh ua1
ZHUA2 zh ua2
ZHUA3 zh ua3
ZHUA4 zh ua4
ZHUAI0 zh uai0
ZHUAI1 zh uai1
ZHUAI2 zh uai2
ZHUAI3 zh uai3
ZHUAI4 zh uai4
ZHUAN0 zh uan0
ZHUAN1 zh uan1
ZHUAN2 zh uan2
ZHUAN3 zh uan3
ZHUAN4 zh uan4
ZHUANG0 zh uang0
ZHUANG1 zh uang1
ZHUANG2 zh uang2
ZHUANG3 zh uang3
ZHUANG4 zh uang4
ZHUI0 zh ui0
ZHUI1 zh ui1
ZHUI2 zh ui2
ZHUI3 zh ui3
ZHUI4 zh ui4
ZHUN0 zh un0
ZHUN1 zh un1
ZHUN2 zh un2
ZHUN3 zh un3
ZHUN4 zh un4
ZHUO0 zh uo0
ZHUO1 zh uo1
ZHUO2 zh uo2
ZHUO3 zh uo3
ZHUO4 zh uo4
ZI0 z iy0
ZI1 z iy1
ZI2 z iy2
ZI3 z iy3
ZI4 z iy4
ZONG0 z ong0
ZONG1 z ong1
ZONG2 z ong2
ZONG3 z ong3
ZONG4 z ong4
ZOU0 z ou0
ZOU1 z ou1
ZOU2 z ou2
ZOU3 z ou3
ZOU4 z ou4
ZU0 z u0
ZU1 z u1
ZU2 z u2
ZU3 z u3
ZU4 z u4
ZUAN0 z uan0
ZUAN1 z uan1
ZUAN2 z uan2
ZUAN3 z uan3
ZUAN4 z uan4
ZUI0 z ui0
ZUI1 z ui1
ZUI2 z ui2
ZUI3 z ui3
ZUI4 z ui4
ZUN0 z un0
ZUN1 z un1
ZUN2 z un2
ZUN3 z un3
ZUN4 z un4
ZUO0 z uo0
ZUO1 z uo1
ZUO2 z uo2
ZUO3 z uo3
ZUO4 z uo4
EI0 ee ei0
EI1 ee ei1
EI2 ee ei2
EI3 ee ei3
EI4 ee ei4
TEI0 t ei0
TEI1 t ei1
TEI2 t ei2
TEI3 t ei3
TEI4 t ei4
HNG0 ee eng0
HNG1 ee eng1
HNG2 ee eng2
HNG3 ee eng3
HNG4 ee eng4
LO0 l o0
LO1 l o1
LO2 l o2
LO3 l o3
LO4 l o4
N0 ee en0
N1 ee en1
N2 ee en2
N3 ee en3
N4 ee en4
NG0 ee eng0
NG1 ee eng1
NG2 ee eng2
NG3 ee eng3
NG4 ee eng4
NOU0 n ao0
NOU1 n ao1
NOU2 n ao2
NOU3 n ao3
NOU4 n ao4
SEI0 s ei0
SEI1 s ei1
SEI2 s ei2
SEI3 s ei3
SEI4 s ei4
A5 aa a5
AI5 aa ai5
AN5 aa an5
ANG5 aa ang5
AO5 aa ao5
BA5 b a5
BAI5 b ai5
BAN5 b an5
BANG5 b ang5
BAO5 b ao5
BEI5 b ei5
BEN5 b en5
BENG5 b eng5
BI5 b i5
BIAN5 b ian5
BIAO5 b iao5
BIE5 b ie5
BIN5 b in5
BING5 b ing5
BO5 b o5
BU5 b u5
CA5 c a5
CAI5 c ai5
CAN5 c an5
CANG5 c ang5
CAO5 c ao5
CE5 c e5
CEN5 c en5
CENG5 c eng5
CHA5 ch a5
CHAI5 ch ai5
CHAN5 ch an5
CHANG5 ch ang5
CHAO5 ch ao5
CHE5 ch e5
CHEN5 ch en5
CHENG5 ch eng5
CHI5 ch ix5
CHONG5 ch ong5
CHOU5 ch ou5
CHU5 ch u5
CHUAI5 ch uai5
CHUAN5 ch uan5
CHUANG5 ch uang5
CHUI5 ch ui5
CHUN5 ch un5
CHUO5 ch uo5
CI5 c iy5
CONG5 c ong5
COU5 c ou5
CU5 c u5
CUAN5 c uan5
CUI5 c ui5
CUN5 c un5
CUO5 c uo5
DA5 d a5
DAI5 d ai5
DAN5 d an5
DANG5 d ang5
DAO5 d ao5
DE5 d e5
DEI5 d ei5
DEN5 d en5
DENG5 d eng5
DI5 d i5
DIA5 d ia5
DIAN5 d ian5
DIAO5 d iao5
DIE5 d ie5
DING5 d ing5
DIU5 d iu5
DONG5 d ong5
DOU5 d ou5
DU5 d u5
DUAN5 d uan5
DUI5 d ui5
DUN5 d un5
DUO5 d uo5
E5 ee e5
EN5 ee en5
ER5 ee er5
FA5 f a5
FAN5 f an5
FANG5 f ang5
FEI5 f ei5
FEN5 f en5
FENG5 f eng5
FO5 f o5
FOU5 f ou5
FU5 f u5
GA5 g a5
GAI5 g ai5
GAN5 g an5
GANG5 g ang5
GAO5 g ao5
GE5 g e5
GEI5 g ei5
GEN5 g en5
GENG5 g eng5
GONG5 g ong5
GOU5 g ou5
GU5 g u5
GUA5 g ua5
GUAI5 g uai5
GUAN5 g uan5
GUANG5 g uang5
GUI5 g ui5
GUN5 g un5
GUO5 g uo5
HA5 h a5
HAI5 h ai5
HAN5 h an5
HANG5 h ang5
HAO5 h ao5
HE5 h e5
HEI5 h ei5
HEN5 h en5
HENG5 h eng5
HONG5 h ong5
HOU5 h ou5
HU5 h u5
HUA5 h ua5
HUAI5 h uai5
HUAN5 h uan5
HUANG5 h uang5
HUI5 h ui5
HUN5 h un5
HUO5 h uo5
JI5 j i5
JIA5 j ia5
JIAN5 j ian5
JIANG5 j iang5
JIAO5 j iao5
JIE5 j ie5
JIN5 j in5
JING5 j ing5
JIONG5 j iong5
JIU5 j iu5
JU5 j v5
JUAN5 j van5
JUE5 j ve5
JUN5 j vn5
KA5 k a5
KAI5 k ai5
KAN5 k an5
KANG5 k ang5
KAO5 k ao5
KE5 k e5
KEI5 k ei5
KEN5 k en5
KENG5 k eng5
KONG5 k ong5
KOU5 k ou5
KU5 k u5
KUA5 k ua5
KUAI5 k uai5
KUAN5 k uan5
KUANG5 k uang5
KUI5 k ui5
KUN5 k un5
KUO5 k uo5
LA5 l a5
LAI5 l ai5
LAN5 l an5
LANG5 l ang5
LAO5 l ao5
LE5 l e5
LEI5 l ei5
LENG5 l eng5
LI5 l i5
LIA5 l ia5
LIAN5 l ian5
LIANG5 l iang5
LIAO5 l iao5
LIE5 l ie5
LIN5 l in5
LING5 l ing5
LIU5 l iu5
LONG5 l ong5
LOU5 l ou5
LU5 l u5
LUAN5 l uan5
LUE5 l ve5
LVE5 l ve5
LUN5 l un5
LUO5 l uo5
LV5 l v5
MA5 m a5
MAI5 m ai5
MAN5 m an5
MANG5 m ang5
MAO5 m ao5
ME5 m e5
MEI5 m ei5
MEN5 m en5
MENG5 m eng5
MI5 m i5
MIAN5 m ian5
MIAO5 m iao5
MIE5 m ie5
MIN5 m in5
MING5 m ing5
MIU5 m iu5
MO5 m o5
MOU5 m ou5
MU5 m u5
NA5 n a5
NAI5 n ai5
NAN5 n an5
NANG5 n ang5
NAO5 n ao5
NE5 n e5
NEI5 n ei5
NEN5 n en5
NENG5 n eng5
NI5 n i5
NIAN5 n ian5
NIANG5 n iang5
NIAO5 n iao5
NIE5 n ie5
NIN5 n in5
NING5 n ing5
NIU5 n iu5
NONG5 n ong5
NU5 n u5
NUAN5 n uan5
NUE5 n ve5
NVE5 n ve5
NUO5 n uo5
NV5 n v5
O5 oo o5
OU5 oo ou5
PA5 p a5
PAI5 p ai5
PAN5 p an5
PANG5 p ang5
PAO5 p ao5
PEI5 p ei5
PEN5 p en5
PENG5 p eng5
PI5 p i5
PIAN5 p ian5
PIAO5 p iao5
PIE5 p ie5
PIN5 p in5
PING5 p ing5
PO5 p o5
POU5 p ou5
PU5 p u5
QI5 q i5
QIA5 q ia5
QIAN5 q ian5
QIANG5 q iang5
QIAO5 q iao5
QIE5 q ie5
QIN5 q in5
QING5 q ing5
QIONG5 q iong5
QIU5 q iu5
QU5 q v5
QUAN5 q van5
QUE5 q ve5
QUN5 q vn5
RAN5 r an5
RANG5 r ang5
RAO5 r ao5
RE5 r e5
REN5 r en5
RENG5 r eng5
RI5 r iz5
RONG5 r ong5
ROU5 r ou5
RU5 r u5
RUAN5 r uan5
RUI5 r ui5
RUN5 r un5
RUO5 r uo5
SA5 s a5
SAI5 s ai5
SAN5 s an5
SANG5 s ang5
SAO5 s ao5
SE5 s e5
SEN5 s en5
SENG5 s eng5
SHA5 sh a5
SHAI5 sh ai5
SHAN5 sh an5
SHANG5 sh ang5
SHAO5 sh ao5
SHE5 sh e5
SHEI5 sh ei5
SHEN5 sh en5
SHENG5 sh eng5
SHI5 sh ix5
SHOU5 sh ou5
SHU5 sh u5
SHUA5 sh ua5
SHUAI5 sh uai5
SHUAN5 sh uan5
SHUANG5 sh uang5
SHUI5 sh ui5
SHUN5 sh un5
SHUO5 sh uo5
SI5 s iy5
SONG5 s ong5
SOU5 s ou5
SU5 s u5
SUAN5 s uan5
SUI5 s ui5
SUN5 s un5
SUO5 s uo5
TA5 t a5
TAI5 t ai5
TAN5 t an5
TANG5 t ang5
TAO5 t ao5
TE5 t e5
TENG5 t eng5
TI5 t i5
TIAN5 t ian5
TIAO5 t iao5
TIE5 t ie5
TING5 t ing5
TONG5 t ong5
TOU5 t ou5
TU5 t u5
TUAN5 t uan5
TUI5 t ui5
TUN5 t un5
TUO5 t uo5
WA5 uu ua5
WAI5 uu uai5
WAN5 uu uan5
WANG5 uu uang5
WEI5 uu ui5
WEN5 uu un5
WENG5 uu ueng5
WO5 uu uo5
WU5 uu u5
XI5 x i5
XIA5 x ia5
XIAN5 x ian5
XIANG5 x iang5
XIAO5 x iao5
XIE5 x ie5
XIN5 x in5
XING5 x ing5
XIONG5 x iong5
XIU5 x iu5
XU5 x v5
XUAN5 x van5
XUE5 x ve5
XUN5 x vn5
YA5 ii ia5
YAN5 ii ian5
YANG5 ii iang5
YAO5 ii iao5
YE5 ii ie5
YI5 ii i5
YIN5 ii in5
YING5 ii ing5
YO5 ii ou5
YONG5 ii iong5
YOU5 ii iu5
YU5 vv v5
YUAN5 vv van5
YUE5 vv ve5
YUN5 vv vn5
YUO5 ii ou5
ZA5 z a5
ZAI5 z ai5
ZAN5 z an5
ZANG5 z ang5
ZAO5 z ao5
ZE5 z e5
ZEI5 z ei5
ZEN5 z en5
ZENG5 z eng5
ZHA5 zh a5
ZHAI5 zh ai5
ZHAN5 zh an5
ZHANG5 zh ang5
ZHAO5 zh ao5
ZHE5 zh e5
ZHEI5 zh ei5
ZHEN5 zh en5
ZHENG5 zh eng5
ZHI5 zh ix5
ZHONG5 zh ong5
ZHOU5 zh ou5
ZHU5 zh u5
ZHUA5 zh ua5
ZHUAI5 zh uai5
ZHUAN5 zh uan5
ZHUANG5 zh uang5
ZHUI5 zh ui5
ZHUN5 zh un5
ZHUO5 zh uo5
ZI5 z iy5
ZONG5 z ong5
ZOU5 z ou5
ZU5 z u5
ZUAN5 z uan5
ZUI5 z ui5
ZUN5 z un5
ZUO5 z uo5
EI5 ee ei5
TEI5 t ei5
HNG5 ee eng5
LO5 l o5
N5 ee en5
NG5 ee eng5
NOU5 n ao5
SEI5 s ei5
\ No newline at end of file
#! /usr/bin/env bash
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
LEXICON_NAME=$1
# download data, generate manifests
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
python3 ${TARGET_DIR}/thchs30/thchs30.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/thchs30"
if [ $? -ne 0 ]; then
echo "Prepare THCHS-30 failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# dump manifest to data/
python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# copy files to data/dict to gen word.lexicon
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# gen word.lexicon
python local/gen_word2phone.py --lexicon-files="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2" --output-path=data/dict/word.lexicon
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# reorganize dataset for MFA
if [ ! -d $EXP_DIR/thchs30_corpus ]; then
echo "reorganizing thchs30 corpus..."
python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME
echo "reorganization done."
fi
fi
echo "THCHS-30 data preparation done."
exit 0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gen Chinese characters to THCHS30-30 phone lexicon using THCHS30-30's lexicon
file1: THCHS-30/data_thchs30/lm_word/lexicon.txt
file2: THCHS-30/resource/dict/lexicon.txt
"""
import argparse
from collections import defaultdict
from pathlib import Path
from typing import List
from typing import Union
# key: (cn, ('ee', 'er4')),value: count
cn_phones_counter = defaultdict(int)
# key: cn, value: list of (phones, num)
cn_counter = defaultdict(list)
# key: cn, value: list of (phones, probabilities)
cn_counter_p = defaultdict(list)
def is_Chinese(ch):
if '\u4e00' <= ch <= '\u9fff':
return True
return False
def proc_line(line: str):
line = line.strip()
if is_Chinese(line[0]):
line_list = line.split()
cn_list = line_list[0]
phone_list = line_list[1:]
if len(cn_list) == len(phone_list) / 2:
new_phone_list = [(phone_list[i], phone_list[i + 1])
for i in range(0, len(phone_list), 2)]
assert len(cn_list) == len(new_phone_list)
for idx, cn in enumerate(cn_list):
phones = new_phone_list[idx]
cn_phones_counter[(cn, phones)] += 1
"""
example lines of output
the first column is a Chinese character
the second is the probability of this pronunciation
and the rest are the phones of this pronunciation
一 0.22 ii i1↩
一 0.45 ii i4↩
一 0.32 ii i2↩
一 0.01 ii i5
"""
def gen_lexicon(lexicon_files: List[Union[str, Path]],
output_path: Union[str, Path]):
for file_path in lexicon_files:
with open(file_path, "r") as f1:
for line in f1:
proc_line(line)
for key in cn_phones_counter:
cn = key[0]
cn_counter[cn].append((key[1], cn_phones_counter[key]))
for key in cn_counter:
phone_count_list = cn_counter[key]
count_sum = sum([x[1] for x in phone_count_list])
for item in phone_count_list:
p = item[1] / count_sum
p = round(p, 2)
if p > 0:
cn_counter_p[key].append((item[0], p))
with open(output_path, "w") as wf:
for key in cn_counter_p:
phone_p_list = cn_counter_p[key]
for item in phone_p_list:
phones, p = item
wf.write(key + " " + str(p) + " " + " ".join(phones) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Gen Chinese characters to phone lexicon for THCHS-30 dataset"
)
# A line of word_lexicon:
# 一丁点 ii i4 d ing1 d ian3
# the first is word, and the rest are the phones of the word, and the len of phones is twice of the word's len
parser.add_argument(
"--lexicon-files",
type=str,
default="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2",
help="lm_word_lexicon files")
parser.add_argument(
"--output-path",
type=str,
default="data/dict/word.lexicon",
help="path to save output word2phone lexicon")
args = parser.parse_args()
lexicon_files = args.lexicon_files.split(" ")
output_path = Path(args.output_path).expanduser()
gen_lexicon(lexicon_files, output_path)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recorganize THCHS-30 for MFA
read manifest.train from root-dir
Link *.wav to output-dir
dump *.lab from manifest.train, such as: text、syllable and phone
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
"""
import argparse
import os
from pathlib import Path
from typing import Union
def link_wav(root_dir: Union[str, Path], output_dir: Union[str, Path]):
wav_scp_path = root_dir / 'wav.scp'
with open(wav_scp_path, 'r') as rf:
for line in rf:
utt, feat = line.strip().split()
wav_path = feat
wav_name = wav_path.split("/")[-1]
new_wav_path = output_dir / wav_name
os.symlink(wav_path, new_wav_path)
def write_lab(root_dir: Union[str, Path],
output_dir: Union[str, Path],
script_type='phone'):
# script_type can in {'word', 'syllable', 'phone'}
json_name = 'text.' + script_type
json_path = root_dir / json_name
with open(json_path, 'r') as rf:
for line in rf:
line = line.strip().split()
utt_id = line[0]
context = ' '.join(line[1:])
transcript_name = utt_id + '.lab'
transcript_path = output_dir / transcript_name
with open(transcript_path, 'wt') as wf:
if script_type == 'word':
# add space between chinese char
context = ''.join([f + ' ' for f in context])[:-1]
wf.write(context + "\n")
def reorganize_thchs30(root_dir: Union[str, Path],
output_dir: Union[str, Path]=None,
script_type='phone'):
output_dir.mkdir(parents=True, exist_ok=True)
link_wav(root_dir, output_dir)
write_lab(root_dir, output_dir, script_type)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Reorganize THCHS-30 dataset for MFA")
parser.add_argument("--root-dir", type=str, help="path to thchs30 dataset.")
parser.add_argument(
"--output-dir",
type=str,
help="path to save outputs (audio and transcriptions)")
parser.add_argument(
"--script-type",
type=str,
default="phone",
help="type of lab ('word'/'syllable'/'phone')")
args = parser.parse_args()
root_dir = Path(args.root_dir).expanduser()
output_dir = Path(args.output_dir).expanduser()
reorganize_thchs30(root_dir, output_dir, args.script_type)
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# MFA is in tools
export PATH=${MAIN_ROOT}/tools/montreal-forced-aligner/bin:$PATH
\ No newline at end of file
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
EXP_DIR=exp
# LEXICON_NAME in {'phone', 'syllable', 'word'}
LEXICON_NAME='phone'
# set MFA num_jobs as half of machine's cpu core number
NUM_JOBS=$((`nproc`/2))
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# download dataset、unzip and generate manifest
# gen lexicon relink gen dump
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
echo "Start prepare thchs30 data for MFA ..."
bash ./local/data.sh $LEXICON_NAME || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# run MFA
if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then
echo "Start MFA training ..."
mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS
echo "MFA training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n"
fi
fi
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example.
## Run
```
. path.sh
bash run.sh
```
## Results
```
exp/
`-- normalized.txt
0 directories, 1 file
```
```
aff31f8aa08e2a7360228c9ce5886b98 exp/normalized.txt
```
```
今天的最低气温达到零下十度.
只要有四分之三十三的人同意,就可以通过决议。
一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
如果剩下的百分之三十点六是过去,那么还有百分之六十九点四.
事情发生在二零二零年三月三十一日的上午八点.
警方正在找一支点二二口径的手枪。
欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务
充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。
快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按九,查询话费、套餐余量、积分及活动返款请按一,手机上网流量开通及取消请按二,查���本机号码及本号所使用套餐请按四,密码修改及重置请按五,紧急开机请按六,挂失请按七,查询充值记录请按八,其它自助服务及工服务请按零
```
今天的最低气温达到-10°C.
只要有33/4的人同意,就可以通过决议。
1945年5月2日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
4月16日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
如果剩下的30.6%是过去,那么还有69.4%.
事情发生在2020/03/31的上午8:00.
警方正在找一支.22口径的手枪。
欢迎致电中国联通,北京2022年冬奥会官方合作伙伴为您服务
充值缴费请按1,查询话费及余量请按2,跳过本次提醒请按井号键。
快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按9,查询话费、套餐余量、积分及活动返款请按1,手机上网流量开通及取消请按2,查询本机号码及本号所使用套餐请按4,密码修改及重置请按5,紧急开机请按6,挂失请按7,查询充值记录请按8,其它自助服务及人工服务请按0
智能客服助理快速查话费、查流量请按9,了解北京联通业务请按1,宽带IPTV新装、查询请按2,障碍报修请按3,充值缴费请按4,投诉建议请按5,政企业务请按7,人工服务请按0,for english severice press star key
您的帐户当前可用余额为63.89元,本月消费为2.17元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
您的帐户当前可用余额为负15.5元,本月消费为59.6元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
尊敬的客户,您目前的话费余额为负14.60元,已低于10元,为保证您的通信畅通,请及时缴纳费用。
您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。
您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您?
您的账户当前可用余额为负36.00元,本月消费36.00元。
请问你是电话13985608526的机主吗?
如您对处理结果不满意,可拨打中国联通集团投诉电话10015进行投诉,按本地通话费收费,返回自助服务请按井号键
“26314”号VIP客服代表为您服务。
尊敬的5G用户,欢迎您致电中国联通
首先是应用了M1芯片的iPad Pro,新款的iPad Pro支持5G,这也是苹果的第二款5G产品线。
除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。
屏幕方面,iPad Pro 12.9版本支持XDR体验的Mini-LEDS显示屏,支持HDR10、杜比视界,还支持杜比全景声。
iPad Pro的秒控键盘这次也推出白色版本。
售价方面,11英寸版本售价799美元起,12.9英寸售价1099美元起。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from text_processing import normalization
parser = argparse.ArgumentParser(
description="Normalize text in Chinese with some rules.")
parser.add_argument("input", type=str, help="the input sentences")
parser.add_argument("output", type=str, help="path to save the output file.")
args = parser.parse_args()
with open(args.input, 'rt') as fin:
with open(args.output, 'wt') as fout:
for sent in fin:
sent = normalization.normalize_sentence(sent.strip())
fout.write(sent)
fout.write('\n')
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${MAIN_ROOT}/third_party:${PYTHONPATH}#
#!/usr/bin/env bash
source path.sh
stage=-1
stop_stage=100
exp_dir=exp
data_dir=data
filename="sentences.txt"
source ${MAIN_ROOT}/utils/parse_options.sh || exit -1
mkdir -p ${exp_dir}
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "stage 1: Processing "
python3 local/test_normalization.py ${data_dir}/${filename} ${exp_dir}/normalized.txt
if [ -f "${exp_dir}/normalized.txt" ]; then
echo "Normalized text save at ${exp_dir}/normalized.txt"
fi
# TODO(chenfeiyu): compute edit distance against ground-truth
fi
echo "done"
exit 0
coverage
gpustat
jsonlines
kaldiio
llvmlite==0.31.0
loguru
numba==0.47.0
numpy==1.18.5
Pillow
pre-commit
pybind11
python-speech-features
resampy==0.2.2
sacrebleu
scipy==1.2.1
sentencepiece
snakeviz
SoundFile==0.9.0.post1
sox
soxbindings
tensorboardX
textgrid
tqdm
typeguard
visualdl==2.2.0
yacs
#! /usr/bin/env bash
cd .. >> /dev/null
source utils/log.sh
SUDO='sudo'
if [ $(id -u) -eq 0 ]; then
SUDO=''
fi
if [ -e /etc/lsb-release ];then
${SUDO} apt-get update -y
${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
if [ $? != 0 ]; then
error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user."
exit -1
fi
fi
source tools/venv/bin/activate
cd -
#install python dependencies
if [ -f "requirements.txt" ]; then
pip3 install -r requirements.txt
fi
if [ $? != 0 ]; then
error_msg "Install python dependencies failed !!!"
exit 1
fi
cd .. >> /dev/null
# install package libsndfile
python3 -c "import soundfile"
if [ $? != 0 ]; then
info_msg "Install package libsndfile into default system path."
wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz"
if [ $? != 0 ]; then
error_msg "Download libsndfile-1.0.28.tar.gz failed !!!"
exit 1
fi
tar -zxvf libsndfile-1.0.28.tar.gz
cd libsndfile-1.0.28
./configure > /dev/null && make > /dev/null && make install > /dev/null
cd ..
rm -rf libsndfile-1.0.28
rm libsndfile-1.0.28.tar.gz
fi
# install decoders
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
cd deepspeech/decoders/swig > /dev/null
sh setup.sh
cd - > /dev/null
fi
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
error_msg "Please check why decoder install error!"
exit -1
fi
info_msg "Install all dependencies successfully."
......@@ -9,14 +9,21 @@ if [ $(id -u) -eq 0 ]; then
fi
if [ -e /etc/lsb-release ];then
#${SUDO} apt-get update
${SUDO} apt-get install -y vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
${SUDO} apt-get update -y
${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
if [ $? != 0 ]; then
error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user."
exit -1
fi
fi
# tools/make
rm tools/*.done
pushd tools && make && popd
source tools/venv/bin/activate
# install python dependencies
if [ -f "requirements.txt" ]; then
pip3 install -r requirements.txt
......@@ -43,6 +50,22 @@ if [ $? != 0 ]; then
rm libsndfile-1.0.28.tar.gz
fi
#install auto-log
python -c "import auto_log"
if [ $? != 0 ]; then
info_msg "Install auto_log into default system path"
test -d AutoLog || git clone https://github.com/LDOUBLEV/AutoLog
if [ $? != 0 ]; then
error_msg "Download auto_log failed !!!"
exit 1
fi
cd AutoLog
pip install -r requirements.txt
python setup.py install
cd ..
rm -rf AutoLog
fi
# install decoders
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
......@@ -66,4 +89,5 @@ if [ $? != 0 ]; then
fi
popd
info_msg "Install all dependencies successfully."
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(speechnn VERSION 0.1)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/src CACHE PATH "Install path prefix." FORCE)
endif(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
# include file
include(cmake/third_party.cmake)
set(CMAKE_VERBOSE_MAKEFILE on)
# set std-14
set(CMAKE_CXX_STANDARD 14)
# # fc_patch dir
# set(FETCHCONTENT_QUIET off)
# get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
# set(FETCHCONTENT_BASE_DIR ${fc_patch})
#
#
# ###############################################################################
# # Option Configurations
# ###############################################################################
# # option configurations
# option(TEST_DEBUG "option for debug" OFF)
#
#
# ###############################################################################
# # Add local library
# ###############################################################################
# # system lib
# find_package()
# # if dir have CmakeLists.txt
# add_subdirectory()
# # if dir do not have CmakeLists.txt
# add_library(lib_name STATIC file.cc)
# target_link_libraries(lib_name item0 item1)
# add_dependencies(lib_name depend-target)
#
#
# ###############################################################################
# # Library installation
# ###############################################################################
# install()
#
#
# ###############################################################################
# # Build binary file
# ###############################################################################
# add_executable()
# target_link_libraries()
#
include(ExternalProject)
# Creat a target named "third_party", which can compile external dependencies on all platform(windows/linux/mac)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
set(THIRD_PARTY_CACHE_PATH "${CMAKE_SOURCE_DIR}" CACHE STRING
"A path cache third party source code to avoid repeated download.")
set(THIRD_PARTY_BUILD_TYPE Release)
set(third_party_deps)
# cache funciton to avoid repeat download code of third_party.
# This function has 4 parameters, URL / REPOSITOR / TAG / DIR:
# 1. URL: specify download url of 3rd party
# 2. REPOSITORY: specify git REPOSITORY of 3rd party
# 3. TAG: specify git tag/branch/commitID of 3rd party
# 4. DIR: overwrite the original SOURCE_DIR when cache directory
#
# The function Return 1 PARENT_SCOPE variables:
# - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
# and you no longer need to set any donwnload steps in ExternalProject_Add.
# For example:
# Cache_third_party(${TARGET}
# REPOSITORY ${TARGET_REPOSITORY}
# TAG ${TARGET_TAG}
# DIR ${TARGET_SOURCE_DIR})
FUNCTION(cache_third_party TARGET)
SET(options "")
SET(oneValueArgs URL REPOSITORY TAG DIR)
SET(multiValueArgs "")
cmake_parse_arguments(cache_third_party "${optionps}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
STRING(REPLACE "extern_" "" TARGET_NAME ${TARGET})
STRING(REGEX REPLACE "[0-9]+" "" TARGET_NAME ${TARGET_NAME})
STRING(TOUPPER ${TARGET_NAME} TARGET_NAME)
IF(cache_third_party_REPOSITORY)
SET(${TARGET_NAME}_DOWNLOAD_CMD
GIT_REPOSITORY ${cache_third_party_REPOSITORY})
IF(cache_third_party_TAG)
LIST(APPEND ${TARGET_NAME}_DOWNLOAD_CMD
GIT_TAG ${cache_third_party_TAG})
ENDIF()
ELSEIF(cache_third_party_URL)
SET(${TARGET_NAME}_DOWNLOAD_CMD
URL ${cache_third_party_URL})
ELSE()
MESSAGE(FATAL_ERROR "Download link (Git repo or URL) must be specified for cache!")
ENDIF()
IF(WITH_TP_CACHE)
IF(NOT cache_third_party_DIR)
MESSAGE(FATAL_ERROR "Please input the ${TARGET_NAME}_SOURCE_DIR for overwriting when -DWITH_TP_CACHE=ON")
ENDIF()
# Generate and verify cache dir for third_party source code
SET(cache_third_party_REPOSITORY ${cache_third_party_REPOSITORY} ${cache_third_party_URL})
IF(cache_third_party_REPOSITORY AND cache_third_party_TAG)
STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
STRING(MD5 HASH_GIT ${cache_third_party_TAG})
STRING(SUBSTRING ${HASH_REPO} 0 8 HASH_REPO)
STRING(SUBSTRING ${HASH_GIT} 0 8 HASH_GIT)
STRING(CONCAT HASH ${HASH_REPO} ${HASH_GIT})
# overwrite the original SOURCE_DIR when cache directory
SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
ELSEIF(cache_third_party_REPOSITORY)
STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
STRING(SUBSTRING ${HASH_REPO} 0 16 HASH)
# overwrite the original SOURCE_DIR when cache directory
SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
ENDIF()
IF(EXISTS ${${cache_third_party_DIR}})
# judge whether the cache dir is empty
FILE(GLOB files ${${cache_third_party_DIR}}/*)
LIST(LENGTH files files_len)
IF(files_len GREATER 0)
list(APPEND ${TARGET_NAME}_DOWNLOAD_CMD DOWNLOAD_COMMAND "")
ENDIF()
ENDIF()
SET(${cache_third_party_DIR} ${${cache_third_party_DIR}} PARENT_SCOPE)
ENDIF()
# Pass ${TARGET_NAME}_DOWNLOAD_CMD to parent scope, the double quotation marks can't be removed
SET(${TARGET_NAME}_DOWNLOAD_CMD "${${TARGET_NAME}_DOWNLOAD_CMD}" PARENT_SCOPE)
ENDFUNCTION()
MACRO(UNSET_VAR VAR_NAME)
UNSET(${VAR_NAME} CACHE)
UNSET(${VAR_NAME})
ENDMACRO()
# Funciton to Download the dependencies during compilation
# This function has 2 parameters, URL / DIRNAME:
# 1. URL: The download url of 3rd dependencies
# 2. NAME: The name of file, that determin the dirname
#
FUNCTION(file_download_and_uncompress URL NAME)
set(options "")
set(oneValueArgs MD5)
set(multiValueArgs "")
cmake_parse_arguments(URL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
MESSAGE(STATUS "Download dependence[${NAME}] from ${URL}, MD5: ${URL_MD5}")
SET(${NAME}_INCLUDE_DIR ${THIRD_PARTY_PATH}/${NAME}/data PARENT_SCOPE)
ExternalProject_Add(
download_${NAME}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${THIRD_PARTY_PATH}/${NAME}
URL ${URL}
URL_MD5 ${URL_MD5}
TIMEOUT 120
DOWNLOAD_DIR ${THIRD_PARTY_PATH}/${NAME}/data/
SOURCE_DIR ${THIRD_PARTY_PATH}/${NAME}/data/
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND ""
)
set(third_party_deps ${third_party_deps} download_${NAME} PARENT_SCOPE)
ENDFUNCTION()
# Correction of flags on different Platform(WIN/MAC) and Print Warning Message
if (APPLE)
if(WITH_MKL)
MESSAGE(WARNING
"Mac is not supported with MKL in Paddle yet. Force WITH_MKL=OFF.")
set(WITH_MKL OFF CACHE STRING "Disable MKL for building on mac" FORCE)
endif()
endif()
if(WIN32 OR APPLE)
MESSAGE(STATUS "Disable XBYAK in Windows and MacOS")
SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)
if(WITH_LIBXSMM)
MESSAGE(WARNING
"Windows, Mac are not supported with libxsmm in Paddle yet."
"Force WITH_LIBXSMM=OFF")
SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM in Windows and MacOS" FORCE)
endif()
if(WITH_BOX_PS)
MESSAGE(WARNING
"Windows or Mac is not supported with BOX_PS in Paddle yet."
"Force WITH_BOX_PS=OFF")
SET(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE)
endif()
if(WITH_PSLIB)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB in Paddle yet."
"Force WITH_PSLIB=OFF")
SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
endif()
if(WITH_LIBMCT)
MESSAGE(WARNING
"Windows or Mac is not supported with LIBMCT in Paddle yet."
"Force WITH_LIBMCT=OFF")
SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
endif()
if(WITH_PSLIB_BRPC)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
"Force WITH_PSLIB_BRPC=OFF")
SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
endif()
endif()
set(WITH_MKLML ${WITH_MKL})
if(NOT DEFINED WITH_MKLDNN)
if(WITH_MKL AND AVX2_FOUND)
set(WITH_MKLDNN ON)
else()
message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
set(WITH_MKLDNN OFF)
endif()
endif()
if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
set(WITH_DGC OFF)
endif()
if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
set(SHALLOW_CLONE "GIT_SHALLOW TRUE") # adds --depth=1 arg to git clone of External_Projects
endif()
########################### include third_party according to flags ###############################
include(third_party/libsndfile) # download, build, install libsndfile
include(third_party/boost) # download boost
include(third_party/eigen) # download eigen3
include(third_party/threadpool) # download threadpool
cmake_minimum_required(VERSION 3.14)
include(ExternalProject)
include(FetchContent)
FetchContent_Declare(
absl
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
include(ExternalProject)
set(BOOST_PROJECT "extern_boost")
# To release PaddlePaddle as a pip package, we have to follow the
# manylinux1 standard, which features as old Linux kernels and
# compilers as possible and recommends CentOS 5. Indeed, the earliest
# CentOS version that works with NVIDIA CUDA is CentOS 6. And a new
# version of boost, say, 1.66.0, doesn't build on CentOS 6. We
# checked that the devtools package of CentOS 6 installs boost 1.41.0.
# So we use 1.41.0 here.
set(BOOST_VER "1.41.0")
set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
set(BOOST_URL "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
MESSAGE(STATUS "BOOST_VERSION: ${BOOST_VER}, BOOST_URL: ${BOOST_URL}")
set(BOOST_PREFIX_DIR ${THIRD_PARTY_PATH}/boost)
set(BOOST_SOURCE_DIR ${THIRD_PARTY_PATH}/boost/src/extern_boost)
cache_third_party(${BOOST_PROJECT}
URL ${BOOST_URL}
DIR BOOST_SOURCE_DIR)
set(BOOST_INCLUDE_DIR "${BOOST_SOURCE_DIR}" CACHE PATH "boost include directory." FORCE)
set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
include_directories(${BOOST_INCLUDE_DIR})
if(WIN32 AND MSVC_VERSION GREATER_EQUAL 1600)
add_definitions(-DBOOST_HAS_STATIC_ASSERT)
endif()
ExternalProject_Add(
${BOOST_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
"${BOOST_DOWNLOAD_CMD}"
URL_MD5 f891e8c2c9424f0565f0129ad9ab4aff
PREFIX ${BOOST_PREFIX_DIR}
DOWNLOAD_DIR ${BOOST_SOURCE_DIR}
SOURCE_DIR ${BOOST_SOURCE_DIR}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
)
add_library(boost INTERFACE)
add_dependencies(boost ${BOOST_PROJECT})
set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR})
include(ExternalProject)
# update eigen to the commit id f612df27 on 03/16/2021
set(EIGEN_PREFIX_DIR ${THIRD_PARTY_PATH}/eigen3)
set(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3/src/extern_eigen3)
set(EIGEN_REPOSITORY https://gitlab.com/libeigen/eigen.git)
set(EIGEN_TAG f612df273689a19d25b45ca4f8269463207c4fee)
cache_third_party(extern_eigen3
REPOSITORY ${EIGEN_REPOSITORY}
TAG ${EIGEN_TAG}
DIR EIGEN_SOURCE_DIR)
if(WIN32)
add_definitions(-DEIGEN_STRONG_INLINE=inline)
elseif(LINUX)
if(WITH_ROCM)
# For HIPCC Eigen::internal::device::numeric_limits is not EIGEN_DEVICE_FUNC
# which will cause compiler error of using __host__ funciont in __host__ __device__
file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src)
file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst)
file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorReductionGpu.h native_src1)
file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h native_dst1)
set(EIGEN_PATCH_COMMAND cp ${native_src} ${native_dst} && cp ${native_src1} ${native_dst1})
endif()
endif()
set(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR})
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
"${EIGEN_DOWNLOAD_CMD}"
PREFIX ${EIGEN_PREFIX_DIR}
SOURCE_DIR ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND ${EIGEN_PATCH_COMMAND}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
add_library(eigen3 INTERFACE)
add_dependencies(eigen3 extern_eigen3)
# sw not support thread_local semantic
if(WITH_SW)
add_definitions(-DEIGEN_AVOID_THREAD_LOCAL)
endif()
cmake_minimum_required(VERSION 3.14)
include(ExternalProject)
include(FetchContent)
FetchContent_Declare(
libsndfile
GIT_REPOSITORY https://github.com/libsndfile/libsndfile.git
GIT_TAG v1.0.30 # tag v1.0.30
)
FetchContent_GetProperties(libsndfile)
cmake_minimum_required(VERSION 3.14)
include(ExternalProject)
include(FetchContent)
FetchContent_Declare(
openfst
GIT_REPOSITORY https://github.com/kkm000/openfst
GIT_TAG 338225416178ac36b8002d70387f5556e44c8d05 # tag win/1.7.2.1
)
FetchContent_GetProperties(openfst)
if(NOT openfst_POPULATED)
FetchContent_Populate(openfst)
include_directories(${openfst_SOURCE_DIR}/src/include)
add_subdirectory(${openfst_SOURCE_DIR} ${openfst_BINARY_DIR})
install(DIRECTORY ${openfst_SOURCE_DIR}/src/include/ DESTINATION include/
FILES_MATCHING PATTERN "*.h")
install(TARGETS fst
EXPORT kaldi-targets
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
if(NOT OPENFST_ROOT_DIR)
message(FATAL_ERROR)
endif()
set(fst_source_dir ${OPENFST_ROOT_DIR}/src/lib)
set(fst_include_dir ${OPENFST_ROOT_DIR}/src/include)
include_directories(${fst_include_dir})
file(GLOB fst_sources "${fst_source_dir}/*.cc")
add_library(fst ${fst_sources})
target_include_directories(fst PUBLIC
$<BUILD_INTERFACE:${fst_include_dir}>
$<INSTALL_INTERFACE:include/openfst>
)
install(TARGETS fst
EXPORT kaldi-targets
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(DIRECTORY ${fst_include_dir}/fst
DESTINATION include/openfst
PATTERN "test/*.h" EXCLUDE
)
unset(fst_source_dir)
unset(fst_include_dir)
unset(fst_sources)
INCLUDE(ExternalProject)
SET(THREADPOOL_PREFIX_DIR ${THIRD_PARTY_PATH}/threadpool)
SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool/src/extern_threadpool)
if(WITH_ASCEND OR WITH_ASCEND_CL)
SET(THREADPOOL_REPOSITORY https://gitee.com/tianjianhe/ThreadPool.git)
else()
SET(THREADPOOL_REPOSITORY ${GIT_URL}/progschj/ThreadPool.git)
endif()
SET(THREADPOOL_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040)
cache_third_party(extern_threadpool
REPOSITORY ${THREADPOOL_REPOSITORY}
TAG ${THREADPOOL_TAG}
DIR THREADPOOL_SOURCE_DIR)
SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR})
INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})
ExternalProject_Add(
extern_threadpool
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
"${THREADPOOL_DOWNLOAD_CMD}"
PREFIX ${THREADPOOL_PREFIX_DIR}
SOURCE_DIR ${THREADPOOL_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
add_library(simple_threadpool INTERFACE)
add_dependencies(simple_threadpool extern_threadpool)
function(get_version)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/src/.version version)
string(STRIP ${version} version)
execute_process(COMMAND git log -n1 --format=%H src/.version
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE version_commit
OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND git rev-list --count "${version_commit}..HEAD"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE patch_number)
string(STRIP ${patch_number} patch_number)
set(KALDI_VERSION ${version} PARENT_SCOPE)
set(KALDI_PATCH_NUMBER ${patch_number} PARENT_SCOPE)
endfunction()
# Fast Transformers for Speech
- Conformer
- Transformer
## Reference
* https://github.com/NVIDIA/FasterTransformer.git
* https://github.com/idiap/fast-transformers
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})
"""
Module containing all the spectrogram classes
"""
# 0.2.0
import torch
import torch.nn as nn
from torch.nn.functional import conv1d, conv2d, fold
import scipy # used only in CFP
import numpy as np
from time import time
# from nnAudio.librosa_functions import * # For debug purpose
# from nnAudio.utils import *
from .librosa_functions import *
from .utils import *
sz_float = 4 # size of a float
epsilon = 10e-8 # fudge factor for normalization
### --------------------------- Spectrogram Classes ---------------------------###
class STFT(torch.nn.Module):
"""This function is to calculate the short-time Fourier transform (STFT) of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
n_fft : int
Size of Fourier transform. Default value is 2048.
win_length : int
the size of window frame and STFT filter.
Default: None (treated as equal to n_fft)
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins.
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
freq_scale : 'linear', 'log', or 'no'
Determine the spacing between each frequency bin. When `linear` or `log` is used,
the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will
start at 0Hz and end at Nyquist frequency with linear spacing.
center : bool
Putting the STFT keneral at the center of the time-step or not. If ``False``, the time
index is the beginning of the STFT kernel, if ``True``, the time index is the center of
the STFT kernel. Default value if ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
iSTFT : bool
To activate the iSTFT module or not. By default, it is False to save GPU memory.
Note: The iSTFT kernel is not trainable. If you want
a trainable iSTFT, use the iSTFT module.
fmin : int
The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
does nothing.
fmax : int
The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
does nothing.
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
trainable : bool
Determine if the STFT kenrels are trainable or not. If ``True``, the gradients for STFT
kernels will also be caluclated and the STFT kernels will be updated during model training.
Default value is ``False``
output_format : str
Control the spectrogram output type, either ``Magnitude``, ``Complex``, or ``Phase``.
The output_format can also be changed during the ``forward`` method.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
``shape = (num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
``shape = (num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.STFT()
>>> specs = spec_layer(x)
"""
def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',
freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,
fmin=50, fmax=6000, sr=22050, trainable=False,
output_format="Complex", verbose=True):
super().__init__()
# Trying to make the default setting same as librosa
if win_length==None: win_length = n_fft
if hop_length==None: hop_length = int(win_length // 4)
self.output_format = output_format
self.trainable = trainable
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
self.freq_bins = freq_bins
self.trainable = trainable
self.pad_amount = self.n_fft // 2
self.window = window
self.win_length = win_length
self.iSTFT = iSTFT
self.trainable = trainable
start = time()
# Create filter windows for stft
kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,
win_length=win_length,
freq_bins=freq_bins,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr,
verbose=verbose)
kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)
kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)
# In this way, the inverse kernel and the forward kernel do not share the same memory...
kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)
kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)
if iSTFT:
self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))
self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))
# Making all these variables nn.Parameter, so that the model can be used with nn.Parallel
# self.kernel_sin = torch.nn.Parameter(self.kernel_sin, requires_grad=self.trainable)
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
# Applying window functions to the Fourier kernels
if window:
window_mask = torch.tensor(window_mask)
wsin = kernel_sin * window_mask
wcos = kernel_cos * window_mask
else:
wsin = kernel_sin
wcos = kernel_cos
if self.trainable==False:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
if self.trainable==True:
wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)
wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
# Prepare the shape of window mask so that it can be used later in inverse
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
if verbose==True:
print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
else:
pass
def forward(self, x, output_format=None):
"""
Convert a batch of waveforms to spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
output_format : str
Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.
Default value is ``Complex``.
"""
output_format = output_format or self.output_format
self.num_samples = x.shape[-1]
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.pad_amount, 0)
elif self.pad_mode == 'reflect':
if self.num_samples < self.pad_amount:
raise AssertionError("Signal length shorter than reflect padding length (n_fft // 2).")
padding = nn.ReflectionPad1d(self.pad_amount)
x = padding(x)
spec_imag = conv1d(x, self.wsin, stride=self.stride)
spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d
# remove redundant parts
spec_real = spec_real[:, :self.freq_bins, :]
spec_imag = spec_imag[:, :self.freq_bins, :]
if output_format=='Magnitude':
spec = spec_real.pow(2) + spec_imag.pow(2)
if self.trainable==True:
return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0
else:
return torch.sqrt(spec)
elif output_format=='Complex':
return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part
elif output_format=='Phase':
return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase
def inverse(self, X, onesided=True, length=None, refresh_win=True):
"""
This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class,
which is to convert spectrograms back to waveforms.
It only works for the complex value spectrograms. If you have the magnitude spectrograms,
please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
Parameters
----------
onesided : bool
If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,
else use ``onesided=False``
length : int
To make sure the inverse STFT has the same output length of the original waveform, please
set `length` as your intended waveform length. By default, ``length=None``,
which will remove ``n_fft//2`` samples from the start and the end of the output.
refresh_win : bool
Recalculating the window sum square. If you have an input with fixed number of timesteps,
you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``
"""
if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):
raise NameError("Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`")
assert X.dim()==4 , "Inverse iSTFT only works for complex number," \
"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2)."\
"\nIf you have a magnitude spectrogram, please consider using Griffin-Lim."
if onesided:
X = extend_fbins(X) # extend freq
X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]
# broadcast dimensions to support 2D convolution
X_real_bc = X_real.unsqueeze(1)
X_imag_bc = X_imag.unsqueeze(1)
a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))
b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))
# compute real and imag part. signal lies in the real part
real = a1 - b2
real = real.squeeze(-2)*self.window_mask
# Normalize the amplitude with n_fft
real /= (self.n_fft)
# Overlap and Add algorithm to connect all the frames
real = overlap_add(real, self.stride)
# Prepare the window sumsqure for division
# Only need to create this window once to save time
# Unless the input spectrograms have different time steps
if hasattr(self, 'w_sum')==False or refresh_win==True:
self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()
self.nonzero_indices = (self.w_sum>1e-10)
else:
pass
real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])
# Remove padding
if length is None:
if self.center:
real = real[:, self.pad_amount:-self.pad_amount]
else:
if self.center:
real = real[:, self.pad_amount:self.pad_amount + length]
else:
real = real[:, :length]
return real
def extra_repr(self) -> str:
return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(
self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable
)
class MelSpectrogram(torch.nn.Module):
"""This function is to calculate the Melspectrogram of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred automatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio.
It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
n_fft : int
The window size for the STFT. Default value is 2048
win_length : int
the size of window frame and STFT filter.
Default: None (treated as equal to n_fft)
n_mels : int
The number of Mel filter banks. The filter banks maps the n_fft to mel bins.
Default value is 128.
hop_length : int
The hop (or stride) size. Default value is 512.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
center : bool
Putting the STFT keneral at the center of the time-step or not. If ``False``,
the time index is the beginning of the STFT kernel, if ``True``, the time index is the
center of the STFT kernel. Default value if ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
htk : bool
When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the
Mel scale is logarithmic. The default value is ``False``.
fmin : int
The starting frequency for the lowest Mel filter bank.
fmax : int
The ending frequency for the highest Mel filter bank.
norm :
if 1, divide the triangular mel weights by the width of the mel band
(area normalization, AKA 'slaney' default in librosa).
Otherwise, leave all the triangles aiming for
a peak value of 1.0
trainable_mel : bool
Determine if the Mel filter banks are trainable or not. If ``True``, the gradients for Mel
filter banks will also be calculated and the Mel filter banks will be updated during model
training. Default value is ``False``.
trainable_STFT : bool
Determine if the STFT kenrels are trainable or not. If ``True``, the gradients for STFT
kernels will also be caluclated and the STFT kernels will be updated during model training.
Default value is ``False``.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.MelSpectrogram()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, n_fft=2048, win_length=None, n_mels=128, hop_length=512,
window='hann', center=True, pad_mode='reflect', power=2.0, htk=False,
fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False,
verbose=True, **kwargs):
super().__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
self.power = power
self.trainable_mel = trainable_mel
self.trainable_STFT = trainable_STFT
# Preparing for the stft layer. No need for center
self.stft = STFT(n_fft=n_fft, win_length=win_length, freq_bins=None,
hop_length=hop_length, window=window, freq_scale='no',
center=center, pad_mode=pad_mode, sr=sr, trainable=trainable_STFT,
output_format="Magnitude", verbose=verbose, **kwargs)
# Create filter windows for stft
start = time()
# Creating kernel for mel spectrogram
start = time()
mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
mel_basis = torch.tensor(mel_basis)
if verbose==True:
print("STFT filter created, time used = {:.4f} seconds".format(time()-start))
print("Mel filter created, time used = {:.4f} seconds".format(time()-start))
else:
pass
if trainable_mel:
# Making everything nn.Parameter, so that this model can support nn.DataParallel
mel_basis = torch.nn.Parameter(mel_basis, requires_grad=trainable_mel)
self.register_parameter('mel_basis', mel_basis)
else:
self.register_buffer('mel_basis', mel_basis)
# if trainable_mel==True:
# self.mel_basis = torch.nn.Parameter(self.mel_basis)
# if trainable_STFT==True:
# self.wsin = torch.nn.Parameter(self.wsin)
# self.wcos = torch.nn.Parameter(self.wcos)
def forward(self, x):
"""
Convert a batch of waveforms to Mel spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = broadcast_dim(x)
spec = self.stft(x, output_format='Magnitude')**self.power
melspec = torch.matmul(self.mel_basis, spec)
return melspec
def extra_repr(self) -> str:
return 'Mel filter banks size = {}, trainable_mel={}'.format(
(*self.mel_basis.shape,), self.trainable_mel, self.trainable_STFT
)
class MFCC(torch.nn.Module):
"""This function is to calculate the Mel-frequency cepstral coefficients (MFCCs) of the input signal.
This algorithm first extracts Mel spectrograms from the audio clips,
then the discrete cosine transform is calcuated to obtain the final MFCCs.
Therefore, the Mel spectrogram part can be made trainable using
``trainable_mel`` and ``trainable_STFT``.
It only support type-II DCT at the moment. Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred autommatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
n_mfcc : int
The number of Mel-frequency cepstral coefficients
norm : string
The default value is 'ortho'. Normalization for DCT basis
**kwargs
Other arguments for Melspectrogram such as n_fft, n_mels, hop_length, and window
Returns
-------
MFCCs : torch.tensor
It returns a tensor of MFCCs. shape = ``(num_samples, n_mfcc, time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.MFCC()
>>> mfcc = spec_layer(x)
"""
def __init__(self, sr=22050, n_mfcc=20, norm='ortho', verbose=True, ref=1.0, amin=1e-10, top_db=80.0, **kwargs):
super().__init__()
self.melspec_layer = MelSpectrogram(sr=sr, verbose=verbose, **kwargs)
self.m_mfcc = n_mfcc
# attributes that will be used for _power_to_db
if amin <= 0:
raise ParameterError('amin must be strictly positive')
amin = torch.tensor([amin])
ref = torch.abs(torch.tensor([ref]))
self.register_buffer('amin', amin)
self.register_buffer('ref', ref)
self.top_db = top_db
self.n_mfcc = n_mfcc
def _power_to_db(self, S):
'''
Refer to https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#power_to_db
for the original implmentation.
'''
log_spec = 10.0 * torch.log10(torch.max(S, self.amin))
log_spec -= 10.0 * torch.log10(torch.max(self.amin, self.ref))
if self.top_db is not None:
if self.top_db < 0:
raise ParameterError('top_db must be non-negative')
# make the dim same as log_spec so that it can be broadcasted
batch_wise_max = log_spec.flatten(1).max(1)[0].unsqueeze(1).unsqueeze(1)
log_spec = torch.max(log_spec, batch_wise_max - self.top_db)
return log_spec
def _dct(self, x, norm=None):
'''
Refer to https://github.com/zh217/torch-dct for the original implmentation.
'''
x = x.permute(0,2,1) # make freq the last axis, since dct applies to the frequency axis
x_shape = x.shape
N = x_shape[-1]
v = torch.cat([x[:, :, ::2], x[:, :, 1::2].flip([2])], dim=2)
Vc = torch.rfft(v, 1, onesided=False)
# TODO: Can make the W_r and W_i trainable here
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V = Vc[:, :, :, 0] * W_r - Vc[:, :, :, 1] * W_i
if norm == 'ortho':
V[:, :, 0] /= np.sqrt(N) * 2
V[:, :, 1:] /= np.sqrt(N / 2) * 2
V = 2 * V
return V.permute(0,2,1) # swapping back the time axis and freq axis
def forward(self, x):
"""
Convert a batch of waveforms to MFCC.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = self.melspec_layer(x)
x = self._power_to_db(x)
x = self._dct(x, norm='ortho')[:,:self.m_mfcc,:]
return x
def extra_repr(self) -> str:
return 'n_mfcc = {}'.format(
(self.n_mfcc)
)
class Gammatonegram(torch.nn.Module):
"""
This function is to calculate the Gammatonegram of the input signal. Input signal should be in either of the following shapes. 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred autommatically if the input follows these 3 shapes. This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
n_fft : int
The window size for the STFT. Default value is 2048
n_mels : int
The number of Gammatonegram filter banks. The filter banks maps the n_fft to Gammatone bins. Default value is 64
hop_length : int
The hop (or stride) size. Default value is 512.
window : str
The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to scipy documentation for possible windowing functions. The default value is 'hann'
center : bool
Putting the STFT keneral at the center of the time-step or not. If ``False``, the time index is the beginning of the STFT kernel, if ``True``, the time index is the center of the STFT kernel. Default value if ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
htk : bool
When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the Mel scale is logarithmic. The default value is ``False``
fmin : int
The starting frequency for the lowest Gammatone filter bank
fmax : int
The ending frequency for the highest Gammatone filter bank
trainable_mel : bool
Determine if the Gammatone filter banks are trainable or not. If ``True``, the gradients for Mel filter banks will also be caluclated and the Mel filter banks will be updated during model training. Default value is ``False``
trainable_STFT : bool
Determine if the STFT kenrels are trainable or not. If ``True``, the gradients for STFT kernels will also be caluclated and the STFT kernels will be updated during model training. Default value is ``False``
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
Examples
--------
>>> spec_layer = Spectrogram.Gammatonegram()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True, pad_mode='reflect',
power=2.0, htk=False, fmin=20.0, fmax=None, norm=1, trainable_bins=False, trainable_STFT=False,
verbose=True):
super(Gammatonegram, self).__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
self.power = power
# Create filter windows for stft
start = time()
wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no',
sr=sr)
wsin = torch.tensor(wsin, dtype=torch.float)
wcos = torch.tensor(wcos, dtype=torch.float)
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
# Creating kenral for Gammatone spectrogram
start = time()
gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
gammatone_basis = torch.tensor(gammatone_basis)
if verbose == True:
print("STFT filter created, time used = {:.4f} seconds".format(time() - start))
print("Gammatone filter created, time used = {:.4f} seconds".format(time() - start))
else:
pass
# Making everything nn.Prarmeter, so that this model can support nn.DataParallel
if trainable_bins:
gammatone_basis = torch.nn.Parameter(gammatone_basis, requires_grad=trainable_bins)
self.register_parameter('gammatone_basis', gammatone_basis)
else:
self.register_buffer('gammatone_basis', gammatone_basis)
# if trainable_mel==True:
# self.mel_basis = torch.nn.Parameter(self.mel_basis)
# if trainable_STFT==True:
# self.wsin = torch.nn.Parameter(self.wsin)
# self.wcos = torch.nn.Parameter(self.wcos)
def forward(self, x):
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.n_fft // 2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.n_fft // 2)
x = padding(x)
spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
+ conv1d(x, self.wcos, stride=self.stride).pow(2)) ** self.power # Doing STFT by using conv1d
gammatonespec = torch.matmul(self.gammatone_basis, spec)
return gammatonespec
class CQT1992(torch.nn.Module):
"""
This alogrithm uses the method proposed in [1], which would run extremely slow if low frequencies (below 220Hz)
are included in the frequency bins.
Please refer to :func:`~nnAudio.Spectrogram.CQT1992v2` for a more
computational and memory efficient version.
[1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred autommatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
fmin : float
The frequency for the lowest CQT bin. Default is 32.70Hz, which coresponds to the note C0.
fmax : float
The frequency for the highest CQT bin. Default is ``None``, therefore the higest CQT bin is
inferred from the ``n_bins`` and ``bins_per_octave``.
If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
will be calculated automatically. Default is ``None``
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
trainable_STFT : bool
Determine if the time to frequency domain transformation kernel for the input audio is trainable or not.
Default is ``False``
trainable_CQT : bool
Determine if the frequency domain CQT kernel is trainable or not.
Default is ``False``
norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
window : str
The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
center : bool
Putting the CQT keneral at the center of the time-step or not. If ``False``, the time index is
the beginning of the CQT kernel, if ``True``, the time index is the center of the CQT kernel.
Default value if ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
trainable : bool
Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
will also be caluclated and the CQT kernels will be updated during model training.
Default value is ``False``.
output_format : str
Determine the return type.
``Magnitude`` will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``;
``Complex`` will return the STFT result in complex number, shape = ``(num_samples, freq_bins,time_steps, 2)``;
``Phase`` will return the phase of the STFT reuslt, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.CQT1992v2()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84,
trainable_STFT=False, trainable_CQT=False, bins_per_octave=12, filter_scale=1,
output_format='Magnitude', norm=1, window='hann', center=True, pad_mode='reflect'):
super().__init__()
# norm arg is not functioning
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
self.norm = norm
self.output_format = output_format
# creating kernels for CQT
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
print("Creating CQT kernels ...", end='\r')
start = time()
cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
sr,
fmin,
n_bins,
bins_per_octave,
norm,
window,
fmax)
self.register_buffer('lenghts', lenghts)
self.frequencies = freqs
cqt_kernels = fft(cqt_kernels)[:,:self.kernel_width//2+1]
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# creating kernels for stft
# self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernel_width # Trying to normalize as librosa
# self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernel_width
print("Creating STFT kernels ...", end='\r')
start = time()
kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.kernel_width,
window='ones',
freq_scale='no')
# Converting kernels from numpy arrays to torch tensors
wsin = torch.tensor(kernel_sin * window)
wcos = torch.tensor(kernel_cos * window)
cqt_kernels_real = torch.tensor(cqt_kernels.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(cqt_kernels.imag.astype(np.float32))
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
if trainable_CQT:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
def forward(self, x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# STFT
fourier_real = conv1d(x, self.wcos, stride=self.hop_length)
fourier_imag = conv1d(x, self.wsin, stride=self.hop_length)
# CQT
CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag),
(fourier_real, fourier_imag))
CQT = torch.stack((CQT_real,-CQT_imag),-1)
if normalization_type == 'librosa':
CQT *= torch.sqrt(self.lenghts.view(-1,1,1))/self.kernel_width
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2/self.kernel_width
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
# if self.norm:
# CQT = CQT/self.kernel_width*torch.sqrt(self.lenghts.view(-1,1,1))
# else:
# CQT = CQT*torch.sqrt(self.lenghts.view(-1,1,1))
if output_format=='Magnitude':
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real))
return torch.stack((phase_real,phase_imag), -1)
def extra_repr(self) -> str:
return 'STFT kernel size = {}, CQT kernel size = {}'.format(
(*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
)
class CQT2010(torch.nn.Module):
"""
This algorithm is using the resampling method proposed in [1].
Instead of convoluting the STFT results with a gigantic CQT kernel covering the full frequency
spectrum, we make a small CQT kernel covering only the top octave.
Then we keep downsampling the input audio by a factor of 2 to convoluting it with the
small CQT kernel. Everytime the input audio is downsampled, the CQT relative to the downsampled
input is equavalent to the next lower octave.
The kernel creation process is still same as the 1992 algorithm. Therefore, we can reuse the code
from the 1992 alogrithm [2]
[1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
early downsampling factor is to downsample the input audio to reduce the CQT kernel size.
The result with and without early downsampling are more or less the same except in the very low
frequency region where freq < 40Hz.
"""
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12,
norm=True, basis_norm=1, window='hann', pad_mode='reflect', trainable_STFT=False, filter_scale=1,
trainable_CQT=False, output_format='Magnitude', earlydownsample=True, verbose=True):
super().__init__()
self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
# basis_norm is for normalizing basis
self.hop_length = hop_length
self.pad_mode = pad_mode
self.n_bins = n_bins
self.output_format = output_format
self.earlydownsample = earlydownsample # TODO: activate early downsampling later if possible
# This will be used to calculate filter_cutoff and creating CQT kernels
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
# Creating lowpass filter and make it a torch tensor
if verbose==True:
print("Creating low pass filter ...", end='\r')
start = time()
lowpass_filter = torch.tensor(create_lowpass_filter(
band_center = 0.5,
kernelLength=256,
transitionBandwidth=0.001
)
)
# Broadcast the tensor to the shape that fits conv1d
self.register_buffer('lowpass_filter', lowpass_filter[None,None,:])
if verbose==True:
print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
# Calculate num of filter requires for the kernel
# n_octaves determines how many resampling requires for the CQT
n_filters = min(bins_per_octave, n_bins)
self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
# print("n_octaves = ", self.n_octaves)
# Calculate the lowest frequency bin for the top octave kernel
self.fmin_t = fmin*2**(self.n_octaves-1)
remainder = n_bins % bins_per_octave
# print("remainder = ", remainder)
if remainder==0:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave)
else:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave)
self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the top minium bins
if fmax_t > sr/2:
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
please reduce the n_bins'.format(fmax_t))
if self.earlydownsample == True: # Do early downsampling if this argument is True
if verbose==True:
print("Creating early downsampling filter ...", end='\r')
start = time()
sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
self.earlydownsample = get_early_downsample_params(sr,
hop_length,
fmax_t,
Q,
self.n_octaves,
verbose)
self.register_buffer('early_downsample_filter', early_downsample_filter)
if verbose==True:
print("Early downsampling filter created, \
time used = {:.4f} seconds".format(time()-start))
else:
self.downsample_factor=1.
# Preparing CQT kernels
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
# print("Q = {}, fmin_t = {}, n_filters = {}".format(Q, self.fmin_t, n_filters))
basis, self.n_fft, _, _ = create_cqt_kernels(Q,
sr,
self.fmin_t,
n_filters,
bins_per_octave,
norm=basis_norm,
topbin_check=False)
# This is for the normalization in the end
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
self.frequencies = freqs
lenghts = np.ceil(Q * sr / freqs)
lenghts = torch.tensor(lenghts).float()
self.register_buffer('lenghts', lenghts)
self.basis=basis
fft_basis = fft(basis)[:,:self.n_fft//2+1] # Convert CQT kenral from time domain to freq domain
# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32))
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# print("Getting cqt kernel done, n_fft = ",self.n_fft)
# Preparing kernels for Short-Time Fourier Transform (STFT)
# We set the frequency range in the CQT filter instead of here.
if verbose==True:
print("Creating STFT kernels ...", end='\r')
start = time()
kernel_sin, kernel_cos, self.bins2freq, _, window = create_fourier_kernels(self.n_fft, window='ones', freq_scale='no')
wsin = kernel_sin * window
wcos = kernel_cos * window
wsin = torch.tensor(wsin)
wcos = torch.tensor(wcos)
if verbose==True:
print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
wcos = torch.nn.Parameter(wcos, requires_grad=trainable_STFT)
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
else:
self.register_buffer('wsin', wsin)
self.register_buffer('wcos', wcos)
if trainable_CQT:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable_CQT)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable_CQT)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
# If center==True, the STFT window will be put in the middle, and paddings at the beginning
# and ending are required.
if self.pad_mode == 'constant':
self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
self.padding = nn.ReflectionPad1d(self.n_fft//2)
def forward(self,x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.earlydownsample==True:
x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
hop = self.hop_length
CQT = get_cqt_complex2(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
wcos=self.wcos, wsin=self.wsin)
x_down = x # Preparing a new variable for downsampling
for i in range(self.n_octaves-1):
hop = hop//2
x_down = downsampling_by_2(x_down, self.lowpass_filter)
CQT1 = get_cqt_complex2(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding,
wcos=self.wcos, wsin=self.wsin)
CQT = torch.cat((CQT1, CQT),1)
CQT = CQT[:,-self.n_bins:,:] # Removing unwanted top bins
if normalization_type == 'librosa':
CQT *= torch.sqrt(self.lenghts.view(-1,1,1))/self.n_fft
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2/self.n_fft
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
return torch.stack((phase_real,phase_imag), -1)
def extra_repr(self) -> str:
return 'STFT kernel size = {}, CQT kernel size = {}'.format(
(*self.wcos.shape,), (*self.cqt_kernels_real.shape,)
)
class CQT1992v2(torch.nn.Module):
"""This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred autommatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
This alogrithm uses the method proposed in [1]. I slightly modify it so that it runs faster
than the original 1992 algorithm, that is why I call it version 2.
[1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
fmin : float
The frequency for the lowest CQT bin. Default is 32.70Hz, which coresponds to the note C0.
fmax : float
The frequency for the highest CQT bin. Default is ``None``, therefore the higest CQT bin is
inferred from the ``n_bins`` and ``bins_per_octave``.
If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored and ``n_bins``
will be calculated automatically. Default is ``None``
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
filter_scale : float > 0
Filter scale factor. Values of filter_scale smaller than 1 can be used to improve the time resolution at the
cost of degrading the frequency resolution. Important to note is that setting for example filter_scale = 0.5 and
bins_per_octave = 48 leads to exactly the same time-frequency resolution trade-off as setting filter_scale = 1
and bins_per_octave = 24, but the former contains twice more frequency bins per octave. In this sense, values
filter_scale < 1 can be seen to implement oversampling of the frequency axis, analogously to the use of zero
padding when calculating the DFT.
norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
window : string, float, or tuple
The windowing function for CQT. If it is a string, It uses ``scipy.signal.get_window``. If it is a
tuple, only the gaussian window wanrantees constant Q factor. Gaussian window should be given as a
tuple ('gaussian', att) where att is the attenuation in the border given in dB.
Please refer to scipy documentation for possible windowing functions. The default value is 'hann'.
center : bool
Putting the CQT keneral at the center of the time-step or not. If ``False``, the time index is
the beginning of the CQT kernel, if ``True``, the time index is the center of the CQT kernel.
Default value if ``True``.
pad_mode : str
The padding method. Default value is 'reflect'.
trainable : bool
Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
will also be caluclated and the CQT kernels will be updated during model training.
Default value is ``False``.
output_format : str
Determine the return type.
``Magnitude`` will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``;
``Complex`` will return the STFT result in complex number, shape = ``(num_samples, freq_bins,time_steps, 2)``;
``Phase`` will return the phase of the STFT reuslt, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.CQT1992v2()
>>> specs = spec_layer(x)
"""
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84,
bins_per_octave=12, filter_scale=1, norm=1, window='hann', center=True, pad_mode='reflect',
trainable=False, output_format='Magnitude', verbose=True):
super().__init__()
self.trainable = trainable
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
self.output_format = output_format
# creating kernels for CQT
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
cqt_kernels, self.kernel_width, lenghts, freqs = create_cqt_kernels(Q,
sr,
fmin,
n_bins,
bins_per_octave,
norm,
window,
fmax)
self.register_buffer('lenghts', lenghts)
self.frequencies = freqs
cqt_kernels_real = torch.tensor(cqt_kernels.real).unsqueeze(1)
cqt_kernels_imag = torch.tensor(cqt_kernels.imag).unsqueeze(1)
if trainable:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
def forward(self,x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
normalization_type : str
Type of the normalisation. The possible options are: \n
'librosa' : the output fits the librosa one \n
'convolutional' : the output conserves the convolutional inequalities of the wavelet transform:\n
for all p ϵ [1, inf] \n
- || CQT ||_p <= || f ||_p || g ||_1 \n
- || CQT ||_p <= || f ||_1 || g ||_p \n
- || CQT ||_2 = || f ||_2 || g ||_2 \n
'wrap' : wraps positive and negative frequencies into positive frequencies. This means that the CQT of a
sinus (or a cosinus) with a constant amplitude equal to 1 will have the value 1 in the bin corresponding to
its frequency.
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# CQT
CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)
CQT_imag = -conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)
if normalization_type == 'librosa':
CQT_real *= torch.sqrt(self.lenghts.view(-1, 1))
CQT_imag *= torch.sqrt(self.lenghts.view(-1, 1))
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT_real *= 2
CQT_imag *= 2
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
if self.trainable==False:
# Getting CQT Amplitude
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
else:
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)+1e-8)
return CQT
elif output_format=='Complex':
return torch.stack((CQT_real,CQT_imag),-1)
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real))
phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real))
return torch.stack((phase_real,phase_imag), -1)
def forward_manual(self,x):
"""
Method for debugging
"""
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.kernel_width//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.kernel_width//2)
x = padding(x)
# CQT
CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)
CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)
# Getting CQT Amplitude
CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
return CQT*torch.sqrt(self.lenghts.view(-1,1))
class CQT2010v2(torch.nn.Module):
"""This function is to calculate the CQT of the input signal.
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
The correct shape will be inferred autommatically if the input follows these 3 shapes.
Most of the arguments follow the convention from librosa.
This class inherits from ``torch.nn.Module``, therefore, the usage is same as ``torch.nn.Module``.
This alogrithm uses the resampling method proposed in [1].
Instead of convoluting the STFT results with a gigantic CQT kernel covering the full frequency
spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the
input audio by a factor of 2 to convoluting it with the small CQT kernel.
Everytime the input audio is downsampled, the CQT relative to the downsampled input is equivalent
to the next lower octave.
The kernel creation process is still same as the 1992 algorithm. Therefore, we can reuse the
code from the 1992 alogrithm [2]
[1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a
constant Q transform.” (1992).
Early downsampling factor is to downsample the input audio to reduce the CQT kernel size.
The result with and without early downsampling are more or less the same except in the very low
frequency region where freq < 40Hz.
Parameters
----------
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
hop_length : int
The hop (or stride) size. Default value is 512.
fmin : float
The frequency for the lowest CQT bin. Default is 32.70Hz, which coresponds to the note C0.
fmax : float
The frequency for the highest CQT bin. Default is ``None``, therefore the higest CQT bin is
inferred from the ``n_bins`` and ``bins_per_octave``. If ``fmax`` is not ``None``, then the
argument ``n_bins`` will be ignored and ``n_bins`` will be calculated automatically.
Default is ``None``
n_bins : int
The total numbers of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``.
bins_per_octave : int
Number of bins per octave. Default is 12.
norm : bool
Normalization for the CQT result.
basis_norm : int
Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization.
Default is ``1``, which is same as the normalization used in librosa.
window : str
The windowing function for CQT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'
pad_mode : str
The padding method. Default value is 'reflect'.
trainable : bool
Determine if the CQT kernels are trainable or not. If ``True``, the gradients for CQT kernels
will also be caluclated and the CQT kernels will be updated during model training.
Default value is ``False``
output_format : str
Determine the return type.
'Magnitude' will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins, time_steps)``;
'Complex' will return the STFT result in complex number, shape = ``(num_samples, freq_bins, time_steps, 2)``;
'Phase' will return the phase of the STFT reuslt, shape = ``(num_samples, freq_bins,time_steps, 2)``.
The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins,time_steps)`` if ``output_format='Magnitude'``;
shape = ``(num_samples, freq_bins,time_steps, 2)`` if ``output_format='Complex' or 'Phase'``;
Examples
--------
>>> spec_layer = Spectrogram.CQT2010v2()
>>> specs = spec_layer(x)
"""
# To DO:
# need to deal with the filter and other tensors
def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, filter_scale=1,
bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect',
earlydownsample=True, trainable=False, output_format='Magnitude', verbose=True):
super().__init__()
self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
# basis_norm is for normalizing basis
self.hop_length = hop_length
self.pad_mode = pad_mode
self.n_bins = n_bins
self.earlydownsample = earlydownsample # We will activate early downsampling later if possible
self.trainable = trainable
self.output_format = output_format
# It will be used to calculate filter_cutoff and creating CQT kernels
Q = float(filter_scale)/(2**(1/bins_per_octave)-1)
# Creating lowpass filter and make it a torch tensor
if verbose==True:
print("Creating low pass filter ...", end='\r')
start = time()
# self.lowpass_filter = torch.tensor(
# create_lowpass_filter(
# band_center = 0.50,
# kernelLength=256,
# transitionBandwidth=0.001))
lowpass_filter = torch.tensor(create_lowpass_filter(
band_center = 0.50,
kernelLength=256,
transitionBandwidth=0.001)
)
# Broadcast the tensor to the shape that fits conv1d
self.register_buffer('lowpass_filter', lowpass_filter[None,None,:])
if verbose==True:
print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
# Caluate num of filter requires for the kernel
# n_octaves determines how many resampling requires for the CQT
n_filters = min(bins_per_octave, n_bins)
self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
if verbose==True:
print("num_octave = ", self.n_octaves)
# Calculate the lowest frequency bin for the top octave kernel
self.fmin_t = fmin*2**(self.n_octaves-1)
remainder = n_bins % bins_per_octave
# print("remainder = ", remainder)
if remainder==0:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave)
else:
# Calculate the top bin frequency
fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave)
self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the top minium bins
if fmax_t > sr/2:
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
please reduce the n_bins'.format(fmax_t))
if self.earlydownsample == True: # Do early downsampling if this argument is True
if verbose==True:
print("Creating early downsampling filter ...", end='\r')
start = time()
sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
self.earlydownsample = get_early_downsample_params(sr,
hop_length,
fmax_t,
Q,
self.n_octaves,
verbose)
self.register_buffer('early_downsample_filter', early_downsample_filter)
if verbose==True:
print("Early downsampling filter created, \
time used = {:.4f} seconds".format(time()-start))
else:
self.downsample_factor=1.
# Preparing CQT kernels
if verbose==True:
print("Creating CQT kernels ...", end='\r')
start = time()
basis, self.n_fft, lenghts, _ = create_cqt_kernels(Q,
sr,
self.fmin_t,
n_filters,
bins_per_octave,
norm=basis_norm,
topbin_check=False)
# For normalization in the end
# The freqs returned by create_cqt_kernels cannot be used
# Since that returns only the top octave bins
# We need the information for all freq bin
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
self.frequencies = freqs
lenghts = np.ceil(Q * sr / freqs)
lenghts = torch.tensor(lenghts).float()
self.register_buffer('lenghts', lenghts)
self.basis = basis
# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)
if trainable:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
cqt_kernels_imag = torch.nn.Parameter(cqt_kernels_imag, requires_grad=trainable)
self.register_parameter('cqt_kernels_real', cqt_kernels_real)
self.register_parameter('cqt_kernels_imag', cqt_kernels_imag)
else:
self.register_buffer('cqt_kernels_real', cqt_kernels_real)
self.register_buffer('cqt_kernels_imag', cqt_kernels_imag)
if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
# print("Getting cqt kernel done, n_fft = ",self.n_fft)
# If center==True, the STFT window will be put in the middle, and paddings at the beginning
# and ending are required.
if self.pad_mode == 'constant':
self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
self.padding = nn.ReflectionPad1d(self.n_fft//2)
def forward(self,x,output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format
x = broadcast_dim(x)
if self.earlydownsample==True:
x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
hop = self.hop_length
CQT = get_cqt_complex(x, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding) # Getting the top octave CQT
x_down = x # Preparing a new variable for downsampling
for i in range(self.n_octaves-1):
hop = hop//2
x_down = downsampling_by_2(x_down, self.lowpass_filter)
CQT1 = get_cqt_complex(x_down, self.cqt_kernels_real, self.cqt_kernels_imag, hop, self.padding)
CQT = torch.cat((CQT1, CQT),1)
CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins
# print("downsample_factor = ",self.downsample_factor)
# print(CQT.shape)
# print(self.lenghts.view(-1,1).shape)
# Normalizing the output with the downsampling factor, 2**(self.n_octaves-1) is make it
# same mag as 1992
CQT = CQT*self.downsample_factor
# Normalize again to get same result as librosa
if normalization_type == 'librosa':
CQT = CQT*torch.sqrt(self.lenghts.view(-1,1,1))
elif normalization_type == 'convolutional':
pass
elif normalization_type == 'wrap':
CQT *= 2
else:
raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
if output_format=='Magnitude':
if self.trainable==False:
# Getting CQT Amplitude
return torch.sqrt(CQT.pow(2).sum(-1))
else:
return torch.sqrt(CQT.pow(2).sum(-1)+1e-8)
elif output_format=='Complex':
return CQT
elif output_format=='Phase':
phase_real = torch.cos(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0]))
return torch.stack((phase_real,phase_imag), -1)
class CQT(CQT1992v2):
"""An abbreviation for :func:`~nnAudio.Spectrogram.CQT1992v2`. Please refer to the :func:`~nnAudio.Spectrogram.CQT1992v2` documentation"""
pass
# The section below is for developing purpose
# Please don't use the following classes
#
class DFT(torch.nn.Module):
"""
Experimental feature before `torch.fft` was made avaliable.
The inverse function only works for 1 single frame. i.e. input shape = (batch, n_fft, 1)
"""
def __init__(self, n_fft=2048, freq_bins=None, hop_length=512,
window='hann', freq_scale='no', center=True, pad_mode='reflect',
fmin=50, fmax=6000, sr=22050):
super().__init__()
self.stride = hop_length
self.center = center
self.pad_mode = pad_mode
self.n_fft = n_fft
# Create filter windows for stft
wsin, wcos, self.bins2freq = create_fourier_kernels(n_fft=n_fft,
freq_bins=n_fft,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr)
self.wsin = torch.tensor(wsin, dtype=torch.float)
self.wcos = torch.tensor(wcos, dtype=torch.float)
def forward(self,x):
"""
Convert a batch of waveforms to spectrums.
Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
x = broadcast_dim(x)
if self.center:
if self.pad_mode == 'constant':
padding = nn.ConstantPad1d(self.n_fft//2, 0)
elif self.pad_mode == 'reflect':
padding = nn.ReflectionPad1d(self.n_fft//2)
x = padding(x)
imag = conv1d(x, self.wsin, stride=self.stride)
real = conv1d(x, self.wcos, stride=self.stride)
return (real, -imag)
def inverse(self,x_real,x_imag):
"""
Convert a batch of waveforms to CQT spectrograms.
Parameters
----------
x_real : torch tensor
Real part of the signal.
x_imag : torch tensor
Imaginary part of the signal.
"""
x_real = broadcast_dim(x_real)
x_imag = broadcast_dim(x_imag)
x_real.transpose_(1,2) # Prepare the right shape to do inverse
x_imag.transpose_(1,2) # Prepare the right shape to do inverse
# if self.center:
# if self.pad_mode == 'constant':
# padding = nn.ConstantPad1d(self.n_fft//2, 0)
# elif self.pad_mode == 'reflect':
# padding = nn.ReflectionPad1d(self.n_fft//2)
# x_real = padding(x_real)
# x_imag = padding(x_imag)
# Watch out for the positive and negative signs
# ifft = e^(+2\pi*j)*X
# ifft(X_real) = (a1, a2)
# ifft(X_imag)*1j = (b1, b2)*1j
# = (-b2, b1)
a1 = conv1d(x_real, self.wcos, stride=self.stride)
a2 = conv1d(x_real, self.wsin, stride=self.stride)
b1 = conv1d(x_imag, self.wcos, stride=self.stride)
b2 = conv1d(x_imag, self.wsin, stride=self.stride)
imag = a2+b1
real = a1-b2
return (real/self.n_fft, imag/self.n_fft)
class iSTFT(torch.nn.Module):
"""This class is to convert spectrograms back to waveforms. It only works for the complex value spectrograms.
If you have the magnitude spectrograms, please use :func:`~nnAudio.Spectrogram.Griffin_Lim`.
The parameters (e.g. n_fft, window) need to be the same as the STFT in order to obtain the correct inverse.
If trainability is not required, it is recommended to use the ``inverse`` method under the ``STFT`` class
to save GPU/RAM memory.
When ``trainable=True`` and ``freq_scale!='no'``, there is no guarantee that the inverse is perfect, please
use with extra care.
Parameters
----------
n_fft : int
The window size. Default value is 2048.
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
Please make sure the value is the same as the forward STFT.
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
Please make sure the value is the same as the forward STFT.
window : str
The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
Please make sure the value is the same as the forward STFT.
freq_scale : 'linear', 'log', or 'no'
Determine the spacing between each frequency bin. When `linear` or `log` is used,
the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will
start at 0Hz and end at Nyquist frequency with linear spacing.
Please make sure the value is the same as the forward STFT.
center : bool
Putting the iSTFT keneral at the center of the time-step or not. If ``False``, the time
index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
the iSTFT kernel. Default value if ``True``.
Please make sure the value is the same as the forward STFT.
fmin : int
The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument
does nothing. Please make sure the value is the same as the forward STFT.
fmax : int
The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument
does nothing. Please make sure the value is the same as the forward STFT.
sr : int
The sampling rate for the input audio. It is used to calucate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
trainable_kernels : bool
Determine if the STFT kenrels are trainable or not. If ``True``, the gradients for STFT
kernels will also be caluclated and the STFT kernels will be updated during model training.
Default value is ``False``.
trainable_window : bool
Determine if the window function is trainable or not.
Default value is ``False``.
verbose : bool
If ``True``, it shows layer information. If ``False``, it suppresses all prints.
Returns
-------
spectrogram : torch.tensor
It returns a batch of waveforms.
Examples
--------
>>> spec_layer = Spectrogram.iSTFT()
>>> specs = spec_layer(x)
"""
def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',
freq_scale='no', center=True, fmin=50, fmax=6000, sr=22050, trainable_kernels=False,
trainable_window=False, verbose=True, refresh_win=True):
super().__init__()
# Trying to make the default setting same as librosa
if win_length==None: win_length = n_fft
if hop_length==None: hop_length = int(win_length // 4)
self.n_fft = n_fft
self.win_length = win_length
self.stride = hop_length
self.center = center
self.pad_amount = self.n_fft // 2
self.refresh_win = refresh_win
start = time()
# Create the window function and prepare the shape for batch-wise-time-wise multiplication
# Create filter windows for inverse
kernel_sin, kernel_cos, _, _, window_mask = create_fourier_kernels(n_fft,
win_length=win_length,
freq_bins=n_fft,
window=window,
freq_scale=freq_scale,
fmin=fmin,
fmax=fmax,
sr=sr,
verbose=False)
window_mask = get_window(window,int(win_length), fftbins=True)
# For inverse, the Fourier kernels do not need to be windowed
window_mask = torch.tensor(window_mask).unsqueeze(0).unsqueeze(-1)
# kernel_sin and kernel_cos have the shape (freq_bins, 1, n_fft, 1) to support 2D Conv
kernel_sin = torch.tensor(kernel_sin, dtype=torch.float).unsqueeze(-1)
kernel_cos = torch.tensor(kernel_cos, dtype=torch.float).unsqueeze(-1)
# Decide if the Fourier kernels are trainable
if trainable_kernels:
# Making all these variables trainable
kernel_sin = torch.nn.Parameter(kernel_sin, requires_grad=trainable_kernels)
kernel_cos = torch.nn.Parameter(kernel_cos, requires_grad=trainable_kernels)
self.register_parameter('kernel_sin', kernel_sin)
self.register_parameter('kernel_cos', kernel_cos)
else:
self.register_buffer('kernel_sin', kernel_sin)
self.register_buffer('kernel_cos', kernel_cos)
# Decide if the window function is trainable
if trainable_window:
window_mask = torch.nn.Parameter(window_mask, requires_grad=trainable_window)
self.register_parameter('window_mask', window_mask)
else:
self.register_buffer('window_mask', window_mask)
if verbose==True:
print("iSTFT kernels created, time used = {:.4f} seconds".format(time()-start))
else:
pass
def forward(self, X, onesided=False, length=None, refresh_win=None):
"""
If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,
else use ``onesided=False``
To make sure the inverse STFT has the same output length of the original waveform, please
set `length` as your intended waveform length. By default, ``length=None``,
which will remove ``n_fft//2`` samples from the start and the end of the output.
If your input spectrograms X are of the same length, please use ``refresh_win=None`` to increase
computational speed.
"""
if refresh_win==None:
refresh_win=self.refresh_win
assert X.dim()==4 , "Inverse iSTFT only works for complex number," \
"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2)"
# If the input spectrogram contains only half of the n_fft
# Use extend_fbins function to get back another half
if onesided:
X = extend_fbins(X) # extend freq
X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]
# broadcast dimensions to support 2D convolution
X_real_bc = X_real.unsqueeze(1)
X_imag_bc = X_imag.unsqueeze(1)
a1 = conv2d(X_real_bc, self.kernel_cos, stride=(1,1))
b2 = conv2d(X_imag_bc, self.kernel_sin, stride=(1,1))
# compute real and imag part. signal lies in the real part
real = a1 - b2
real = real.squeeze(-2)*self.window_mask
# Normalize the amplitude with n_fft
real /= (self.n_fft)
# Overlap and Add algorithm to connect all the frames
real = overlap_add(real, self.stride)
# Prepare the window sumsqure for division
# Only need to create this window once to save time
# Unless the input spectrograms have different time steps
if hasattr(self, 'w_sum')==False or refresh_win==True:
self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()
self.nonzero_indices = (self.w_sum>1e-10)
else:
pass
real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])
# Remove padding
if length is None:
if self.center:
real = real[:, self.pad_amount:-self.pad_amount]
else:
if self.center:
real = real[:, self.pad_amount:self.pad_amount + length]
else:
real = real[:, :length]
return real
class Griffin_Lim(torch.nn.Module):
"""
Converting Magnitude spectrograms back to waveforms based on the "fast Griffin-Lim"[1].
This Griffin Lim is a direct clone from librosa.griffinlim.
[1] Perraudin, N., Balazs, P., & Søndergaard, P. L. “A fast Griffin-Lim algorithm,”
IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), Oct. 2013.
Parameters
----------
n_fft : int
The window size. Default value is 2048.
n_iter=32 : int
The number of iterations for Griffin-Lim. The default value is ``32``
hop_length : int
The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
Please make sure the value is the same as the forward STFT.
window : str
The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
scipy documentation for possible windowing functions. The default value is 'hann'.
Please make sure the value is the same as the forward STFT.
center : bool
Putting the iSTFT keneral at the center of the time-step or not. If ``False``, the time
index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
the iSTFT kernel. Default value if ``True``.
Please make sure the value is the same as the forward STFT.
momentum : float
The momentum for the update rule. The default value is ``0.99``.
device : str
Choose which device to initialize this layer. Default value is 'cpu'
"""
def __init__(self,
n_fft,
n_iter=32,
hop_length=None,
win_length=None,
window='hann',
center=True,
pad_mode='reflect',
momentum=0.99,
device='cpu'):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.n_iter = n_iter
self.center = center
self.pad_mode = pad_mode
self.momentum = momentum
self.device = device
if win_length==None:
self.win_length=n_fft
else:
self.win_length=win_length
if hop_length==None:
self.hop_length = n_fft//4
else:
self.hop_length = hop_length
# Creating window function for stft and istft later
self.w = torch.tensor(get_window(window,
int(self.win_length),
fftbins=True),
device=device).float()
def forward(self, S):
"""
Convert a batch of magnitude spectrograms to waveforms.
Parameters
----------
S : torch tensor
Spectrogram of the shape ``(batch, n_fft//2+1, timesteps)``
"""
assert S.dim()==3 , "Please make sure your input is in the shape of (batch, freq_bins, timesteps)"
# Initializing Random Phase
rand_phase = torch.randn(*S.shape, device=self.device)
angles = torch.empty((*S.shape,2), device=self.device)
angles[:, :,:,0] = torch.cos(2 * np.pi * rand_phase)
angles[:,:,:,1] = torch.sin(2 * np.pi * rand_phase)
# Initializing the rebuilt magnitude spectrogram
rebuilt = torch.zeros(*angles.shape, device=self.device)
for _ in range(self.n_iter):
tprev = rebuilt # Saving previous rebuilt magnitude spec
# spec2wav conversion
# print(f'win_length={self.win_length}\tw={self.w.shape}')
inverse = torch.istft(S.unsqueeze(-1) * angles,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
center=self.center)
# wav2spec conversion
rebuilt = torch.stft(inverse,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
pad_mode=self.pad_mode)
# Phase update rule
angles[:,:,:] = rebuilt[:,:,:] - (self.momentum / (1 + self.momentum)) * tprev[:,:,:]
# Phase normalization
angles = angles.div(torch.sqrt(angles.pow(2).sum(-1)).unsqueeze(-1) + 1e-16) # normalizing the phase
# Using the final phase to reconstruct the waveforms
inverse = torch.istft(S.unsqueeze(-1) * angles,
self.n_fft,
self.hop_length,
win_length=self.win_length,
window=self.w,
center=self.center)
return inverse
class Combined_Frequency_Periodicity(nn.Module):
"""
Vectorized version of the code in https://github.com/leo-so/VocalMelodyExtPatchCNN/blob/master/MelodyExt.py.
This feature is described in 'Combining Spectral and Temporal Representations for Multipitch Estimation of Polyphonic Music'
https://ieeexplore.ieee.org/document/7118691
Under development, please report any bugs you found
"""
def __init__(self,fr=2, fs=16000, hop_length=320,
window_size=2049, fc=80, tc=1/1000,
g=[0.24, 0.6, 1], NumPerOct=48):
super().__init__()
self.window_size = window_size
self.hop_length = hop_length
# variables for STFT part
self.N = int(fs/float(fr)) # Will be used to calculate padding
self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True) # it won't be used but will be returned
self.pad_value = ((self.N-window_size))
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
self.register_buffer('h',torch.tensor(h))
# variables for CFP
self.NumofLayer = np.size(g)
self.g = g
self.tc_idx = round(fs*tc) # index to filter out top tc_idx and bottom tc_idx bins
self.fc_idx = round(fc/fr) # index to filter out top fc_idx and bottom fc_idx bins
self.HighFreqIdx = int(round((1/tc)/fr)+1)
self.HighQuefIdx = int(round(fs/fc)+1)
# attributes to be returned
self.f = self.f[:self.HighFreqIdx]
self.q = np.arange(self.HighQuefIdx)/float(fs)
# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
return spec, ceps
def forward(self, x):
tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
window=self.h, onesided=False, pad_mode='constant')
tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h) # calcuate magnitude
tfr0 = tfr0.transpose(1,2)[:,1:-1] #transpose F and T axis and discard first and last frames
# The transpose is necessary for rfft later
# (batch, timesteps, n_fft)
tfr, ceps = self._CFP(tfr0)
# return tfr0
# removing duplicate bins
tfr0 = tfr0[:,:,:int(round(self.N/2))]
tfr = tfr[:,:,:int(round(self.N/2))]
ceps = ceps[:,:,:int(round(self.N/2))]
# Crop up to the highest frequency
tfr0 = tfr0[:,:,:self.HighFreqIdx]
tfr = tfr[:,:,:self.HighFreqIdx]
ceps = ceps[:,:,:self.HighQuefIdx]
tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
Z = tfrLF * tfrLQ
# Only need to calculate this once
self.t = np.arange(self.hop_length,
np.ceil(len(x)/float(self.hop_length))*self.hop_length,
self.hop_length) # it won't be used but will be returned
return Z, tfrL0, tfrLF, tfrLQ
def nonlinear_func(self, X, g, cutoff):
cutoff = int(cutoff)
if g!=0:
X = torch.relu(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
X = X.pow(g)
else: # when g=0, it converges to log
X = torch.log(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
return X
def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
StartFreq = fc
StopFreq = 1/tc
Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
central_freq = [] # A list holding the frequencies in log scale
for i in range(0, Nest):
CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
if CenFreq < StopFreq:
central_freq.append(CenFreq)
else:
break
Nest = len(central_freq)
freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=np.float)
# Calculating the freq_band_transformation
for i in range(1, Nest-1):
l = int(round(central_freq[i-1]/fr))
r = int(round(central_freq[i+1]/fr)+1)
#rounding1
if l >= r-1:
freq_band_transformation[i, l] = 1
else:
for j in range(l, r):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
# Calculating the quef_band_transformation
f = 1/q # divide by 0, do I need to fix this?
quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=np.float)
for i in range(1, Nest-1):
for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
return freq_band_transformation, quef_band_transformation
class CFP(nn.Module):
"""
This is the modified version so that the number of timesteps fits with other classes
Under development, please report any bugs you found
"""
def __init__(self,fr=2, fs=16000, hop_length=320,
window_size=2049, fc=80, tc=1/1000,
g=[0.24, 0.6, 1], NumPerOct=48):
super().__init__()
self.window_size = window_size
self.hop_length = hop_length
# variables for STFT part
self.N = int(fs/float(fr)) # Will be used to calculate padding
self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True) # it won't be used but will be returned
self.pad_value = ((self.N-window_size))
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
self.register_buffer('h',torch.tensor(h))
# variables for CFP
self.NumofLayer = np.size(g)
self.g = g
self.tc_idx = round(fs*tc) # index to filter out top tc_idx and bottom tc_idx bins
self.fc_idx = round(fc/fr) # index to filter out top fc_idx and bottom fc_idx bins
self.HighFreqIdx = int(round((1/tc)/fr)+1)
self.HighQuefIdx = int(round(fs/fc)+1)
# attributes to be returned
self.f = self.f[:self.HighFreqIdx]
self.q = np.arange(self.HighQuefIdx)/float(fs)
# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)
return spec, ceps
def forward(self, x):
tfr0 = torch.stft(x, self.N, hop_length=self.hop_length, win_length=self.window_size,
window=self.h, onesided=False, pad_mode='constant')
tfr0 = torch.sqrt(tfr0.pow(2).sum(-1))/torch.norm(self.h) # calcuate magnitude
tfr0 = tfr0.transpose(1,2) #transpose F and T axis and discard first and last frames
# The transpose is necessary for rfft later
# (batch, timesteps, n_fft)
tfr, ceps = self._CFP(tfr0)
# return tfr0
# removing duplicate bins
tfr0 = tfr0[:,:,:int(round(self.N/2))]
tfr = tfr[:,:,:int(round(self.N/2))]
ceps = ceps[:,:,:int(round(self.N/2))]
# Crop up to the highest frequency
tfr0 = tfr0[:,:,:self.HighFreqIdx]
tfr = tfr[:,:,:self.HighFreqIdx]
ceps = ceps[:,:,:self.HighQuefIdx]
tfrL0 = torch.matmul(self.freq2logfreq_matrix, tfr0.transpose(1,2))
tfrLF = torch.matmul(self.freq2logfreq_matrix, tfr.transpose(1,2))
tfrLQ = torch.matmul(self.quef2logfreq_matrix, ceps.transpose(1,2))
Z = tfrLF * tfrLQ
# Only need to calculate this once
self.t = np.arange(self.hop_length,
np.ceil(len(x)/float(self.hop_length))*self.hop_length,
self.hop_length) # it won't be used but will be returned
return Z#, tfrL0, tfrLF, tfrLQ
def nonlinear_func(self, X, g, cutoff):
cutoff = int(cutoff)
if g!=0:
X = torch.relu(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
X = X.pow(g)
else: # when g=0, it converges to log
X = torch.log(X)
X[:, :, :cutoff] = 0
X[:, :, -cutoff:] = 0
return X
def create_logfreq_matrix(self, f, q, fr, fc, tc, NumPerOct, fs):
StartFreq = fc
StopFreq = 1/tc
Nest = int(np.ceil(np.log2(StopFreq/StartFreq))*NumPerOct)
central_freq = [] # A list holding the frequencies in log scale
for i in range(0, Nest):
CenFreq = StartFreq*pow(2, float(i)/NumPerOct)
if CenFreq < StopFreq:
central_freq.append(CenFreq)
else:
break
Nest = len(central_freq)
freq_band_transformation = np.zeros((Nest-1, len(f)), dtype=np.float)
# Calculating the freq_band_transformation
for i in range(1, Nest-1):
l = int(round(central_freq[i-1]/fr))
r = int(round(central_freq[i+1]/fr)+1)
#rounding1
if l >= r-1:
freq_band_transformation[i, l] = 1
else:
for j in range(l, r):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
# Calculating the quef_band_transformation
f = 1/q # divide by 0, do I need to fix this?
quef_band_transformation = np.zeros((Nest-1, len(f)), dtype=np.float)
for i in range(1, Nest-1):
for j in range(int(round(fs/central_freq[i+1])), int(round(fs/central_freq[i-1])+1)):
if f[j] > central_freq[i-1] and f[j] < central_freq[i]:
quef_band_transformation[i, j] = (f[j] - central_freq[i-1])/(central_freq[i] - central_freq[i-1])
elif f[j] > central_freq[i] and f[j] < central_freq[i+1]:
quef_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i])
return freq_band_transformation, quef_band_transformation
__version__ = "0.2.2"
\ No newline at end of file
"""
Module containing functions cloned from librosa
To make sure nnAudio would not become broken when updating librosa
"""
import numpy as np
import warnings
### ----------------Functions for generating kenral for Mel Spectrogram------------ ###
# This code is equalvant to from librosa.filters import mel
# By doing so, we can run nnAudio without installing librosa
def fft2gammatonemx(sr=20000, n_fft=2048, n_bins=64, width=1.0, fmin=0.0,
fmax=11025, maxlen=1024):
"""
# Ellis' description in MATLAB:
# [wts,cfreqa] = fft2gammatonemx(nfft, sr, nfilts, width, minfreq, maxfreq, maxlen)
# Generate a matrix of weights to combine FFT bins into
# Gammatone bins. nfft defines the source FFT size at
# sampling rate sr. Optional nfilts specifies the number of
# output bands required (default 64), and width is the
# constant width of each band in Bark (default 1).
# minfreq, maxfreq specify range covered in Hz (100, sr/2).
# While wts has nfft columns, the second half are all zero.
# Hence, aud spectrum is
# fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft));
# maxlen truncates the rows to this many bins.
# cfreqs returns the actual center frequencies of each
# gammatone band in Hz.
#
# 2009/02/22 02:29:25 Dan Ellis dpwe@ee.columbia.edu based on rastamat/audspec.m
# Sat May 27 15:37:50 2017 Maddie Cusimano, mcusi@mit.edu 27 May 2017: convert to python
"""
wts = np.zeros([n_bins, n_fft], dtype=np.float32)
# after Slaney's MakeERBFilters
EarQ = 9.26449;
minBW = 24.7;
order = 1;
nFr = np.array(range(n_bins)) + 1
em = EarQ * minBW
cfreqs = (fmax + em) * np.exp(nFr * (-np.log(fmax + em) + np.log(fmin + em)) / n_bins) - em
cfreqs = cfreqs[::-1]
GTord = 4
ucircArray = np.array(range(int(n_fft / 2 + 1)))
ucirc = np.exp(1j * 2 * np.pi * ucircArray / n_fft);
# justpoles = 0 :taking out the 'if' corresponding to this.
ERB = width * np.power(np.power(cfreqs / EarQ, order) + np.power(minBW, order), 1 / order);
B = 1.019 * 2 * np.pi * ERB;
r = np.exp(-B / sr)
theta = 2 * np.pi * cfreqs / sr
pole = r * np.exp(1j * theta)
T = 1 / sr
ebt = np.exp(B * T);
cpt = 2 * cfreqs * np.pi * T;
ccpt = 2 * T * np.cos(cpt);
scpt = 2 * T * np.sin(cpt);
A11 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2);
A12 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2);
A13 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2);
A14 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2);
zros = -np.array([A11, A12, A13, A14]) / T;
wIdx = range(int(n_fft / 2 + 1))
gain = np.abs((-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 - 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) * (-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (np.cos(2 * cfreqs * np.pi * T) + np.sqrt(
3 - 2 ** (3 / 2)) * np.sin(2 * cfreqs * np.pi * T))) * (
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) * (
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
np.cos(2 * cfreqs * np.pi * T) + np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
2 * cfreqs * np.pi * T))) / (
-2 / np.exp(2 * B * T) - 2 * np.exp(4 * 1j * cfreqs * np.pi * T) + 2 * (
1 + np.exp(4 * 1j * cfreqs * np.pi * T)) / np.exp(B * T)) ** 4);
# in MATLAB, there used to be 64 where here it says n_bins:
wts[:, wIdx] = ((T ** 4) / np.reshape(gain, (n_bins, 1))) * np.abs(
ucirc - np.reshape(zros[0], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[1], (n_bins, 1))) * np.abs(
ucirc - np.reshape(zros[2], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[3], (n_bins, 1))) * (np.abs(
np.power(np.multiply(np.reshape(pole, (n_bins, 1)) - ucirc, np.conj(np.reshape(pole, (n_bins, 1))) - ucirc),
-GTord)));
wts = wts[:, range(maxlen)];
return wts, cfreqs
def gammatone(sr, n_fft, n_bins=64, fmin=20.0, fmax=None, htk=False,
norm=1, dtype=np.float32):
"""Create a Filterbank matrix to combine FFT bins into Gammatone bins
Parameters
----------
sr : number > 0 [scalar]
sampling rate of the incoming signal
n_fft : int > 0 [scalar]
number of FFT components
n_bins : int > 0 [scalar]
number of Mel bands to generate
fmin : float >= 0 [scalar]
lowest frequency (in Hz)
fmax : float >= 0 [scalar]
highest frequency (in Hz).
If `None`, use `fmax = sr / 2.0`
htk : bool [scalar]
use HTK formula instead of Slaney
norm : {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
dtype : np.dtype
The data type of the output basis.
By default, uses 32-bit (single-precision) floating point.
Returns
-------
G : np.ndarray [shape=(n_bins, 1 + n_fft/2)]
Gammatone transform matrix
"""
if fmax is None:
fmax = float(sr) / 2
n_bins = int(n_bins)
weights,_ = fft2gammatonemx(sr=sr, n_fft=n_fft, n_bins=n_bins, fmin=fmin, fmax=fmax, maxlen=int(n_fft//2+1))
return (1/n_fft)*weights
def mel_to_hz(mels, htk=False):
"""Convert mel bin numbers to frequencies
Examples
--------
>>> librosa.mel_to_hz(3)
200.
>>> librosa.mel_to_hz([1,2,3,4,5])
array([ 66.667, 133.333, 200. , 266.667, 333.333])
Parameters
----------
mels : np.ndarray [shape=(n,)], float
mel bins to convert
htk : bool
use HTK formula instead of Slaney
Returns
-------
frequencies : np.ndarray [shape=(n,)]
input mels in Hz
See Also
--------
hz_to_mel
"""
mels = np.asanyarray(mels)
if htk:
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if mels.ndim:
# If we have vector data, vectorize
log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
elif mels >= min_log_mel:
# If we have scalar data, check directly
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
return freqs
def hz_to_mel(frequencies, htk=False):
"""Convert Hz to Mels
Examples
--------
>>> librosa.hz_to_mel(60)
0.9
>>> librosa.hz_to_mel([110, 220, 440])
array([ 1.65, 3.3 , 6.6 ])
Parameters
----------
frequencies : number or np.ndarray [shape=(n,)] , float
scalar or array of frequencies
htk : bool
use HTK formula instead of Slaney
Returns
-------
mels : number or np.ndarray [shape=(n,)]
input frequencies in Mels
See Also
--------
mel_to_hz
"""
frequencies = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (frequencies - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
if frequencies.ndim:
# If we have array data, vectorize
log_t = (frequencies >= min_log_hz)
mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep
elif frequencies >= min_log_hz:
# If we have scalar data, heck directly
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
return mels
def fft_frequencies(sr=22050, n_fft=2048):
'''Alternative implementation of `np.fft.fftfreq`
Parameters
----------
sr : number > 0 [scalar]
Audio sampling rate
n_fft : int > 0 [scalar]
FFT window size
Returns
-------
freqs : np.ndarray [shape=(1 + n_fft/2,)]
Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)`
Examples
--------
>>> librosa.fft_frequencies(sr=22050, n_fft=16)
array([ 0. , 1378.125, 2756.25 , 4134.375,
5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ])
'''
return np.linspace(0,
float(sr) / 2,
int(1 + n_fft//2),
endpoint=True)
def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False):
"""
This function is cloned from librosa 0.7.
Please refer to the original
`documentation <https://librosa.org/doc/latest/generated/librosa.mel_frequencies.html?highlight=mel_frequencies#librosa.mel_frequencies>`__
for more info.
Parameters
----------
n_mels : int > 0 [scalar]
Number of mel bins.
fmin : float >= 0 [scalar]
Minimum frequency (Hz).
fmax : float >= 0 [scalar]
Maximum frequency (Hz).
htk : bool
If True, use HTK formula to convert Hz to mel.
Otherwise (False), use Slaney's Auditory Toolbox.
Returns
-------
bin_frequencies : ndarray [shape=(n_mels,)]
Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel
axis.
Examples
--------
>>> librosa.mel_frequencies(n_mels=40)
array([ 0. , 85.317, 170.635, 255.952,
341.269, 426.586, 511.904, 597.221,
682.538, 767.855, 853.173, 938.49 ,
1024.856, 1119.114, 1222.042, 1334.436,
1457.167, 1591.187, 1737.532, 1897.337,
2071.84 , 2262.393, 2470.47 , 2697.686,
2945.799, 3216.731, 3512.582, 3835.643,
4188.417, 4573.636, 4994.285, 5453.621,
5955.205, 6502.92 , 7101.009, 7754.107,
8467.272, 9246.028, 10096.408, 11025. ])
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(fmin, htk=htk)
max_mel = hz_to_mel(fmax, htk=htk)
mels = np.linspace(min_mel, max_mel, n_mels)
return mel_to_hz(mels, htk=htk)
def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False,
norm=1, dtype=np.float32):
"""
This function is cloned from librosa 0.7.
Please refer to the original
`documentation <https://librosa.org/doc/latest/generated/librosa.filters.mel.html>`__
for more info.
Create a Filterbank matrix to combine FFT bins into Mel-frequency bins
Parameters
----------
sr : number > 0 [scalar]
sampling rate of the incoming signal
n_fft : int > 0 [scalar]
number of FFT components
n_mels : int > 0 [scalar]
number of Mel bands to generate
fmin : float >= 0 [scalar]
lowest frequency (in Hz)
fmax : float >= 0 [scalar]
highest frequency (in Hz).
If `None`, use `fmax = sr / 2.0`
htk : bool [scalar]
use HTK formula instead of Slaney
norm : {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
dtype : np.dtype
The data type of the output basis.
By default, uses 32-bit (single-precision) floating point.
Returns
-------
M : np.ndarray [shape=(n_mels, 1 + n_fft/2)]
Mel transform matrix
Notes
-----
This function caches at level 10.
Examples
--------
>>> melfb = librosa.filters.mel(22050, 2048)
>>> melfb
array([[ 0. , 0.016, ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ],
...,
[ 0. , 0. , ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ]])
Clip the maximum frequency to 8KHz
>>> librosa.filters.mel(22050, 2048, fmax=8000)
array([[ 0. , 0.02, ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ],
...,
[ 0. , 0. , ..., 0. , 0. ],
[ 0. , 0. , ..., 0. , 0. ]])
>>> import matplotlib.pyplot as plt
>>> plt.figure()
>>> librosa.display.specshow(melfb, x_axis='linear')
>>> plt.ylabel('Mel filter')
>>> plt.title('Mel filter bank')
>>> plt.colorbar()
>>> plt.tight_layout()
>>> plt.show()
"""
if fmax is None:
fmax = float(sr) / 2
if norm is not None and norm != 1 and norm != np.inf:
raise ParameterError('Unsupported norm: {}'.format(repr(norm)))
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i+2] / fdiff[i+1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
if norm == 1:
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
# Only check weights if f_mel[0] is positive
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
# This means we have an empty channel somewhere
warnings.warn('Empty filters detected in mel frequency basis. '
'Some channels will produce empty responses. '
'Try increasing your sampling rate (and fmax) or '
'reducing n_mels.')
return weights
### ------------------End of Functions for generating kenral for Mel Spectrogram ----------------###
### ------------------Functions for making STFT same as librosa ---------------------------------###
def pad_center(data, size, axis=-1, **kwargs):
'''Wrapper for np.pad to automatically center an array prior to padding.
This is analogous to `str.center()`
Examples
--------
>>> # Generate a vector
>>> data = np.ones(5)
>>> librosa.util.pad_center(data, 10, mode='constant')
array([ 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.])
>>> # Pad a matrix along its first dimension
>>> data = np.ones((3, 5))
>>> librosa.util.pad_center(data, 7, axis=0)
array([[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1.],
[ 1., 1., 1., 1., 1.],
[ 1., 1., 1., 1., 1.],
[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.]])
>>> # Or its second dimension
>>> librosa.util.pad_center(data, 7, axis=1)
array([[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.]])
Parameters
----------
data : np.ndarray
Vector to be padded and centered
size : int >= len(data) [scalar]
Length to pad `data`
axis : int
Axis along which to pad and center the data
kwargs : additional keyword arguments
arguments passed to `np.pad()`
Returns
-------
data_padded : np.ndarray
`data` centered and padded to length `size` along the
specified axis
Raises
------
ParameterError
If `size < data.shape[axis]`
See Also
--------
numpy.pad
'''
kwargs.setdefault('mode', 'constant')
n = data.shape[axis]
lpad = int((size - n) // 2)
lengths = [(0, 0)] * data.ndim
lengths[axis] = (lpad, int(size - n - lpad))
if lpad < 0:
raise ParameterError(('Target size ({:d}) must be '
'at least input size ({:d})').format(size, n))
return np.pad(data, lengths, **kwargs)
### ------------------End of functions for making STFT same as librosa ---------------------------###
"""
Module containing helper functions such as overlap sum and Fourier kernels generators
"""
import torch
from torch.nn.functional import conv1d, fold
import numpy as np
from time import time
import math
from scipy.signal import get_window
from scipy import signal
from scipy import fft
import warnings
from nnAudio.librosa_functions import *
## --------------------------- Filter Design ---------------------------##
def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
w_stacks = w.unsqueeze(-1).repeat((1,n_frames)).unsqueeze(0)
# Window length + stride*(frames-1)
output_len = w_stacks.shape[1] + stride*(w_stacks.shape[2]-1)
return fold(w_stacks**power, (1,output_len), kernel_size=(1,n_fft), stride=stride)
def overlap_add(X, stride):
n_fft = X.shape[1]
output_len = n_fft + stride*(X.shape[2]-1)
return fold(X, (1,output_len), kernel_size=(1,n_fft), stride=stride).flatten(1)
def uniform_distribution(r1,r2, *size, device):
return (r1 - r2) * torch.rand(*size, device=device) + r2
def extend_fbins(X):
"""Extending the number of frequency bins from `n_fft//2+1` back to `n_fft` by
reversing all bins except DC and Nyquist and append it on top of existing spectrogram"""
X_upper = torch.flip(X[:,1:-1],(0,1))
X_upper[:,:,:,1] = -X_upper[:,:,:,1] # For the imaganinry part, it is an odd function
return torch.cat((X[:, :, :], X_upper), 1)
def downsampling_by_n(x, filterKernel, n):
"""A helper function that downsamples the audio by a arbitary factor n.
It is used in CQT2010 and CQT2010v2.
Parameters
----------
x : torch.Tensor
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
filterKernel : str
Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
n : int
The downsampling factor
Returns
-------
torch.Tensor
The downsampled waveform
Examples
--------
>>> x_down = downsampling_by_n(x, filterKernel)
"""
x = conv1d(x,filterKernel,stride=n, padding=(filterKernel.shape[-1]-1)//2)
return x
def downsampling_by_2(x, filterKernel):
"""A helper function that downsamples the audio by half. It is used in CQT2010 and CQT2010v2
Parameters
----------
x : torch.Tensor
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
filterKernel : str
Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
Returns
-------
torch.Tensor
The downsampled waveform
Examples
--------
>>> x_down = downsampling_by_2(x, filterKernel)
"""
x = conv1d(x,filterKernel,stride=2, padding=(filterKernel.shape[-1]-1)//2)
return x
## Basic tools for computation ##
def nextpow2(A):
"""A helper function to calculate the next nearest number to the power of 2.
Parameters
----------
A : float
A float number that is going to be rounded up to the nearest power of 2
Returns
-------
int
The nearest power of 2 to the input number ``A``
Examples
--------
>>> nextpow2(6)
3
"""
return int(np.ceil(np.log2(A)))
## Basic tools for computation ##
def prepow2(A):
"""A helper function to calculate the next nearest number to the power of 2.
Parameters
----------
A : float
A float number that is going to be rounded up to the nearest power of 2
Returns
-------
int
The nearest power of 2 to the input number ``A``
Examples
--------
>>> nextpow2(6)
3
"""
return int(np.floor(np.log2(A)))
def complex_mul(cqt_filter, stft):
"""Since PyTorch does not support complex numbers and its operation.
We need to write our own complex multiplication function. This one is specially
designed for CQT usage.
Parameters
----------
cqt_filter : tuple of torch.Tensor
The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``
Returns
-------
tuple of torch.Tensor
The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
"""
cqt_filter_real = cqt_filter[0]
cqt_filter_imag = cqt_filter[1]
fourier_real = stft[0]
fourier_imag = stft[1]
CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
return CQT_real, CQT_imag
def broadcast_dim(x):
"""
Auto broadcast input so that it can fits into a Conv1d
"""
if x.dim() == 2:
x = x[:, None, :]
elif x.dim() == 1:
# If nn.DataParallel is used, this broadcast doesn't work
x = x[None, None, :]
elif x.dim() == 3:
pass
else:
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
return x
def broadcast_dim_conv2d(x):
"""
Auto broadcast input so that it can fits into a Conv2d
"""
if x.dim() == 3:
x = x[:, None, :,:]
else:
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
return x
## Kernal generation functions ##
def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,
freq_scale='linear', window='hann', verbose=True):
""" This function creates the Fourier Kernel for STFT, Melspectrogram and CQT.
Most of the parameters follow librosa conventions. Part of the code comes from
pytorch_musicnet. https://github.com/jthickstun/pytorch_musicnet
Parameters
----------
n_fft : int
The window size
freq_bins : int
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
fmin : int
The starting frequency for the lowest frequency bin.
If freq_scale is ``no``, this argument does nothing.
fmax : int
The ending frequency for the highest frequency bin.
If freq_scale is ``no``, this argument does nothing.
sr : int
The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
Setting the correct sampling rate is very important for calculating the correct frequency.
freq_scale: 'linear', 'log', or 'no'
Determine the spacing between each frequency bin.
When 'linear' or 'log' is used, the bin spacing can be controlled by ``fmin`` and ``fmax``.
If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
Returns
-------
wsin : numpy.array
Imaginary Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
wcos : numpy.array
Real Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
bins2freq : list
Mapping each frequency bin to frequency in Hz.
binslist : list
The normalized frequency ``k`` in digital domain.
This ``k`` is in the Discrete Fourier Transform equation $$
"""
if freq_bins==None: freq_bins = n_fft//2+1
if win_length==None: win_length = n_fft
s = np.arange(0, n_fft, 1.)
wsin = np.empty((freq_bins,1,n_fft))
wcos = np.empty((freq_bins,1,n_fft))
start_freq = fmin
end_freq = fmax
bins2freq = []
binslist = []
# num_cycles = start_freq*d/44000.
# scaling_ind = np.log(end_freq/start_freq)/k
# Choosing window shape
window_mask = get_window(window,int(win_length), fftbins=True)
window_mask = pad_center(window_mask, n_fft)
if freq_scale == 'linear':
if verbose==True:
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
f"get a valid freq range")
start_bin = start_freq*n_fft/sr
scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
for k in range(freq_bins): # Only half of the bins contain useful info
# print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
binslist.append((k*scaling_ind+start_bin))
wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
elif freq_scale == 'log':
if verbose==True:
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
f"get a valid freq range")
start_bin = start_freq*n_fft/sr
scaling_ind = np.log(end_freq/start_freq)/freq_bins
for k in range(freq_bins): # Only half of the bins contain useful info
# print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))
bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)
binslist.append((np.exp(k*scaling_ind)*start_bin))
wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
elif freq_scale == 'no':
for k in range(freq_bins): # Only half of the bins contain useful info
bins2freq.append(k*sr/n_fft)
binslist.append(k)
wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)
wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)
else:
print("Please select the correct frequency scale, 'linear' or 'log'")
return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)
# Tools for CQT
def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1,
window='hann', fmax=None, topbin_check=True):
"""
Automatically create CQT kernels in time domain
"""
fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
# minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
if (fmax != None) and (n_bins == None):
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
elif (fmax == None) and (n_bins != None):
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
else:
warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning)
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
if np.max(freqs) > fs/2 and topbin_check==True:
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
please reduce the n_bins'.format(np.max(freqs)))
tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
lengths = np.ceil(Q * fs / freqs)
for k in range(0, int(n_bins)):
freq = freqs[k]
l = np.ceil(Q * fs / freq)
# Centering the kernels
if l%2==1: # pad more zeros on RHS
start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
else:
start = int(np.ceil(fftLen / 2.0 - l / 2.0))
sig = get_window_dispatch(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
if norm: # Normalizing the filter # Trying to normalize like librosa
tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
else:
tempKernel[k, start:start + int(l)] = sig
# specKernel[k, :] = fft(tempKernel[k])
# return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lenghts).float()
return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
def get_window_dispatch(window, N, fftbins=True):
if isinstance(window, str):
return get_window(window, N, fftbins=fftbins)
elif isinstance(window, tuple):
if window[0] == 'gaussian':
assert window[1] >= 0
sigma = np.floor(- N / 2 / np.sqrt(- 2 * np.log(10**(- window[1] / 20))))
return get_window(('gaussian', sigma), N, fftbins=fftbins)
else:
Warning("Tuple windows may have undesired behaviour regarding Q factor")
elif isinstance(window, float):
Warning("You are using Kaiser window with beta factor " + str(window) + ". Correct behaviour not checked.")
else:
raise Exception("The function get_window from scipy only supports strings, tuples and floats.")
def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
for how to multiple the STFT result with the CQT kernel
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
a constant Q transform.” (1992)."""
# STFT, converting the audio input from time domain to frequency domain
try:
x = padding(x) # When center == True, we need padding at the beginning and ending
except:
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
"padding with reflection mode might not be the best choice, try using constant padding",
UserWarning)
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
return torch.stack((CQT_real, CQT_imag),-1)
def get_cqt_complex2(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None):
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
for how to multiple the STFT result with the CQT kernel
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
a constant Q transform.” (1992)."""
# STFT, converting the audio input from time domain to frequency domain
try:
x = padding(x) # When center == True, we need padding at the beginning and ending
except:
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
"padding with reflection mode might not be the best choice, try using constant padding",
UserWarning)
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
if wcos==None or wsin==None:
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
else:
fourier_real = conv1d(x, wcos, stride=hop_length)
fourier_imag = conv1d(x, wsin, stride=hop_length)
# Multiplying input with the CQT kernel in freq domain
CQT_real, CQT_imag = complex_mul((cqt_kernels_real, cqt_kernels_imag),
(fourier_real, fourier_imag))
return torch.stack((CQT_real, CQT_imag),-1)
def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
"""
Calculate the highest frequency we need to preserve and the lowest frequency we allow
to pass through.
Note that frequency is on a scale from 0 to 1 where 0 is 0 and 1 is Nyquist frequency of
the signal BEFORE downsampling.
"""
# transitionBandwidth = 0.03
passbandMax = band_center / (1 + transitionBandwidth)
stopbandMin = band_center * (1 + transitionBandwidth)
# Unlike the filter tool we used online yesterday, this tool does
# not allow us to specify how closely the filter matches our
# specifications. Instead, we specify the length of the kernel.
# The longer the kernel is, the more precisely it will match.
# kernelLength = 256
# We specify a list of key frequencies for which we will require
# that the filter match a specific output gain.
# From [0.0 to passbandMax] is the frequency range we want to keep
# untouched and [stopbandMin, 1.0] is the range we want to remove
keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0]
# We specify a list of output gains to correspond to the key
# frequencies listed above.
# The first two gains are 1.0 because they correspond to the first
# two key frequencies. the second two are 0.0 because they
# correspond to the stopband frequencies
gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
# This command produces the filter kernel coefficients
filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
return filterKernel.astype(np.float32)
def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
"""Used in CQT2010 and CQT2010v2"""
window_bandwidth = 1.5 # for hann window
filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
sr, hop_length, downsample_factor = early_downsample(sr,
hop_length,
n_octaves,
sr//2,
filter_cutoff)
if downsample_factor != 1:
if verbose==True:
print("Can do early downsample, factor = ", downsample_factor)
earlydownsample=True
# print("new sr = ", sr)
# print("new hop_length = ", hop_length)
early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor,
kernelLength=256,
transitionBandwidth=0.03)
early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
else:
if verbose==True:
print("No early downsampling is required, downsample_factor = ", downsample_factor)
early_downsample_filter = None
earlydownsample=False
return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
def early_downsample(sr, hop_length, n_octaves,
nyquist, filter_cutoff):
'''Return new sampling rate and hop length after early dowansampling'''
downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
# print("downsample_count = ", downsample_count)
downsample_factor = 2**(downsample_count)
hop_length //= downsample_factor # Getting new hop_length
new_sr = sr / float(downsample_factor) # Getting new sampling rate
sr = new_sr
return sr, hop_length, downsample_factor
# The following two downsampling count functions are obtained from librosa CQT
# They are used to determine the number of pre resamplings if the starting and ending frequency
# are both in low frequency regions.
def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
'''Compute the number of early downsampling operations'''
downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
filter_cutoff)) - 1) - 1)
# print("downsample_count1 = ", downsample_count1)
num_twos = nextpow2(hop_length)
downsample_count2 = max(0, num_twos - n_octaves + 1)
# print("downsample_count2 = ",downsample_count2)
return min(downsample_count1, downsample_count2)
def early_downsample(sr, hop_length, n_octaves,
nyquist, filter_cutoff):
'''Return new sampling rate and hop length after early dowansampling'''
downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
# print("downsample_count = ", downsample_count)
downsample_factor = 2**(downsample_count)
hop_length //= downsample_factor # Getting new hop_length
new_sr = sr / float(downsample_factor) # Getting new sampling rate
sr = new_sr
return sr, hop_length, downsample_factor
\ No newline at end of file
import setuptools
import codecs
import os.path
def read(rel_path):
here = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(here, rel_path), 'r') as fp:
return fp.read()
def get_version(rel_path):
for line in read(rel_path).splitlines():
if line.startswith('__version__'):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
else:
raise RuntimeError("Unable to find version string.")
setuptools.setup(
name="nnAudio", # Replace with your own username
version=get_version("nnAudio/__init__.py"),
author="KinWaiCheuk",
author_email="u3500684@connect.hku.hk",
description="A fast GPU audio processing toolbox with 1D convolutional neural network",
long_description='',
long_description_content_type="text/markdown",
url="https://github.com/KinWaiCheuk/nnAudio",
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
)
# Creating parameters for STFT test
"""
It is equivalent to
[(1024, 128, 'ones'),
(1024, 128, 'hann'),
(1024, 128, 'hamming'),
(2048, 128, 'ones'),
(2048, 512, 'ones'),
(2048, 128, 'hann'),
(2048, 512, 'hann'),
(2048, 128, 'hamming'),
(2048, 512, 'hamming'),
(None, None, None)]
"""
stft_parameters = []
n_fft = [1024,2048]
hop_length = {128,512,1024}
window = ['ones', 'hann', 'hamming']
for i in n_fft:
for k in window:
for j in hop_length:
if j < (i/2):
stft_parameters.append((i,j,k))
stft_parameters.append((256, None, 'hann'))
stft_with_win_parameters = []
n_fft = [512,1024]
win_length = [400, 900]
hop_length = {128,256}
for i in n_fft:
for j in win_length:
if j < i:
for k in hop_length:
if k < (i/2):
stft_with_win_parameters.append((i,j,k))
mel_win_parameters = [(512,400), (1024, 1000)]
\ No newline at end of file
import pytest
import librosa
import torch
import matplotlib.pyplot as plt
from scipy.signal import chirp, sweep_poly
from nnAudio.Spectrogram import *
from parameters import *
gpu_idx=0
# librosa example audio for testing
example_y, example_sr = librosa.load(librosa.util.example_audio_file())
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_inverse2(n_fft, hop_length, window, device):
x = torch.tensor(example_y,device=device)
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
istft = iSTFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
X = stft(x.unsqueeze(0), output_format="Complex")
x_recon = istft(X, length=x.shape[0], onesided=True).squeeze()
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-5, atol=1e-3)
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_inverse(n_fft, hop_length, window, device):
x = torch.tensor(example_y, device=device)
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, iSTFT=True).to(device)
X = stft(x.unsqueeze(0), output_format="Complex")
x_recon = stft.inverse(X, length=x.shape[0]).squeeze()
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
# def test_inverse_GPU(n_fft, hop_length, window):
# x = torch.tensor(example_y,device=f'cuda:{gpu_idx}')
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
# X = stft(x.unsqueeze(0), output_format="Complex")
# x_recon = stft.inverse(X, num_samples=x.shape[0]).squeeze()
# assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_complex(n_fft, hop_length, window, device):
x = example_y
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
assert real_diff and imag_diff
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
# def test_stft_complex_GPU(n_fft, hop_length, window):
# x = example_y
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
# X = stft(torch.tensor(x,device=f'cuda:{gpu_idx}').unsqueeze(0), output_format="Complex")
# X_real, X_imag = X[:, :, :, 0].squeeze().detach().cpu(), X[:, :, :, 1].squeeze().detach().cpu()
# X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
# real_diff, imag_diff = np.allclose(X_real, X_librosa.real, rtol=1e-3, atol=1e-3), \
# np.allclose(X_imag, X_librosa.imag, rtol=1e-3, atol=1e-3)
# assert real_diff and imag_diff
@pytest.mark.parametrize("n_fft, win_length, hop_length", stft_with_win_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_complex_winlength(n_fft, win_length, hop_length, device):
x = example_y
stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
X_librosa = librosa.stft(x, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
assert real_diff and imag_diff
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_magnitude(device):
x = example_y
stft = STFT(n_fft=2048, hop_length=512).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Magnitude").squeeze()
X_librosa, _ = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_stft_phase(device):
x = example_y
stft = STFT(n_fft=2048, hop_length=512).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Phase")
X_real, X_imag = torch.cos(X).squeeze(), torch.sin(X).squeeze()
_, X_librosa = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
real_diff, imag_diff = np.mean(np.abs(X_real.cpu().numpy() - X_librosa.real)), \
np.mean(np.abs(X_imag.cpu().numpy() - X_librosa.imag))
# I find that np.allclose is too strict for allowing phase to be similar to librosa.
# Hence for phase we use average element-wise distance as the test metric.
assert real_diff < 2e-4 and imag_diff < 2e-4
@pytest.mark.parametrize("n_fft, win_length", mel_win_parameters)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_mel_spectrogram(n_fft, win_length, device):
x = example_y
melspec = MelSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=512).to(device)
X = melspec(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
X_librosa = librosa.feature.melspectrogram(x, n_fft=n_fft, win_length=win_length, hop_length=512)
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992(sr=fs, fmin=220, output_format="Magnitude",
n_bins=80, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Complex
stft = CQT1992(sr=fs, fmin=220, output_format="Complex",
n_bins=80, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Phase
stft = CQT1992(sr=fs, fmin=220, output_format="Phase",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
assert True
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010(sr=fs, fmin=110, output_format="Magnitude",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Complex
stft = CQT2010(sr=fs, fmin=110, output_format="Complex",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# Phase
stft = CQT2010(sr=fs, fmin=110, output_format="Phase",
n_bins=160, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
assert True
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992_v2_log(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-mag-ground-truth.npy")
X = torch.log(X + 1e-5)
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-phase-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_1992_v2_linear(device):
# Linear sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='linear')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-mag-ground-truth.npy")
X = torch.log(X + 1e-5)
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-phase-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010_v2_log(device):
# Log sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='logarithmic')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
X = torch.log(X + 1e-2)
# np.save("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# np.save("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# # Phase
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
# n_bins=207, bins_per_octave=24)
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
# # np.save("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth", X.cpu())
# ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth.npy")
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_cqt_2010_v2_linear(device):
# Linear sweep case
fs = 44100
t = 1
f0 = 55
f1 = 22050
s = np.linspace(0, t, fs*t)
x = chirp(s, f0, 1, f1, method='linear')
x = x.astype(dtype=np.float32)
# Magnitude
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
X = torch.log(X + 1e-2)
# np.save("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Complex
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
n_bins=207, bins_per_octave=24).to(device)
X = stft(torch.tensor(x, device=device).unsqueeze(0))
# np.save("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth", X.cpu())
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth.npy")
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
# Phase
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
# n_bins=207, bins_per_octave=24)
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
# # np.save("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth", X.cpu())
# ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth.npy")
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
def test_mfcc(device):
x = example_y
mfcc = MFCC(sr=example_sr).to(device)
X = mfcc(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
X_librosa = librosa.feature.mfcc(x, sr=example_sr)
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
x = torch.randn((4,44100)) # Create a batch of input for the following Data.Parallel test
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_STFT_Parallel(device):
spec_layer = STFT(hop_length=512, n_fft=2048, window='hann',
freq_scale='no',
output_format='Complex').to(device)
inverse_spec_layer = iSTFT(hop_length=512, n_fft=2048, window='hann',
freq_scale='no').to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
inverse_spec_layer_parallel = torch.nn.DataParallel(inverse_spec_layer)
spec = spec_layer_parallel(x)
x_recon = inverse_spec_layer_parallel(spec, onesided=True, length=x.shape[-1])
assert np.allclose(x_recon.detach().cpu(), x.detach().cpu(), rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_MelSpectrogram_Parallel(device):
spec_layer = MelSpectrogram(sr=22050, n_fft=2048, n_mels=128, hop_length=512,
window='hann', center=True, pad_mode='reflect',
power=2.0, htk=False, fmin=0.0, fmax=None, norm=1,
verbose=True).to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_MFCC_Parallel(device):
spec_layer = MFCC().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT1992_Parallel(device):
spec_layer = CQT1992(fmin=110, n_bins=60, bins_per_octave=12).to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT1992v2_Parallel(device):
spec_layer = CQT1992v2().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT2010_Parallel(device):
spec_layer = CQT2010().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
def test_CQT2010v2_Parallel(device):
spec_layer = CQT2010v2().to(device)
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
spec = spec_layer_parallel(x)
\ No newline at end of file
import paddle
import numpy as np
from typing import Tuple, Optional, Union
# https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/src/feat/feature-window.cc#L109
def povey_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = (0.5 - 0.5 * np.cos(a * i) )**0.85
return win
def hann_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.5 - 0.5 * np.cos(a * i)
return win
def sine_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = np.sin(0.5 * a * i)
return win
def hamm_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.54 - 0.46 * np.cos(a * i)
return win
def get_window(wintype:Optional[str], winlen:int) -> np.ndarray:
"""get window function
Args:
wintype (Optional[str]): window type.
winlen (int): window length in samples.
Raises:
ValueError: not support window.
Returns:
np.ndarray: window coeffs.
"""
# calculate window
if not wintype or wintype == 'rectangular':
window = np.ones(winlen)
elif wintype == "hann":
window = hann_window(winlen)
elif wintype == "hamm":
window = hamm_window(winlen)
elif wintype == "povey":
window = povey_window(winlen)
else:
msg = f"{wintype} Not supported yet!"
raise ValueError(msg)
return window
def dft_matrix(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# https://en.wikipedia.org/wiki/Discrete_Fourier_transform
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
n = np.arange(0, n_fft, 1.)
wsin = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
wcos = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
for k in range(n_bin): # Only half of the bins contain useful info
wsin[k,:] = -np.sin(2*np.pi*k*n/n_fft)[:kernel_size]
wcos[k,:] = np.cos(2*np.pi*k*n/n_fft)[:kernel_size]
w_real = wcos
w_imag = wsin
return w_real, w_imag, kernel_size
def dft_matrix_fast(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
# https://en.wikipedia.org/wiki/DFT_matrix
# https://ccrma.stanford.edu/~jos/st/Matrix_Formulation_DFT.html
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
return w_real, w_imag, kernel_size
def bin2hz(bin:Union[List[int], np.ndarray], N:int, sr:int)->List[float]:
"""FFT bins to Hz.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Args:
bins (List[int] or np.ndarray): bin index.
N (int): the number of samples, or FFT points.
sr (int): sampling rate.
Returns:
List[float]: Hz's.
"""
hz = bin * float(sr) / N
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identical sized array is returned.
"""
return 1127 * np.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identical sized array is returned.
"""
return 700 * (np.exp(mel/1127.0)-1)
def rms_to_db(rms: float):
"""Root Mean Square to dB.
Args:
rms ([float]): root mean square
Returns:
float: dB
"""
return 20.0 * math.log10(max(1e-16, rms))
def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.
dB = dBFS + 3.0103
dBFS = db - 3.0103
e.g. 0 dB = -3.0103 dBFS
Args:
rms ([float]): root mean square
Returns:
float: dBFS
"""
return rms_to_db(rms) - 3.0103
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
def gain_db_to_ratio(gain_db: float):
"""dB to ratio
Args:
gain_db (float): gain in dB
Returns:
float: scale in amp
"""
return math.pow(10.0, gain_db / 20.0)
\ No newline at end of file
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
import soundfile as sf
from .common import get_window
from .common import dft_matrix
def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2d=True)->Tuple[int, np.ndarray]:
"""load wav file.
Args:
wavpath (str): wav path.
sr (int, optional): expect sample rate. Defaults to None.
dtype (str, optional): wav data bits. Defaults to 'int16'.
Returns:
Tuple[int, np.ndarray]: sr (int), wav (int16) [T, C].
"""
wav, r_sr = sf.read(wavpath, start=start, stop=stop, dtype=dtype, always_2d=always_2d)
if sr:
assert sr == r_sr
return r_sr, wav
def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'):
"""write wav file.
Args:
wavpath (str): file path to save.
wav (np.ndarray): wav data.
sr (int): data samplerate.
dtype (str, optional): wav bit format. Defaults to 'PCM_16'.
"""
sf.write(wavpath, wav, sr, subtype=dtype)
def frames(x: Tensor,
num_samples: Tensor,
sr: int,
win_length: float,
stride_length: float,
clip: bool = False) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (B, T), batched waveform.
num_samples : Tensor
Shape (B, ), number of samples of each waveform.
sr: int
Sampling Rate.
win_length : float
Window length in ms.
stride_length : float
Stride length in ms.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default True
Returns
-------
frames : Tensor
Shape (B, T', win_length).
num_frames : Tensor
Shape (B, ) number of valid frames
"""
assert stride_length <= win_length
stride_length = int(stride_length * sr)
win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length
padding = (0, 0)
if not clip:
num_frames += 1
need_samples = num_frames * stride_length + win_length
padding = (0, need_samples - num_samples - 1)
weight = paddle.eye(win_length).unsqueeze(1) #[win_length, 1, win_length]
frames = F.conv1d(x.unsqueeze(-1),
weight,
padding=padding,
stride=(stride_length, ),
data_format='NLC')
return frames, num_frames
def dither(signal:Tensor, dither_value=1.0)->Tensor:
"""dither frames for log compute.
Args:
signal (Tensor): [B, T, D]
dither_value (float, optional): [scalar]. Defaults to 1.0.
Returns:
Tensor: [B, T, D]
"""
D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal
def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc.
Args:
signal (Tensor): [B, T, D]
Returns:
Tensor: [B, T, D]
"""
signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
"""perform preemphasis on the input signal.
Args:
signal (Tensor): [B, T, D], The signal to filter.
coeff (float, optional): [scalar].The preemphasis coefficient. 0 is no filter, Defaults to 0.97.
Returns:
Tensor: [B, T, D]
"""
return paddle.concat([
(1-coeff)*signal[:, :, 0:1],
signal[:, :, 1:] - coeff * signal[:, :, :-1]
], axis=-1)
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters
------------
n_fft : int
Number of samples in a frame.
sr: int
Number of Samplilng rate.
stride_length : float
Number of samples shifted between adjacent frames.
win_length : float
Length of the window.
clip: bool
Whether to clip audio is necesaary.
"""
def __init__(self,
n_fft: int,
sr: int,
win_length: float,
stride_length: float,
dither:float=0.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
clip: bool = False):
super().__init__()
self.sr = sr
self.win_length = win_length
self.stride_length = stride_length
self.dither = dither
self.preemph_coeff = preemph_coeff
self.remove_dc_offset = remove_dc_offset
self.window_type = window_type
self.clip = clip
self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin
)
# calculate window
window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (kernel_size, 2 * n_bins)
w = np.transpose(w)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor [shape=(B,)]
Number of samples of each waveform.
Returns
------------
C : Tensor
Shape(B, T', n_bins, 2) Spectrogram.
num_frames: Tensor
Shape (B,) number of samples of each spectrogram
"""
batch_size = paddle.shape(num_samples)
F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither:
F = dither(F, self.dither)
if self.remove_dc_offset:
F = remove_dc_offset(F)
if self.preemph_coeff:
F = preemphasis(F)
C = paddle.matmul(F, self.weight) # [B, T, K] [K, 2 * n_bins]
C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin])
C = C.transpose([0, 1, 3, 2])
return C, nframe
def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum |X_k|^2.
Args:
C (Tensor): [B, T, C, 2]
Returns:
Tensor: [B, T, C]
"""
real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum |X_k|.
Args:
C (Tensor): [B, T, C, 2]
eps (float): epsilon.
Returns:
Tensor: [B, T, C]
"""
pspec = powspec(C)
return paddle.sqrt(pspec + eps)
def logspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute log-spectrum 20log10∣X_k∣.
Args:
C (Tensor): [description]
eps ([type], optional): [description]. Defaults to 1e-10.
Returns:
Tensor: [description]
"""
spec = magspec(C)
return 20 * paddle.log10(spec + eps)
from typing import Tuple
import numpy as np
import paddle
import unittest
import decimal
import numpy
import math
import logging
from pathlib import Path
from scipy.fftpack import dct
from third_party.paddle_audio.frontend import kaldi
def round_half_up(number):
return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))
def rolling_window(a, window, step=1):
# http://ellisvalentiner.com/post/2017-03-21-np-strides-trick
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]
def do_dither(signal, dither_value=1.0):
signal += numpy.random.normal(size=signal.shape) * dither_value
return signal
def do_remove_dc_offset(signal):
signal -= numpy.mean(signal)
return signal
def do_preemphasis(signal, coeff=0.97):
"""perform preemphasis on the input signal.
:param signal: The signal to filter.
:param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.
:returns: the filtered signal.
"""
return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])
def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param winfunc: the analysis window to apply to each frame. By default no window is applied.
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: an array of frames. Size is NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype is 'povey':
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = numpy.zeros(frames.shape)
for frm in range(frames.shape[0]):
frames[frm,:] = do_dither(frames[frm,:], dither) # dither
frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset
raw_frames[frm,:] = frames[frm,:]
frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize
return frames * win, raw_frames
def magspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warn(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return numpy.absolute(complex_spec)
def powspec(frames, NFFT):
"""Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.
"""
return numpy.square(magspec(frames, NFFT))
def mfcc(signal,samplerate=16000,winlen=0.025,winstep=0.01,numcep=13,
nfilt=23,nfft=512,lowfreq=20,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,
ceplifter=22,useEnergy=True,wintype='povey'):
"""Compute MFCC features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param numcep: the number of cepstrum to return, default 13
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param ceplifter: apply a lifter to final cepstral coefficients. 0 is no lifter. Default is 22.
:param appendEnergy: if this is true, the zeroth cepstral coefficient is replaced with the log of the total frame energy.
:param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming
:returns: A numpy array of size (NUMFRAMES by numcep) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither,remove_dc_offset,preemph,wintype)
feat = numpy.log(feat)
feat = dct(feat, type=2, axis=1, norm='ortho')[:,:numcep]
feat = lifter(feat,ceplifter)
if useEnergy: feat[:,0] = numpy.log(energy) # replace first cepstral coefficient with log of frame energy
return feat
def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
"""Compute Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming
winfunc=lambda x:numpy.ones((x,))
:returns: 2 values. The first is a numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector. The
second return value is the energy in each frame (total energy, unwindowed)
"""
highfreq= highfreq or samplerate/2
frames,raw_frames = sigproc.framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)
pspec = sigproc.powspec(frames,nfft) # nearly the same until this part
energy = numpy.sum(raw_frames**2,1) # this stores the raw energy in each frame
energy = numpy.where(energy == 0,numpy.finfo(float).eps,energy) # if energy is zero, we get problems with log
fb = get_filterbanks(nfilt,nfft,samplerate,lowfreq,highfreq)
feat = numpy.dot(pspec,fb.T) # compute the filterbank energies
feat = numpy.where(feat == 0,numpy.finfo(float).eps,feat) # if feat is zero, we get problems with log
return feat,energy
def logfbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=64,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,wintype='hamming'):
"""Compute log Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:returns: A numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither, remove_dc_offset,preemph,wintype)
return numpy.log(feat)
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identical sized array is returned.
"""
return 1127 * numpy.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identical sized array is returned.
"""
return 700 * (numpy.exp(mel/1127.0)-1)
def get_filterbanks(nfilt=26,nfft=512,samplerate=16000,lowfreq=0,highfreq=None):
"""Compute a Mel-filterbank. The filters are stored in the rows, the columns correspond
to fft bins. The filters are returned as an array of size nfilt * (nfft/2 + 1)
:param nfilt: the number of filters in the filterbank, default 20.
:param nfft: the FFT size. Default is 512.
:param samplerate: the samplerate of the signal we are working with. Affects mel spacing.
:param lowfreq: lowest band edge of mel filters, default 0 Hz
:param highfreq: highest band edge of mel filters, default samplerate/2
:returns: A numpy array of size nfilt * (nfft/2 + 1) containing filterbank. Each row holds 1 filter.
"""
highfreq= highfreq or samplerate/2
assert highfreq <= samplerate/2, "highfreq is greater than samplerate/2"
# compute points evenly spaced in mels
lowmel = hz2mel(lowfreq)
highmel = hz2mel(highfreq)
# check kaldi/src/feat/Mel-computations.h
fbank = numpy.zeros([nfilt,nfft//2+1])
mel_freq_delta = (highmel-lowmel)/(nfilt+1)
for j in range(0,nfilt):
leftmel = lowmel+j*mel_freq_delta
centermel = lowmel+(j+1)*mel_freq_delta
rightmel = lowmel+(j+2)*mel_freq_delta
for i in range(0,nfft//2):
mel=hz2mel(i*samplerate/nfft)
if mel>leftmel and mel<rightmel:
if mel<centermel:
fbank[j,i]=(mel-leftmel)/(centermel-leftmel)
else:
fbank[j,i]=(rightmel-mel)/(rightmel-centermel)
return fbank
def lifter(cepstra, L=22):
"""Apply a cepstral lifter the the matrix of cepstra. This has the effect of increasing the
magnitude of the high frequency DCT coeffs.
:param cepstra: the matrix of mel-cepstra, will be numframes * numcep in size.
:param L: the liftering coefficient to use. Default is 22. L <= 0 disables lifter.
"""
if L > 0:
nframes,ncoeff = numpy.shape(cepstra)
n = numpy.arange(ncoeff)
lift = 1 + (L/2.)*numpy.sin(numpy.pi*n/L)
return lift*cepstra
else:
# values of L <= 0, do nothing
return cepstra
def delta(feat, N):
"""Compute delta features from a feature vector sequence.
:param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector.
:param N: For each frame, calculate delta features based on preceding and following N frames
:returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector.
"""
if N < 1:
raise ValueError('N must be an integer >= 1')
NUMFRAMES = len(feat)
denominator = 2 * sum([i**2 for i in range(1, N+1)])
delta_feat = numpy.empty_like(feat)
padded = numpy.pad(feat, ((N, N), (0, 0)), mode='edge') # padded version of feat
for t in range(NUMFRAMES):
delta_feat[t] = numpy.dot(numpy.arange(-N, N+1), padded[t : t+2*N+1]) / denominator # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1]
return delta_feat
##### modify for test ######
def framesig_without_dither_dc_preemphasize(sig, frame_len, frame_step, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param winfunc: the analysis window to apply to each frame. By default no window is applied.
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: an array of frames. Size is NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype is 'povey':
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
elif wintype == '':
win = numpy.ones(frame_len)
elif wintype == 'hann':
win = numpy.hanning(frame_len)
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = frames
return frames * win, raw_frames
def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None, wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
return frames_with_win, raw_frames
def complexspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warn(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return complex_spec
def stft_with_window(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
spec = magspec(frames_with_win, nfft) # nearly the same until this part
scomplex = complexspec(frames_with_win, nfft)
rspec = magspec(raw_frames, nfft)
rcomplex = complexspec(raw_frames, nfft)
return spec, scomplex, rspec, rcomplex
class TestKaldiFE(unittest.TestCase):
def setUp(self):
self. this_dir = Path(__file__).parent
self.wavpath = str(self.this_dir / 'english.wav')
self.winlen=0.025 # ms
self.winstep=0.01 # ms
self.nfft=512
self.lowfreq = 0
self.highfreq = None
self.wintype='hamm'
self.nfilt=40
paddle.set_device('cpu')
def test_read(self):
import scipy.io.wavfile as wav
rate, sig = wav.read(self.wavpath)
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
self.assertTrue(np.all(sig == wav))
self.assertEqual(rate, sr)
def test_frames(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
_, fs = frames(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
t_fs, t_nframe = kaldi.frames(t_wav, t_wavlen, sr, self.winlen, self.winstep, clip=False)
t_fs = t_fs.astype(fs.dtype)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1]
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magsepc_winprocess(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
fs, _= framesig(wav, self.winlen*sr, self.winstep*sr,
dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True)
spec = magspec(fs, self.nfft) # nearly the same until this part
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(
self.nfft, sr, self.winlen, self.winstep,
window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(spec.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1)
self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1))
def test_powspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
stft_win = np.square(stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))
# from python_speech_features import mfcc
# from python_speech_features import delta
# from python_speech_features import logfbank
# import scipy.io.wavfile as wav
# (rate,sig) = wav.read("english.wav")
# # note that generally nfilt=40 is used for speech recognition
# fbank_feat = logfbank(sig,nfilt=23,lowfreq=20,dither=0,wintype='povey')
# # the computed fbank coefficents of english.wav with dimension [110,23]
# # [ 12.2865 12.6906 13.1765 15.714 16.064 15.7553 16.5746 16.9205 16.6472 16.1302 16.4576 16.7326 16.8864 17.7215 18.88 19.1377 19.1495 18.6683 18.3886 20.3506 20.2772 18.8248 18.1899
# # 11.9198 13.146 14.7215 15.8642 17.4288 16.394 16.8238 16.1095 16.4297 16.6331 16.3163 16.5093 17.4981 18.3429 19.6555 19.6263 19.8435 19.0534 19.001 20.0287 19.7707 19.5852 19.1112
# # ...
# # ...
# # the same with that using kaldi commands: compute-fbank-feats --dither=0.0
# mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')
# # the computed mfcc coefficents of english.wav with dimension [110,13]
# # [ 17.1337 -23.3651 -7.41751 -7.73686 -21.3682 -8.93884 -3.70843 4.68346 -16.0676 12.782 -7.24054 8.25089 10.7292
# # 17.1692 -23.3028 -5.61872 -4.0075 -23.287 -20.6101 -5.51584 -6.15273 -14.4333 8.13052 -0.0345329 2.06274 -0.564298
# # ...
# # ...
# # the same with that using kaldi commands: compute-mfcc-feats --dither=0.0
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
from .sentence_split import split
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num
from .chronology import RE_TIME, RE_DATE, RE_DATE2
from .chronology import replace_time, replace_date, replace_date2
from .quantifier import RE_TEMPERATURE
from .quantifier import replace_temperature
from .phone import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
from .char_convert import tranditional_to_simplified
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
def normalize_sentence(sentence):
# basic character conversions
sentence = tranditional_to_simplified(sentence)
sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
F2H_DIGITS).translate(F2H_SPACE)
# number related NSW verbalization
sentence = RE_DATE.sub(replace_date, sentence)
sentence = RE_DATE2.sub(replace_date2, sentence)
sentence = RE_TIME.sub(replace_time, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = RE_RANGE.sub(replace_range, sentence)
sentence = RE_FRAC.sub(replace_frac, sentence)
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
sentence = RE_MOBILE_PHONE.sub(replace_phone, sentence)
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
sentence = RE_NUMBER.sub(replace_number, sentence)
return sentence
def normalize(text):
sentences = split(text)
sentences = [normalize_sentence(sent) for sent in sentences]
return sentences
"""Traditional and simplified Chinese conversion with
`opencc <https://github.com/BYVoid/OpenCC>`_.
"""
import opencc
_t2s_converter = opencc.OpenCC("t2s.json")
_s2t_converter = opencc.OpenCC('s2t.json')
def tranditional_to_simplified(text: str) -> str:
return _t2s_converter.convert(text)
def simplified_to_traditional(text: str) -> str:
return _s2t_converter.convert(text)
import re
from .num import verbalize_cardinal, verbalize_digit, num2str, DIGITS
def _time_num2str(num_string: str) -> str:
"""A special case for verbalizing number in time."""
result = num2str(num_string.lstrip('0'))
if num_string.startswith('0'):
result = DIGITS['0'] + result
return result
# 时刻表达式
RE_TIME = re.compile(
r'([0-1]?[0-9]|2[0-3])'
r':([0-5][0-9])'
r'(:([0-5][0-9]))?'
)
def replace_time(match: re.Match) -> str:
hour = match.group(1)
minute = match.group(2)
second = match.group(4)
result = f"{num2str(hour)}点"
if minute.lstrip('0'):
result += f"{_time_num2str(minute)}分"
if second and second.lstrip('0'):
result += f"{_time_num2str(second)}秒"
return result
RE_DATE = re.compile(
r'(\d{4}|\d{2})年'
r'((0?[1-9]|1[0-2])月)?'
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?'
)
def replace_date(match: re.Match) -> str:
year = match.group(1)
month = match.group(3)
day = match.group(5)
result = ""
if year:
result += f"{verbalize_digit(year)}年"
if month:
result += f"{verbalize_cardinal(month)}月"
if day:
result += f"{verbalize_cardinal(day)}{match.group(9)}"
return result
# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
RE_DATE2 = re.compile(
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])'
)
def replace_date2(match: re.Match) -> str:
year = match.group(1)
month = match.group(3)
day = match.group(4)
result = ""
if year:
result += f"{verbalize_digit(year)}年"
if month:
result += f"{verbalize_cardinal(month)}月"
if day:
result += f"{verbalize_cardinal(day)}日"
return result
import string
import re
from pypinyin.constants import SUPPORT_UCS4
# 全角半角转换
# 英文字符全角 -> 半角映射表 (num: 52)
F2H_ASCII_LETTERS = {
chr(ord(char) + 65248): char
for char in string.ascii_letters
}
# 英文字符半角 -> 全角映射表
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
# 数字字符全角 -> 半角映射表 (num: 10)
F2H_DIGITS = {
chr(ord(char) + 65248): char
for char in string.digits
}
# 数字字符半角 -> 全角映射表
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}
# 标点符号全角 -> 半角映射表 (num: 32)
F2H_PUNCTUATIONS = {
chr(ord(char) + 65248): char
for char in string.punctuation
}
# 标点符号半角 -> 全角映射表
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
# 空格 (num: 1)
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}
# 非"有拼音的汉字"的字符串,可用于NSW提取
if SUPPORT_UCS4:
RE_NSW = re.compile(
r'(?:[^'
r'\u3007' # 〇
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF]
r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F]
r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D]
r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F]
r'])+'
)
else:
RE_NSW = re.compile( # pragma: no cover
r'(?:[^'
r'\u3007' # 〇
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
r'])+'
)
"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""
import re
from typing import List
from collections import OrderedDict
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
1: '十',
2: '百',
3: '千',
4: '万',
8: '亿',
})
# 分数表达式
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
def replace_frac(match: re.Match) -> str:
sign = match.group(1)
nominator = match.group(2)
denominator = match.group(3)
sign: str = "负" if sign else ""
nominator: str = num2str(nominator)
denominator: str = num2str(denominator)
result = f"{sign}{denominator}分之{nominator}"
return result
# 百分数表达式
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
def replace_percentage(match: re.Match) -> str:
sign = match.group(1)
percent = match.group(2)
sign: str = "负" if sign else ""
percent: str = num2str(percent)
result = f"{sign}百分之{percent}"
return result
# 整数表达式
# 带负号或者不带负号的整数 12, -10
RE_INTEGER = re.compile(
r'(-?)'
r'(\d+)'
)
# 编号-无符号整形
# 00078
RE_DEFAULT_NUM = re.compile(r'\d{4}\d*')
def replace_default_num(match: re.Match):
number = match.group(0)
return verbalize_digit(number)
# 数字表达式
# 1. 整数: -10, 10;
# 2. 浮点数: 10.2, -0.3
# 3. 不带符号和整数部分的纯浮点数: .22, .38
RE_NUMBER = re.compile(
r'(-?)((\d+)(\.\d+)?)'
r'|(\.(\d+))'
)
def replace_number(match: re.Match) -> str:
sign = match.group(1)
number = match.group(2)
pure_decimal = match.group(5)
if pure_decimal:
result = num2str(pure_decimal)
else:
sign: str = "负" if sign else ""
number: str = num2str(number)
result = f"{sign}{number}"
return result
# 范围表达式
# 12-23, 12~23
RE_RANGE = re.compile(
r'(\d+)[-~](\d+)'
)
def replace_range(match: re.Match) -> str:
first, second = match.group(1), match.group(2)
first: str = num2str(first)
second: str = num2str(second)
result = f"{first}{second}"
return result
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
stripped = value_string.lstrip('0')
if len(stripped) == 0:
return []
elif len(stripped) == 1:
if use_zero and len(stripped) < len(value_string):
return [DIGITS['0'], DIGITS[stripped]]
else:
return [DIGITS[stripped]]
else:
largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
first_part = value_string[:-largest_unit]
second_part = value_string[-largest_unit:]
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
def verbalize_cardinal(value_string: str) -> str:
if not value_string:
return ''
# 000 -> '零' , 0 -> '零'
value_string = value_string.lstrip('0')
if len(value_string) == 0:
return DIGITS['0']
result_symbols = _get_value(value_string)
# verbalized number starting with '一十*' is abbreviated as `十*`
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS['1'] and result_symbols[1] == UNITS[1]:
result_symbols = result_symbols[1:]
return ''.join(result_symbols)
def verbalize_digit(value_string: str, alt_one=False) -> str:
result_symbols = [DIGITS[digit] for digit in value_string]
result = ''.join(result_symbols)
if alt_one:
result.replace("一", "幺")
return result
def num2str(value_string: str) -> str:
integer_decimal = value_string.split('.')
if len(integer_decimal) == 1:
integer = integer_decimal[0]
decimal = ''
elif len(integer_decimal) == 2:
integer, decimal = integer_decimal
else:
raise ValueError(f"The value string: '${value_string}' has more than one point in it.")
result = verbalize_cardinal(integer)
decimal = decimal.rstrip('0')
if decimal:
# '.22' is verbalized as '点二二'
# '3.20' is verbalized as '三点二
result += '点' + verbalize_digit(decimal)
return result
import re
from .num import verbalize_digit
# 规范化固话/手机号码
# 手机
# http://www.jihaoba.com/news/show/13680
# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
# 联通:130、131、132、156、155、186、185、176
# 电信:133、153、189、180、181、177
RE_MOBILE_PHONE= re.compile(
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
def phone2str(phone_string: str, mobile=True) -> str:
if mobile:
sp_parts = phone_string.strip('+').split()
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sp_parts])
return result
else:
sil_parts = phone_string.split('-')
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sil_parts])
return result
def replace_phone(match: re.Match) -> str:
return phone2str(match.group(0))
import re
from .num import num2str
# 温度表达式,温度会影响负号的读法
# -3°C 零下三度
RE_TEMPERATURE = re.compile(
r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)'
)
def replace_temperature(match: re.Match) -> str:
sign = match.group(1)
temperature = match.group(2)
unit = match.group(3)
sign: str = "零下" if sign else ""
temperature: str = num2str(temperature)
unit: str = "摄氏度" if unit == "摄氏度" else "度"
result = f"{sign}{temperature}{unit}"
return result
import re
from typing import List
SENTENCE_SPLITOR = re.compile(r'([。!?][”’]?)')
def split(text: str) -> List[str]:
"""Split long text into sentences with sentence-splitting punctuations.
Parameters
----------
text : str
The input text.
Returns
-------
List[str]
Sentences.
"""
text = SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
return sentences
SHELL:= /bin/bash
PYTHON:= python3.7
CXX ?= g++
CC ?= gcc # used for sph2pipe
# CXX = clang++ # Uncomment these lines...
# CC = clang # ...to build with Clang.
WGET ?= wget
.PHONY: all clean
all: virtualenv kenlm.done sox.done soxbindings.done
all: virtualenv kenlm.done sox.done soxbindings.done mfa.done sclite.done
virtualenv:
test -d venv || virtualenv -p $(PYTHON) venv
......@@ -18,8 +27,8 @@ kenlm.done:
apt install -y build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
cd kenlm && python setup.py install
rm -rf kenlm/build && mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
source venv/bin/activate; cd kenlm && python setup.py install
touch kenlm.done
sox.done:
......@@ -31,5 +40,57 @@ sox.done:
soxbindings.done:
test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git
source venv/bin/activate; cd soxbindings && python3 setup.py install
source venv/bin/activate; cd soxbindings && python setup.py install
touch soxbindings.done
mfa.done:
test -d montreal-forced-aligner || wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz
tar xvf montreal-forced-aligner_linux.tar.gz
touch mfa.done
#== SCTK ===============================================================================
# SCTK official repo does not have version tags. Here's the mapping:
# # 2.4.9 = 659bc36; 2.4.10 = d914e1b; 2.4.11 = 20159b5.
SCTK_GITHASH = 20159b5
SCTK_CXFLAGS = -w -march=native
SCTK_MKENV = CFLAGS="$(CFLAGS) $(SCTK_CXFLAGS)" \
CXXFLAGS="$(CXXFLAGS) -std=c++11 $(SCTK_CXFLAGS)" \
# Keep the existing target 'sclite' to avoid breaking the users who might have
# scripted it in.
.PHONY: sclite.done sctk_cleaned sctk_made
sclite.done sctk_made: sctk/.compiled
touch sclite.done
sctk/.compiled: sctk
rm -f sctk/.compiled
$(SCTK_MKENV) $(MAKE) -C sctk config
$(SCTK_MKENV) $(MAKE) -C sctk all doc
$(MAKE) -C sctk install
touch sctk/.compiled
# The GitHub archive unpacks into SCTK-{40-character-long-hash}/
sctk: sctk-$(SCTK_GITHASH).tar.gz
tar zxvf sctk-$(SCTK_GITHASH).tar.gz
rm -rf sctk-$(SCTK_GITHASH) sctk
mv SCTK-$(SCTK_GITHASH)* sctk-$(SCTK_GITHASH)
ln -s sctk-$(SCTK_GITHASH) sctk
touch sctk-$(SCTK_GITHASH).tar.gz
sctk-$(SCTK_GITHASH).tar.gz:
if [ -d '$(DOWNLOAD_DIR)' ]; then \
cp -p '$(DOWNLOAD_DIR)/sctk-$(SCTK_GITHASH).tar.gz' .; \
else \
$(WGET) -nv -T 10 -t 3 -O sctk-$(SCTK_GITHASH).tar.gz \
https://github.com/usnistgov/SCTK/archive/$(SCTK_GITHASH).tar.gz; \
fi
sctk_cleaned:
-for d in sctk/ sctk-*/; do \
[ ! -f $$d/.compiled ] || $(MAKE) -C $$d clean; \
rm -f $$d/.compiled; \
done
1. kaldi
deps gcc, mkl or openblas
2. OpenFST/ngram/pynini
deps gcc
3. MFA
deps kaldi
#!/bin/bash
set -e
set -x
# gcc
apt update -y
apt install build-essential -y
apt install software-properties-common -y
add-apt-repository ppa:ubuntu-toolchain-r/test
apt install gcc-8 g++-8 -y
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 80
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 80
update-alternatives --config gcc
# gfortran
apt-get install gfortran-8
#!/bin/bash
# Installation script for Kaldi
#
set -e
apt-get install subversion -y
KALDI_GIT="--depth 1 -b master https://github.com/kaldi-asr/kaldi.git"
KALDI_DIR="$PWD/kaldi"
if [ ! -d "$KALDI_DIR" ]; then
git clone $KALDI_GIT $KALDI_DIR
else
echo "$KALDI_DIR already exists!"
fi
cd "$KALDI_DIR/tools"
git pull
# Prevent kaldi from switching default python version
mkdir -p "python"
touch "python/.use_default_python"
./extras/check_dependencies.sh
make -j4
pushd ../src
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${KALDI_DIR}/../OpenBLAS/install
make clean -j && make depend -j && make -j4
popd
echo "Done installing Kaldi."
#!/bin/bash
apt install -y build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
rm -rf kenlm/build && mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
#!/usr/bin/env bash
VER=1.10
WGET=${WGET:-wget}
if [ ! -f liblbfgs-$VER.tar.gz ]; then
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/liblbfgs-$VER.tar.gz" . || exit 1
else
$WGET https://github.com/downloads/chokkan/liblbfgs/liblbfgs-$VER.tar.gz || exit 1
fi
fi
tar -xzf liblbfgs-$VER.tar.gz
cd liblbfgs-$VER
./configure --prefix=`pwd`
make
# due to the liblbfgs project directory structure, we have to use -i
# but the erros are completely harmless
make -i install
cd ..
(
[ ! -z "${LIBLBFGS}" ] && \
echo >&2 "LIBLBFGS variable is aleady defined. Undefining..." && \
unset LIBLBFGS
[ -f ./env.sh ] && . ./env.sh
[ ! -z "${LIBLBFGS}" ] && \
echo >&2 "libLBFGS config is already in env.sh" && exit
wd=`pwd`
wd=`readlink -f $wd || pwd`
echo "export LIBLBFGS=$wd/liblbfgs-1.10"
echo export LD_LIBRARY_PATH='${LD_LIBRARY_PATH:-}':'${LIBLBFGS}'/lib/.libs
) >> env.sh
#!/bin/bash
# install openblas, kaldi before
test -d Montreal-Forced-Aligner || git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
pushd Montreal-Forced-Aligner && python setup.py install && popd
test -d kaldi || { echo "need install kaldi first"; exit 1;}
mfa thirdparty kaldi $PWD/kaldi
mfa thirdparty validate
echo "install mfa pass."
#!/usr/bin/env bash
WGET=${WGET:-wget}
# The script automatically choose default settings of miniconda for installation
# Miniconda will be installed in the HOME directory. ($HOME/miniconda3).
# Also don't make miniconda's python as default.
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/Miniconda3-latest-Linux-x86_64.sh" . || exit 1
else
$WGET https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh || exit 1
fi
bash Miniconda3-latest-Linux-x86_64.sh -b
$HOME/miniconda3/bin/python -m pip install --user tqdm
$HOME/miniconda3/bin/python -m pip install --user scikit-learn
$HOME/miniconda3/bin/python -m pip install --user librosa
$HOME/miniconda3/bin/python -m pip install --user h5py
#!/usr/bin/env bash
# Intel MKL is now freely available even for commercial use. This script
# attempts to install the MKL package automatically from Intel's repository.
#
# For manual repository setup instructions, see:
# https://software.intel.com/articles/installing-intel-free-libs-and-python-yum-repo
# https://software.intel.com/articles/installing-intel-free-libs-and-python-apt-repo
#
# For other package managers, or non-Linux platforms, see:
# https://software.intel.com/mkl/choose-download
set -o pipefail
default_package=intel-mkl-64bit-2020.0-088
yum_repo='https://yum.repos.intel.com/mkl/setup/intel-mkl.repo'
apt_repo='https://apt.repos.intel.com/mkl'
intel_key_url='https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB'
Usage () {
cat >&2 <<EOF
Usage: $0 [-s] [<MKL-package>]
Checks if MKL is present on the system, and/or attempts to install it.
If <MKL-package> is not provided, ${default_package} will be installed.
Intel packages are installed under the /opt/intel directory. You should be root
to install MKL into this directory; run this script using the sudo command.
Options:
-s - Skip check for MKL being already present.
-p <suse|redhat|debian|fedora|arch> -- Force type of package management. Use only
if automatic detection fails, as instructed.
-h - Show this message.
Environment:
CC The C compiler to use for MKL check. If not set, uses 'cc'.
EOF
exit 2
}
Fatal () { echo "$0: $@"; exit 1; }
Have () { type -t "$1" >/dev/null; }
# Option values.
skip_cc=
distro=
while getopts ":hksp:" opt; do
case ${opt} in
h) Usage ;;
s) skip_cc=yes ;;
p) case $OPTARG in
suse|redhat|debian|fedora|arch) distro=$OPTARG ;;
*) Fatal "invalid value -p '${OPTARG}'. " \
"Allowed: 'suse', 'redhat', 'debian', 'fedora', or 'arch'."
esac ;;
\?) echo >&2 "$0: invalid option -${OPTARG}."; Usage ;;
esac
done
shift $((OPTIND-1))
orig_arg_package=${1-''}
package=${1:-$default_package}
# Check that we are actually on Linux, otherwise give a helpful reference.
[[ $(uname) == Linux ]] || Fatal "\
This script can be used on Linux only, and your system is $(uname).
Installer packages for Mac and Windows are available for download from Intel:
https://software.intel.com/mkl/choose-download"
# Test if MKL is already installed on the system.
if [[ ! $skip_cc ]]; then
: ${CC:=cc}
Have "$CC" || Fatal "\
C compiler $CC not found.
You can skip the check for MKL presence by invoking this script with the '-s'
option to this script, but you will need a functional compiler anyway, so we
recommend that you install it first."
mkl_version=$($CC -E -I /opt/intel/mkl/include - <<< \
'#include <mkl_version.h>
__INTEL_MKL__.__INTEL_MKL_MINOR__.__INTEL_MKL_UPDATE__' 2>/dev/null |
tail -n 1 ) || mkl_version=
mkl_version=${mkl_version// /}
[[ $mkl_version ]] && Fatal "\
MKL version $mkl_version is already installed.
You can skip the check for MKL presence by invoking this script with the '-s'
option and proceed with automated installation, but we highly discourage
this. This script will register Intel repositories with your system, and it
seems that they have been already registered, or MKL has been installed some
other way.
You should use your package manager to check which MKL package is already
installed. Note that Intel packages register the latest installed version of
the library as the default. If your installed version is older than
$package, it makes sense to upgrade."
fi
# Try to determine which package manager the distro uses, unless overridden.
if [[ ! $distro ]]; then
dist_vars=$(cat /etc/os-release 2>/dev/null)
eval "$dist_vars"
for rune in $CPE_NAME $ID $ID_LIKE; do
case "$rune" in
cpe:/o:fedoraproject:fedora:2[01]) distro=redhat; break;; # Use yum.
rhel|centos) distro=redhat; break;;
redhat|suse|fedora|debian|arch) distro=$rune; break;;
esac
done
# Certain old distributions do not have /etc/os-release. We are unlikely to
# encounter these in the wild, but just in case.
# NOTE: Do not try to guess Fedora specifically here! Fedora 20 and below
# detect as redhat, and this is good, because they use yum by default.
[[ ! $distro && -f /etc/redhat-release ]] && distro=redhat
[[ ! $distro && -f /etc/SuSE-release ]] && distro=suse
[[ ! $distro && -f /etc/debian_release ]] && distro=debian
[[ ! $distro && -f /etc/arch-release ]] && distro=arch
[[ ! $distro ]] && Fatal "\
Unable to determine package management style.
Invoke this script with the option '-p <style>', where <style> can be:
redhat -- RedHat-like, uses yum and rpm for package management.
fedora -- Fedora 22+, also RedHat-like, but uses dnf instead of yum.
suse -- SUSE-like, uses zypper and rpm.
debian -- Debian-like, uses apt and dpkg.
arch -- Archlinux, uses pacman.
We do not currently support other package management systems. Check the Intel's
documentation at https://software.intel.com/mkl/choose-download for other
install options."
echo >&2 "$0: Your system is using ${distro}-style package management."
fi
# Check for root.
if [[ "$(id -u)" -ne 0 ]]; then
echo >&2 "$0: You must be root to install MKL.
Restart this script using the 'sudo' command, as:
sudo $0 -sp $distro $package
We recommend adding the '-sp $distro' options to skip the MKL and distro
detection, since this has already been done. This minimizes the number of
programs invoked with the root privileges to keep your system safe from
unexpected or erroneous changes. Also, if you are setting the CC environment
variable, sudo might not allow it to propagate to the command that it invokes."
if [ -t 0 ]; then
echo; read -ep "Run the above sudo command now? [Y/n]:"
case $REPLY in
''|[Yy]*) set -x; exec sudo "$0" -sp "$distro" "$package"
esac
fi
exit 0
fi
# The install variants, each in a function to simplify error reporting.
# Each one invokes a subshell with a 'set -x' to to show system-modifying
# commands it runs. The subshells simply limit the scope of this diagnostics
# and avoid creating noise (if we were using 'set +x', it would be printed).
Install_redhat () {
# yum-utils contains yum-config-manager, in case the user does not have it.
( set -x
rpm --import $intel_key_url
yum -y install yum-utils &&
yum-config-manager --add-repo "$yum_repo" &&
yum -y install "$package" )
}
Install_fedora () {
( set -x
rpm --import $intel_key_url
dnf -y install 'dnf-command(config-manager)' &&
dnf config-manager --add-repo "$yum_repo" &&
dnf -y install "$package" )
}
Install_suse () {
# zypper bug until libzypp-17.6.4: '--gpg-auto-import-keys' is ignored.
# See https://github.com/openSUSE/zypper/issues/144#issuecomment-418685933
# We must disable gpg checks with '--no-gpg-checks'. I won't bend backwards
# as far as check the installed .so version...
( set -x
rpm --import $intel_key_url
zypper addrepo "$yum_repo" &&
zypper --gpg-auto-import-keys --no-gpg-checks \
--non-interactive install "$package" )
}
Install_debian () {
local keyring='/usr/share/keyrings/intel-sw-products.gpg' \
sources_d='/etc/apt/sources.list.d' \
trusted_d='/etc/apt/trusted.gpg.d' \
apt_maj= apt_min= apt_ver=
# apt before 1.2 does not understand the signed-by option, and always
# look for the keyring in their trusted.gpg.d directory. This is not
# considered a good security practice any more. If apt is old, add a link
# to the keyring file and remind the user to delete it when apt is upgraded.
IFS=' .' builtin read _ apt_maj apt_min _ < <(apt-get --version)
apt_ver=$(builtin printf '%03d%03d' $apt_maj $apt_min)
# Get alternative location of /etc/apt/sources.list.d, if so configured.
eval $(apt-config shell sources_d Dir::Etc::sourceparts/f \
trusted_d Dir::Etc::trustedparts/f)
# apt is much more involved to configure than other package managers, as fas
# as third-party security keys go.
( set -x;
apt-get update &&
apt-get install -y wget apt-transport-https ca-certificates gnupg &&
wget -qO- $intel_key_url | apt-key --keyring $keyring add - &&
echo "deb [signed-by=${keyring}] $apt_repo all main" \
> "$sources_d/intel-mkl.list" ) || return 1
if [[ $apt_ver < '001002' ]]; then
( set -x; ln -s "$keyring" "${trusted_d}/" ) || return 1
fi
( set +x
apt-get update &&
apt-get install -y "$package" ) || return 1
# Print the message after the large install, so the user may notice. I hope...
if [[ $apt_ver < '001002' ]]; then
echo >&2 "$0: Your apt-get version is earlier than 1.2.
This version does not understand individual repositories signing keys, and
trusts all keys in $trusted_d. We have created a link
$trusted_d/$(basename $keyring) pointing to the file
$keyring. If/when you upgrade your system to
a higher version of apt, removing this link will help make it more secure.
This is not considered a severe security issue, but separating keyrings is the
current recommended security practice."
fi
}
Install_arch () {
( set -x
echo y | pacman -Syu intel-mkl && # In pacman we don't specify the version
pacman -Q --info intel-mkl | grep -v None
)
}
# Register MKL .so libraries with the ld.so.
ConfigLdSo() {
[ -d /etc/ld.so.conf.d ] || return 0
type -t ldconfig >/dev/null || return 0
echo >&2 "$0: Configuring ld runtime bindings"
( set -x;
echo >/etc/ld.so.conf.d/intel-mkl.conf "\
/opt/intel/lib/intel64
/opt/intel/mkl/lib/intel64"
ldconfig )
}
# Invoke installation.
if Install_${distro} && ConfigLdSo; then
echo >&2 "$0: MKL package $package was successfully installed"
else
Fatal "MKL package $package installation FAILED.
Please open an issue with us at https://github.com/kaldi-asr/kaldi/ if you
believe this is a bug."
fi
#!/bin/bash
set -e
set -x
# need support c++17, so need gcc >= 8
# openfst
ngram=ngram-1.3.13
openfst=openfst-1.8.1
shared=true
export CPLUS_INCLUDE_PATH=$PWD/${openfst}/install/include/:$CPLUS_INCLUDE_PATH
export LD_LIBRARY_PATH=$PWD/${openfst}/install/lib/:$LD_LIBRARY_PATH
test -e ${ngram}.tar.gz || wget http://www.openfst.org/twiki/pub/GRM/NGramDownload/${ngram}.tar.gz
test -d ${ngram} || tar -xvf ${ngram}.tar.gz && chown -R root:root ${ngram}
if [ $shared == true ];then
pushd ${ngram} && ./configure --enable-shared && popd
else
pushd ${ngram} && ./configure --enable-static && popd
fi
pushd ${ngram} && make -j && make install && popd
#!/usr/bin/env bash
OPENBLAS_VERSION=0.3.13
WGET=${WGET:-wget}
set -e
if ! command -v gfortran 2>/dev/null; then
echo "$0: gfortran is not installed. Please install it, e.g. by:"
echo " apt-get install gfortran"
echo "(if on Debian or Ubuntu), or:"
echo " yum install gcc-gfortran"
echo "(if on RedHat/CentOS). On a Mac, if brew is installed, it's:"
echo " brew install gfortran"
exit 1
fi
tarball=OpenBLAS-$OPENBLAS_VERSION.tar.gz
rm -rf xianyi-OpenBLAS-* OpenBLAS OpenBLAS-*.tar.gz
if [ -d "$DOWNLOAD_DIR" ]; then
cp -p "$DOWNLOAD_DIR/$tarball" .
else
url=$($WGET -qO- "https://api.github.com/repos/xianyi/OpenBLAS/releases/tags/v${OPENBLAS_VERSION}" | python -c 'import sys,json;print(json.load(sys.stdin)["tarball_url"])')
test -n "$url"
$WGET -t3 -nv -O $tarball "$url"
fi
tar xzf $tarball
mv xianyi-OpenBLAS-* OpenBLAS
make PREFIX=$(pwd)/OpenBLAS/install USE_LOCKING=1 USE_THREAD=0 -C OpenBLAS all install
if [ $? -eq 0 ]; then
echo "OpenBLAS is installed successfully."
rm $tarball
fi
#!/bin/bash
set -e
set -x
# need support c++17, so need gcc >= 8
# openfst
openfst=openfst-1.8.1
shared=true
test -e ${openfst}.tar.gz || wget http://www.openfst.org/twiki/pub/FST/FstDownload/${openfst}.tar.gz
test -d ${openfst} || tar -xvf ${openfst}.tar.gz && chown -R root:root ${openfst}
if [ $shared == true ];then
pushd ${openfst} && ./configure --enable-shared --enable-compact-fsts --enable-compress --enable-const-fsts --enable-far --enable-linear-fsts --enable-lookahead-fsts --enable-mpdt --enable-ngram-fsts --enable-pdt --enable-python --enable-special --enable-bin --enable-grm --prefix ${PWD}/install && popd
else
pushd ${openfst} && ./configure --enable-static --enable-compact-fsts --enable-compress --enable-const-fsts --enable-far --enable-linear-fsts --enable-lookahead-fsts --enable-mpdt --enable-ngram-fsts --enable-pdt --enable-python --enable-special --enable-bin --enable-grm --prefix ${PWD}/install && popd
fi
pushd ${openfst} && make -j && make install && popd
suffix_path=$(python3 -c 'import sysconfig; import os; from pathlib import Path; site = sysconfig.get_paths()["purelib"]; site=Path(site); pwd = os.getcwd(); suffix = site.parts[-2:]; print(os.path.join(*suffix));')
wfst_so_path=${PWD}/${openfst}/install/lib/${suffix_path}
cp ${wfst_so_path}/pywrapfst.* $(python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
#!/bin/bash
set -e
set -x
pynini=pynini-2.1.4
openfst=openfst-1.8.1
LIBRARY_PATH=$PWD/${openfst}/install/lib
test -e ${pynini}.tar.gz || wget http://www.openfst.org/twiki/pub/GRM/PyniniDownload/${pynini}.tar.gz
test -d ${pynini} || tar -xvf ${pynini}.tar.gz && chown -R root:root ${pynini}
pushd ${pynini} && LIBRARY_PATH=$LIBRARY_PATH python setup.py install && popd
#!/usr/bin/env bash
current_path=`pwd`
current_dir=`basename "$current_path"`
if [ "tools" != "$current_dir" ]; then
echo "You should run this script in tools/ directory!!"
exit 1
fi
if [ ! -d liblbfgs-1.10 ]; then
echo Installing libLBFGS library to support MaxEnt LMs
bash extras/install_liblbfgs.sh || exit 1
fi
# http://www.speech.sri.com/projects/srilm/download.html
if [ ! -f srilm.tgz ] && [ ! -f srilm.tar.gz ]; then # Changed format type from tgz to tar.gz as the srilm v1.7.3 downloads as tar.gz
echo This script cannot install SRILM in a completely automatic
echo way because you need to put your address in a download form.
echo Please download SRILM from http://www.speech.sri.com/projects/srilm/download.html
echo put it in ./srilm.tar.gz , then run this script.
echo Note: You may have to rename the downloaded file to remove version name from filename eg: mv srilm-1.7.3.tar.gz srilm.tar.gz
exit 1
fi
! which gawk 2>/dev/null && \
echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1;
mkdir -p srilm
cd srilm
if [ -f ../srilm.tgz ]; then
tar -xvzf ../srilm.tgz # Old SRILM format
elif [ -f ../srilm.tar.gz ]; then
tar -xvzf ../srilm.tar.gz # Changed format type from tgz to tar.gz
fi
major=`awk -F. '{ print $1 }' RELEASE`
minor=`awk -F. '{ print $2 }' RELEASE`
micro=`awk -F. '{ print $3 }' RELEASE`
if [ $major -le 1 ] && [ $minor -le 7 ] && [ $micro -le 1 ]; then
echo "Detected version 1.7.1 or earlier. Applying patch."
patch -p0 < ../extras/srilm.patch
fi
# set the SRILM variable in the top-level Makefile to this directory.
cp Makefile tmpf
cat tmpf | awk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \
> Makefile || exit 1
rm tmpf
mtype=`sbin/machine-type`
echo HAVE_LIBLBFGS=1 >> common/Makefile.machine.$mtype
grep ADDITIONAL_INCLUDES common/Makefile.machine.$mtype | \
sed 's|$| -I$(SRILM)/../liblbfgs-1.10/include|' \
>> common/Makefile.machine.$mtype
grep ADDITIONAL_LDFLAGS common/Makefile.machine.$mtype | \
sed 's|$| -L$(SRILM)/../liblbfgs-1.10/lib/ -Wl,-rpath -Wl,$(SRILM)/../liblbfgs-1.10/lib/|' \
>> common/Makefile.machine.$mtype
make || exit
cd ..
(
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM variable is aleady defined. Undefining..." && \
unset SRILM
[ -f ./env.sh ] && . ./env.sh
[ ! -z "${SRILM}" ] && \
echo >&2 "SRILM config is already in env.sh" && exit
wd=`pwd`
wd=`readlink -f $wd || pwd`
echo "export SRILM=$wd/srilm"
dirs="\${PATH}"
for directory in $(cd srilm && find bin -type d ) ; do
dirs="$dirs:\${SRILM}/$directory"
done
echo "export PATH=$dirs"
) >> env.sh
echo >&2 "Installation of SRILM finished successfully"
echo >&2 "Please source the tools/env.sh in your path.sh to enable it"
--- dstruct/src/Trie.orig 2016-11-08 19:53:40.524000000 +0000
+++ dstruct/src/Trie.cc 2016-11-08 19:53:59.088000000 +0000
@@ -200,11 +200,14 @@
if (removedData == 0) {
Trie<KeyT,DataT> node;
if (sub.remove(keys[0], &node)) {
+#if !defined(__GNUC__) || !(__GNUC__ >= 4 && __GNUC_MINOR__ >= 9 || __GNUC__ > 4)
/*
* XXX: Call subtrie destructor explicitly since we're not
* passing the removed node to the caller.
+ * !!! Triggers bug with gcc >= 4.9 optimization !!!
*/
node.~Trie();
+#endif
return true;
} else {
return false;
# Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#! /usr/bin/env bash
if [ $# != 2 ]; then
echo "usage: ${0} ckpt_dir avg_num"
if [ $# != 3 ]; then
echo "usage: ${0} [best|latest] ckpt_dir avg_num"
exit -1
fi
ckpt_dir=${1}
average_num=${2}
avg_mode=${1} # best,latest
ckpt_dir=${2}
average_num=${3}
decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
avg_model.py \
--dst_model ${decode_checkpoint} \
--ckpt_dir ${ckpt_dir} \
--num ${average_num} \
--val_best
if [ $avg_mode == best ];then
# best
avg_model.py \
--dst_model ${decode_checkpoint} \
--ckpt_dir ${ckpt_dir} \
--num ${average_num} \
--val_best
else
# latest
avg_model.py \
--dst_model ${decode_checkpoint} \
--ckpt_dir ${ckpt_dir} \
--num ${average_num}
fi
if [ $? -ne 0 ]; then
echo "Failed in avg ckpt!"
......
......@@ -27,33 +27,33 @@ def main(args):
val_scores = []
beat_val_scores = []
selected_epochs = []
if args.val_best:
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss))
val_scores = np.array(val_scores)
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss))
val_scores = np.array(val_scores)
if args.val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("best val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
else:
path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-args.num:]
sorted_val_scores = val_scores
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
print(path_list)
avg = None
......@@ -78,6 +78,7 @@ def main(args):
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
......
#!/usr/bin/env bash
# 2020 author Jiayu DU
# Apache 2.0
# This script reads in an Arpa format language model, and converts it into the
# KenLM format language model.
[ -f path.sh ] && . ./path.sh;
# begin configuration section
kenlm_opts="" # e.g. "-q 8 -b 8" for 8bits quantization
model_type="trie" # "trie" or "probing". trie is smaller, probing is faster.
# end configuration section
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "Usage: "
echo " $0 [options] <arpa-lm-path> <kenlm-path>"
echo "e.g.:"
echo " $0 data/local/lm/4gram.arpa data/lang_test/G.trie"
echo "Options:"
echo " --model-type can be either \"trie\" or \"probing\""
echo " --kenlm-opts directly pass through to kenlm"
echo " e.g. for 8bits quantization, feed \"-q 8 -b 8\""
exit 1;
fi
export LC_ALL=C
arpa_lm=$1
kenlm=$2
if ! which build_binary >& /dev/null ; then
echo "$0: cannot find KenLM's build_binary tool,"
echo "check kenlm installation (tools/extras/install_kenlm_query_only.sh)."
exit 1
fi
mkdir -p $(dirname $kenlm)
build_binary $kenlm_opts $model_type $arpa_lm $kenlm
echo "$0: Successfully built arpa into kenlm format: $kenlm"
exit 0
\ No newline at end of file
......@@ -25,6 +25,7 @@ from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.utility import BLANK
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.utility import SOS
from deepspeech.frontend.utility import SPACE
from deepspeech.frontend.utility import UNK
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
......@@ -44,6 +45,11 @@ add_arg('manifest_paths', str,
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('text_keys', str,
'text',
"keys of the text in manifest for building vocabulary. "
"You can provide multiple k.",
nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
......@@ -55,13 +61,13 @@ args = parser.parse_args()
def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'])
line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line)
def dump_text_manifest(fileobj, manifest_path):
def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
fileobj.write(line_json['text'] + "\n")
fileobj.write(line_json[key] + "\n")
def main():
print_arguments(args, globals())
......@@ -78,7 +84,9 @@ def main():
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in args.manifest_paths:
dump_text_manifest(fp, manifest_path)
text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys
for text_key in text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
......@@ -102,6 +110,8 @@ def main():
for token, count in count_sorted:
if count < args.count_threshold:
break
# replace space by `<space>`
token = SPACE if token == ' ' else token
tokens.append(token)
tokens = sorted(tokens)
......
......@@ -27,7 +27,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples', int, 2000, "# of samples to for statistics.")
add_arg('specgram_type', str,
add_arg('spectrum_type', str,
'linear',
"Audio feature type. Options: linear, mfcc, fbank.",
choices=['linear', 'mfcc', 'fbank'])
......@@ -58,7 +58,7 @@ def main():
augmentation_pipeline = AugmentationPipeline('{}')
audio_featurizer = AudioFeaturizer(
specgram_type=args.specgram_type,
spectrum_type=args.spectrum_type,
feat_dim=args.feat_dim,
delta_delta=args.delta_delta,
stride_ms=args.stride_ms,
......
#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest into wav.scp text.word [text.syllable text.phone]"""
import argparse
from pathlib import Path
from typing import Union
from deepspeech.frontend.utility import read_manifest
key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
filename = {
'text': 'text.word',
'syllable': 'text.syllable',
'phone': 'text.phone',
'feat': 'wav.scp',
}
def dump_manifest(manifest_path, output_dir: Union[str, Path]):
output_dir = Path(output_dir).expanduser()
manifest_path = Path(manifest_path).expanduser()
manifest_jsons = read_manifest(manifest_path)
first_line = manifest_jsons[0]
file_map = {}
for k in first_line.keys():
if k not in key_whitelist:
continue
file_map[k] = open(output_dir / filename[k], 'w')
for line_json in manifest_jsons:
for k in line_json.keys():
if k not in key_whitelist:
continue
file_map[k].write(line_json['utt'] + ' ' + line_json[k] + '\n')
for _, file in file_map.items():
file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="dump manifest to wav.scp text.word ...")
parser.add_argument("--manifest-path", type=str, help="path to manifest")
parser.add_argument(
"--output-dir",
type=str,
help="path to save outputs(audio and transcriptions)")
args = parser.parse_args()
dump_manifest(args.manifest_path, args.output_dir)
#!/bin/bash
if [ $# == 1 ];then
echo "usage: ${0} manifest_file"
exit -1
fi
manifest=$1
jq -S '.feat_shape[0]' ${manifest} | sort -nu
#!/usr/bin/env python3
# Apache 2.0
import argparse
import codecs
import sys
is_python2 = sys.version_info[0] == 2
def get_parser():
parser = argparse.ArgumentParser(
description="filter words in a text file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--exclude",
"-v",
dest="exclude",
action="store_true",
help="exclude filter words", )
parser.add_argument("filt", type=str, help="filter list")
parser.add_argument("infile", type=str, help="input file")
return parser
def main(args):
args = get_parser().parse_args(args)
filter_file(args.infile, args.filt, args.exclude)
def filter_file(infile, filt, exclude):
vocab = set()
with codecs.open(filt, "r", encoding="utf-8") as vocabfile:
for line in vocabfile:
vocab.add(line.strip())
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
if is_python2 else sys.stdout.buffer)
with codecs.open(infile, "r", encoding="utf-8") as textfile:
for line in textfile:
if exclude:
print(" ".join(
map(
lambda word: word if word not in vocab else "",
line.strip().split(), )))
else:
print(" ".join(
map(
lambda word: word if word in vocab else "<UNK>",
line.strip().split(), )))
if __name__ == "__main__":
main(sys.argv[1:])
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: utils/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
\ No newline at end of file
......@@ -26,7 +26,7 @@ from deepspeech.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), mat(ark), scp")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
......@@ -53,7 +53,8 @@ def main():
fout = open(args.output_path, 'w', encoding='utf-8')
# get feat dim
mean, std = load_cmvn(args.cmvn_path, filetype='json')
filetype = args.cmvn_path.split(".")[-1]
mean, istd = load_cmvn(args.cmvn_path, filetype=filetype)
feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}")
......@@ -75,6 +76,7 @@ def main():
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
line_json['filetype'] = 'sound'
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n')
......
#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest with more metadata."""
import argparse
import functools
import json
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath of the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
# bpe
add_arg('spm_model_prefix', str, None,
"spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm")
add_arg('output_path', str, None, "filepath of formated manifest.", required=True)
# yapf: disable
args = parser.parse_args()
def main():
print_arguments(args, globals())
fout = open(args.output_path, 'w', encoding='utf-8')
# get feat dim
mean, std = load_cmvn(args.cmvn_path, filetype='json')
feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}")
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
count = 0
for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
# text: translation text, text1: transcript text.
# Currently only support joint-vocab, will add separate vocabs setting.
line = line_json['text']
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token'] = tokens
line_json['token_id'] = tokenids
line_json['token_shape'] = (len(tokenids), vocab_size)
line = line_json['text1']
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token1'] = tokens
line_json['token_id1'] = tokenids
line_json['token_shape1'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n')
count += 1
print(f"Examples number: {count}")
fout.close()
if __name__ == '__main__':
main()
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2013-2016 Johns Hopkins University (author: Daniel Povey)
# 2015 Hainan Xu
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Adds disambiguation symbols to a lexicon.
# Outputs still in the normal lexicon format.
# Disambig syms are numbered #1, #2, #3, etc. (#0
# reserved for symbol in grammar).
# Outputs the number of disambig syms to the standard output.
# With the --pron-probs option, expects the second field
# of each lexicon line to be a pron-prob.
# With the --sil-probs option, expects three additional
# fields after the pron-prob, representing various components
# of the silence probability model.
$pron_probs = 0;
$sil_probs = 0;
$first_allowed_disambig = 1;
for ($n = 1; $n <= 3 && @ARGV > 0; $n++) {
if ($ARGV[0] eq "--pron-probs") {
$pron_probs = 1;
shift @ARGV;
}
if ($ARGV[0] eq "--sil-probs") {
$sil_probs = 1;
shift @ARGV;
}
if ($ARGV[0] eq "--first-allowed-disambig") {
$first_allowed_disambig = 0 + $ARGV[1];
if ($first_allowed_disambig < 1) {
die "add_lex_disambig.pl: invalid --first-allowed-disambig option: $first_allowed_disambig\n";
}
shift @ARGV;
shift @ARGV;
}
}
if (@ARGV != 2) {
die "Usage: add_lex_disambig.pl [opts] <lexicon-in> <lexicon-out>\n" .
"This script adds disambiguation symbols to a lexicon in order to\n" .
"make decoding graphs determinizable; it adds pseudo-phone\n" .
"disambiguation symbols #1, #2 and so on at the ends of phones\n" .
"to ensure that all pronunciations are different, and that none\n" .
"is a prefix of another.\n" .
"It prints to the standard output the number of the largest-numbered" .
"disambiguation symbol that was used.\n" .
"\n" .
"Options: --pron-probs Expect pronunciation probabilities in the 2nd field\n" .
" --sil-probs [should be with --pron-probs option]\n" .
" Expect 3 extra fields after the pron-probs, for aspects of\n" .
" the silence probability model\n" .
" --first-allowed-disambig <n> The number of the first disambiguation symbol\n" .
" that this script is allowed to add. By default this is\n" .
" #1, but you can set this to a larger value using this option.\n" .
"e.g.:\n" .
" add_lex_disambig.pl lexicon.txt lexicon_disambig.txt\n" .
" add_lex_disambig.pl --pron-probs lexiconp.txt lexiconp_disambig.txt\n" .
" add_lex_disambig.pl --pron-probs --sil-probs lexiconp_silprob.txt lexiconp_silprob_disambig.txt\n";
}
$lexfn = shift @ARGV;
$lexoutfn = shift @ARGV;
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
# (1) Read in the lexicon.
@L = ( );
while(<L>) {
@A = split(" ", $_);
push @L, join(" ", @A);
}
# (2) Work out the count of each phone-sequence in the
# lexicon.
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
if ($pron_probs) {
$p = shift @A;
if (!($p > 0.0 && $p <= 1.0)) { die "Bad lexicon line $l (expecting pron-prob as second field)"; }
}
if ($sil_probs) {
$silp = shift @A;
if (!($silp > 0.0 && $silp <= 1.0)) { die "Bad lexicon line $l for silprobs"; }
$correction = shift @A;
if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; }
$correction = shift @A;
if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; }
}
if (!(@A)) {
die "Bad lexicon line $1, no phone in phone list";
}
$count{join(" ",@A)}++;
}
# (3) For each left sub-sequence of each phone-sequence, note down
# that it exists (for identifying prefixes of longer strings).
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
if ($pron_probs) { shift @A; } # remove pron-prob.
if ($sil_probs) {
shift @A; # Remove silprob
shift @A; # Remove silprob
}
while(@A > 0) {
pop @A; # Remove last phone
$issubseq{join(" ",@A)} = 1;
}
}
# (4) For each entry in the lexicon:
# if the phone sequence is unique and is not a
# prefix of another word, no diambig symbol.
# Else output #1, or #2, #3, ... if the same phone-seq
# has already been assigned a disambig symbol.
open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n";
# max_disambig will always be the highest-numbered disambiguation symbol that
# has been used so far.
$max_disambig = $first_allowed_disambig - 1;
foreach $l (@L) {
@A = split(" ", $l);
$word = shift @A;
if ($pron_probs) {
$pron_prob = shift @A;
}
if ($sil_probs) {
$sil_word_prob = shift @A;
$word_sil_correction = shift @A;
$prev_nonsil_correction = shift @A
}
$phnseq = join(" ", @A);
if (!defined $issubseq{$phnseq}
&& $count{$phnseq} == 1) {
; # Do nothing.
} else {
if ($phnseq eq "") { # need disambig symbols for the empty string
# that are not use anywhere else.
$max_disambig++;
$reserved_for_the_empty_string{$max_disambig} = 1;
$phnseq = "#$max_disambig";
} else {
$cur_disambig = $last_used_disambig_symbol_of{$phnseq};
if (!defined $cur_disambig) {
$cur_disambig = $first_allowed_disambig;
} else {
$cur_disambig++; # Get a number that has not been used yet for
# this phone sequence.
}
while (defined $reserved_for_the_empty_string{$cur_disambig}) {
$cur_disambig++;
}
if ($cur_disambig > $max_disambig) {
$max_disambig = $cur_disambig;
}
$last_used_disambig_symbol_of{$phnseq} = $cur_disambig;
$phnseq = $phnseq . " #" . $cur_disambig;
}
}
if ($pron_probs) {
if ($sil_probs) {
print O "$word\t$pron_prob\t$sil_word_prob\t$word_sil_correction\t$prev_nonsil_correction\t$phnseq\n";
} else {
print O "$word\t$pron_prob\t$phnseq\n";
}
} else {
print O "$word\t$phnseq\n";
}
}
print $max_disambig . "\n";
\ No newline at end of file
#!/bin/bash
# Copyright 2015 Yajie Miao (Carnegie Mellon University)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the
# phoneme and character-based lexicons.
set -eo pipefail
. utils/parse_options.sh
if [ $# -ne 3 ]; then
echo "usage: utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>"
echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang"
echo "<dict-src-dir> should contain the following files:"
echo "lexicon.txt lexicon_numbers.txt units.txt"
echo "options: "
exit 1;
fi
srcdir=$1
tmpdir=$2
dir=$3
mkdir -p $dir $tmpdir
[ -f path.sh ] && . ./path.sh
cp $srcdir/units.txt $dir
# Add probabilities to lexicon entries. There is in fact no point of doing this here since all the entries have 1.0.
# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is.
perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1;
# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst.
# Without these symbols, determinization will fail.
# default first disambiguation is #1
ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt`
# add #0 (#0 reserved for symbol in grammar).
ndisambig=$[$ndisambig+1];
( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list
# Get the full list of CTC tokens used in FST. These tokens include <eps>, the blank <blk>,
# the actual model unit, and the disambiguation symbols.
cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list
(echo '<eps>';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt
# ctc_token_fst_corrected is too big and too slow for character based chinese modeling,
# so here just use simple ctc_token_fst
utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \
fstarcsort --sort_type=olabel > $dir/T.fst || exit 1;
# Encode the words with indices. Will be used in lexicon and language model FST compiling.
cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk '
BEGIN {
print "<eps> 0";
}
{
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $dir/words.txt || exit 1;
# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time.
token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'`
word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'`
utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \
--keep_isymbols=false --keep_osymbols=false | \
fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \
fstarcsort --sort_type=olabel > $dir/L.fst || exit 1;
echo "Lexicon and Token FSTs compiling succeeded"
\ No newline at end of file
#!/usr/bin/env python3
import argparse
def main(args):
"""Token Transducer"""
# <eps> entry
print('0 1 <eps> <eps>')
# skip begining and ending <blank>
print('1 1 <blank> <eps>')
print('2 2 <blank> <eps>')
# <eps> exit
print('2 0 <eps> <eps>')
# linking `token` between node 1 and node 2
with open(args.token_file, 'r') as fin:
node = 3
for entry in fin:
fields = entry.strip().split(' ')
phone = fields[0]
if phone == '<eps>' or phone == '<blank>':
continue
elif '#' in phone:
# disambiguous phone
# `token` maybe ending with disambiguous symbol
print('{} {} {} {}'.format(0, 0, '<eps>', phone))
else:
# eating `token`
print('{} {} {} {}'.format(1, node, phone, phone))
# remove repeating `token`
print('{} {} {} {}'.format(node, node, phone, '<eps>'))
# leaving `token`
print('{} {} {} {}'.format(node, 2, '<eps>', '<eps>'))
node += 1
# Fianl node
print('0')
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='FST: CTC Token FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambigous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
import argparse
def il(n):
"""ilabel"""
return n + 1
def ol(n):
"""olabel"""
return n + 1
def s(n):
"""state"""
return n
def main(args):
with open(args.token_file) as f:
lines = f.readlines()
# token count w/0 <blank> <eps>
phone_count = 0
disambig_count = 0
for line in lines:
sp = line.strip().split()
phone = sp[0]
if phone == '<eps>' or phone == '<blank>':
continue
if phone.startswith('#'):
disambig_count += 1
else:
phone_count += 1
# 1. add start state
# first token is <blank>:0
print('0 0 {} 0'.format(il(0)))
# 2. 0 -> i, i -> i, i -> 0
# non-blank token start from 1
for i in range(1, phone_count + 1):
# eating `token`
print('0 {} {} {}'.format(s(i), il(i), ol(i)))
# remove repeating `token`
print('{} {} {} 0'.format(s(i), s(i), il(i)))
# skip ending <blank> `token`
print('{} 0 {} 0'.format(s(i), il(0)))
# 3. i -> other phone
# non-blank token to other non-blank token
for i in range(1, phone_count + 1):
for j in range(1, phone_count + 1):
if i != j:
print('{} {} {} {}'.format(s(i), s(j), il(j), ol(j)))
# 4. add disambiguous arcs on every final state
# blank and non-blank token maybe ending with disambiguous `token`
for i in range(0, phone_count + 1):
for j in range(phone_count + 2, phone_count + disambig_count + 2):
print('{} {} {} {}'.format(s(i), s(i), 0, j))
# 5. every i is final state
# blank and non-blank `token` are final state
for i in range(0, phone_count + 1):
print(s(i))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='FST: CTC Token unfold FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambigous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces epsilon with #0 on the input side only, of the G.fst
# acceptor.
while(<>){
if (/\s+#0\s+/) {
print STDERR "$0: ERROR: LM has word #0, " .
"which is reserved as disambiguation symbol\n";
exit 1;
}
s:^(\d+\s+\d+\s+)\<eps\>(\s+):$1#0$2:;
print;
}
\ No newline at end of file
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# makes lexicon FST, in text form, from lexicon (pronunciation probabilities optional).
$pron_probs = 0;
if ((@ARGV > 0) && ($ARGV[0] eq "--pron-probs")) {
$pron_probs = 1;
shift @ARGV;
}
if (@ARGV != 1 && @ARGV != 3 && @ARGV != 4) {
print STDERR "Usage: make_lexicon_fst.pl [--pron-probs] lexicon.txt [silprob silphone [sil_disambig_sym]] >lexiconfst.txt\n\n";
print STDERR "Creates a lexicon FST that transduces phones to words, and may allow optional silence.\n\n";
print STDERR "Note: ordinarily, each line of lexicon.txt is:\n";
print STDERR " word phone1 phone2 ... phoneN;\n";
print STDERR "if the --pron-probs option is used, each line is:\n";
print STDERR " word pronunciation-probability phone1 phone2 ... phoneN.\n\n";
print STDERR "The probability 'prob' will typically be between zero and one, and note that\n";
print STDERR "it's generally helpful to normalize so the largest one for each word is 1.0, but\n";
print STDERR "this is your responsibility.\n\n";
print STDERR "The silence disambiguation symbol, e.g. something like #5, is used only\n";
print STDERR "when creating a lexicon with disambiguation symbols, e.g. L_disambig.fst,\n";
print STDERR "and was introduced to fix a particular case of non-determinism of decoding graphs.\n\n";
exit(1);
}
$lexfn = shift @ARGV;
if (@ARGV == 0) {
$silprob = 0.0;
} elsif (@ARGV == 2) {
($silprob,$silphone) = @ARGV;
} else {
($silprob,$silphone,$sildisambig) = @ARGV;
}
if ($silprob != 0.0) {
$silprob < 1.0 || die "Sil prob cannot be >= 1.0";
$silcost = -log($silprob);
$nosilcost = -log(1.0 - $silprob);
}
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
if ( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
$loopstate = 0;
$nextstate = 1; # next unallocated state.
while (<L>) {
@A = split(" ", $_);
@A == 0 && die "Empty lexicon line.";
foreach $a (@A) {
if ($a eq "<eps>") {
die "Bad lexicon line $_ (<eps> is forbidden)";
}
}
$w = shift @A;
if (! $pron_probs) {
$pron_cost = 0.0;
} else {
$pron_prob = shift @A;
if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
die "Bad pronunciation probability in line $_";
}
$pron_cost = -log($pron_prob);
}
if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if (@A > 0) {
$ns = $nextstate++;
} else {
$ns = $loopstate;
}
print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
$word_or_eps = "<eps>";
$pron_cost_string = ""; # so we only print it on the first arc of the word.
$s = $ns;
}
}
print "$loopstate\t0\n"; # final-cost.
} else { # have silence probs.
$startstate = 0;
$loopstate = 1;
$silstate = 2; # state from where we go to loopstate after emitting silence.
print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
if (!defined $sildisambig) {
print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
print "$silstate\t$loopstate\t$silphone\t<eps>\n"; # no cost.
$nextstate = 3;
} else {
$disambigstate = 3;
$nextstate = 4;
print "$startstate\t$disambigstate\t$silphone\t<eps>\t$silcost\n"; # silence.
print "$silstate\t$disambigstate\t$silphone\t<eps>\n"; # no cost.
print "$disambigstate\t$loopstate\t$sildisambig\t<eps>\n"; # silence disambiguation symbol.
}
while (<L>) {
@A = split(" ", $_);
$w = shift @A;
if (! $pron_probs) {
$pron_cost = 0.0;
} else {
$pron_prob = shift @A;
if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
die "Bad pronunciation probability in line $_";
}
$pron_cost = -log($pron_prob);
}
if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if (@A > 0) {
$ns = $nextstate++;
print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
$word_or_eps = "<eps>";
$pron_cost_string = ""; $pron_cost = 0.0; # so we only print it the 1st time.
$s = $ns;
} elsif (!defined($silphone) || $p ne $silphone) {
# This is non-deterministic but relatively compact,
# and avoids epsilons.
$local_nosilcost = $nosilcost + $pron_cost;
$local_silcost = $silcost + $pron_cost;
print "$s\t$loopstate\t$p\t$word_or_eps\t$local_nosilcost\n";
print "$s\t$silstate\t$p\t$word_or_eps\t$local_silcost\n";
} else {
# no point putting opt-sil after silence word.
print "$s\t$loopstate\t$p\t$word_or_eps$pron_cost_string\n";
}
}
}
print "$loopstate\t0\n"; # final-cost.
}
\ No newline at end of file
#!/bin/bash
if [ -f path.sh ]; then . path.sh; fi
lm_dir=$1
src_lang=$2
tgt_lang=$3
arpa_lm=${lm_dir}/lm.arpa
[ ! -f $arpa_lm ] && { echo "No such file $arpa_lm"; exit 1;}
rm -rf $tgt_lang
cp -r $src_lang $tgt_lang
# Compose the language model to FST
# grep -i或--ignore-case 忽略字符大小写的差别。
# grep -v或--revert-match 反转查找。
# arpa2fst: remove the embedded symbols from the FST
# arpa2fst: make sure there are no out-of-vocabulary words in the language model
# arpa2fst: remove "illegal" sequences of the start and end-ofsentence symbols
# eps2disambig.pl: replace epsilons on the input side with the special disambiguation symbol #0.
# s2eps.pl: replaces <s> and </s> with <eps> (on both input and output sides), for the G.fst acceptor.
# G.fst, the disambiguation symbol #0 only appears on the input side
# do eps2disambig.pl and s2eps.pl maybe just for fallowing `fstrmepsilon`.
cat $arpa_lm | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
grep -v -i '<unk>' | \
grep -v -i '<spoken_noise>' | \
arpa2fst --read-symbol-table=$tgt_lang/words.txt --keep-symbols=true - | fstprint | \
utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$tgt_lang/words.txt \
--osymbols=$tgt_lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G.fst
echo "Checking how stochastic G is (the first of these numbers should be small):"
fstisstochastic $tgt_lang/G.fst
# Compose the token, lexicon and language-model FST into the final decoding graph
# minimization: the same as minimization algorithm that applies to weighted acceptors;
# the only change relevant here is that it avoids pushing weights,
# hence preserving stochasticity
fsttablecompose $tgt_lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
fsttablecompose $tgt_lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
echo "Composing decoding graph TLG.fst succeeded"
#rm -r $tgt_lang/LG.fst # We don't need to keep this intermediate FST
\ No newline at end of file
#!/usr/bin/env python3
import argparse
def main(args):
# load `unit` or `vocab` file
unit_table = set()
with open(args.unit_file, 'r') as fin:
for line in fin:
unit = line.strip()
unit_table.add(unit)
def contain_oov(units):
for unit in units:
if unit not in unit_table:
return True
return False
# load spm model
bpemode = args.bpemodel
if bpemode:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.Load(sys.bpemodel)
# used to filter polyphone
lexicon_table = set()
with open(args.in_lexicon, 'r') as fin, \
open(args.out_lexicon, 'w') as fout:
for line in fin:
word = line.split()[0]
if word == 'SIL' and not bpemode: # `sil` might be a valid piece in bpemodel
continue
elif word == '<SPOKEN_NOISE>':
continue
else:
# each word only has one pronunciation for e2e system
if word in lexicon_table:
continue
if bpemode:
pieces = sp.EncodeAsPieces(word)
if contain_oov(pieces):
print('Ignoring words {}, which contains oov unit'.
format(''.join(word).strip('▁')))
continue
chars = ' '.join(
[p if p in unit_table else '<unk>' for p in pieces])
else:
# ignore words with OOV
if contain_oov(word):
print('Ignoring words {}, which contains oov unit'.
format(word))
continue
# Optional, append ▁ in front of english word
# we assume the model unit of our e2e system is char now.
if word.encode('utf8').isalpha() and '▁' in unit_table:
word = '▁' + word
chars = ' '.join(word) # word is a char list
fout.write('{} {}\n'.format(word, chars))
lexicon_table.add(word)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='FST: preprae e2e(char/spm) dict')
parser.add_argument(
'--unit_file',
required=True,
help='e2e model unit file(lang_char.txt/vocab.txt). line: char/spm_pices'
)
parser.add_argument(
'--in_lexicon',
required=True,
help='raw lexicon file. line: word ph0 ... phn')
parser.add_argument(
'--out_lexicon',
required=True,
help='output lexicon file. line: word char0 ... charn')
parser.add_argument('--bpemodel', default=None, help='bpemodel')
args = parser.parse_args()
print(args)
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script removes lines that contain these OOVs on either the
# third or fourth fields of the line. It is intended to remove arcs
# with OOVs on, from FSTs (probably compiled from ARPAs with OOVs in).
if ( @ARGV < 1 && @ARGV > 2) {
die "Usage: remove_oovs.pl unk_list.txt [ printed-fst ]\n";
}
$unklist = shift @ARGV;
open(S, "<$unklist") || die "Failed opening unknown-symbol list $unklist\n";
while(<S>){
@A = split(" ", $_);
@A == 1 || die "Bad line in unknown-symbol list: $_";
$unk{$A[0]} = 1;
}
$num_removed = 0;
while(<>){
@A = split(" ", $_);
if(defined $unk{$A[2]} || defined $unk{$A[3]}) {
$num_removed++;
} else {
print;
}
}
print STDERR "remove_oovs.pl: removed $num_removed lines.\n";
#!/usr/bin/env python3
import argparse
def main(args):
# skip <blank> `token`
print('0 0 <blank> <eps>')
with open(args.token_file, 'r') as fin:
for entry in fin:
fields = entry.strip().split(' ')
phone = fields[0]
if phone == '<eps>' or phone == '<blank>':
continue
elif '#' in phone:
# disambiguous phone
# maybe add disambiguous `token`
print('{} {} {} {}'.format(0, 0, '<eps>', phone))
else:
# eating `token`
print('{} {} {} {}'.format(0, 0, phone, phone))
# final state
print('0')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='FST: RNN-T Token FST transducer')
parser.add_argument(
'--token_file',
required=True,
help='e2e model token file. line: token(char/phone/spm/disambigous)')
args = parser.parse_args()
main(args)
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces <s> and </s> with <eps> (on both input and output sides),
# for the G.fst acceptor.
while(<>){
@A = split(" ", $_);
if ( @A >= 4 ) {
if ($A[2] eq "<s>" || $A[2] eq "</s>") { $A[2] = "<eps>"; }
if ($A[3] eq "<s>" || $A[3] eq "</s>") { $A[3] = "<eps>"; }
}
print join("\t", @A) . "\n";
}
\ No newline at end of file
文件模式从 100644 更改为 100755
#!/usr/bin/env python3
"""Manifest file to key-value files."""
import argparse
import functools
from pathlib import Path
from utils.utility import add_arguments
from utils.utility import print_arguments
from utils.utility import read_manifest
def main(args):
print_arguments(args, globals())
count = 0
outdir = Path(args.output_path)
wav_scp = outdir / 'wav.scp'
dur_scp = outdir / 'duration'
text_scp = outdir / 'text'
manifest_jsons = read_manifest(args.manifest_path)
with wav_scp.open('w') as fwav, dur_scp.open('w') as fdur, text_scp.open(
'w') as ftxt:
for line_json in manifest_jsons:
utt = line_json['utt']
feat = line_json['feat']
file_ext = Path(feat).suffix # .wav
text = line_json['text']
feat_shape = line_json['feat_shape']
dur = feat_shape[0]
feat_dim = feat_shape[1]
if 'token' in line_json:
tokens = line_json['token']
tokenids = line_json['token_id']
token_shape = line_json['token_shape']
token_len = token_shape[0]
vocab_dim = token_shape[1]
if file_ext == '.wav':
fwav.write(f"{utt} {feat}\n")
fdur.write(f"{utt} {dur}\n")
ftxt.write(f"{utt} {text}\n")
count += 1
print(f"Examples number: {count}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('output_path', str,
'data/train',
"dir path to dump wav.scp/duaration/text files.")
# yapf: disable
args = parser.parse_args()
main(args)
......@@ -22,7 +22,7 @@ lmbin=${2}.klm.bin
# https://kheafield.com/code/kenlm/estimation/
echo "build arpa lm."
lmplz -o ${order} -S ${mem} --prune ${prune} < ${text} >${arpa} || { echo "train kenlm error!"; exit -1; }
lmplz -o ${order} -S ${mem} --prune ${prune} < ${text} > ${arpa} || { echo "train kenlm error!"; exit -1; }
# https://kheafield.com/code/kenlm/
echo "build binary lm."
......
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
# if the command appears multiple times uses the smallest option.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
$max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, "</proc/cpuinfo")) { # Linux
while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
# Invoked before each job is started helps to run jobs selectively.
#
# Given the name of the output logfile decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# process calling exit)
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
# This logic could have been implemented in the main execution loop
# but a subroutine to preserve the current level of readability of
# that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
# Lets wait for a change in any child's status
# Then we have to work out which child finished
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now we have submitted all the jobs, lets wait until all the jobs finish
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);
\ No newline at end of file
文件模式从 100644 更改为 100755
#!/usr/bin/env bash
unset GREP_OPTIONS
set -u # Check for undefined variables
die() {
# Print a message and exit with code 1.
#
# Usage: die <error_message>
# e.g., die "Something bad happened."
echo $@
exit 1
}
echo "Collecting system information..."
OUTPUT_FILE=pd_env.txt
python_bin_path=$(which python || which python3 || die "Cannot find Python binary")
{
echo
echo '== check python ==================================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_python.py
import platform
print("""python version: %s
python branch: %s
python build version: %s
python compiler version: %s
python implementation: %s
""" % (
platform.python_version(),
platform.python_branch(),
platform.python_build(),
platform.python_compiler(),
platform.python_implementation(),
))
EOF
${python_bin_path} /tmp/check_python.py 2>&1 >> ${OUTPUT_FILE}
{
echo
echo '== check os platform ==============================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_os.py
import platform
print("""os: %s
os kernel version: %s
os release version: %s
os platform: %s
linux distribution: %s
linux os distribution: %s
mac version: %s
uname: %s
architecture: %s
machine: %s
""" % (
platform.system(),
platform.version(),
platform.release(),
platform.platform(),
platform.linux_distribution(),
platform.dist(),
platform.mac_ver(),
platform.uname(),
platform.architecture(),
platform.machine(),
))
EOF
${python_bin_path} /tmp/check_os.py 2>&1 >> ${OUTPUT_FILE}
{
echo
echo '== are we in docker ============================================='
num=`cat /proc/1/cgroup | grep docker | wc -l`;
if [ $num -ge 1 ]; then
echo "Yes"
else
echo "No"
fi
echo
echo '== compiler ====================================================='
c++ --version 2>&1
echo
echo '== check pips ==================================================='
pip list 2>&1 | grep "proto\|numpy\|paddlepaddle"
echo
echo '== check for virtualenv ========================================='
${python_bin_path} -c "import sys;print(hasattr(sys, \"real_prefix\"))"
echo
echo '== paddlepaddle import ============================================'
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_pd.py
import paddle as pd;
pd.set_device('cpu')
print("pd.version.full_version = %s" % pd.version.full_version)
print("pd.version.commit = %s" % pd.version.commit)
print("pd.__version__ = %s" % pd.__version__)
print("Sanity check: %r" % pd.zeros([1,2,3])[:1])
EOF
${python_bin_path} /tmp/check_pd.py 2>&1 >> ${OUTPUT_FILE}
LD_DEBUG=libs ${python_bin_path} -c "import paddle" 2>>${OUTPUT_FILE} > /tmp/loadedlibs
{
grep libcudnn.so /tmp/loadedlibs
echo
echo '== env =========================================================='
if [ -z ${LD_LIBRARY_PATH+x} ]; then
echo "LD_LIBRARY_PATH is unset";
else
echo LD_LIBRARY_PATH ${LD_LIBRARY_PATH} ;
fi
if [ -z ${DYLD_LIBRARY_PATH+x} ]; then
echo "DYLD_LIBRARY_PATH is unset";
else
echo DYLD_LIBRARY_PATH ${DYLD_LIBRARY_PATH} ;
fi
echo
echo '== nvidia-smi ==================================================='
nvidia-smi 2>&1
echo
echo '== cuda libs ==================================================='
} >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libcudart*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libudnn*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
{
echo
echo '== paddlepaddle installed from info =================='
pip show paddlepaddle-gpu
echo
echo '== python version =============================================='
echo '(major, minor, micro, releaselevel, serial)'
python -c 'import sys; print(sys.version_info[:])'
echo
echo '== bazel version ==============================================='
bazel version
echo '== cmake version ==============================================='
cmake --version
} >> ${OUTPUT_FILE}
# Remove any words with google.
mv $OUTPUT_FILE old-$OUTPUT_FILE
grep -v -i google old-${OUTPUT_FILE} > $OUTPUT_FILE
echo "Wrote environment to ${OUTPUT_FILE}. You can review the contents of that file."
echo "and use it to populate the fields in the github issue template."
echo
echo "cat ${OUTPUT_FILE}"
echo
文件模式从 100644 更改为 100755
parallel/run.pl
\ No newline at end of file
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
[ -f ./path.sh ] && . ./path.sh
# non language symbol
nlsyms=""
wer=false
bpe=""
bpemodel=""
remove_blank=true
filter=""
num_spkrs=1
help_message="Usage: $0 <data-dir> <dict>"
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "${help_message}"
exit 1;
fi
dir=$1
dic=$2
cat ${dir}/data.*.json > ${dir}/data.json
if [ $num_spkrs -eq 1 ]; then
json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${dir}/ref.trn --hyps ${dir}/hyp.trn
if ${remove_blank}; then
sed -i.bak2 -r 's/<blank> //g' ${dir}/hyp.trn
fi
if [ -n "${nlsyms}" ]; then
cp ${dir}/ref.trn ${dir}/ref.trn.org
cp ${dir}/hyp.trn ${dir}/hyp.trn.org
filt.py -v ${nlsyms} ${dir}/ref.trn.org > ${dir}/ref.trn
filt.py -v ${nlsyms} ${dir}/hyp.trn.org > ${dir}/hyp.trn
fi
if [ -n "${filter}" ]; then
sed -i.bak3 -f ${filter} ${dir}/hyp.trn
sed -i.bak3 -f ${filter} ${dir}/ref.trn
fi
sclite -r ${dir}/ref.trn trn -h ${dir}/hyp.trn trn -i rm -o all stdout > ${dir}/result.txt
echo "write a CER (or TER) result in ${dir}/result.txt"
grep -e Avg -e SPKR -m 2 ${dir}/result.txt
if ${wer}; then
if [ -n "$bpe" ]; then
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref.trn | sed -e "s/▁/ /g" > ${dir}/ref.wrd.trn
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp.trn | sed -e "s/▁/ /g" > ${dir}/hyp.wrd.trn
else
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref.trn > ${dir}/ref.wrd.trn
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp.trn > ${dir}/hyp.wrd.trn
fi
sclite -r ${dir}/ref.wrd.trn trn -h ${dir}/hyp.wrd.trn trn -i rm -o all stdout > ${dir}/result.wrd.txt
echo "write a WER result in ${dir}/result.wrd.txt"
grep -e Avg -e SPKR -m 2 ${dir}/result.wrd.txt
fi
elif [ ${num_spkrs} -lt 4 ]; then
ref_trns=""
hyp_trns=""
for i in $(seq ${num_spkrs}); do
ref_trns=${ref_trns}"${dir}/ref${i}.trn "
hyp_trns=${hyp_trns}"${dir}/hyp${i}.trn "
done
json2trn.py ${dir}/data.json ${dic} --num-spkrs ${num_spkrs} --refs ${ref_trns} --hyps ${hyp_trns}
for n in $(seq ${num_spkrs}); do
if ${remove_blank}; then
sed -i.bak2 -r 's/<blank> //g' ${dir}/hyp${n}.trn
fi
if [ -n "${nlsyms}" ]; then
cp ${dir}/ref${n}.trn ${dir}/ref${n}.trn.org
cp ${dir}/hyp${n}.trn ${dir}/hyp${n}.trn.org
filt.py -v ${nlsyms} ${dir}/ref${n}.trn.org > ${dir}/ref${n}.trn
filt.py -v ${nlsyms} ${dir}/hyp${n}.trn.org > ${dir}/hyp${n}.trn
fi
if [ -n "${filter}" ]; then
sed -i.bak3 -f ${filter} ${dir}/hyp${n}.trn
sed -i.bak3 -f ${filter} ${dir}/ref${n}.trn
fi
done
results_str=""
for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
ind_r=$((i / num_spkrs + 1))
ind_h=$((i % num_spkrs + 1))
results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.txt "
sclite -r ${dir}/ref${ind_r}.trn trn -h ${dir}/hyp${ind_h}.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.txt
done
echo "write CER (or TER) results in ${dir}/result_r*h*.txt"
eval_perm_free_error.py --num-spkrs ${num_spkrs} \
${results_str} > ${dir}/min_perm_result.json
sed -n '2,4p' ${dir}/min_perm_result.json
if ${wer}; then
for n in $(seq ${num_spkrs}); do
if [ -n "$bpe" ]; then
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/ref${n}.trn | sed -e "s/▁/ /g" > ${dir}/ref${n}.wrd.trn
spm_decode --model=${bpemodel} --input_format=piece < ${dir}/hyp${n}.trn | sed -e "s/▁/ /g" > ${dir}/hyp${n}.wrd.trn
else
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/ref${n}.trn > ${dir}/ref${n}.wrd.trn
sed -e "s/ //g" -e "s/(/ (/" -e "s/<space>/ /g" ${dir}/hyp${n}.trn > ${dir}/hyp${n}.wrd.trn
fi
done
results_str=""
for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do
ind_r=$((i / num_spkrs + 1))
ind_h=$((i % num_spkrs + 1))
results_str=${results_str}"${dir}/result_r${ind_r}h${ind_h}.wrd.txt "
sclite -r ${dir}/ref${ind_r}.wrd.trn trn -h ${dir}/hyp${ind_h}.wrd.trn trn -i rm -o all stdout > ${dir}/result_r${ind_r}h${ind_h}.wrd.txt
done
echo "write WER results in ${dir}/result_r*h*.wrd.txt"
eval_perm_free_error.py --num-spkrs ${num_spkrs} \
${results_str} > ${dir}/min_perm_result.wrd.json
sed -n '2,4p' ${dir}/min_perm_result.wrd.json
fi
fi
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
while(<>){
@A = split(" ", $_);
@A > 1 || die "Invalid line in spk2utt file: $_";
$s = shift @A;
foreach $u ( @A ) {
print "$u $s\n";
}
}
#!/usr/bin/env bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
set -o errexit
if [ $# != 2 ]; then
echo "Usage: split_data.sh data-dir num-to-split"
exit 1
fi
data=$1
numsplit=$2
if [ $numsplit -le 0 ]; then
echo "Invalid num-split argument $numsplit";
exit 1;
fi
n=0;
feats=""
wavs=""
utt2spks=""
texts=""
nu=`cat $data/utt2spk | wc -l`
nf=`cat $data/feats.scp | wc -l`
nt=`cat $data/text | wc -l`
if [ $nu -ne $nf ]; then
echo "split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf);"
echo "this script may produce incorrectly split data."
echo "use utils/fix_data_dir.sh to fix this."
fi
if [ $nt -ne 0 -a $nu -ne $nt ]; then
echo "split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt);"
echo "this script may produce incorrectly split data."
echo "use utils/fix_data_dir.sh to fix this."
fi
# utilsscripts/get_split.pl returns "0 1 2 3" or "00 01 .. 18 19" or whatever.
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
mkdir -p $data/split$numsplit/$n
feats="$feats $data/split$numsplit/$n/feats.scp"
wavs="$wavs $data/split$numsplit/$n/wav.scp"
texts="$texts $data/split$numsplit/$n/text"
utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk"
done
split_scp.pl --utt2spk=$data/utt2spk $data/utt2spk $utt2spks
split_scp.pl --utt2spk=$data/utt2spk $data/feats.scp $feats
[ -f $data/wav.scp ] && \
split_scp.pl --utt2spk=$data/utt2spk $data/wav.scp $wavs
[ -f $data/text ] && \
split_scp.pl --utt2spk=$data/utt2spk $data/text $texts
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
utt2spk_to_spk2utt.pl $data/split$numsplit/$n/utt2spk \
> $data/split$numsplit/$n/spk2utt
# for completeness, also split the spk2gender file
[ -f $data/spk2gender ] && \
filter_scp.pl $data/split$numsplit/$n/spk2utt $data/spk2gender \
> $data/split$numsplit/$n/spk2gender
done
exit 0
\ No newline at end of file
#!/usr/bin/env bash
set -o errexit
if [ $# != 2 ]; then
echo "Usage: split_json.sh manifest num-to-split"
exit 1
fi
data=data
jsonfile=$1
numsplit=$2
if [ $numsplit -le 0 ]; then
echo "Invalid num-split argument $numsplit";
exit 1;
fi
n=0;
jsons=""
# utilsscripts/get_split.pl returns "0 1 2 3" or "00 01 .. 18 19" or whatever.
# for n in `get_splits.pl $numsplit`; do
for n in `seq 1 $numsplit`; do # Changed this to usual number sequence -Arnab
mkdir -p $data/split$numsplit/$n
jsons="$jsons $data/split$numsplit/$n/${jsonfile}"
done
split_scp.pl $data/${jsonfile} $jsons
exit 0
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each but.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use
# the utt2spk option it will make sure these chunks coincide with
# speaker boundaries. In this case, if there are more chunks
# than speakers (and in some other circumstances), some of the
# resulting chunks will be empty and it
# will print a warning.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
for ($x = 1; $x <= 2; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
if ($num_jobs <= 0 || $job_id < 0 || $job_id >= $num_jobs) {
die "Invalid num-jobs and job-id: $num_jobs and $job_id";
}
}
if ($ARGV[0] =~ "--utt2spk=(.+)") {
$utt2spk_file=$1;
shift;
}
}
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die "Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ... \n" .
" or: split_scp.pl -j num-jobs job-id [--utt2spk=<utt2spk_file>] in.scp [out.scp]\n" .
" ... where 0 <= job-id < num-jobs.";
}
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open(U, "<$utt2spk_file") || die "Failed to open utt2spk file $utt2spk_file";
while(<U>) {
@A = split;
@A == 2 || die "Bad line $_ in utt2spk file $utt2spk_file";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
open(I, "<$inscp") || die "Opening input scp file $inscp";
@spkrs = ();
while(<I>) {
@A = split;
if(@A == 0) { die "Empty or space-only line in scp file $inscp"; }
$u = $A[0];
$s = $utt2spk{$u};
if(!defined $s) { die "No such utterance $u in utt2spk file $utt2spk_file"; }
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = "";
}
$spk_count{$s}++;
$spk_data{$s} = $spk_data{$s} . $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfn = $OUTPUTS[$scpidx];
open(F, ">$scpfn") || die "Could not open scp file $scpfn for writing.";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "Warning: split_scp.pl producing empty .scp file $scpfn (too many splits and too few speakers?)\n";
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print F $spk_data{$spk};
$count += $spk_count{$spk};
}
if($count != $scpcount[$scpidx]) { die "Count mismatch [code error]"; }
}
close(F);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open(I, "<$inscp") || die "Opening input scp file $inscp";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<I>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "split_scp.pl: warning: empty input scp file $inscp";
}
$linesperscp = int( ($numlines+($numscps-1)) / $numscps); # the +$(numscps-1) forces rounding up.
# [just doing int() rounds down].
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
open(O, ">$scpfile") || die "Opening output scp file $scpfile";
for($n = $linesperscp * $scpidx; $n < $numlines && $n < $linesperscp*($scpidx+1); $n++) {
print O $F[$n];
}
close(O) || die "Closing scp file $scpfile";
}
}
\ No newline at end of file
#!/bin/bash
if [ $# != 4 ];then
echo "usage: $0 ckpt_prefix model_config mean_std vocab"
if [ $# != 5 ];then
echo "usage: $0 ckpt_prefix model_config mean_std vocab pack_name"
exit -1
fi
......@@ -9,6 +9,7 @@ ckpt_prefix=$1
model_config=$2
mean_std=$3
vocab=$4
pack_name=$5
output=release
......@@ -18,9 +19,15 @@ function clean() {
}
trap clean EXIT
cp ${ckpt_prefix}.* ${output}
# ckpt_prfix dir
if [ -d ${ckpt_prefix} ];then
cp -r ${ckpt_prefix} ${output}
fi
# ckpt_prfix.{json,...}
cp ${ckpt_prefix}.* ${output}
# model config, mean std, vocab
cp ${model_config} ${mean_std} ${vocab} ${output}
tar zcvf release.tar.gz ${output}
tar zcvf ${pack_name}.release.tar.gz ${output}
echo "tarball done!"
echo "tarball: ${pack_name}.release.tar.gz done!"
#!/usr/bin/env bash
# 2020 Author Jiayu DU
# Apache 2.0
# This script uses kenlm to estimate an arpa model from plain text,
# it is a resort when you hit memory limit dealing with large corpus
# kenlm estimates arpa using on-disk structure,
# as long as you have big enough hard disk, memory shouldn't be a problem.
# by default, kenlm use up to 50% of your local memory,
# you can control this through -S option
[ -f path.sh ] && . ./path.sh;
kenlm_opts="" # e.g. "-o 4 -S 50% --prune 0 5 7 7"
if [ $# != 4 ]; then
echo "$0 <text> <kaldi_symbol_table> <working_dir> <arpa_name>"
echo "e.g. $0 train.txt words.txt wdir 4gram"
exit 1
fi
text=$1
symbol_table=$2
dir=$3
arpa_name=$4
if ! which lmplz >& /dev/null ; then
echo "$0: cannot find training tool *lmplz*."
echo "tools/extras/install_kenlm_query_only.sh installs kenlm at tools/kenlm"
echo "it only supports runtime mode, to actually train an arpa using KenLM,"
echo "you need a complete KenLM installation(depends on EIGEN and BOOST),"
echo "follow KenLM's building instructions at (https://github.com/kpu/kenlm)"
exit 1
fi
# the text should be properly pre-processed, e.g:
# cleand, normalized and possibly word-segmented
# get rid off irrelavent symbols
grep -v '<eps>' $symbol_table \
| grep -v '#0' \
| grep -v '<unk>' | grep -v '<UNK>' \
| grep -v '<s>' | grep -v '</s>' \
| awk '{print $1}' \
> $dir/ngram.vocab
# To make sure that kenlm & kaldi have strictly the same vocabulary:
# 1. feed vocabulary into kenlm via --limit_vocab_file
# 2. cat vocabulary to training text, so each word at least appear once
#
# TL;DR reason:
# Unlike SRILM's -limit-vocab, kenlm's --limit_vocab_file option
# spcifies a *valid* set of vocabulary, whereas *valid but unseen*
# words are discarded in final arpa.
# So the trick is,
# we explicitly add kaldi's vocab(one word per line) to training text,
# making each word appear at least once.
# kenlm never prunes unigram,
# so this always generates consistent kenlm vocabuary as kaldi has.
# The effect of this is like add-one smoothing to unigram counts,
# shouldn't have significant impacts in practice.
cat $dir/ngram.vocab $text \
| lmplz $kenlm_opts --limit_vocab_file $dir/ngram.vocab \
> $dir/${arpa_name}.arpa
echo "$0: Done training arpa to: $dir/${arpa_name}.arpa"
\ No newline at end of file
......@@ -11,11 +11,93 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import tarfile
import zipfile
from typing import Text
from paddle.dataset.common import md5file
__all__ = [
"check_md5sum", "getfile_insensitive", "download_multi", "download",
"unpack", "unzip", "md5file", "print_arguments", "add_arguments",
"read_manifest"
]
def read_manifest(manifest_path):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
for json_line in open(manifest_path, 'r'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
return manifest
def print_arguments(args, info=None):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
filename = ""
if info:
filename = info["__file__"]
filename = os.path.basename(filename)
print(f"----------- {filename} Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("-----------------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def getfile_insensitive(path):
......@@ -54,6 +136,19 @@ def download(url, md5sum, target_dir):
return filepath
def check_md5sum(filepath: Text, md5sum: Text) -> bool:
"""check md5sum of file.
Args:
filepath (Text): [description]
md5sum (Text): [description]
Returns:
bool: same or not.
"""
return md5file(filepath) == md5sum
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
......
文件模式从 100644 更改为 100755
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# converts an utt2spk file to a spk2utt file.
# Takes input from the stdin or from a file argument;
# output goes to the standard out.
if ( @ARGV > 1 ) {
die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
}
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
push (@{$spk_hash{$s}}, "$u");
}
foreach $s (@spklist) {
$l = join(' ',@{$spk_hash{$s}});
print "$s $l\n";
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册