提交 837aaa9e 编写于 作者: Y Yibing Liu

Update training & profiling scripts for new config

上级 86235262
The minimum PaddlePaddle version needed for the code sample in this directory is the lastest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
### TODO
## Deep Automatic Speech Recognition
This project is still under active development.
### Introduction
TBD
### Installation
#### Kaldi
The decoder depends on [kaldi](https://github.com/kaldi-asr/kaldi), install it by flowing its instructions. Then
```shell
export KALDI_ROOT=<absolute path to kaldi>
```
#### Decoder
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/fluid/DeepASR/decoder
sh setup.sh
```
### Data reprocessing
TBD
### Training
TBD
### Inference & Decoding
TBD
### Question and Contribution
TBD
export CUDA_VISIBLE_DEVICES=2,3,4,5
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u ../../tools/profile.py --feature_lst data/train_feature.lst \
--label_lst data/train_label.lst \
--mean_var data/aishell/global_mean_var \
--parallel \
--frame_dim 2640 \
--class_num 101 \
--frame_dim 80 \
--class_num 3040 \
export CUDA_VISIBLE_DEVICES=2,3,4,5
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u ../../train.py --train_feature_lst data/train_feature.lst \
--train_label_lst data/train_label.lst \
--val_feature_lst data/val_feature.lst \
--val_label_lst data/val_label.lst \
--mean_var data/aishell/global_mean_var \
--checkpoints checkpoints \
--frame_dim 2640 \
--class_num 101 \
--frame_dim 80 \
--class_num 3040 \
--infer_models '' \
--batch_size 128 \
--learning_rate 0.00016 \
--batch_size 64 \
--learning_rate 6.4e-5 \
--parallel
~
......@@ -147,7 +147,7 @@ def profile(args):
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
trans_splice.TransSplice(), trans_delay.TransDelay(5)
trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5)
]
data_reader = reader.AsyncDataReader(args.feature_lst, args.label_lst, -1)
......@@ -170,6 +170,8 @@ def profile(args):
frames_seen = 0
# load_data
(features, labels, lod, _) = batch_data
features = np.reshape(features, (-1, 11, 3, args.frame_dim))
features = np.transpose(features, (0, 2, 1, 3))
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册