From db121226b83b57a8c1e76b913ac05993980d90b0 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 22 Dec 2021 17:19:22 +0800 Subject: [PATCH] clean aishell asr1 conf & compare ctc loss with torch and warpctc_pytorch (#1191) --- .gitignore | 1 + docs/topic/ctc/ctc_loss_compare.ipynb | 520 ++++++++++++++++++ examples/aishell/asr1/RESULTS.md | 8 +- .../aishell/asr1/conf/chunk_conformer.yaml | 89 ++- examples/aishell/asr1/conf/conformer.yaml | 73 +-- examples/aishell/asr1/conf/transformer.yaml | 81 ++- paddlespeech/s2t/exps/u2/model.py | 30 +- 7 files changed, 635 insertions(+), 167 deletions(-) create mode 100644 docs/topic/ctc/ctc_loss_compare.ipynb diff --git a/.gitignore b/.gitignore index 8cbb734d..cc8fff87 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ build docs/build/ +docs/topic/ctc/warp-ctc/ tools/venv tools/kenlm diff --git a/docs/topic/ctc/ctc_loss_compare.ipynb b/docs/topic/ctc/ctc_loss_compare.ipynb new file mode 100644 index 00000000..95b2af50 --- /dev/null +++ b/docs/topic/ctc/ctc_loss_compare.ipynb @@ -0,0 +1,520 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ff6ff1e0", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "33af5f76", + "metadata": {}, + "outputs": [], + "source": [ + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9b566b73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'warp-ctc'...\n", + "remote: Enumerating objects: 829, done.\u001b[K\n", + "remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n", + "Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n", + "Resolving deltas: 100% (419/419), done.\n", + "Checking connectivity... 
done.\n" + ] + } + ], + "source": [ + "!git clone https://github.com/SeanNaren/warp-ctc.git" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4a087a09", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n" + ] + } + ], + "source": [ + "%cd warp-ctc" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f55dc29a", + "metadata": {}, + "outputs": [], + "source": [ + "mkdir -p build" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fe79f4cf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n" + ] + } + ], + "source": [ + "cd build" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3d25c718", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-- The C compiler identification is GNU 5.4.0\n", + "-- The CXX compiler identification is GNU 5.4.0\n", + "-- Check for working C compiler: /usr/bin/cc\n", + "-- Check for working C compiler: /usr/bin/cc -- works\n", + "-- Detecting C compiler ABI info\n", + "-- Detecting C compiler ABI info - done\n", + "-- Detecting C compile features\n", + "-- Detecting C compile features - done\n", + "-- Check for working CXX compiler: /usr/bin/c++\n", + "-- Check for working CXX compiler: /usr/bin/c++ -- works\n", + "-- Detecting CXX compiler ABI info\n", + "-- Detecting CXX compiler ABI info - done\n", + "-- Detecting CXX compile features\n", + "-- Detecting CXX compile features - done\n", + "-- Looking for pthread.h\n", + "-- Looking for pthread.h - found\n", + "-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n", + "-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n", + "-- Looking for pthread_create in pthreads\n", + "-- Looking for pthread_create in pthreads - not found\n", + "-- Looking for pthread_create in pthread\n", + "-- Looking for pthread_create in pthread - found\n", + "-- Found Threads: TRUE \n", + "-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n", + "-- cuda found TRUE\n", + "-- Building shared library with GPU support\n", + "-- Configuring done\n", + "-- Generating done\n", + "-- Build files have been written to: /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n" + ] + } + ], + "source": [ + "!cmake .." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7a4238f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n", + "[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n", + "\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n", + "[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n", + "[ 33%] Built target warpctc\n", + "[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n", + "\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n", + "[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n", + "[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n", + "[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n", + "[ 77%] Built target test_cpu\n", + "\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n", + "[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n", + "[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n", + "[100%] Built target test_gpu\n" + ] + } + ], + "source": [ + "!make -j" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "31761a31", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n" + ] + } + ], + "source": [ + "cd .." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f53316f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding\n" + ] + } + ], + "source": [ + "cd pytorch_binding" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "084f1e49", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "running install\n", + "running bdist_egg\n", + "running egg_info\n", + "creating warpctc_pytorch.egg-info\n", + "writing warpctc_pytorch.egg-info/PKG-INFO\n", + "writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n", + "writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n", + "writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n", + "writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n", + "installing library code to build/bdist.linux-x86_64/egg\n", + "running install_lib\n", + "running build_py\n", + "creating build\n", + "creating build/lib.linux-x86_64-3.9\n", + "creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n", + "copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n", + "running build_ext\n", + "building 'warpctc_pytorch._warp_ctc' extension\n", + "creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n", + "creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n", + "Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n", + "Compiling objects...\n", + "Allowing ninja to set a default 
number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n", + "g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n", + "creating build/bdist.linux-x86_64\n", + "creating build/bdist.linux-x86_64/egg\n", + "creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n", + "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n", + "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n", + "byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/__init__.py to __init__.cpython-39.pyc\n", + "creating stub loader for warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so\n", + "byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/_warp_ctc.py to _warp_ctc.cpython-39.pyc\n", + "creating build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying warpctc_pytorch.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying warpctc_pytorch.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying warpctc_pytorch.egg-info/dependency_links.txt -> 
build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying warpctc_pytorch.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n", + "zip_safe flag not set; analyzing archive contents...\n", + "warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n", + "creating dist\n", + "creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n", + "removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n", + "Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n", + "removing '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' (and everything under it)\n", + "creating /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n", + "Extracting warpctc_pytorch-0.1-py3.9-linux-x86_64.egg to /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages\n", + "warpctc-pytorch 0.1 is already the active version in easy-install.pth\n", + "\n", + "Installed /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n", + "Processing dependencies for warpctc-pytorch==0.1\n", + "Finished processing dependencies for warpctc-pytorch==0.1\n" + ] + } + ], + "source": [ + "!python setup.py install" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ee4ca9e3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python 3.9.5\r\n" + ] + } + ], + "source": [ + "!python -V" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "59255ed8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n" + ] + } + ], + "source": [ + "cd .." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1dae09b9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import warpctc_pytorch as wp\n", + "import paddle.nn as pn\n", + "import paddle" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "83d0762e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.10.0+cu102'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "62501e2c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.2.0'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "paddle.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9e8e0f40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 1, 5])\n", + "2.4628584384918213\n", + "[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n", + "\n", + " [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n" + ] + } + ], + "source": [ + "probs = torch.FloatTensor([[\n", + " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", + " ]]).transpose(0, 1).contiguous()\n", + "print(probs.size())\n", + "labels = torch.IntTensor([1, 2])\n", + "label_sizes = torch.IntTensor([2])\n", + "probs_sizes = torch.IntTensor([2])\n", + "probs.requires_grad_(True)\n", + "bs = probs.size(1)\n", + "\n", + "ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n", + "cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n", + "cost = cost.sum() / bs\n", + "print(cost.item())\n", + "cost.backward()\n", + "print(probs.grad.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "2cd46569", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.4628584384918213\n", + "[[[ 0.1770312 -0.7081248 0.1770312 0.1770312 0.1770312]]\n", + "\n", + " [[ 0.1770312 0.1770312 -0.7081248 0.1770312 0.1770312]]]\n" + ] + } + ], + "source": [ + "probs = torch.FloatTensor([[\n", + " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", + " ]]).transpose(0, 1).contiguous()\n", + "labels = torch.IntTensor([1, 2])\n", + "label_sizes = torch.IntTensor([2])\n", + "probs_sizes = torch.IntTensor([2])\n", + "probs.requires_grad_(True)\n", + "bs = probs.size(1)\n", + "\n", + "log_probs = torch.log_softmax(probs, axis=-1)\n", + "\n", + "ctc_loss1 = nn.CTCLoss(reduction='none')\n", + "cost = ctc_loss1(log_probs, labels, probs_sizes, label_sizes)\n", + "cost = cost.sum() / bs\n", + "print(cost.item())\n", + "cost.backward()\n", + "print(probs.grad.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "85c3461a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 1, 5]\n", + "[1, 2]\n", + "2.4628584384918213\n", + "[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n", + "\n", + " [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n" + ] + } + ], + "source": [ + "paddle.set_device('cpu')\n", + "probs = paddle.to_tensor([[\n", + " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n", + " 
]]).transpose([1,0,2])\n", + "print(probs.shape) # (T, B, D)\n", + "labels = paddle.to_tensor([[1, 2]], dtype='int32') #(B,L)\n", + "print(labels.shape)\n", + "label_sizes = paddle.to_tensor([2], dtype='int64')\n", + "probs_sizes = paddle.to_tensor([2], dtype='int64')\n", + "bs = paddle.shape(probs)[1]\n", + "probs.stop_gradient=False\n", + "\n", + "ctc_loss = pn.CTCLoss(reduction='none')\n", + "cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n", + "cost = cost.sum() / bs\n", + "print(cost.item())\n", + "cost.backward()\n", + "print(probs.grad.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d390cd91", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/aishell/asr1/RESULTS.md b/examples/aishell/asr1/RESULTS.md index 783e179e..b68d6992 100644 --- a/examples/aishell/asr1/RESULTS.md +++ b/examples/aishell/asr1/RESULTS.md @@ -25,7 +25,7 @@ Need set `decoding.decoding_chunk_size=16` when decoding. | Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER | | --- | --- | --- | --- | --- | --- | --- | --- | -| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention | 3.858648955821991 | 0.057293 | -| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_greedy_search | 3.858648955821991 | 0.061837 | -| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_prefix_beam_search | 3.858648955821991 | 0.061685 | -| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention_rescoring | 3.858648955821991 | 0.053844 | +| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention | 3.8103787302970886 | 0.056588 | +| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_greedy_search | 3.8103787302970886 | 0.059932 | +| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_prefix_beam_search | 3.8103787302970886 | 0.059989 | +| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention_rescoring | 3.8103787302970886 | 0.052273 | diff --git a/examples/aishell/asr1/conf/chunk_conformer.yaml b/examples/aishell/asr1/conf/chunk_conformer.yaml index 80b45587..50eaef98 100644 --- a/examples/aishell/asr1/conf/chunk_conformer.yaml +++ b/examples/aishell/asr1/conf/chunk_conformer.yaml @@ -1,41 +1,3 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 - max_input_len: 20.0 # second - min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - - -collator: - vocab_filepath: data/lang_char/vocab.txt - unit_type: 'char' - spm_model_prefix: '' - augmentation_config: conf/preprocess.yaml - batch_size: 32 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - 
keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - # network architecture model: cmvn_file: @@ -52,8 +14,8 @@ model: attention_dropout_rate: 0.0 input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 normalize_before: True - use_cnn_module: True cnn_module_kernel: 15 + use_cnn_module: True activation_type: 'swish' pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' @@ -76,21 +38,47 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + + +collator: + vocab_filepath: data/lang_char/vocab.txt + unit_type: 'char' + augmentation_config: conf/preprocess.yaml + feat_dim: 80 + stride_ms: 10.0 + window_ms: 25.0 + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + batch_size: 64 + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + num_workers: 0 + subsampling_factor: 1 + num_encs: 1 + + training: - n_epoch: 240 - accum_grad: 4 + n_epoch: 240 + accum_grad: 2 global_grad_clip: 5.0 optim: adam optim_conf: - lr: 0.001 + lr: 0.002 weight_decay: 1e-6 - scheduler: warmuplr + scheduler: warmuplr scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 @@ -101,22 +89,15 @@ training: decoding: + beam_size: 10 batch_size: 128 error_rate_type: cer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: true # simulate streaming inference. Defaults to False. - + simulate_streaming: False # simulate streaming inference. Defaults to False. 
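
As a quick sanity check (not part of the patch), the relocated collator block can be read back with plain PyYAML; the keys printed below are exactly the ones that the refactored U2Trainer.setup_dataloader() in paddlespeech/s2t/exps/u2/model.py (last hunks of this patch) now takes from config.collator instead of hard-coded defaults. This is a hypothetical standalone sketch, assuming a checkout of the repo and that it is run from the repository root.

    # Hypothetical check: dump the collator fields of the cleaned-up config.
    import yaml

    with open("examples/aishell/asr1/conf/chunk_conformer.yaml") as f:
        config = yaml.safe_load(f)

    collator = config["collator"]
    # These keys replace the values that used to be hard-coded as BatchDataLoader
    # arguments in paddlespeech/s2t/exps/u2/model.py (see the model.py diff below).
    for key in ("sortagrad", "batch_size", "maxlen_in", "maxlen_out", "minibatches",
                "batch_count", "batch_bins", "batch_frames_in", "batch_frames_out",
                "batch_frames_inout", "num_workers", "subsampling_factor", "num_encs"):
        print(f"{key}: {collator[key]}")
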
diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index 67a96e69..907e3a94 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -1,33 +1,3 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - -collator: - vocab_filepath: data/lang_char/vocab.txt - unit_type: 'char' - spm_model_prefix: '' - augmentation_config: conf/preprocess.yaml - batch_size: 64 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - # network architecture model: cmvn_file: @@ -44,8 +14,8 @@ model: attention_dropout_rate: 0.0 input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 normalize_before: True - use_cnn_module: True cnn_module_kernel: 15 + use_cnn_module: True activation_type: 'swish' pos_enc_layer_type: 'rel_pos' selfattention_layer_type: 'rel_selfattn' @@ -64,11 +34,36 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + + +collator: + vocab_filepath: data/lang_char/vocab.txt + unit_type: 'char' + augmentation_config: conf/preprocess.yaml + feat_dim: 80 + stride_ms: 10.0 + window_ms: 25.0 + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + batch_size: 64 + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + num_workers: 0 + subsampling_factor: 1 + num_encs: 1 + training: n_epoch: 240 @@ -78,7 +73,7 @@ training: optim_conf: lr: 0.002 weight_decay: 1e-6 - scheduler: warmuplr # pytorch v1.1.0+ required + scheduler: warmuplr scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 @@ -89,16 +84,10 @@ training: decoding: + beam_size: 10 batch_size: 128 error_rate_type: cer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. @@ -106,5 +95,3 @@ decoding: # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. simulate_streaming: False # simulate streaming inference. Defaults to False. 
- - diff --git a/examples/aishell/asr1/conf/transformer.yaml b/examples/aishell/asr1/conf/transformer.yaml index e1006309..7c5fa624 100644 --- a/examples/aishell/asr1/conf/transformer.yaml +++ b/examples/aishell/asr1/conf/transformer.yaml @@ -1,40 +1,3 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 - max_input_len: 20.0 # second - min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - - -collator: - vocab_filepath: data/lang_char/vocab.txt - unit_type: 'char' - spm_model_prefix: '' - augmentation_config: conf/preprocess.yaml - batch_size: 64 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - # network architecture model: cmvn_file: @@ -66,12 +29,40 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + + +collator: + unit_type: 'char' + vocab_filepath: data/lang_char/vocab.txt + feat_dim: 80 + stride_ms: 10.0 + window_ms: 25.0 + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + batch_size: 64 + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + augmentation_config: conf/preprocess.yaml + num_workers: 0 + subsampling_factor: 1 + num_encs: 1 + + + training: n_epoch: 240 accum_grad: 2 @@ -91,22 +82,14 @@ training: decoding: + beam_size: 10 batch_size: 128 error_rate_type: cer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. - - + simulate_streaming: False # simulate streaming inference. Defaults to False. 
\ No newline at end of file diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 404058ed..9fb7067f 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -254,19 +254,18 @@ class U2Trainer(Trainer): self.train_loader = BatchDataLoader( json_file=config.data.train_manifest, train_mode=True, - sortagrad=False, + sortagrad=config.collator.sortagrad, batch_size=config.collator.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, + maxlen_in=config.collator.maxlen_in, + maxlen_out=config.collator.maxlen_out, + minibatches=config.collator.minibatches, mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.collator. - augmentation_config, # aug will be off when train_mode=False + batch_count=config.collator.batch_count, + batch_bins=config.collator.batch_bins, + batch_frames_in=config.collator.batch_frames_in, + batch_frames_out=config.collator.batch_frames_out, + batch_frames_inout=config.collator.batch_frames_inout, + preprocess_conf=config.collator.augmentation_config, n_iter_processes=config.collator.num_workers, subsampling_factor=1, num_encs=1) @@ -285,8 +284,7 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator. - augmentation_config, # aug will be off when train_mode=False + preprocess_conf=config.collator.augmentation_config, n_iter_processes=config.collator.num_workers, subsampling_factor=1, num_encs=1) @@ -307,8 +305,7 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator. - augmentation_config, # aug will be off when train_mode=False + preprocess_conf=config.collator.augmentation_config, n_iter_processes=1, subsampling_factor=1, num_encs=1) @@ -327,8 +324,7 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator. - augmentation_config, # aug will be off when train_mode=False + preprocess_conf=config.collator.augmentation_config, n_iter_processes=1, subsampling_factor=1, num_encs=1) -- GitLab
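
For reference, the torch and paddle cells of docs/topic/ctc/ctc_loss_compare.ipynb condense into the short script below. It is a sketch of the same toy example, not part of the patch; warpctc_pytorch is omitted because it needs the custom warp-ctc build shown in the notebook. With torch 1.10 and paddle 2.2 (as in the notebook), both losses come out to roughly 2.4629 and the gradients agree to float precision. Note that torch.nn.CTCLoss is fed log-probabilities, while the notebook feeds paddle.nn.CTCLoss the raw probabilities and still gets the same value, which suggests the normalization happens inside the op.

    # Condensed re-run of the notebook's torch vs. paddle CTC loss comparison.
    import paddle
    import paddle.nn as pn
    import torch
    import torch.nn as nn

    # Toy batch: T=2 time steps, B=1 utterance, D=5 classes; target sequence [1, 2].
    data = [[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]

    # torch.nn.CTCLoss expects (T, B, D) log-probabilities.
    t_probs = torch.FloatTensor(data).transpose(0, 1).contiguous()
    t_probs.requires_grad_(True)
    t_cost = nn.CTCLoss(reduction='none')(
        torch.log_softmax(t_probs, dim=-1),
        torch.IntTensor([1, 2]),   # concatenated targets
        torch.IntTensor([2]),      # input lengths
        torch.IntTensor([2]))      # target lengths
    t_cost = t_cost.sum() / t_probs.size(1)   # average over the batch
    t_cost.backward()
    print('torch :', t_cost.item())
    print(t_probs.grad.numpy())

    # paddle.nn.CTCLoss, called exactly as in the notebook (raw probabilities in).
    paddle.set_device('cpu')
    p_probs = paddle.to_tensor(data).transpose([1, 0, 2])   # (T, B, D)
    p_probs.stop_gradient = False
    p_cost = pn.CTCLoss(reduction='none')(
        p_probs,
        paddle.to_tensor([[1, 2]], dtype='int32'),   # (B, L) targets
        paddle.to_tensor([2], dtype='int64'),        # input lengths
        paddle.to_tensor([2], dtype='int64'))        # target lengths
    p_cost = p_cost.sum() / p_probs.shape[1]
    p_cost.backward()
    print('paddle:', p_cost.item())
    print(p_probs.grad.numpy())
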