From 16fa4245ec3f29ba2bc558e1d21c685e171c14de Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 17 Mar 2021 12:48:39 +0000 Subject: [PATCH] add test; attention --- .notebook/hack_api_test.ipynb | 290 +++++++++++++++ .notebook/jit_infer.ipynb | 616 ++++++++++++++++++++++++++++++++ deepspeech/modules/__init__.py | 51 ++- deepspeech/modules/attention.py | 227 ++++++++++++ requirements.txt | 1 + 5 files changed, 1180 insertions(+), 5 deletions(-) create mode 100644 .notebook/hack_api_test.ipynb create mode 100644 .notebook/jit_infer.ipynb create mode 100644 deepspeech/modules/attention.py diff --git a/.notebook/hack_api_test.ipynb b/.notebook/hack_api_test.ipynb new file mode 100644 index 00000000..f653084e --- /dev/null +++ b/.notebook/hack_api_test.ipynb @@ -0,0 +1,290 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "breeding-haven", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x\n" + ] + }, + { + "data": { + "text/plain": [ + "'/home/ssd5/zhanghui/DeepSpeech2.x'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%cd ..\n", + "%pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "appropriate-theta", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LICENSE deepspeech examples\t\t requirements.txt tools\r\n", + "README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n", + "README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n" + ] + } + ], + "source": [ + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "entire-bloom", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " def convert_to_list(value, n, name, dtype=np.int):\n", + "WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n", + "WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n", + "WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", + "WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n", + "WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n", + "WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n", + "WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n", + "WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n" + ] + } + ], + "source": [ + "from deepspeech.modules import loss" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "governmental-aircraft", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + } + ], + "source": [ + "import paddle" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "proprietary-disaster", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + " paddle.VarBase>" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "paddle.Tensor.repeat" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "first-diagram", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "paddle.Tensor.size" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "intelligent-david", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "paddle.Tensor.cat" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bronze-tenant", + "metadata": {}, + "outputs": [], + "source": [ + "a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "balanced-bearing", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.size" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "extreme-republic", + "metadata": {}, + "outputs": [], + "source": [ + "def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n", + " nargs = len(args)\n", + " assert (nargs <= 1)\n", + " s = paddle.shape(xs)\n", + " if nargs == 1:\n", + " return s[args[0]]\n", + " else:\n", + " return s\n", + "\n", + "# logger.warn(\n", + "# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n", + "# )\n", + "paddle.Tensor.size = size" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "gross-addiction", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", + " [7])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.size(0)\n", + "a.size()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "adverse-dining", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", + " [7])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.size()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "popular-potato", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb new file mode 100644 index 00000000..49e395b3 --- /dev/null +++ b/.notebook/jit_infer.ipynb @@ -0,0 +1,616 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x\n" + ] + }, + { + "data": { + "text/plain": [ + "'/home/ssd5/zhanghui/DeepSpeech2.x'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%cd ..\n", + "%pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-03-17 11:09:34,972 - WARNING - override cat of paddle.Tensor if exists or register, remove this when fixed!\n", + "2021-03-17 11:09:34,973 - WARNING - override size of paddle.Tensor if exists or register (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", + "2021-03-17 11:09:34,974 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n", + "2021-03-17 11:09:34,975 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", + "2021-03-17 11:09:34,975 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n", + "2021-03-17 11:09:34,976 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n", + "2021-03-17 11:09:34,976 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n", + "2021-03-17 11:09:34,977 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n", + "2021-03-17 11:09:34,977 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", + " from numpy.dual import register_func\n", + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import argparse\n", + "import functools\n", + "import paddle\n", + "import numpy as np\n", + "\n", + "from deepspeech.utils.socket_server import warm_up_test\n", + "from deepspeech.utils.socket_server import AsrTCPServer\n", + "from deepspeech.utils.socket_server import AsrRequestHandler\n", + "\n", + "from deepspeech.training.cli import default_argument_parser\n", + "from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n", + "\n", + "from deepspeech.frontend.utility import read_manifest\n", + "from deepspeech.utils.utility import add_arguments, print_arguments\n", + "\n", + "from deepspeech.models.deepspeech2 import DeepSpeech2Model\n", + "from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n", + "from deepspeech.io.dataset import ManifestDataset\n", + "\n", + "\n", + "from paddle.inference import Config\n", + "from paddle.inference import create_predictor\n", + "\n", + "from deepspeech.frontend.utility import read_manifest" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0.0\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + } + ], + "source": [ + "print(paddle.__version__)\n", + "print(paddle.version)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data:\n", + " augmentation_config: conf/augmentation.config\n", + " batch_size: 64\n", + " dev_manifest: data/manifest.dev\n", + " keep_transcription_text: False\n", + " max_duration: 27.0\n", + " max_freq: None\n", + " mean_std_filepath: examples/aishell/data/mean_std.npz\n", + " min_duration: 0.0\n", + " n_fft: None\n", + " num_workers: 0\n", + " random_seed: 0\n", + " shuffle_method: batch_shuffle\n", + " sortagrad: True\n", + " specgram_type: linear\n", + " stride_ms: 10.0\n", + " target_dB: -20\n", + " target_sample_rate: 16000\n", + " test_manifest: examples/aishell/data/manifest.test\n", + " train_manifest: data/manifest.train\n", + " use_dB_normalization: True\n", + " vocab_filepath: examples/aishell/data/vocab.txt\n", + " window_ms: 20.0\n", + "decoding:\n", + " alpha: 2.6\n", + " batch_size: 128\n", + " beam_size: 300\n", + " beta: 5.0\n", + " cutoff_prob: 0.99\n", + " cutoff_top_n: 40\n", + " decoding_method: ctc_beam_search\n", + " error_rate_type: cer\n", + " lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n", + " num_proc_bsearch: 10\n", + "model:\n", + " num_conv_layers: 2\n", + " num_rnn_layers: 3\n", + " rnn_layer_size: 1024\n", + " share_rnn_weights: False\n", + " use_gru: True\n", + "training:\n", + " global_grad_clip: 5.0\n", + " lr: 0.0005\n", + " lr_decay: 0.83\n", + " n_epoch: 30\n", + " weight_decay: 1e-06\n", + "----------- Configuration Arguments -----------\n", + "checkpoint_path: examples/aishell/ckpt/checkpoints/step-1876\n", + "config: examples/aishell/conf/deepspeech2.yaml\n", + "device: gpu\n", + "dump_config: None\n", + "export_path: None\n", + "host_ip: localhost\n", + "host_port: 8086\n", + "model_dir: None\n", + "model_file: examples/aishell/jit.model.pdmodel\n", + "nprocs: 1\n", + "opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n", + "output: None\n", + "params_file: examples/aishell/jit.model.pdiparams\n", + "speech_save_dir: demo_cache\n", + "use_gpu: True\n", + "warmup_manifest: examples/aishell/data/manifest.test\n", + "------------------------------------------------\n" + ] + } + ], + "source": [ + "parser = default_argument_parser()\n", + "add_arg = functools.partial(add_arguments, argparser=parser)\n", + "add_arg('host_ip', str,\n", + " 'localhost',\n", + " \"Server's IP address.\")\n", + "add_arg('host_port', int, 8086, \"Server's IP port.\")\n", + "add_arg('speech_save_dir', str,\n", + " 'demo_cache',\n", + " \"Directory to save demo audios.\")\n", + "add_arg('warmup_manifest', str, \"examples/aishell/data/manifest.test\", \"Filepath of manifest to warm up.\")\n", + "add_arg(\n", + " \"--model_file\",\n", + " type=str,\n", + " default=\"examples/aishell/jit.model.pdmodel\",\n", + " help=\"Model filename, Specify this when your model is a combined model.\"\n", + ")\n", + "add_arg(\n", + " \"--params_file\",\n", + " type=str,\n", + " default=\"examples/aishell/jit.model.pdiparams\",\n", + " help=\n", + " \"Parameter filename, Specify this when your model is a combined model.\"\n", + ")\n", + "add_arg(\n", + " \"--model_dir\",\n", + " type=str,\n", + " default=None,\n", + " help=\n", + " \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n", + ")\n", + "add_arg(\"--use_gpu\",type=bool,default=True, help=\"Whether use gpu.\")\n", + "args = parser.parse_args(\"--checkpoint_path examples/aishell/ckpt/checkpoints/step-1876 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split())\n", + "\n", + "\n", + "config = get_cfg_defaults()\n", + "if args.config:\n", + " config.merge_from_file(args.config)\n", + "if args.opts:\n", + " config.merge_from_list(args.opts)\n", + "config.freeze()\n", + "print(config)\n", + "\n", + "args.warmup_manifest = config.data.test_manifest\n", + "print_arguments(args)\n", + "\n", + "if args.dump_config:\n", + " with open(args.dump_config, 'w') as f:\n", + " print(config, file=f)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = ManifestDataset(\n", + " config.data.test_manifest,\n", + " config.data.vocab_filepath,\n", + " config.data.mean_std_filepath,\n", + " augmentation_config=\"{}\",\n", + " max_duration=config.data.max_duration,\n", + " min_duration=config.data.min_duration,\n", + " stride_ms=config.data.stride_ms,\n", + " window_ms=config.data.window_ms,\n", + " n_fft=config.data.n_fft,\n", + " max_freq=config.data.max_freq,\n", + " target_sample_rate=config.data.target_sample_rate,\n", + " specgram_type=config.data.specgram_type,\n", + " use_dB_normalization=config.data.use_dB_normalization,\n", + " target_dB=config.data.target_dB,\n", + " random_seed=config.data.random_seed,\n", + " keep_transcription_text=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/paddlepaddle_gpu-0.0.0-py3.7-linux-x86_64.egg/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.ctc_lo.weight. decoder.ctc_lo.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/paddlepaddle_gpu-0.0.0-py3.7-linux-x86_64.egg/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.ctc_lo.bias. decoder.ctc_lo.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "2021-03-17 11:10:00,017 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt/checkpoints/step-1876.pdparams\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "layer summary:\n", + "encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n", + "encoder.conv.conv_in.bn.weight|[32]|32\n", + "encoder.conv.conv_in.bn.bias|[32]|32\n", + "encoder.conv.conv_in.bn._mean|[32]|32\n", + "encoder.conv.conv_in.bn._variance|[32]|32\n", + "encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n", + "encoder.conv.conv_stack.0.bn.weight|[32]|32\n", + "encoder.conv.conv_stack.0.bn.bias|[32]|32\n", + "encoder.conv.conv_stack.0.bn._mean|[32]|32\n", + "encoder.conv.conv_stack.0.bn._variance|[32]|32\n", + "encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n", + "encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n", + "encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n", + "encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n", + "encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n", + "encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n", + "encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n", + "encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", + "encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n", + "decoder.ctc_lo.weight|[2048, 4300]|8806400\n", + "decoder.ctc_lo.bias|[4300]|4300\n", + "layer has 66 parameters, 80148012 elements.\n" + ] + } + ], + "source": [ + "model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n", + " args.checkpoint_path)\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def init_predictor(args):\n", + " if args.model_dir is not None:\n", + " config = Config(args.model_dir)\n", + " else:\n", + " config = Config(args.model_file, args.params_file)\n", + "\n", + " config.enable_memory_optim()\n", + " if args.use_gpu:\n", + " config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n", + " else:\n", + " # If not specific mkldnn, you can set the blas thread.\n", + " # The thread num should not be greater than the number of cores in the CPU.\n", + " config.set_cpu_math_library_num_threads(4)\n", + " #config.enable_mkldnn()\n", + " \n", + " config.switch_ir_optim(False)\n", + "\n", + " predictor = create_predictor(config)\n", + " return predictor\n", + "\n", + "def run(predictor, audio, audio_len):\n", + " # copy img data to input tensor\n", + " input_names = predictor.get_input_names()\n", + " for i, name in enumerate(input_names):\n", + " print(\"input:\", i, name)\n", + " \n", + " audio_tensor = predictor.get_input_handle('audio')\n", + " audio_tensor.reshape(audio.shape)\n", + " audio_tensor.copy_from_cpu(audio.copy())\n", + " \n", + " audiolen_tensor = predictor.get_input_handle('audio_len')\n", + " audiolen_tensor.reshape(audio_len.shape)\n", + " audiolen_tensor.copy_from_cpu(audio_len.copy())\n", + "\n", + " output_names = predictor.get_output_names()\n", + " for i, name in enumerate(output_names):\n", + " print(\"output:\", i, name)\n", + "\n", + " # do the inference\n", + " predictor.run()\n", + "\n", + " results = []\n", + " # get out data from output tensor\n", + " output_names = predictor.get_output_names()\n", + " for i, name in enumerate(output_names):\n", + " output_tensor = predictor.get_output_handle(name)\n", + " output_data = output_tensor.copy_to_cpu()\n", + " results.append(output_data)\n", + "\n", + " return results\n", + "\n", + "\n", + "predictor = init_predictor(args)\n", + "\n", + "def file_to_transcript(filename):\n", + " print(filename)\n", + " feature = dataset.process_utterance(filename, \"\")\n", + " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", + " audio_len = feature[0].shape[1]\n", + " audio_len = np.array([audio_len]).astype('int64') # [1]\n", + " \n", + " \n", + " i_probs = run(predictor, audio, audio_len)\n", + " print('jit:', i_probs[0], type(i_probs[0]))\n", + " \n", + " audio = paddle.to_tensor(audio)\n", + " audio_len = paddle.to_tensor(audio_len)\n", + " print(audio.shape)\n", + " print(audio_len.shape)\n", + " \n", + " #eouts, eouts_len = model.encoder(audio, audio_len)\n", + " #probs = model.decoder.probs(eouts)\n", + " probs = model.forward(audio, audio_len)\n", + " print('paddle:', probs.numpy())\n", + " \n", + " flag = np.allclose(i_probs[0], probs.numpy())\n", + " print(flag)\n", + " \n", + " return probs\n", + "\n", + "# result_transcript = model.decode(\n", + "# audio,\n", + "# audio_len,\n", + "# vocab_list=dataset.vocab_list,\n", + "# decoding_method=config.decoding.decoding_method,\n", + "# lang_model_path=config.decoding.lang_model_path,\n", + "# beam_alpha=config.decoding.alpha,\n", + "# beam_beta=config.decoding.beta,\n", + "# beam_size=config.decoding.beam_size,\n", + "# cutoff_prob=config.decoding.cutoff_prob,\n", + "# cutoff_top_n=config.decoding.cutoff_top_n,\n", + "# num_processes=config.decoding.num_proc_bsearch)\n", + "# return result_transcript[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../..//examples/dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0121.wav\n", + "/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../..//examples/dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0121.wav\n", + "input: 0 audio\n", + "input: 1 audio_len\n", + "output: 0 tmp_75\n", + "jit: [[[1.40282078e-04 3.31296207e-04 5.57157793e-04 ... 1.07916087e-04\n", + " 8.73636964e-05 1.96113906e-04]\n", + " [1.38032061e-04 2.70526099e-04 4.53807996e-04 ... 1.02293277e-04\n", + " 8.40202629e-05 1.90612729e-04]\n", + " [1.38912103e-04 2.45687814e-04 3.99624696e-04 ... 9.70420660e-05\n", + " 7.88255784e-05 1.80084753e-04]\n", + " ...\n", + " [3.28999187e-04 2.59723864e-04 3.03535169e-04 ... 2.82066030e-04\n", + " 1.11002744e-04 1.27009131e-04]\n", + " [2.91427423e-04 2.20203598e-04 2.85082555e-04 ... 3.27318383e-04\n", + " 1.09202861e-04 1.17112293e-04]\n", + " [3.63971514e-04 1.47859042e-04 2.24457763e-04 ... 3.63016297e-04\n", + " 1.34765272e-04 1.61947115e-04]]] \n", + "[1, 161, 419]\n", + "[1]\n", + "paddle: [[[3.4913886e-04 2.5836096e-04 4.2449642e-04 ... 7.2210147e-05\n", + " 7.1211573e-05 2.0057644e-04]\n", + " [3.8406707e-04 2.4088801e-04 5.0910388e-04 ... 6.1701416e-05\n", + " 6.7852285e-05 2.3967208e-04]\n", + " [4.1069370e-04 2.5478008e-04 6.7985675e-04 ... 5.8369777e-05\n", + " 6.2065104e-05 2.5938542e-04]\n", + " ...\n", + " [6.6656910e-04 3.1835871e-04 7.5929717e-04 ... 1.1990797e-04\n", + " 3.7087579e-05 3.4520373e-04]\n", + " [4.7881933e-04 2.7979453e-04 6.7949941e-04 ... 1.2511105e-04\n", + " 4.5631223e-05 3.7984925e-04]\n", + " [2.8661705e-04 2.9201157e-04 4.5970027e-04 ... 1.4581002e-04\n", + " 7.8281126e-05 3.8263199e-04]]]\n", + "False\n" + ] + } + ], + "source": [ + "manifest = read_manifest(args.warmup_manifest)\n", + "\n", + "for idx, sample in enumerate(manifest[:1]):\n", + " print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n", + " start_time = time.time()\n", + " transcript = file_to_transcript(sample['audio_filepath'])\n", + " finish_time = time.time()\n", + "# print(\"Response Time: %f, Transcript: %s\" %\n", + "# (finish_time - start_time, transcript))\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 161, 419) (1,)\n", + "input: 0 audio\n", + "input: 1 audio_len\n", + "output: 0 tmp_75\n", + "jit: [[[1.40282078e-04 3.31296207e-04 5.57157793e-04 ... 1.07916087e-04\n", + " 8.73636964e-05 1.96113906e-04]\n", + " [1.38032061e-04 2.70526099e-04 4.53807996e-04 ... 1.02293277e-04\n", + " 8.40202629e-05 1.90612729e-04]\n", + " [1.38912103e-04 2.45687814e-04 3.99624696e-04 ... 9.70420660e-05\n", + " 7.88255784e-05 1.80084753e-04]\n", + " ...\n", + " [3.28999187e-04 2.59723864e-04 3.03535169e-04 ... 2.82066030e-04\n", + " 1.11002744e-04 1.27009131e-04]\n", + " [2.91427423e-04 2.20203598e-04 2.85082555e-04 ... 3.27318383e-04\n", + " 1.09202861e-04 1.17112293e-04]\n", + " [3.63971514e-04 1.47859042e-04 2.24457763e-04 ... 3.63016297e-04\n", + " 1.34765272e-04 1.61947115e-04]]]\n" + ] + } + ], + "source": [ + "def test(filename):\n", + " feature = dataset.process_utterance(filename, \"\")\n", + " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", + " audio_len = feature[0].shape[1]\n", + " audio_len = np.array([audio_len]).astype('int64') # [1]\n", + " \n", + " print(audio.shape, audio_len.shape)\n", + "\n", + " i_probs = run(predictor, audio, audio_len)\n", + " print('jit:', i_probs[0])\n", + " return i_probs\n", + " \n", + "probs = test(sample['audio_filepath'])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/paddlepaddle_gpu-0.0.0-py3.7-linux-x86_64.egg/paddle/tensor/creation.py:143: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + } + ], + "source": [ + "a = paddle.to_tensor([1,3,4])\n", + "a.numpy?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/deepspeech/modules/__init__.py b/deepspeech/modules/__init__.py index 925848e8..973bc062 100644 --- a/deepspeech/modules/__init__.py +++ b/deepspeech/modules/__init__.py @@ -43,20 +43,53 @@ if not hasattr(paddle.Tensor, 'cat'): paddle.Tensor.cat = paddle.Tensor.concat +def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: + return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place)) + + +if not hasattr(paddle.Tensor, 'eq'): + logger.warn( + "override eq of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.eq = eq + + +def contiguous(xs: paddle.Tensor) -> paddle.Tensor: + return xs + + +if not hasattr(paddle.Tensor, 'contiguous'): + logger.warn( + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.contiguous = contiguous + + def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: nargs = len(args) assert (nargs <= 1) s = paddle.shape(xs) if nargs == 1: - return s[args] + return s[args[0]] else: return s -# logger.warn( -# "override size of paddle.Tensor if exists or register, remove this when fixed!" -# ) -# paddle.Tensor.size = size +#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. +logger.warn( + "override size of paddle.Tensor " + "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" +) +paddle.Tensor.size = size + + +def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + return xs.reshape(args) + + +if not hasattr(paddle.Tensor, 'view'): + logger.warn("register user view to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view = view def masked_fill(xs: paddle.Tensor, @@ -185,6 +218,14 @@ if not hasattr(paddle.nn, 'ConstantPad2d'): "register user ConstantPad2d to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) +if not hasattr(paddle, 'softmax'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle, 'sigmoid'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) + # hack loss def ctc_loss(logits, diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py new file mode 100644 index 00000000..d75a7f84 --- /dev/null +++ b/deepspeech/modules/attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Head Attention layer definition.""" +import math +import logging +from typing import Optional, Tuple + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +logger = logging.getLogger(__name__) + +__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] + + +class MultiHeadedAttention(nn.Layer): + """Multi-Head Attention layer.""" + + def __init__(self, n_head: int, n_feat: int, dropout_rate: float): + """Construct an MultiHeadedAttention object. + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Transform query, key and value. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + Returns: + paddle.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + paddle.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + paddle.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, + value: paddle.Tensor, + scores: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute attention context vector. + Args: + value (paddle.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (paddle.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (paddle.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + paddle.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = paddle.softmax( + scores, axis=-1).masked_fill(mask, + 0.0) # (batch, head, time1, time2) + else: + attn = paddle.softmax( + scores, axis=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose([0, 2, 1, 3]).contiguous().view( + n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + scores = paddle.matmul(q, k.transpose( + [0, 1, 3, 2])) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding.""" + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu: bool=False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros( + (x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_emb: torch.Tensor, + mask: Optional[torch.Tensor]): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. + # matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/requirements.txt b/requirements.txt index 14d7c032..f0d92f7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ python_speech_features tensorboardX yacs typeguard +pre-commit -- GitLab