diff --git a/PaddleSpeech/DeepASR/README.md b/PaddleSpeech/DeepASR/README.md
index 6b9913fd30a56ef2328bc62e9b36e496f6763430..6318b05650f96140d1a1757fbf9ddc6149ebeac5 100644
--- a/PaddleSpeech/DeepASR/README.md
+++ b/PaddleSpeech/DeepASR/README.md
@@ -1,14 +1,35 @@
-The minimum PaddlePaddle version needed for the code sample in this directory is the lastest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
+[中文](README_cn.md)
+
+The minimum PaddlePaddle version needed for the code sample in this directory is v1.6.0. If you are on a version of PaddlePaddle earlier than this, [please update your installation](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html).
+
+---
+
+DeepASR (Deep Automatic Speech Recognition) is a speech recognition system based on PaddlePaddle Fluid and Kaldi. It uses the Fluid framework to configure and train the acoustic models used in speech recognition, and integrates the Kaldi decoder. It is designed to let Kaldi users train acoustic models rapidly and at large scale, while relying on Kaldi for complex speech data preprocessing and the final decoding stage.
+
+### Content
+- [Introduction](#Introduction)
+- [Installation](#Installation)
+- [Data preprocessing](#Data preprocessing)
+- [Training](#Training)
+- [Perf profiling](#Perf profiling)
+- [Inference & Decoding](#Inference & Decoding)
+- [Scoring error rate](#Scoring error rate)
+- [Aishell example](#Aishell example)
+- [Question and Contribution](#Question and Contribution)
+
-## Deep Automatic Speech Recognition
 ### Introduction
-TBD
+The DeepASR acoustic model consists of a single convolution layer followed by a stack of LSTMP layers, trained with a cross-entropy loss. The convolution layer performs preliminary feature extraction, and the stacked LSTMP layers model the temporal structure of the input. [LSTMP](https://arxiv.org/abs/1402.1128) (LSTM with recurrent projection layer) extends the traditional LSTM by adding a projection layer that maps the hidden state to a lower dimension before feeding it into the next time step. This structure significantly improves the performance of the LSTM while greatly reducing its parameter size and computational cost.
+
+Figure 1. LSTMP topology
+
+Figure 2. The learning curve of the acoustic model on the Aishell dataset
+
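To make the LSTMP description concrete, below is a minimal NumPy sketch of one projected-LSTM time step. It illustrates the idea only and is not the repository's implementation; the fused gate layout and all names and shapes are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstmp_step(x, r_prev, c_prev, W, R, b, W_proj):
    """One LSTMP time step (illustrative shapes, not the repo's kernels).

    x:      input at time t,          shape (input_dim,)
    r_prev: projected state at t-1,   shape (proj_dim,)
    c_prev: cell state at t-1,        shape (hidden_dim,)
    W:      input weights,            shape (4 * hidden_dim, input_dim)
    R:      recurrent weights,        shape (4 * hidden_dim, proj_dim)
    b:      bias,                     shape (4 * hidden_dim,)
    W_proj: projection weights,       shape (proj_dim, hidden_dim)
    """
    gates = W @ x + R @ r_prev + b
    i, f, o, g = np.split(gates, 4)            # input, forget, output, candidate
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    r = W_proj @ h                             # project hidden state down
    return r, c                                # the recurrence carries r, not h
```

The recurrence carrying the low-dimensional `r` instead of the full hidden state `h` is what shrinks the recurrent weight matrices from `4 * hidden_dim x hidden_dim` to `4 * hidden_dim x proj_dim`.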
@@ -66,7 +68,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py \
--mean_var global_mean_var \
--parallel
```
-Here `train_feature.lst` and `train_label.lst` are the feature-list and label-list files of the training set; similarly, `val_feature.lst` and `val_label.lst` are the corresponding list files of the validation set. During actual training, important parameters such as the size of the modeling units and the learning rate must be specified correctly. For a description of these parameters, run
+Here `train_feature.lst` and `train_label.lst` are the feature-list and label-list files of the training set; similarly, `val_feature.lst` and `val_label.lst` are the corresponding list files of the validation set. During actual training, important parameters such as the size of the LSTMP hidden layers and the learning rate must be specified correctly. For a description of these parameters, run
```shell
python train.py --help
@@ -75,7 +77,7 @@ python train.py --help
### Time analysis of the training process
-Fluid's built-in profiler can be used to analyze the performance of the training process and obtain operator-level execution times for the network
+Fluid's built-in profiler can be used to analyze the performance of the training process and obtain operator-level execution times for the network.
```shell
CUDA_VISIBLE_DEVICES=0 python -u tools/profile.py \
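For reference, `tools/profile.py` builds on Fluid's profiler context manager. The sketch below shows only the assumed core pattern, not the actual script; the program and executor setup is elided.

```python
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())  # assumes the network was built above

# Profile a few iterations and print per-operator times sorted by total cost.
with profiler.profiler('All', 'total'):
    for _ in range(10):
        exe.run(fluid.default_main_program())
```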
@@ -89,7 +91,7 @@ CUDA_VISIBLE_DEVICES=0 python -u tools/profile.py \
### Inference and decoding
-Once the acoustic model has been sufficiently trained, the checkpoints saved during training can be used to decode input audio and produce the speech-to-text recognition results
+Once the acoustic model has been sufficiently trained, the checkpoints saved during training can be used to decode input audio and produce the speech-to-text recognition results.
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u infer_by_ckpt.py \
@@ -101,7 +103,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -u infer_by_ckpt.py \
--parallel
```
-### Evaluating the error rate
+### Error rate evaluation
Commonly used metrics for evaluating a speech recognition system are the word error rate (WER) and the character error rate (CER). DeepASR also implements the corresponding scoring tools, which are run as follows
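For readers unfamiliar with CER: it is the character-level edit distance between hypothesis and reference, divided by the reference length. The repository's own implementation is in `tools/error_rate.py` (see `char_errors` imported by `infer_by_ckpt.py`); the following is an independent minimal sketch for illustration.

```python
def cer(reference, hypothesis):
    """Character error rate: edit distance / reference length (illustrative)."""
    ref, hyp = list(reference), list(hypothesis)
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[-1][-1] / float(len(ref))
```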
@@ -121,7 +123,7 @@ key3 text3
### Aishell example
-This section uses the [Aishell dataset](http://www.aishelltech.com/kysjcp) as an example to show how to go from data preprocessing to decoded output. Aishell is an open Mandarin Chinese speech dataset released by Beijing Shell Shell Technology Co., Ltd. It contains 178 hours of speech from 400 speakers covering different accent regions; the raw data is available from [openslr](http://www.openslr.org/33). To simplify the workflow, a preprocessed version of the dataset is provided for download:
+This section uses the [Aishell dataset](http://www.aishelltech.com/kysjcp) as an example to walk through the whole process from data preprocessing to decoded output. Aishell is an open Mandarin Chinese speech dataset released by Beijing Shell Shell Technology Co., Ltd. It contains 178 hours of speech from 400 speakers covering different accent regions; the raw data is available from [openslr](http://www.openslr.org/33). To simplify the workflow, a preprocessed version of the dataset is provided for download:
```
cd examples/aishell
@@ -171,7 +173,7 @@ BAC009S0768W0128 土地 交易 可能 随着 供应 淡季 的 到来 而 降温
sh score_cer.sh
```
-The output is similar to the following
+A sample of the output is shown below
```
Error rate[cer] = 0.101971 (10683/104765),
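Reading the bracketed pair as accumulated character errors over total reference characters, 10683 / 104765 ≈ 0.101971, which matches the reported CER.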
diff --git a/PaddleSpeech/DeepASR/data_utils/util.py b/PaddleSpeech/DeepASR/data_utils/util.py
index 4a5a8a3f1dad1c46ed773fd48d713e276717d5e5..8f52995acf653708bead9c147d3a9e7124fe08ce 100644
--- a/PaddleSpeech/DeepASR/data_utils/util.py
+++ b/PaddleSpeech/DeepASR/data_utils/util.py
@@ -8,6 +8,11 @@ from tblib import Traceback
 import numpy as np
+def lodtensor_to_ndarray(result):
+    """Convert a LoDTensor fetched from the executor into (ndarray, lod)."""
+    return np.array(result), result.lod()
+
+
def to_lodtensor(data, place):
"""convert tensor to lodtensor
"""
diff --git a/PaddleSpeech/DeepASR/examples/aishell/infer_by_ckpt.sh b/PaddleSpeech/DeepASR/examples/aishell/infer_by_ckpt.sh
index 2d31757451849afc1412421376484d2ad41962bc..5a68c09f500e127fc56444072ea38a83e7c9b60d 100644
--- a/PaddleSpeech/DeepASR/examples/aishell/infer_by_ckpt.sh
+++ b/PaddleSpeech/DeepASR/examples/aishell/infer_by_ckpt.sh
@@ -5,6 +5,7 @@ python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint checkpoints/deep_asr.latest.checkpoint \
--infer_feature_lst data/test_feature.lst \
--mean_var data/global_mean_var \
+ --device CPU \
--frame_dim 80 \
--class_num 3040 \
--num_threads 24 \
diff --git a/PaddleSpeech/DeepASR/infer_by_ckpt.py b/PaddleSpeech/DeepASR/infer_by_ckpt.py
index 1e0fb15c6d6f05aa1e054b37333b0fa0cb5cd8d9..b49b99970bdecb15bcf0a0d1605aef5917b05d97 100644
--- a/PaddleSpeech/DeepASR/infer_by_ckpt.py
+++ b/PaddleSpeech/DeepASR/infer_by_ckpt.py
@@ -16,7 +16,7 @@ import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray, split_infer_result
from model_utils.model import stacked_lstmp_model
-from decoder.post_latgen_faster_mapped import Decoder
+from post_latgen_faster_mapped import Decoder
from tools.error_rate import char_errors
@@ -188,8 +188,17 @@ def infer_from_ckpt(args):
if not os.path.exists(args.checkpoint):
raise IOError("Invalid checkpoint!")
+ feature = fluid.data(
+ name='feature',
+ shape=[None, 3, 11, args.frame_dim],
+ dtype='float32',
+ lod_level=1)
+ label = fluid.data(
+ name='label', shape=[None, 1], dtype='int64', lod_level=1)
+
prediction, avg_cost, accuracy = stacked_lstmp_model(
- frame_dim=args.frame_dim,
+ feature=feature,
+ label=label,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
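For context on `lod_level=1` in the declarations above: it marks the input as a variable-length (LoD) tensor, where an offset vector records sequence boundaries inside the flattened batch. A minimal feeding sketch, with shapes assumed from the example scripts (`frame_dim` = 80):

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
feature = fluid.LoDTensor()
# Two utterances of 3 and 4 frames flattened into one 7-frame batch.
feature.set(np.random.rand(7, 3, 11, 80).astype('float32'), place)
feature.set_lod([[0, 3, 7]])  # offsets: utterance 1 = [0, 3), utterance 2 = [3, 7)
```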
diff --git a/PaddleSpeech/DeepASR/tools/__init__.py b/PaddleSpeech/DeepASR/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
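The empty `tools/__init__.py` makes `tools` an importable package, so that `from tools.error_rate import char_errors` in `infer_by_ckpt.py` resolves.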
diff --git a/PaddleSpeech/DeepASR/train.py b/PaddleSpeech/DeepASR/train.py
index 5ed467b242836d53e1f0f25247c3f7b72ff28bae..ac5925a9fa5fd7483b6dce16d1e718050e57a83d 100644
--- a/PaddleSpeech/DeepASR/train.py
+++ b/PaddleSpeech/DeepASR/train.py
@@ -150,15 +150,33 @@ def train(args):
train_program = fluid.Program()
train_startup = fluid.Program()
+ input_fields = {
+ 'names': ['feature', 'label'],
+ 'shapes': [[None, 3, 11, args.frame_dim], [None, 1]],
+ 'dtypes': ['float32', 'int64'],
+ 'lod_levels': [1, 1]
+ }
+
with fluid.program_guard(train_program, train_startup):
with fluid.unique_name.guard():
- py_train_reader = fluid.layers.py_reader(
- capacity=10,
- shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
- dtypes=['float32', 'int64'],
- lod_levels=[1, 1],
- name='train_reader')
- feature, label = fluid.layers.read_file(py_train_reader)
+
+ inputs = [
+ fluid.data(
+ name=input_fields['names'][i],
+ shape=input_fields['shapes'][i],
+ dtype=input_fields['dtypes'][i],
+ lod_level=input_fields['lod_levels'][i])
+ for i in range(len(input_fields['names']))
+ ]
+
+ train_reader = fluid.io.DataLoader.from_generator(
+ feed_list=inputs,
+ capacity=64,
+ iterable=False,
+ use_double_buffer=True)
+
+ (feature, label) = inputs
+
prediction, avg_cost, accuracy = stacked_lstmp_model(
feature=feature,
label=label,
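The hunks above and below replace the deprecated `py_reader` API with `fluid.io.DataLoader`. Stripped of the training logic, the migrated pattern reduces to roughly the following sketch; `batch_generator` and `exe` are assumed to exist:

```python
import paddle.fluid as fluid

feature = fluid.data(name='feature', shape=[None, 3, 11, 80],
                     dtype='float32', lod_level=1)
label = fluid.data(name='label', shape=[None, 1], dtype='int64', lod_level=1)

loader = fluid.io.DataLoader.from_generator(
    feed_list=[feature, label], capacity=64,
    iterable=False, use_double_buffer=True)
loader.set_batch_generator(batch_generator)  # generator yielding LoD tensors

loader.start()
while True:
    try:
        exe.run(fetch_list=[])  # assumes a compiled program and executor
    except fluid.core.EOFException:
        loader.reset()  # rewind the reader at the end of each pass
        break
```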
@@ -179,13 +197,22 @@ def train(args):
test_startup = fluid.Program()
with fluid.program_guard(test_program, test_startup):
with fluid.unique_name.guard():
- py_test_reader = fluid.layers.py_reader(
- capacity=10,
- shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
- dtypes=['float32', 'int64'],
- lod_levels=[1, 1],
- name='test_reader')
- feature, label = fluid.layers.read_file(py_test_reader)
+ inputs = [
+ fluid.data(
+ name=input_fields['names'][i],
+ shape=input_fields['shapes'][i],
+ dtype=input_fields['dtypes'][i],
+ lod_level=input_fields['lod_levels'][i])
+ for i in range(len(input_fields['names']))
+ ]
+
+ test_reader = fluid.io.DataLoader.from_generator(
+ feed_list=inputs,
+ capacity=64,
+ iterable=False,
+ use_double_buffer=True)
+
+ (feature, label) = inputs
prediction, avg_cost, accuracy = stacked_lstmp_model(
feature=feature,
label=label,
@@ -237,7 +264,7 @@ def train(args):
args.minimum_batch_size):
yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
- py_train_reader.decorate_tensor_provider(train_data_provider)
+ train_reader.set_batch_generator(train_data_provider)
if (os.path.exists(args.val_feature_lst) and
os.path.exists(args.val_label_lst)):
@@ -254,7 +281,7 @@ def train(args):
args.batch_size, args.minimum_batch_size):
yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
- py_test_reader.decorate_tensor_provider(test_data_provider)
+ test_reader.set_batch_generator(test_data_provider)
# validation
def test(exe):
@@ -267,7 +294,7 @@ def train(args):
test_accs = []
while True:
if batch_id == 0:
- py_test_reader.start()
+ test_reader.start()
try:
if args.parallel:
cost, acc = exe.run(
@@ -283,7 +310,7 @@ def train(args):
test_accs.append(np.array(acc)[0])
batch_id += 1
except fluid.core.EOFException:
- py_test_reader.reset()
+ test_reader.reset()
break
return np.mean(test_costs), np.mean(test_accs)
@@ -293,7 +320,7 @@ def train(args):
batch_id = 0
while True:
if batch_id == 0:
- py_train_reader.start()
+ train_reader.start()
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
try:
if args.parallel:
@@ -307,7 +334,7 @@ def train(args):
if to_print else [],
return_numpy=False)
except fluid.core.EOFException:
- py_train_reader.reset()
+ train_reader.reset()
break
if to_print: