Commit 4f3851f6 authored by Li Fuchen, committed by Yibing Liu

update DeepASR API and README (#3764) (#3821)

* update DeepASR API and README

* update readme

* update README
Parent b9cbe195
[中文](README_cn.md)
The minimum PaddlePaddle version needed for the code sample in this directory is v1.6.0. If you are on a version of PaddlePaddle earlier than this, [please update your installation](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html).
---
DeepASR (Deep Automatic Speech Recognition) is a speech recognition system based on PaddlePaddle Fluid and Kaldi. It uses the Fluid framework to configure and train acoustic models for speech recognition and integrates the Kaldi decoder. It is designed to help Kaldi users train acoustic models rapidly and at scale, while relying on Kaldi for complex speech data preprocessing and the final decoding stage.
### Content
- [Introduction](#introduction)
- [Installation](#installation)
- [Data preprocessing](#data-preprocessing)
- [Training](#training)
- [Perf profiling](#perf-profiling)
- [Inference & Decoding](#inference--decoding)
- [Scoring error rate](#scoring-error-rate)
- [Aishell example](#aishell-example)
- [Questions and Contributions](#questions-and-contributions)
## Deep Automatic Speech Recognition
### Introduction
The acoustic model in DeepASR consists of a single convolutional layer followed by a stack of LSTMP layers: the convolution performs preliminary feature extraction, the stacked LSTMP layers model the temporal structure, and cross entropy is used as the loss function. [LSTMP](https://arxiv.org/abs/1402.1128) (LSTM with recurrent projection layer) is an extension of the traditional LSTM: it adds a projection layer that maps the hidden state to a lower dimension before it is fed into the next time step. This structure greatly reduces the parameter size and computational complexity of the LSTM while also improving its performance.
<p align="center">
<img src="images/lstmp.png" height=240 width=480 hspace='10'/> <br />
Figure 1. LSTMP topology
</p>
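To make the projection step concrete, below is a minimal sketch of a single LSTMP time step in plain NumPy. This is illustrative only: the names, shapes, and gate conventions are assumptions for exposition, not DeepASR's actual implementation. The key point is that the state fed back to the next time step is the projected vector `r`, whose dimension `proj_dim` is smaller than the hidden dimension:

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def lstmp_step(x, r_prev, c_prev, W, U, b, W_proj):
    """One LSTMP time step: a standard LSTM cell whose hidden state is
    projected down to proj_dim before being used as the recurrent input."""
    z = x @ W + r_prev @ U + b            # stacked gate pre-activations
    i, f, o, g = np.split(z, 4, axis=-1)  # input, forget, output, candidate
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # new cell state
    h = sigmoid(o) * np.tanh(c)           # hidden state, size hidden_dim
    r = h @ W_proj                        # projected state, size proj_dim
    return r, c
```

Because the recurrent input is `r` rather than `h`, the recurrent weight matrix `U` has shape `(proj_dim, 4 * hidden_dim)` instead of `(hidden_dim, 4 * hidden_dim)`, which is where the savings in parameters and computation come from.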
### Installation
#### Kaldi
The decoder depends on [Kaldi](https://github.com/kaldi-asr/kaldi). If Kaldi is not present in your environment, please `git clone` its source code and install it by following its instructions. Then set the environment variable `KALDI_ROOT`:
```shell
export KALDI_ROOT=<absolute path to kaldi>
```
#### Decoder
Enter the directory where the decoder source is located.
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/fluid/DeepASR/decoder
```
Run the installation script.
```shell
sh setup.sh
```
Once the compilation finishes successfully, the decoder is ready to use.
### Data preprocessing
Refer to [Kaldi's data preparation process](http://kaldi-asr.org/doc/data_prep.html) to complete the feature extraction and label alignment of the audio data.
### Training
You can train the model on CPU or GPU; for example, to train on GPU:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var \
--parallel
```
where `train_feature.lst` and `train_label.lst` are the feature list file and the label list file of the training data; similarly, `val_feature.lst` and `val_label.lst` are the list files of the validation data. During actual training, important arguments such as the size of LSTMP's hidden units and the learning rate should be specified correctly. For instructions on these parameters, please run:
```shell
python train.py --help
```
to get more information.
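The exact format of these list files is defined by DeepASR's data readers, so the following is only an illustration: a hypothetical helper that writes one `<utterance-key> <path>` line per feature file (the directory layout and `.npy` extension are assumptions):

```python
from pathlib import Path

def write_list_file(feature_dir, list_path, pattern="*.npy"):
    """Write '<utterance-key> <path>' lines; check DeepASR's data_utils
    readers for the exact entries they expect (hypothetical layout)."""
    with open(list_path, "w") as f:
        for p in sorted(Path(feature_dir).glob(pattern)):
            f.write(f"{p.stem} {p.resolve()}\n")

write_list_file("data/train/feats", "train_feature.lst")
```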
### Perf profiling
With the profiler tool provided by Fluid, you can analyze the performance of the training process and obtain operator-level execution times for the network.
```shell
CUDA_VISIBLE_DEVICES=0 python -u tools/profile.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var
```
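`tools/profile.py` presumably wraps this profiling around the DeepASR training program. To profile a program of your own directly, the sketch below shows the underlying Fluid profiler context manager, assuming the PaddlePaddle 1.x `fluid.profiler` API; the tiny network here is just a stand-in:

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

# A trivial stand-in program; substitute the real training network.
x = fluid.data(name='x', shape=[None, 16], dtype='float32')
y = fluid.layers.fc(input=x, size=4)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# state is 'CPU', 'GPU' or 'All'; sorted_key='total' ranks operators by
# accumulated execution time in the report printed on exit.
with profiler.profiler(state='CPU', sorted_key='total'):
    for _ in range(10):
        exe.run(feed={'x': np.random.rand(8, 16).astype('float32')},
                fetch_list=[y.name])
```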
### Inference & Decoding
Once the acoustic model has been fully trained, a checkpoint saved during training can be used to decode input audio and produce speech-to-text recognition results.
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u infer_by_ckpt.py \
--batch_size 96 \
--checkpoint deep_asr.pass_1.checkpoint \
--infer_feature_lst test_feature.lst \
--infer_label_lst test_label.lst \
--mean_var global_mean_var \
--parallel
```
### Scoring error rate
Word Error Rate (WER) and Character Error Rate (CER) are commonly used to evaluate speech recognition systems. Related measurement tools are also implemented in DeepASR.
```
python score_error_rate.py --error_rate_type cer --ref ref.txt --hyp decoding.txt
```
The parameter `error_rate_type` indicates which error rate to measure, i.e. WER or CER; `ref.txt` and `decoding.txt` are the reference text and the actual decoded text respectively, both in the same format:
```
key1 text1
key2 text2
key3 text3
...
```
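For reference, CER is the character-level edit distance between the reference and the hypothesis divided by the reference length; WER applies the same computation to word sequences. The sketch below shows this standard definition (an illustration, not the project's `score_error_rate.py` implementation):

```python
def edit_distance(ref, hyp):
    """Levenshtein distance computed with a single rolling row."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # deletion
                                     dp[j - 1] + 1,    # insertion
                                     prev + (r != h))  # substitution
    return dp[-1]

def cer(ref_text, hyp_text):
    ref = ref_text.replace(" ", "")  # CER ignores the word separators
    hyp = hyp_text.replace(" ", "")
    return edit_distance(ref, hyp) / max(len(ref), 1)

print(cer("今天 天气 很 好", "今天 天汽 好"))  # 2 errors / 6 chars = 0.333...
```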
### Aishell example
This section uses the [Aishell dataset](http://www.aishelltech.com/kysjcp) as an example to show the complete process from data preprocessing to decoding output. Aishell is an open Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd. It is 178 hours long and contains speech recorded by 400 speakers from different accent regions. The original data can be obtained from [openslr](http://www.openslr.org/33). To simplify the process, a preprocessed dataset is provided for download:
```
cd examples/aishell
sh prepare_data.sh
```
It includes the training data for the acoustic model and the auxiliary files used during decoding. After downloading the data, the training process can be profiled before training starts:
```
sh profile.sh
```
Training:
```
sh train.sh
```
By default, 4 GPUs are used for training. In practice, arguments such as the batch size and the learning rate can be adjusted according to the number of available GPUs and the amount of memory. Typical curves of the cost and accuracy during training are shown in Figure 2.
<p align="center">
<img src="images/learning_curve.png" height=480 width=640 hspace='10'/> <br />
Figure 2. Learning curves of the acoustic model on the Aishell dataset
</p>
After training, you can run inference to recognize the text in the test data:
```
sh infer_by_ckpt.sh
```
This involves two important stages: prediction with the acoustic model and decoding into text. The following is an example of the decoded output:
```
...
BAC009S0764W0239 十一 五 期间 我 国 累计 境外 投资 七千亿 美元
BAC009S0765W0140 在 了解 送 方 的 资产 情况 与 需求 之后
BAC009S0915W0291 这 对 苹果 来说 不 是 件 容易 的 事 儿
BAC009S0769W0159 今年 土地 收入 预计 近 四万亿 元
BAC009S0907W0451 由 浦东 商店 作为 掩护
BAC009S0768W0128 土地 交易 可能 随着 供应 淡季 的 到来 而 降温
...
```
Each line corresponds to one output utterance, starting with the key of the audio sample, followed by the decoded Chinese text separated into words. Run the script to evaluate the character error rate (CER) of the decoding results:
```
sh score_cer.sh
```
A sample of its output is shown below:
```
Error rate[cer] = 0.101971 (10683/104765),
total 7176 sentences in hyp, 0 not presented in ref.
```
With an acoustic model trained for about 20 epochs, a CER of roughly 10% can be obtained on the Aishell test data.
### Questions and Contributions
DeepASR currently provides only the Aishell example; we welcome users to test the complete training process on more datasets and to contribute to this project.
[English](README.md)
Running the sample code in this directory requires PaddlePaddle v1.6.0 or later. If your PaddlePaddle installation is older than this, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
---
DeepASR (Deep Automatic Speech Recognition) is a speech recognition system based on PaddlePaddle Fluid and [Kaldi](http://www.kaldi-asr.org). It uses the Fluid framework to configure and train the acoustic model for speech recognition and integrates Kaldi's decoder. It is designed to help users already familiar with Kaldi to train acoustic models rapidly and at scale, and to use Kaldi for complex speech data preprocessing and the final decoding stage.
### Content
- [Model overview](#model-overview)
- [Installation](#installation)
- [Data preprocessing](#data-preprocessing)
- [Training](#training)
- [Perf profiling](#perf-profiling)
- [Inference & Decoding](#inference--decoding)
- [Scoring error rate](#scoring-error-rate)
- [Aishell example](#aishell-example)
- [Questions and Contributions](#questions-and-contributions)
### Model overview
The acoustic model in DeepASR consists of a single convolutional layer followed by stacked LSTMP layers: the convolution performs preliminary feature extraction, the stacked LSTMP layers model the temporal structure, and cross entropy is used as the loss function. [LSTMP](https://arxiv.org/abs/1402.1128) (LSTM with recurrent projection layer) extends the traditional LSTM by adding a projection layer that maps the hidden state to a lower dimension before feeding it into the next time step, a structure that greatly reduces the parameter size and computational complexity of the LSTM while also improving its performance.
<p align="center">
<img src="images/lstmp.png" height=240 width=480 hspace='10'/> <br />
Figure 1. LSTMP topology
</p>

...

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var \
--parallel
```
where `train_feature.lst` and `train_label.lst` are the feature list file and the label list file of the training set; similarly, `val_feature.lst` and `val_label.lst` are the list files of the validation set. During actual training, important parameters such as the size of the LSTMP hidden layer and the learning rate must be specified correctly. For a description of these parameters, please run:
```shell
python train.py --help
```
### Perf profiling
With the profiler tool provided by Fluid, you can analyze the performance of the training process and obtain operator-level execution times for the network.
```shell
CUDA_VISIBLE_DEVICES=0 python -u tools/profile.py \
--train_feature_lst train_feature.lst \
--train_label_lst train_label.lst \
--val_feature_lst val_feature.lst \
--val_label_lst val_label.lst \
--mean_var global_mean_var
```
### Inference & Decoding
Once the acoustic model has been fully trained, a checkpoint saved during training can be used to decode the input audio and obtain speech-to-text recognition results.
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u infer_by_ckpt.py \
--batch_size 96 \
--checkpoint deep_asr.pass_1.checkpoint \
--infer_feature_lst test_feature.lst \
--infer_label_lst test_label.lst \
--mean_var global_mean_var \
--parallel
```
### Scoring error rate
Word Error Rate (WER) and Character Error Rate (CER) are commonly used metrics for evaluating a speech recognition system; DeepASR also implements the corresponding measurement tools, which are run as follows:
```
python score_error_rate.py --error_rate_type cer --ref ref.txt --hyp decoding.txt
```
The parameter `error_rate_type` indicates which error rate to measure, i.e. WER or CER; `ref.txt` and `decoding.txt` are the reference text and the actual decoded text respectively, both in the same format:
```
key1 text1
key2 text2
key3 text3
...
```
### Aishell 实例
This section uses the [Aishell dataset](http://www.aishelltech.com/kysjcp) as an example to show the complete process from data preprocessing to decoding output. Aishell is an open Chinese Mandarin speech corpus released by Beijing Shell Shell Technology Co., Ltd. It is 178 hours long and contains speech recorded by 400 speakers from different accent regions; the original data can be obtained from [openslr](http://www.openslr.org/33). To simplify the process, a preprocessed dataset is provided for download:
```
cd examples/aishell
sh prepare_data.sh
```

...

Run the script to evaluate the character error rate (CER) of the decoding results:
```
sh score_cer.sh
```
A sample of its output is shown below:
```
Error rate[cer] = 0.101971 (10683/104765),
total 7176 sentences in hyp, 0 not presented in ref.
```
@@ -8,6 +8,10 @@ from tblib import Traceback
import numpy as np


def lodtensor_to_ndarray(result):
    return np.array(result), result.lod()


def to_lodtensor(data, place):
    """convert tensor to lodtensor
    """
......
@@ -5,6 +5,7 @@ python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint checkpoints/deep_asr.latest.checkpoint \
--infer_feature_lst data/test_feature.lst \
--mean_var data/global_mean_var \
--device CPU \
--frame_dim 80 \
--class_num 3040 \
--num_threads 24 \
......
@@ -16,7 +16,7 @@ import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray, split_infer_result
from model_utils.model import stacked_lstmp_model
-from decoder.post_latgen_faster_mapped import Decoder
+from post_latgen_faster_mapped import Decoder
from tools.error_rate import char_errors
@@ -188,8 +188,17 @@ def infer_from_ckpt(args):
     if not os.path.exists(args.checkpoint):
         raise IOError("Invalid checkpoint!")
+    feature = fluid.data(
+        name='feature',
+        shape=[None, 3, 11, args.frame_dim],
+        dtype='float32',
+        lod_level=1)
+    label = fluid.data(
+        name='label', shape=[None, 1], dtype='int64', lod_level=1)
     prediction, avg_cost, accuracy = stacked_lstmp_model(
-        frame_dim=args.frame_dim,
+        feature=feature,
+        label=label,
         hidden_dim=args.hidden_dim,
         proj_dim=args.proj_dim,
         stacked_num=args.stacked_num,
......
@@ -150,15 +150,33 @@ def train(args):
     train_program = fluid.Program()
     train_startup = fluid.Program()
+    input_fields = {
+        'names': ['feature', 'label'],
+        'shapes': [[None, 3, 11, args.frame_dim], [None, 1]],
+        'dtypes': ['float32', 'int64'],
+        'lod_levels': [1, 1]
+    }
     with fluid.program_guard(train_program, train_startup):
         with fluid.unique_name.guard():
-            py_train_reader = fluid.layers.py_reader(
-                capacity=10,
-                shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
-                dtypes=['float32', 'int64'],
-                lod_levels=[1, 1],
-                name='train_reader')
-            feature, label = fluid.layers.read_file(py_train_reader)
+            inputs = [
+                fluid.data(
+                    name=input_fields['names'][i],
+                    shape=input_fields['shapes'][i],
+                    dtype=input_fields['dtypes'][i],
+                    lod_level=input_fields['lod_levels'][i])
+                for i in range(len(input_fields['names']))
+            ]
+            train_reader = fluid.io.DataLoader.from_generator(
+                feed_list=inputs,
+                capacity=64,
+                iterable=False,
+                use_double_buffer=True)
+            (feature, label) = inputs
             prediction, avg_cost, accuracy = stacked_lstmp_model(
                 feature=feature,
                 label=label,
@@ -179,13 +197,22 @@ def train(args):
     test_startup = fluid.Program()
     with fluid.program_guard(test_program, test_startup):
         with fluid.unique_name.guard():
-            py_test_reader = fluid.layers.py_reader(
-                capacity=10,
-                shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
-                dtypes=['float32', 'int64'],
-                lod_levels=[1, 1],
-                name='test_reader')
-            feature, label = fluid.layers.read_file(py_test_reader)
+            inputs = [
+                fluid.data(
+                    name=input_fields['names'][i],
+                    shape=input_fields['shapes'][i],
+                    dtype=input_fields['dtypes'][i],
+                    lod_level=input_fields['lod_levels'][i])
+                for i in range(len(input_fields['names']))
+            ]
+            test_reader = fluid.io.DataLoader.from_generator(
+                feed_list=inputs,
+                capacity=64,
+                iterable=False,
+                use_double_buffer=True)
+            (feature, label) = inputs
             prediction, avg_cost, accuracy = stacked_lstmp_model(
                 feature=feature,
                 label=label,
@@ -237,7 +264,7 @@ def train(args):
args.minimum_batch_size):
yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
-    py_train_reader.decorate_tensor_provider(train_data_provider)
+    train_reader.set_batch_generator(train_data_provider)
if (os.path.exists(args.val_feature_lst) and
os.path.exists(args.val_label_lst)):
@@ -254,7 +281,7 @@ def train(args):
args.batch_size, args.minimum_batch_size):
yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
-        py_test_reader.decorate_tensor_provider(test_data_provider)
+        test_reader.set_batch_generator(test_data_provider)
# validation
def test(exe):
@@ -267,7 +294,7 @@ def train(args):
test_accs = []
while True:
if batch_id == 0:
-                py_test_reader.start()
+                test_reader.start()
try:
if args.parallel:
cost, acc = exe.run(
@@ -283,7 +310,7 @@ def train(args):
test_accs.append(np.array(acc)[0])
batch_id += 1
except fluid.core.EOFException:
-                py_test_reader.reset()
+                test_reader.reset()
break
return np.mean(test_costs), np.mean(test_accs)
@@ -293,7 +320,7 @@ def train(args):
batch_id = 0
while True:
if batch_id == 0:
-                py_train_reader.start()
+                train_reader.start()
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
try:
if args.parallel:
@@ -307,7 +334,7 @@ def train(args):
if to_print else [],
return_numpy=False)
except fluid.core.EOFException:
-                py_train_reader.reset()
+                train_reader.reset()
break
if to_print:
......
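Taken together, the hunks above replace the older `fluid.layers.py_reader` pipeline with `fluid.data` placeholders plus `fluid.io.DataLoader`, while keeping the same `start()`/`reset()` driving loop. Below is a minimal self-contained sketch of the new pattern, assuming the PaddlePaddle 1.6 static-graph API; the toy network, shapes, and data are illustrative:

```python
import numpy as np
import paddle.fluid as fluid

# Declare typed inputs; lod_level=1 marks variable-length sequence data.
feature = fluid.data(name='feature', shape=[None, 8], dtype='float32',
                     lod_level=1)
out = fluid.layers.fc(input=feature, size=4)

loader = fluid.io.DataLoader.from_generator(
    feed_list=[feature], capacity=64, iterable=False, use_double_buffer=True)

def batch_generator():
    for _ in range(3):
        # Five frames partitioned into two utterances of lengths 2 and 3.
        yield [fluid.create_lod_tensor(
            np.random.rand(5, 8).astype('float32'), [[2, 3]],
            fluid.CPUPlace())]

loader.set_batch_generator(batch_generator)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

loader.start()
while True:
    try:
        exe.run(fetch_list=[out.name], return_numpy=False)
    except fluid.core.EOFException:
        loader.reset()  # same start()/try/reset() pattern as train.py above
        break
```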