PaddlePaddle / DeepSpeech

Commit db121226 (unverified)
Authored by Hui Zhang on Dec 22, 2021
Committed by GitHub on Dec 22, 2021
clean aishell asr1 conf & compare ctc loss with torch and warpctc_pytorch (#1191)
Parent 54119650
Showing 7 changed files with 635 additions and 167 deletions (+635 -167)
.gitignore (+1, -0)
docs/topic/ctc/ctc_loss_compare.ipynb (+520, -0)
examples/aishell/asr1/RESULTS.md (+4, -4)
examples/aishell/asr1/conf/chunk_conformer.yaml (+35, -54)
examples/aishell/asr1/conf/conformer.yaml (+30, -43)
examples/aishell/asr1/conf/transformer.yaml (+32, -49)
paddlespeech/s2t/exps/u2/model.py (+13, -17)
.gitignore
...
@@ -15,6 +15,7 @@
 build
 docs/build/
+docs/topic/ctc/warp-ctc/
 tools/venv
 tools/kenlm
...
docs/topic/ctc/ctc_loss_compare.ipynb
0 → 100644
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ff6ff1e0",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "33af5f76",
"metadata": {},
"outputs": [],
"source": [
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b566b73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'warp-ctc'...\n",
"remote: Enumerating objects: 829, done.\u001b[K\n",
"remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n",
"Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n",
"Resolving deltas: 100% (419/419), done.\n",
"Checking connectivity... done.\n"
]
}
],
"source": [
"!git clone https://github.com/SeanNaren/warp-ctc.git"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4a087a09",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"%cd warp-ctc"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f55dc29a",
"metadata": {},
"outputs": [],
"source": [
"mkdir -p build"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe79f4cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
]
}
],
"source": [
"cd build"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3d25c718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-- The C compiler identification is GNU 5.4.0\n",
"-- The CXX compiler identification is GNU 5.4.0\n",
"-- Check for working C compiler: /usr/bin/cc\n",
"-- Check for working C compiler: /usr/bin/cc -- works\n",
"-- Detecting C compiler ABI info\n",
"-- Detecting C compiler ABI info - done\n",
"-- Detecting C compile features\n",
"-- Detecting C compile features - done\n",
"-- Check for working CXX compiler: /usr/bin/c++\n",
"-- Check for working CXX compiler: /usr/bin/c++ -- works\n",
"-- Detecting CXX compiler ABI info\n",
"-- Detecting CXX compiler ABI info - done\n",
"-- Detecting CXX compile features\n",
"-- Detecting CXX compile features - done\n",
"-- Looking for pthread.h\n",
"-- Looking for pthread.h - found\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n",
"-- Looking for pthread_create in pthreads\n",
"-- Looking for pthread_create in pthreads - not found\n",
"-- Looking for pthread_create in pthread\n",
"-- Looking for pthread_create in pthread - found\n",
"-- Found Threads: TRUE \n",
"-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n",
"-- cuda found TRUE\n",
"-- Building shared library with GPU support\n",
"-- Configuring done\n",
"-- Generating done\n",
"-- Build files have been written to: /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
]
}
],
"source": [
"!cmake .."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7a4238f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n",
"[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n",
"[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 33%] Built target warpctc\n",
"[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n",
"[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n",
"[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n",
"[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 77%] Built target test_cpu\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n",
"[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n",
"[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[100%] Built target test_gpu\n"
]
}
],
"source": [
"!make -j"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "31761a31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"cd .."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f53316f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding\n"
]
}
],
"source": [
"cd pytorch_binding"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "084f1e49",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"running install\n",
"running bdist_egg\n",
"running egg_info\n",
"creating warpctc_pytorch.egg-info\n",
"writing warpctc_pytorch.egg-info/PKG-INFO\n",
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"installing library code to build/bdist.linux-x86_64/egg\n",
"running install_lib\n",
"running build_py\n",
"creating build\n",
"creating build/lib.linux-x86_64-3.9\n",
"creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"running build_ext\n",
"building 'warpctc_pytorch._warp_ctc' extension\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n",
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
"Compiling objects...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n",
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
"creating build/bdist.linux-x86_64\n",
"creating build/bdist.linux-x86_64/egg\n",
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/__init__.py to __init__.cpython-39.pyc\n",
"creating stub loader for warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so\n",
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/_warp_ctc.py to _warp_ctc.cpython-39.pyc\n",
"creating build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
"zip_safe flag not set; analyzing archive contents...\n",
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
"creating dist\n",
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"removing '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' (and everything under it)\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"Extracting warpctc_pytorch-0.1-py3.9-linux-x86_64.egg to /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages\n",
"warpctc-pytorch 0.1 is already the active version in easy-install.pth\n",
"\n",
"Installed /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"Processing dependencies for warpctc-pytorch==0.1\n",
"Finished processing dependencies for warpctc-pytorch==0.1\n"
]
}
],
"source": [
"!python setup.py install"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ee4ca9e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python 3.9.5\r\n"
]
}
],
"source": [
"!python -V"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "59255ed8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"cd .."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1dae09b9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import warpctc_pytorch as wp\n",
"import paddle.nn as pn\n",
"import paddle"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "83d0762e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.10.0+cu102'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "62501e2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.2.0'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.__version__"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "9e8e0f40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
"\n",
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
]
}
],
"source": [
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2cd46569",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.4628584384918213\n",
"[[[ 0.1770312 -0.7081248 0.1770312 0.1770312 0.1770312]]\n",
"\n",
" [[ 0.1770312 0.1770312 -0.7081248 0.1770312 0.1770312]]]\n"
]
}
],
"source": [
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"\n",
"ctc_loss1 = nn.CTCLoss(reduction='none')\n",
"cost = ctc_loss1(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "85c3461a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2, 1, 5]\n",
"[1, 2]\n",
"2.4628584384918213\n",
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
"\n",
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
]
}
],
"source": [
"paddle.set_device('cpu')\n",
"probs = paddle.to_tensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
" ]]).transpose([1,0,2])\n",
"print(probs.shape) # (T, B, D)\n",
"labels = paddle.to_tensor([[1, 2]], dtype='int32') #(B,L)\n",
"print(labels.shape)\n",
"label_sizes = paddle.to_tensor([2], dtype='int64')\n",
"probs_sizes = paddle.to_tensor([2], dtype='int64')\n",
"bs = paddle.shape(probs)[1]\n",
"probs.stop_gradient=False\n",
"\n",
"ctc_loss = pn.CTCLoss(reduction='none')\n",
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d390cd91",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
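The notebook above builds SeanNaren/warp-ctc, installs its warpctc_pytorch binding, and checks that warpctc_pytorch.CTCLoss, torch.nn.CTCLoss, and paddle.nn.CTCLoss agree on a toy batch (loss of about 2.4629 and matching input gradients). Below is a condensed sketch of that comparison, assuming the three packages are importable exactly as in the notebook; it is illustrative, not part of the commit.

```python
# Condensed from the notebook cells above: one toy batch (T=2, B=1, D=5, labels [1, 2])
# fed to all three CTC implementations; each should report a loss of ~2.4629.
import paddle
import paddle.nn as pn
import torch
import torch.nn as nn
import warpctc_pytorch as wp  # built from SeanNaren/warp-ctc as shown above

acts = [[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]  # (B=1, T=2, D=5)
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])

# warp-ctc takes raw activations (softmax is applied inside the op).
probs = torch.FloatTensor(acts).transpose(0, 1).contiguous()  # (T, B, D)
cost_warp = wp.CTCLoss(size_average=False, length_average=False)(
    probs, labels, probs_sizes, label_sizes).sum() / probs.size(1)

# torch.nn.CTCLoss expects log-probabilities, so apply log_softmax first.
log_probs = torch.log_softmax(
    torch.FloatTensor(acts).transpose(0, 1).contiguous(), dim=-1)
cost_torch = nn.CTCLoss(reduction='none')(
    log_probs, labels, probs_sizes, label_sizes).sum() / log_probs.size(1)

# paddle.nn.CTCLoss is fed the same raw activations, exactly as in the notebook.
paddle.set_device('cpu')
probs_pd = paddle.to_tensor(acts).transpose([1, 0, 2])  # (T, B, D)
cost_paddle = pn.CTCLoss(reduction='none')(
    probs_pd,
    paddle.to_tensor([[1, 2]], dtype='int32'),
    paddle.to_tensor([2], dtype='int64'),
    paddle.to_tensor([2], dtype='int64')).sum() / probs_pd.shape[1]

print(cost_warp.item(), cost_torch.item(), cost_paddle.item())  # all ~2.4629
```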
examples/aishell/asr1/RESULTS.md
...
@@ -25,7 +25,7 @@ Need set `decoding.decoding_chunk_size=16` when decoding.
 | Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention | 3.858648955821991 | 0.057293 |
+| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention | 3.8103787302970886 | 0.056588 |
-| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_greedy_search | 3.858648955821991 | 0.061837 |
+| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_greedy_search | 3.8103787302970886 | 0.059932 |
-| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_prefix_beam_search | 3.858648955821991 | 0.061685 |
+| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | ctc_prefix_beam_search | 3.8103787302970886 | 0.059989 |
-| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention_rescoring | 3.858648955821991 | 0.053844 |
+| transformer | 31.95M | conf/transformer.yaml | spec_aug | test | attention_rescoring | 3.8103787302970886 | 0.052273 |
examples/aishell/asr1/conf/chunk_conformer.yaml
-# https://yaml.org/type/float.html
-data:
-    train_manifest: data/manifest.train
-    dev_manifest: data/manifest.dev
-    test_manifest: data/manifest.test
-    min_input_len: 0.5
-    max_input_len: 20.0 # second
-    min_output_len: 0.0
-    max_output_len: 400.0
-    min_output_input_ratio: 0.05
-    max_output_input_ratio: 10.0
-collator:
-    vocab_filepath: data/lang_char/vocab.txt
-    unit_type: 'char'
-    spm_model_prefix: ''
-    augmentation_config: conf/preprocess.yaml
-    batch_size: 32
-    raw_wav: True # use raw_wav or kaldi feature
-    spectrum_type: fbank #linear, mfcc, fbank
-    feat_dim: 80
-    delta_delta: False
-    dither: 1.0
-    target_sample_rate: 16000
-    max_freq: None
-    n_fft: None
-    stride_ms: 10.0
-    window_ms: 25.0
-    use_dB_normalization: True
-    target_dB: -20
-    random_seed: 0
-    keep_transcription_text: False
-    sortagrad: True
-    shuffle_method: batch_shuffle
-    num_workers: 2
 # network architecture
 model:
     cmvn_file:
...
@@ -52,8 +14,8 @@ model:
     attention_dropout_rate: 0.0
     input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
     normalize_before: True
-    use_cnn_module: True
-    cnn_module_kernel: 15
+    cnn_module_kernel: 15
+    use_cnn_module: True
     activation_type: 'swish'
     pos_enc_layer_type: 'rel_pos'
     selfattention_layer_type: 'rel_selfattn'
...
@@ -76,21 +38,47 @@ model:
     # hybrid CTC/attention
     model_conf:
         ctc_weight: 0.3
-        ctc_dropoutrate: 0.0
-        ctc_grad_norm_type: null
         lsm_weight: 0.1 # label smoothing option
         length_normalized_loss: false
+data:
+    train_manifest: data/manifest.train
+    dev_manifest: data/manifest.dev
+    test_manifest: data/manifest.test
+collator:
+    vocab_filepath: data/lang_char/vocab.txt
+    unit_type: 'char'
+    augmentation_config: conf/preprocess.yaml
+    feat_dim: 80
+    stride_ms: 10.0
+    window_ms: 25.0
+    sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+    batch_size: 64
+    maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+    maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+    minibatches: 0 # for debug
+    batch_count: auto
+    batch_bins: 0
+    batch_frames_in: 0
+    batch_frames_out: 0
+    batch_frames_inout: 0
+    num_workers: 0
+    subsampling_factor: 1
+    num_encs: 1
 training:
     n_epoch: 240
-    accum_grad: 4
+    accum_grad: 2
     global_grad_clip: 5.0
     optim: adam
     optim_conf:
-        lr: 0.001
+        lr: 0.002
         weight_decay: 1e-6
     scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
         lr_decay: 1.0
...
@@ -101,22 +89,15 @@ training:
 decoding:
+    beam_size: 10
     batch_size: 128
     error_rate_type: cer
     decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
-    lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-    alpha: 2.5
-    beta: 0.3
-    beam_size: 10
-    cutoff_prob: 1.0
-    cutoff_top_n: 0
-    num_proc_bsearch: 8
     ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
     decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
         # <0: for decoding, use full chunk.
         # >0: for decoding, use fixed chunk size as set.
         # 0: used for training, it's prohibited here.
     num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
-    simulate_streaming: true # simulate streaming inference. Defaults to False.
+    simulate_streaming: False # simulate streaming inference. Defaults to False.
examples/aishell/asr1/conf/conformer.yaml
-# https://yaml.org/type/float.html
-data:
-    train_manifest: data/manifest.train
-    dev_manifest: data/manifest.dev
-    test_manifest: data/manifest.test
-collator:
-    vocab_filepath: data/lang_char/vocab.txt
-    unit_type: 'char'
-    spm_model_prefix: ''
-    augmentation_config: conf/preprocess.yaml
-    batch_size: 64
-    raw_wav: True # use raw_wav or kaldi feature
-    spectrum_type: fbank #linear, mfcc, fbank
-    feat_dim: 80
-    delta_delta: False
-    dither: 1.0
-    target_sample_rate: 16000
-    max_freq: None
-    n_fft: None
-    stride_ms: 10.0
-    window_ms: 25.0
-    use_dB_normalization: True
-    target_dB: -20
-    random_seed: 0
-    keep_transcription_text: False
-    sortagrad: True
-    shuffle_method: batch_shuffle
-    num_workers: 2
 # network architecture
 model:
     cmvn_file:
...
@@ -44,8 +14,8 @@ model:
     attention_dropout_rate: 0.0
     input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
     normalize_before: True
-    use_cnn_module: True
-    cnn_module_kernel: 15
+    cnn_module_kernel: 15
+    use_cnn_module: True
     activation_type: 'swish'
     pos_enc_layer_type: 'rel_pos'
     selfattention_layer_type: 'rel_selfattn'
...
@@ -64,11 +34,36 @@ model:
     # hybrid CTC/attention
     model_conf:
         ctc_weight: 0.3
-        ctc_dropoutrate: 0.0
-        ctc_grad_norm_type: null
         lsm_weight: 0.1 # label smoothing option
         length_normalized_loss: false
+data:
+    train_manifest: data/manifest.train
+    dev_manifest: data/manifest.dev
+    test_manifest: data/manifest.test
+collator:
+    vocab_filepath: data/lang_char/vocab.txt
+    unit_type: 'char'
+    augmentation_config: conf/preprocess.yaml
+    feat_dim: 80
+    stride_ms: 10.0
+    window_ms: 25.0
+    sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+    batch_size: 64
+    maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+    maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+    minibatches: 0 # for debug
+    batch_count: auto
+    batch_bins: 0
+    batch_frames_in: 0
+    batch_frames_out: 0
+    batch_frames_inout: 0
+    num_workers: 0
+    subsampling_factor: 1
+    num_encs: 1
 training:
     n_epoch: 240
...
@@ -78,7 +73,7 @@ training:
     optim_conf:
         lr: 0.002
         weight_decay: 1e-6
-    scheduler: warmuplr # pytorch v1.1.0+ required
+    scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
         lr_decay: 1.0
...
@@ -89,16 +84,10 @@ training:
 decoding:
+    beam_size: 10
     batch_size: 128
     error_rate_type: cer
     decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
-    lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-    alpha: 2.5
-    beta: 0.3
-    beam_size: 10
-    cutoff_prob: 1.0
-    cutoff_top_n: 0
-    num_proc_bsearch: 8
     ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
     decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
         # <0: for decoding, use full chunk.
...
@@ -106,5 +95,3 @@ decoding:
         # 0: used for training, it's prohibited here.
     num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
     simulate_streaming: False # simulate streaming inference. Defaults to False.
examples/aishell/asr1/conf/transformer.yaml
-# https://yaml.org/type/float.html
-data:
-    train_manifest: data/manifest.train
-    dev_manifest: data/manifest.dev
-    test_manifest: data/manifest.test
-    min_input_len: 0.5
-    max_input_len: 20.0 # second
-    min_output_len: 0.0
-    max_output_len: 400.0
-    min_output_input_ratio: 0.05
-    max_output_input_ratio: 10.0
-collator:
-    vocab_filepath: data/lang_char/vocab.txt
-    unit_type: 'char'
-    spm_model_prefix: ''
-    augmentation_config: conf/preprocess.yaml
-    batch_size: 64
-    raw_wav: True # use raw_wav or kaldi feature
-    spectrum_type: fbank #linear, mfcc, fbank
-    feat_dim: 80
-    delta_delta: False
-    dither: 1.0
-    target_sample_rate: 16000
-    max_freq: None
-    n_fft: None
-    stride_ms: 10.0
-    window_ms: 25.0
-    use_dB_normalization: True
-    target_dB: -20
-    random_seed: 0
-    keep_transcription_text: False
-    sortagrad: True
-    shuffle_method: batch_shuffle
-    num_workers: 2
 # network architecture
 model:
     cmvn_file:
...
@@ -66,12 +29,40 @@ model:
     # hybrid CTC/attention
     model_conf:
         ctc_weight: 0.3
-        ctc_dropoutrate: 0.0
-        ctc_grad_norm_type: null
         lsm_weight: 0.1 # label smoothing option
         length_normalized_loss: false
+# https://yaml.org/type/float.html
+data:
+    train_manifest: data/manifest.train
+    dev_manifest: data/manifest.dev
+    test_manifest: data/manifest.test
+collator:
+    unit_type: 'char'
+    vocab_filepath: data/lang_char/vocab.txt
+    feat_dim: 80
+    stride_ms: 10.0
+    window_ms: 25.0
+    sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+    batch_size: 64
+    maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+    maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+    minibatches: 0 # for debug
+    batch_count: auto
+    batch_bins: 0
+    batch_frames_in: 0
+    batch_frames_out: 0
+    batch_frames_inout: 0
+    augmentation_config: conf/preprocess.yaml
+    num_workers: 0
+    subsampling_factor: 1
+    num_encs: 1
 training:
     n_epoch: 240
     accum_grad: 2
...
@@ -91,22 +82,14 @@ training:
 decoding:
+    beam_size: 10
     batch_size: 128
     error_rate_type: cer
     decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
-    lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-    alpha: 2.5
-    beta: 0.3
-    beam_size: 10
-    cutoff_prob: 1.0
-    cutoff_top_n: 0
-    num_proc_bsearch: 8
     ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
     decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
         # <0: for decoding, use full chunk.
         # >0: for decoding, use fixed chunk size as set.
         # 0: used for training, it's prohibited here.
     num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
     simulate_streaming: False # simulate streaming inference. Defaults to False.
\ No newline at end of file
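After this cleanup, the batching knobs (sortagrad, maxlen_in/maxlen_out, batch_count, batch_bins, num_workers, and so on) all live under the collator: section of the YAML files above, which is exactly what the model.py change below reads as config.collator.<key>. A minimal sketch of inspecting them with plain PyYAML, run from the repository root; this is illustrative only, since the project wraps its configs in its own loader rather than a bare dict:

```python
# Illustrative only: load one of the cleaned configs with PyYAML and print the
# collator fields that the U2 trainer below now consumes via config.collator.*.
import yaml

with open("examples/aishell/asr1/conf/transformer.yaml") as f:
    config = yaml.safe_load(f)

collator = config["collator"]
print(collator["sortagrad"], collator["maxlen_in"], collator["maxlen_out"])
print(collator["batch_count"], collator["batch_bins"], collator["num_workers"])
```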
paddlespeech/s2t/exps/u2/model.py
...
@@ -254,19 +254,18 @@ class U2Trainer(Trainer):
         self.train_loader = BatchDataLoader(
             json_file=config.data.train_manifest,
             train_mode=True,
-            sortagrad=False,
+            sortagrad=config.collator.sortagrad,
             batch_size=config.collator.batch_size,
-            maxlen_in=float('inf'),
-            maxlen_out=float('inf'),
-            minibatches=0,
+            maxlen_in=config.collator.maxlen_in,
+            maxlen_out=config.collator.maxlen_out,
+            minibatches=config.collator.minibatches,
             mini_batch_size=self.args.ngpu,
-            batch_count='auto',
-            batch_bins=0,
-            batch_frames_in=0,
-            batch_frames_out=0,
-            batch_frames_inout=0,
-            preprocess_conf=config.collator.
-            augmentation_config,  # aug will be off when train_mode=False
+            batch_count=config.collator.batch_count,
+            batch_bins=config.collator.batch_bins,
+            batch_frames_in=config.collator.batch_frames_in,
+            batch_frames_out=config.collator.batch_frames_out,
+            batch_frames_inout=config.collator.batch_frames_inout,
+            preprocess_conf=config.collator.augmentation_config,  # aug will be off when train_mode=False
             n_iter_processes=config.collator.num_workers,
             subsampling_factor=1,
             num_encs=1)
...
@@ -285,8 +284,7 @@ class U2Trainer(Trainer):
             batch_frames_in=0,
             batch_frames_out=0,
             batch_frames_inout=0,
-            preprocess_conf=config.collator.
-            augmentation_config,  # aug will be off when train_mode=False
+            preprocess_conf=config.collator.augmentation_config,  # aug will be off when train_mode=False
             n_iter_processes=config.collator.num_workers,
             subsampling_factor=1,
             num_encs=1)
...
@@ -307,8 +305,7 @@ class U2Trainer(Trainer):
             batch_frames_in=0,
             batch_frames_out=0,
             batch_frames_inout=0,
-            preprocess_conf=config.collator.
-            augmentation_config,  # aug will be off when train_mode=False
+            preprocess_conf=config.collator.augmentation_config,  # aug will be off when train_mode=False
             n_iter_processes=1,
             subsampling_factor=1,
             num_encs=1)
...
@@ -327,8 +324,7 @@ class U2Trainer(Trainer):
             batch_frames_in=0,
             batch_frames_out=0,
             batch_frames_inout=0,
-            preprocess_conf=config.collator.
-            augmentation_config,  # aug will be off when train_mode=False
+            preprocess_conf=config.collator.augmentation_config,  # aug will be off when train_mode=False
             n_iter_processes=1,
             subsampling_factor=1,
             num_encs=1)
...