Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
c40b6f40
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c40b6f40
编写于
12月 28, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor the train and test config,test=asr
上级
425b085f
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
229 addition
and
195 deletion
+229
-195
examples/aishell/asr1/conf/conformer.yaml
examples/aishell/asr1/conf/conformer.yaml
+88
-92
examples/aishell/asr1/conf/decode.yaml
examples/aishell/asr1/conf/decode.yaml
+12
-0
examples/aishell/asr1/local/align.sh
examples/aishell/asr1/local/align.sh
+6
-4
examples/aishell/asr1/local/test.sh
examples/aishell/asr1/local/test.sh
+7
-4
examples/aishell/asr1/local/test_wav.sh
examples/aishell/asr1/local/test_wav.sh
+7
-5
examples/aishell/asr1/run.sh
examples/aishell/asr1/run.sh
+4
-3
paddlespeech/s2t/exps/u2/bin/alignment.py
paddlespeech/s2t/exps/u2/bin/alignment.py
+6
-0
paddlespeech/s2t/exps/u2/bin/test.py
paddlespeech/s2t/exps/u2/bin/test.py
+8
-2
paddlespeech/s2t/exps/u2/bin/test_wav.py
paddlespeech/s2t/exps/u2/bin/test_wav.py
+18
-18
paddlespeech/s2t/exps/u2/config.py
paddlespeech/s2t/exps/u2/config.py
+5
-6
paddlespeech/s2t/exps/u2/model.py
paddlespeech/s2t/exps/u2/model.py
+54
-55
paddlespeech/s2t/training/cli.py
paddlespeech/s2t/training/cli.py
+8
-0
paddlespeech/s2t/training/trainer.py
paddlespeech/s2t/training/trainer.py
+5
-5
paddlespeech/s2t/utils/utility.py
paddlespeech/s2t/utils/utility.py
+1
-1
未找到文件。
examples/aishell/asr1/conf/conformer.yaml
浏览文件 @
c40b6f40
# network architecture
############################################
model
:
# Network Architecture #
cmvn_file
:
############################################
cmvn_file_type
:
"
json"
#model:
# encoder related
cmvn_file
:
encoder
:
conformer
cmvn_file_type
:
"
json"
encoder_conf
:
# encoder related
output_size
:
256
# dimension of attention
encoder
:
conformer
attention_heads
:
4
encoder_conf
:
linear_units
:
2048
# the number of units of position-wise feed forward
output_size
:
256
# dimension of attention
num_blocks
:
12
# the number of encoder blocks
attention_heads
:
4
dropout_rate
:
0.1
linear_units
:
2048
# the number of units of position-wise feed forward
positional_dropout_rate
:
0.1
num_blocks
:
12
# the number of encoder blocks
attention_dropout_rate
:
0.0
dropout_rate
:
0.1
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
positional_dropout_rate
:
0.1
normalize_before
:
True
attention_dropout_rate
:
0.0
cnn_module_kernel
:
15
input_layer
:
conv2d
# encoder input type, you can chose conv2d, conv2d6 and conv2d8
use_cnn_module
:
True
normalize_before
:
True
activation_type
:
'
swish'
cnn_module_kernel
:
15
pos_enc_layer_type
:
'
rel_pos'
use_cnn_module
:
True
selfattention_layer_type
:
'
rel_selfattn'
activation_type
:
'
swish'
pos_enc_layer_type
:
'
rel_pos'
selfattention_layer_type
:
'
rel_selfattn'
# decoder related
# decoder related
decoder
:
transformer
decoder
:
transformer
decoder_conf
:
decoder_conf
:
attention_heads
:
4
attention_heads
:
4
linear_units
:
2048
linear_units
:
2048
num_blocks
:
6
num_blocks
:
6
dropout_rate
:
0.1
dropout_rate
:
0.1
positional_dropout_rate
:
0.1
positional_dropout_rate
:
0.1
self_attention_dropout_rate
:
0.0
self_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
src_attention_dropout_rate
:
0.0
# hybrid CTC/attention
# hybrid CTC/attention
model_conf
:
model_conf
:
ctc_weight
:
0.3
ctc_weight
:
0.3
lsm_weight
:
0.1
# label smoothing option
lsm_weight
:
0.1
# label smoothing option
length_normalized_loss
:
false
length_normalized_loss
:
false
data
:
###########################################
train_manifest
:
data/manifest.train
# Data #
dev_manifest
:
data/manifest.dev
###########################################
test_manifest
:
data/manifest.test
#data:
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test
###########################################
# Dataloader #
###########################################
#collator:
vocab_filepath
:
data/lang_char/vocab.txt
unit_type
:
'
char'
augmentation_config
:
conf/preprocess.yaml
spm_model_prefix
:
'
'
feat_dim
:
80
stride_ms
:
10.0
window_ms
:
25.0
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size
:
64
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
minibatches
:
0
# for debug
batch_count
:
auto
batch_bins
:
0
batch_frames_in
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
num_workers
:
0
subsampling_factor
:
1
num_encs
:
1
collator
:
###########################################
vocab_filepath
:
data/lang_char/vocab.txt
# training #
unit_type
:
'
char'
###########################################
augmentation_config
:
conf/preprocess.yaml
#training:
feat_dim
:
80
n_epoch
:
240
stride_ms
:
10.0
accum_grad
:
2
window_ms
:
25.0
global_grad_clip
:
5.0
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
optim
:
adam
batch_size
:
64
optim_conf
:
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
lr
:
0.002
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
weight_decay
:
1.0e-6
minibatches
:
0
# for debug
scheduler
:
warmuplr
batch_count
:
auto
scheduler_conf
:
batch_bins
:
0
warmup_steps
:
25000
batch_frames_in
:
0
lr_decay
:
1.0
batch_frames_out
:
0
log_interval
:
100
batch_frames_inout
:
0
checkpoint
:
num_workers
:
0
kbest_n
:
50
subsampling_factor
:
1
latest_n
:
5
num_encs
:
1
training
:
n_epoch
:
240
accum_grad
:
2
global_grad_clip
:
5.0
optim
:
adam
optim_conf
:
lr
:
0.002
weight_decay
:
1e-6
scheduler
:
warmuplr
scheduler_conf
:
warmup_steps
:
25000
lr_decay
:
1.0
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
decoding
:
beam_size
:
10
batch_size
:
128
error_rate_type
:
cer
decoding_method
:
attention
# 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight
:
0.5
# ctc weight for attention rescoring decode mode.
decoding_chunk_size
:
-1
# decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks
:
-1
# number of left chunks for decoding. Defaults to -1.
simulate_streaming
:
False
# simulate streaming inference. Defaults to False.
examples/aishell/asr1/conf/decode.yaml
0 → 100644
浏览文件 @
c40b6f40
#decoding:
beam_size
:
10
decode_batch_size
:
128
error_rate_type
:
cer
decoding_method
:
attention
# 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight
:
0.5
# ctc weight for attention rescoring decode mode.
decoding_chunk_size
:
-1
# decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks
:
-1
# number of left chunks for decoding. Defaults to -1.
simulate_streaming
:
False
# simulate streaming inference. Defaults to False.
examples/aishell/asr1/local/align.sh
浏览文件 @
c40b6f40
#!/bin/bash
#!/bin/bash
if
[
$#
!=
2
]
;
then
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
echo
"usage:
${
0
}
config_path
decode_config_path
ckpt_path_prefix"
exit
-1
exit
-1
fi
fi
...
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
...
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo
"using
$ngpu
gpus..."
echo
"using
$ngpu
gpus..."
config_path
=
$1
config_path
=
$1
ckpt_prefix
=
$2
decode_config_path
=
$2
ckpt_prefix
=
$3
batch_size
=
1
batch_size
=
1
output_dir
=
${
ckpt_prefix
}
output_dir
=
${
ckpt_prefix
}
...
@@ -20,9 +21,10 @@ mkdir -p ${output_dir}
...
@@ -20,9 +21,10 @@ mkdir -p ${output_dir}
python3
-u
${
BIN_DIR
}
/alignment.py
\
python3
-u
${
BIN_DIR
}
/alignment.py
\
--ngpu
${
ngpu
}
\
--ngpu
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--decode_config
${
decode_config_path
}
\
--result_file
${
output_dir
}
/
${
type
}
.align
\
--result_file
${
output_dir
}
/
${
type
}
.align
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.batch_size
${
batch_size
}
--opts
decoding.
decode_
batch_size
${
batch_size
}
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in ctc alignment!"
echo
"Failed in ctc alignment!"
...
...
examples/aishell/asr1/local/test.sh
浏览文件 @
c40b6f40
#!/bin/bash
#!/bin/bash
if
[
$#
!=
2
]
;
then
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
echo
"usage:
${
0
}
config_path
decode_config_path
ckpt_path_prefix"
exit
-1
exit
-1
fi
fi
...
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
...
@@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo
"using
$ngpu
gpus..."
echo
"using
$ngpu
gpus..."
config_path
=
$1
config_path
=
$1
ckpt_prefix
=
$2
decode_config_path
=
$2
ckpt_prefix
=
$3
chunk_mode
=
false
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
...
@@ -36,10 +37,11 @@ for type in attention ctc_greedy_search; do
...
@@ -36,10 +37,11 @@ for type in attention ctc_greedy_search; do
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--ngpu
${
ngpu
}
\
--ngpu
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--decode_config
${
decode_config_path
}
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.batch_size
${
batch_size
}
--opts
decoding.
decode_
batch_size
${
batch_size
}
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in evaluation!"
echo
"Failed in evaluation!"
...
@@ -55,6 +57,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
...
@@ -55,6 +57,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--ngpu
${
ngpu
}
\
--ngpu
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--decode_config
${
decode_config_path
}
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.decoding_method
${
type
}
\
...
...
examples/aishell/asr1/local/test_wav.sh
浏览文件 @
c40b6f40
#!/bin/bash
#!/bin/bash
if
[
$#
!=
3
]
;
then
if
[
$#
!=
4
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix audio_file"
echo
"usage:
${
0
}
config_path
decode_config_path
ckpt_path_prefix audio_file"
exit
-1
exit
-1
fi
fi
...
@@ -9,8 +9,9 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
...
@@ -9,8 +9,9 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo
"using
$ngpu
gpus..."
echo
"using
$ngpu
gpus..."
config_path
=
$1
config_path
=
$1
ckpt_prefix
=
$2
decode_config_path
=
$2
audio_file
=
$3
ckpt_prefix
=
$3
audio_file
=
$4
mkdir
-p
data
mkdir
-p
data
wget
-nc
https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav
-P
data/
wget
-nc
https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav
-P
data/
...
@@ -42,10 +43,11 @@ for type in attention_rescoring; do
...
@@ -42,10 +43,11 @@ for type in attention_rescoring; do
python3
-u
${
BIN_DIR
}
/test_wav.py
\
python3
-u
${
BIN_DIR
}
/test_wav.py
\
--ngpu
${
ngpu
}
\
--ngpu
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--decode_config
${
decode_config_path
}
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.batch_size
${
batch_size
}
\
--opts
decoding.
decode_
batch_size
${
batch_size
}
\
--audio_file
${
audio_file
}
--audio_file
${
audio_file
}
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
...
...
examples/aishell/asr1/run.sh
浏览文件 @
c40b6f40
...
@@ -6,6 +6,7 @@ gpus=0,1,2,3
...
@@ -6,6 +6,7 @@ gpus=0,1,2,3
stage
=
0
stage
=
0
stop_stage
=
50
stop_stage
=
50
conf_path
=
conf/conformer.yaml
conf_path
=
conf/conformer.yaml
decode_conf_path
=
conf/decode.yaml
avg_num
=
20
avg_num
=
20
audio_file
=
data/demo_01_03.wav
audio_file
=
data/demo_01_03.wav
...
@@ -32,18 +33,18 @@ fi
...
@@ -32,18 +33,18 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
${
decode_conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# ctc alignment of test data
# ctc alignment of test data
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
${
decode_conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
fi
# Optionally, you can add LM and test it with runtime.
# Optionally, you can add LM and test it with runtime.
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
# test a single .wav file
# test a single .wav file
CUDA_VISIBLE_DEVICES
=
0 ./local/test_wav.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
${
audio_file
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/test_wav.sh
${
conf_path
}
${
decode_conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
${
audio_file
}
||
exit
-1
fi
fi
# Not supported at now!!!
# Not supported at now!!!
...
...
paddlespeech/s2t/exps/u2/bin/alignment.py
浏览文件 @
c40b6f40
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Alignment for U2 model."""
"""Alignment for U2 model."""
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.model
import
U2Tester
as
Tester
from
paddlespeech.s2t.exps.u2.model
import
U2Tester
as
Tester
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.training.cli
import
default_argument_parser
...
@@ -41,6 +43,10 @@ if __name__ == "__main__":
...
@@ -41,6 +43,10 @@ if __name__ == "__main__":
config
=
get_cfg_defaults
()
config
=
get_cfg_defaults
()
if
args
.
config
:
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
merge_from_file
(
args
.
config
)
if
args
.
decode_config
:
decode_confs
=
CfgNode
(
new_allowed
=
True
)
decode_confs
.
merge_from_file
(
args
.
decode_config
)
config
.
decoding
=
decode_confs
if
args
.
opts
:
if
args
.
opts
:
config
.
merge_from_list
(
args
.
opts
)
config
.
merge_from_list
(
args
.
opts
)
config
.
freeze
()
config
.
freeze
()
...
...
paddlespeech/s2t/exps/u2/bin/test.py
浏览文件 @
c40b6f40
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
"""Evaluation for U2 model."""
"""Evaluation for U2 model."""
import
cProfile
import
cProfile
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.model
import
U2Tester
as
Tester
from
paddlespeech.s2t.exps.u2.model
import
U2Tester
as
Tester
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.utils.utility
import
print_arguments
from
paddlespeech.s2t.utils.utility
import
print_arguments
# TODO(hui zhang): dynamic load
# TODO(hui zhang): dynamic load
def
main_sp
(
config
,
args
):
def
main_sp
(
config
,
args
):
...
@@ -35,7 +37,7 @@ def main(config, args):
...
@@ -35,7 +37,7 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save asr result to
# save asr result to
parser
.
add_argument
(
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -45,6 +47,10 @@ if __name__ == "__main__":
...
@@ -45,6 +47,10 @@ if __name__ == "__main__":
config
=
get_cfg_defaults
()
config
=
get_cfg_defaults
()
if
args
.
config
:
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
merge_from_file
(
args
.
config
)
if
args
.
decode_config
:
decode_confs
=
CfgNode
(
new_allowed
=
True
)
decode_confs
.
merge_from_file
(
args
.
decode_config
)
config
.
decoding
=
decode_confs
if
args
.
opts
:
if
args
.
opts
:
config
.
merge_from_list
(
args
.
opts
)
config
.
merge_from_list
(
args
.
opts
)
config
.
freeze
()
config
.
freeze
()
...
...
paddlespeech/s2t/exps/u2/bin/test_wav.py
浏览文件 @
c40b6f40
...
@@ -18,6 +18,7 @@ from pathlib import Path
...
@@ -18,6 +18,7 @@ from pathlib import Path
import
paddle
import
paddle
import
soundfile
import
soundfile
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
...
@@ -36,23 +37,22 @@ class U2Infer():
...
@@ -36,23 +37,22 @@ class U2Infer():
self
.
args
=
args
self
.
args
=
args
self
.
config
=
config
self
.
config
=
config
self
.
audio_file
=
args
.
audio_file
self
.
audio_file
=
args
.
audio_file
self
.
sr
=
config
.
collator
.
target_sample_rate
self
.
preprocess_conf
=
config
.
collator
.
augmentation_config
self
.
preprocess_conf
=
config
.
augmentation_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
text_feature
=
TextFeaturizer
(
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
config
.
collator
.
unit_type
,
unit_type
=
config
.
unit_type
,
vocab
=
config
.
collator
.
vocab_filepath
,
vocab
=
config
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
config
.
spm_model_prefix
)
paddle
.
set_device
(
'gpu'
if
self
.
args
.
ngpu
>
0
else
'cpu'
)
paddle
.
set_device
(
'gpu'
if
self
.
args
.
ngpu
>
0
else
'cpu'
)
# model
# model
model_conf
=
config
.
model
model_conf
=
config
with
UpdateConfig
(
model_conf
):
with
UpdateConfig
(
model_conf
):
model_conf
.
input_dim
=
config
.
collator
.
feat_dim
model_conf
.
input_dim
=
config
.
feat_dim
model_conf
.
output_dim
=
self
.
text_feature
.
vocab_size
model_conf
.
output_dim
=
self
.
text_feature
.
vocab_size
model
=
U2Model
.
from_config
(
model_conf
)
model
=
U2Model
.
from_config
(
model_conf
)
self
.
model
=
model
self
.
model
=
model
...
@@ -70,10 +70,6 @@ class U2Infer():
...
@@ -70,10 +70,6 @@ class U2Infer():
# read
# read
audio
,
sample_rate
=
soundfile
.
read
(
audio
,
sample_rate
=
soundfile
.
read
(
self
.
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
self
.
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
if
sample_rate
!=
self
.
sr
:
logger
.
error
(
f
"sample rate error:
{
sample_rate
}
, need
{
self
.
sr
}
"
)
sys
.
exit
(
-
1
)
audio
=
audio
[:,
0
]
audio
=
audio
[:,
0
]
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
...
@@ -85,17 +81,17 @@ class U2Infer():
...
@@ -85,17 +81,17 @@ class U2Infer():
ilen
=
paddle
.
to_tensor
(
feat
.
shape
[
0
])
ilen
=
paddle
.
to_tensor
(
feat
.
shape
[
0
])
xs
=
paddle
.
to_tensor
(
feat
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
xs
=
paddle
.
to_tensor
(
feat
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
cf
g
=
self
.
config
.
decoding
decode_confi
g
=
self
.
config
.
decoding
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
xs
,
xs
,
ilen
,
ilen
,
text_feature
=
self
.
text_feature
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cf
g
.
decoding_method
,
decoding_method
=
decode_confi
g
.
decoding_method
,
beam_size
=
cf
g
.
beam_size
,
beam_size
=
decode_confi
g
.
beam_size
,
ctc_weight
=
cf
g
.
ctc_weight
,
ctc_weight
=
decode_confi
g
.
ctc_weight
,
decoding_chunk_size
=
cf
g
.
decoding_chunk_size
,
decoding_chunk_size
=
decode_confi
g
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cf
g
.
num_decoding_left_chunks
,
num_decoding_left_chunks
=
decode_confi
g
.
num_decoding_left_chunks
,
simulate_streaming
=
cf
g
.
simulate_streaming
)
simulate_streaming
=
decode_confi
g
.
simulate_streaming
)
rsl
=
result_transcripts
[
0
][
0
]
rsl
=
result_transcripts
[
0
][
0
]
utt
=
Path
(
self
.
audio_file
).
name
utt
=
Path
(
self
.
audio_file
).
name
logger
.
info
(
f
"hyp:
{
utt
}
{
result_transcripts
[
0
][
0
]
}
"
)
logger
.
info
(
f
"hyp:
{
utt
}
{
result_transcripts
[
0
][
0
]
}
"
)
...
@@ -136,6 +132,10 @@ if __name__ == "__main__":
...
@@ -136,6 +132,10 @@ if __name__ == "__main__":
config
=
get_cfg_defaults
()
config
=
get_cfg_defaults
()
if
args
.
config
:
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
merge_from_file
(
args
.
config
)
if
args
.
decode_config
:
decode_confs
=
CfgNode
(
new_allowed
=
True
)
decode_confs
.
merge_from_file
(
args
.
decode_config
)
config
.
decoding
=
decode_confs
if
args
.
opts
:
if
args
.
opts
:
config
.
merge_from_list
(
args
.
opts
)
config
.
merge_from_list
(
args
.
opts
)
config
.
freeze
()
config
.
freeze
()
...
...
paddlespeech/s2t/exps/u2/config.py
浏览文件 @
c40b6f40
...
@@ -19,19 +19,18 @@ from paddlespeech.s2t.io.collator import SpeechCollator
...
@@ -19,19 +19,18 @@ from paddlespeech.s2t.io.collator import SpeechCollator
from
paddlespeech.s2t.io.dataset
import
ManifestDataset
from
paddlespeech.s2t.io.dataset
import
ManifestDataset
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.models.u2
import
U2Model
_C
=
CfgNode
()
_C
=
CfgNode
(
new_allowed
=
True
)
_C
.
data
=
ManifestDataset
.
params
(
)
ManifestDataset
.
params
(
_C
)
_C
.
collator
=
SpeechCollator
.
params
(
)
SpeechCollator
.
params
(
_C
)
_C
.
model
=
U2Model
.
params
(
)
U2Model
.
params
(
_C
)
_C
.
training
=
U2Trainer
.
params
(
)
U2Trainer
.
params
(
_C
)
_C
.
decoding
=
U2Tester
.
params
()
_C
.
decoding
=
U2Tester
.
params
()
def
get_cfg_defaults
():
def
get_cfg_defaults
():
"""Get a yacs CfgNode object with default values for my_project."""
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# Return a clone so that the defaults will not be altered
...
...
paddlespeech/s2t/exps/u2/model.py
浏览文件 @
c40b6f40
...
@@ -77,7 +77,7 @@ class U2Trainer(Trainer):
...
@@ -77,7 +77,7 @@ class U2Trainer(Trainer):
super
().
__init__
(
config
,
args
)
super
().
__init__
(
config
,
args
)
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
train_conf
=
self
.
config
start
=
time
.
time
()
start
=
time
.
time
()
# forward
# forward
...
@@ -120,7 +120,7 @@ class U2Trainer(Trainer):
...
@@ -120,7 +120,7 @@ class U2Trainer(Trainer):
for
k
,
v
in
losses_np
.
items
():
for
k
,
v
in
losses_np
.
items
():
report
(
k
,
v
)
report
(
k
,
v
)
report
(
"batch_size"
,
self
.
config
.
collator
.
batch_size
)
report
(
"batch_size"
,
self
.
config
.
batch_size
)
report
(
"accum"
,
train_conf
.
accum_grad
)
report
(
"accum"
,
train_conf
.
accum_grad
)
report
(
"step_cost"
,
iteration_time
)
report
(
"step_cost"
,
iteration_time
)
...
@@ -153,7 +153,7 @@ class U2Trainer(Trainer):
...
@@ -153,7 +153,7 @@ class U2Trainer(Trainer):
if
ctc_loss
:
if
ctc_loss
:
valid_losses
[
'val_ctc_loss'
].
append
(
float
(
ctc_loss
))
valid_losses
[
'val_ctc_loss'
].
append
(
float
(
ctc_loss
))
if
(
i
+
1
)
%
self
.
config
.
training
.
log_interval
==
0
:
if
(
i
+
1
)
%
self
.
config
.
log_interval
==
0
:
valid_dump
=
{
k
:
np
.
mean
(
v
)
for
k
,
v
in
valid_losses
.
items
()}
valid_dump
=
{
k
:
np
.
mean
(
v
)
for
k
,
v
in
valid_losses
.
items
()}
valid_dump
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
valid_dump
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
...
@@ -182,7 +182,7 @@ class U2Trainer(Trainer):
...
@@ -182,7 +182,7 @@ class U2Trainer(Trainer):
self
.
before_train
()
self
.
before_train
()
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
self
.
model
.
train
()
try
:
try
:
...
@@ -214,8 +214,7 @@ class U2Trainer(Trainer):
...
@@ -214,8 +214,7 @@ class U2Trainer(Trainer):
k
.
split
(
','
))
==
2
else
""
k
.
split
(
','
))
==
2
else
""
msg
+=
","
msg
+=
","
msg
=
msg
[:
-
1
]
# remove the last ","
msg
=
msg
[:
-
1
]
# remove the last ","
if
(
batch_index
+
1
if
(
batch_index
+
1
)
%
self
.
config
.
log_interval
==
0
:
)
%
self
.
config
.
training
.
log_interval
==
0
:
logger
.
info
(
msg
)
logger
.
info
(
msg
)
data_start_time
=
time
.
time
()
data_start_time
=
time
.
time
()
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -252,29 +251,29 @@ class U2Trainer(Trainer):
...
@@ -252,29 +251,29 @@ class U2Trainer(Trainer):
if
self
.
train
:
if
self
.
train
:
# train/valid dataset, return token ids
# train/valid dataset, return token ids
self
.
train_loader
=
BatchDataLoader
(
self
.
train_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
train_manifest
,
json_file
=
config
.
train_manifest
,
train_mode
=
True
,
train_mode
=
True
,
sortagrad
=
config
.
collator
.
sortagrad
,
sortagrad
=
config
.
sortagrad
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
batch_size
,
maxlen_in
=
config
.
collator
.
maxlen_in
,
maxlen_in
=
config
.
maxlen_in
,
maxlen_out
=
config
.
collator
.
maxlen_out
,
maxlen_out
=
config
.
maxlen_out
,
minibatches
=
config
.
collator
.
minibatches
,
minibatches
=
config
.
minibatches
,
mini_batch_size
=
self
.
args
.
ngpu
,
mini_batch_size
=
self
.
args
.
ngpu
,
batch_count
=
config
.
collator
.
batch_count
,
batch_count
=
config
.
batch_count
,
batch_bins
=
config
.
collator
.
batch_bins
,
batch_bins
=
config
.
batch_bins
,
batch_frames_in
=
config
.
collator
.
batch_frames_in
,
batch_frames_in
=
config
.
batch_frames_in
,
batch_frames_out
=
config
.
collator
.
batch_frames_out
,
batch_frames_out
=
config
.
batch_frames_out
,
batch_frames_inout
=
config
.
collator
.
batch_frames_inout
,
batch_frames_inout
=
config
.
batch_frames_inout
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
preprocess_conf
=
config
.
augmentation_config
,
n_iter_processes
=
config
.
collator
.
num_workers
,
n_iter_processes
=
config
.
num_workers
,
subsampling_factor
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
num_encs
=
1
)
self
.
valid_loader
=
BatchDataLoader
(
self
.
valid_loader
=
BatchDataLoader
(
json_file
=
config
.
d
ata
.
d
ev_manifest
,
json_file
=
config
.
dev_manifest
,
train_mode
=
False
,
train_mode
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
...
@@ -284,18 +283,18 @@ class U2Trainer(Trainer):
...
@@ -284,18 +283,18 @@ class U2Trainer(Trainer):
batch_frames_in
=
0
,
batch_frames_in
=
0
,
batch_frames_out
=
0
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
preprocess_conf
=
config
.
augmentation_config
,
n_iter_processes
=
config
.
collator
.
num_workers
,
n_iter_processes
=
config
.
num_workers
,
subsampling_factor
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
num_encs
=
1
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
else
:
# test dataset, return raw text
# test dataset, return raw text
self
.
test_loader
=
BatchDataLoader
(
self
.
test_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
test_manifest
,
json_file
=
config
.
test_manifest
,
train_mode
=
False
,
train_mode
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
decoding
.
batch_size
,
batch_size
=
config
.
decoding
.
decode_
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
...
@@ -305,16 +304,16 @@ class U2Trainer(Trainer):
...
@@ -305,16 +304,16 @@ class U2Trainer(Trainer):
batch_frames_in
=
0
,
batch_frames_in
=
0
,
batch_frames_out
=
0
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
preprocess_conf
=
config
.
augmentation_config
,
n_iter_processes
=
1
,
n_iter_processes
=
1
,
subsampling_factor
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
num_encs
=
1
)
self
.
align_loader
=
BatchDataLoader
(
self
.
align_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
test_manifest
,
json_file
=
config
.
test_manifest
,
train_mode
=
False
,
train_mode
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
decoding
.
batch_size
,
batch_size
=
config
.
decoding
.
decode_
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
...
@@ -324,7 +323,7 @@ class U2Trainer(Trainer):
...
@@ -324,7 +323,7 @@ class U2Trainer(Trainer):
batch_frames_in
=
0
,
batch_frames_in
=
0
,
batch_frames_out
=
0
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
preprocess_conf
=
config
.
augmentation_config
,
n_iter_processes
=
1
,
n_iter_processes
=
1
,
subsampling_factor
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
num_encs
=
1
)
...
@@ -332,7 +331,7 @@ class U2Trainer(Trainer):
...
@@ -332,7 +331,7 @@ class U2Trainer(Trainer):
def
setup_model
(
self
):
def
setup_model
(
self
):
config
=
self
.
config
config
=
self
.
config
model_conf
=
config
.
model
model_conf
=
config
with
UpdateConfig
(
model_conf
):
with
UpdateConfig
(
model_conf
):
if
self
.
train
:
if
self
.
train
:
...
@@ -355,7 +354,7 @@ class U2Trainer(Trainer):
...
@@ -355,7 +354,7 @@ class U2Trainer(Trainer):
if
not
self
.
train
:
if
not
self
.
train
:
return
return
train_config
=
config
.
training
train_config
=
config
optim_type
=
train_config
.
optim
optim_type
=
train_config
.
optim
optim_conf
=
train_config
.
optim_conf
optim_conf
=
train_config
.
optim_conf
scheduler_type
=
train_config
.
scheduler
scheduler_type
=
train_config
.
scheduler
...
@@ -375,7 +374,7 @@ class U2Trainer(Trainer):
...
@@ -375,7 +374,7 @@ class U2Trainer(Trainer):
config
,
config
,
parameters
,
parameters
,
lr_scheduler
=
None
,
):
lr_scheduler
=
None
,
):
train_config
=
config
.
training
train_config
=
config
optim_type
=
train_config
.
optim
optim_type
=
train_config
.
optim
optim_conf
=
train_config
.
optim_conf
optim_conf
=
train_config
.
optim_conf
scheduler_type
=
train_config
.
scheduler
scheduler_type
=
train_config
.
scheduler
...
@@ -415,7 +414,7 @@ class U2Tester(U2Trainer):
...
@@ -415,7 +414,7 @@ class U2Tester(U2Trainer):
error_rate_type
=
'wer'
,
# Error rate type for evaluation. Options `wer`, 'cer'
error_rate_type
=
'wer'
,
# Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch
=
8
,
# # of CPUs for beam search.
num_proc_bsearch
=
8
,
# # of CPUs for beam search.
beam_size
=
10
,
# Beam search width.
beam_size
=
10
,
# Beam search width.
batch_size
=
16
,
# decoding batch size
decode_
batch_size
=
16
,
# decoding batch size
ctc_weight
=
0.0
,
# ctc weight for attention rescoring decode mode.
ctc_weight
=
0.0
,
# ctc weight for attention rescoring decode mode.
decoding_chunk_size
=-
1
,
# decoding chunk size. Defaults to -1.
decoding_chunk_size
=-
1
,
# decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# <0: for decoding, use full chunk.
...
@@ -432,9 +431,9 @@ class U2Tester(U2Trainer):
...
@@ -432,9 +431,9 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
super
().
__init__
(
config
,
args
)
self
.
text_feature
=
TextFeaturizer
(
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
unit_type
=
self
.
config
.
unit_type
,
vocab
=
self
.
config
.
collator
.
vocab_filepath
,
vocab
=
self
.
config
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
spm_model_prefix
)
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
...
@@ -453,10 +452,10 @@ class U2Tester(U2Trainer):
...
@@ -453,10 +452,10 @@ class U2Tester(U2Trainer):
texts
,
texts
,
texts_len
,
texts_len
,
fout
=
None
):
fout
=
None
):
cf
g
=
self
.
config
.
decoding
decode_confi
g
=
self
.
config
.
decoding
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_func
=
error_rate
.
char_errors
if
cf
g
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
errors_func
=
error_rate
.
char_errors
if
decode_confi
g
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
error_rate_func
=
error_rate
.
cer
if
cf
g
.
error_rate_type
==
'cer'
else
error_rate
.
wer
error_rate_func
=
error_rate
.
cer
if
decode_confi
g
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
start_time
=
time
.
time
()
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
self
.
text_feature
)
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
self
.
text_feature
)
...
@@ -464,12 +463,12 @@ class U2Tester(U2Trainer):
...
@@ -464,12 +463,12 @@ class U2Tester(U2Trainer):
audio
,
audio
,
audio_len
,
audio_len
,
text_feature
=
self
.
text_feature
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cf
g
.
decoding_method
,
decoding_method
=
decode_confi
g
.
decoding_method
,
beam_size
=
cf
g
.
beam_size
,
beam_size
=
decode_confi
g
.
beam_size
,
ctc_weight
=
cf
g
.
ctc_weight
,
ctc_weight
=
decode_confi
g
.
ctc_weight
,
decoding_chunk_size
=
cf
g
.
decoding_chunk_size
,
decoding_chunk_size
=
decode_confi
g
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cf
g
.
num_decoding_left_chunks
,
num_decoding_left_chunks
=
decode_confi
g
.
num_decoding_left_chunks
,
simulate_streaming
=
cf
g
.
simulate_streaming
)
simulate_streaming
=
decode_confi
g
.
simulate_streaming
)
decode_time
=
time
.
time
()
-
start_time
decode_time
=
time
.
time
()
-
start_time
for
utt
,
target
,
result
,
rec_tids
in
zip
(
for
utt
,
target
,
result
,
rec_tids
in
zip
(
...
@@ -488,15 +487,15 @@ class U2Tester(U2Trainer):
...
@@ -488,15 +487,15 @@ class U2Tester(U2Trainer):
logger
.
info
(
f
"Utt:
{
utt
}
"
)
logger
.
info
(
f
"Utt:
{
utt
}
"
)
logger
.
info
(
f
"Ref:
{
target
}
"
)
logger
.
info
(
f
"Ref:
{
target
}
"
)
logger
.
info
(
f
"Hyp:
{
result
}
"
)
logger
.
info
(
f
"Hyp:
{
result
}
"
)
logger
.
info
(
"One example error rate [%s] = %f"
%
logger
.
info
(
"One example error rate [%s] = %f"
%
(
(
cf
g
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
decode_confi
g
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
return
dict
(
return
dict
(
errors_sum
=
errors_sum
,
errors_sum
=
errors_sum
,
len_refs
=
len_refs
,
len_refs
=
len_refs
,
num_ins
=
num_ins
,
# num examples
num_ins
=
num_ins
,
# num examples
error_rate
=
errors_sum
/
len_refs
,
error_rate
=
errors_sum
/
len_refs
,
error_rate_type
=
cf
g
.
error_rate_type
,
error_rate_type
=
decode_confi
g
.
error_rate_type
,
num_frames
=
audio_len
.
sum
().
numpy
().
item
(),
num_frames
=
audio_len
.
sum
().
numpy
().
item
(),
decode_time
=
decode_time
)
decode_time
=
decode_time
)
...
@@ -507,7 +506,7 @@ class U2Tester(U2Trainer):
...
@@ -507,7 +506,7 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
self
.
model
.
eval
()
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
config
.
collator
.
stride_ms
stride_ms
=
self
.
config
.
stride_ms
error_rate_type
=
None
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
num_frames
=
0.0
num_frames
=
0.0
...
@@ -558,15 +557,15 @@ class U2Tester(U2Trainer):
...
@@ -558,15 +557,15 @@ class U2Tester(U2Trainer):
"ref_len"
:
"ref_len"
:
len_refs
,
len_refs
,
"decode_method"
:
"decode_method"
:
self
.
config
.
decoding
.
decoding
_method
,
self
.
config
.
decoding_method
,
})
})
f
.
write
(
data
+
'
\n
'
)
f
.
write
(
data
+
'
\n
'
)
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
align
(
self
):
def
align
(
self
):
ctc_utils
.
ctc_align
(
self
.
config
,
self
.
model
,
self
.
align_loader
,
ctc_utils
.
ctc_align
(
self
.
config
,
self
.
model
,
self
.
align_loader
,
self
.
config
.
decoding
.
batch_size
,
self
.
config
.
decoding
.
decode_
batch_size
,
self
.
config
.
collator
.
stride_ms
,
self
.
vocab_list
,
self
.
config
.
stride_ms
,
self
.
vocab_list
,
self
.
args
.
result_file
)
self
.
args
.
result_file
)
def
load_inferspec
(
self
):
def
load_inferspec
(
self
):
...
@@ -577,10 +576,10 @@ class U2Tester(U2Trainer):
...
@@ -577,10 +576,10 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec.
List[paddle.static.InputSpec]: input spec.
"""
"""
from
paddlespeech.s2t.models.u2
import
U2InferModel
from
paddlespeech.s2t.models.u2
import
U2InferModel
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
t
est
_loader
,
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
t
rain
_loader
,
self
.
config
.
model
.
clone
(),
self
.
config
.
clone
(),
self
.
args
.
checkpoint_path
)
self
.
args
.
checkpoint_path
)
feat_dim
=
self
.
t
est
_loader
.
feat_dim
feat_dim
=
self
.
t
rain
_loader
.
feat_dim
input_spec
=
[
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
1
,
None
,
feat_dim
],
paddle
.
static
.
InputSpec
(
shape
=
[
1
,
None
,
feat_dim
],
dtype
=
'float32'
),
# audio, [B,T,D]
dtype
=
'float32'
),
# audio, [B,T,D]
...
...
paddlespeech/s2t/training/cli.py
浏览文件 @
c40b6f40
...
@@ -97,6 +97,14 @@ def default_argument_parser(parser=None):
...
@@ -97,6 +97,14 @@ def default_argument_parser(parser=None):
train_group
.
add_argument
(
train_group
.
add_argument
(
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to `this` file."
)
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to `this` file."
)
test_group
=
parser
.
add_argument_group
(
title
=
'Test Options'
,
description
=
None
)
test_group
.
add_argument
(
"--decode_config"
,
metavar
=
"DECODE_CONFIG_FILE"
,
help
=
"decode config file."
)
profile_group
=
parser
.
add_argument_group
(
profile_group
=
parser
.
add_argument_group
(
title
=
'Benchmark Options'
,
description
=
None
)
title
=
'Benchmark Options'
,
description
=
None
)
profile_group
.
add_argument
(
profile_group
.
add_argument
(
...
...
paddlespeech/s2t/training/trainer.py
浏览文件 @
c40b6f40
...
@@ -117,8 +117,8 @@ class Trainer():
...
@@ -117,8 +117,8 @@ class Trainer():
self
.
init_parallel
()
self
.
init_parallel
()
self
.
checkpoint
=
Checkpoint
(
self
.
checkpoint
=
Checkpoint
(
kbest_n
=
self
.
config
.
training
.
checkpoint
.
kbest_n
,
kbest_n
=
self
.
config
.
checkpoint
.
kbest_n
,
latest_n
=
self
.
config
.
training
.
checkpoint
.
latest_n
)
latest_n
=
self
.
config
.
checkpoint
.
latest_n
)
# set random seed if needed
# set random seed if needed
if
args
.
seed
:
if
args
.
seed
:
...
@@ -129,8 +129,8 @@ class Trainer():
...
@@ -129,8 +129,8 @@ class Trainer():
if
hasattr
(
self
.
args
,
if
hasattr
(
self
.
args
,
"benchmark_batch_size"
)
and
self
.
args
.
benchmark_batch_size
:
"benchmark_batch_size"
)
and
self
.
args
.
benchmark_batch_size
:
with
UpdateConfig
(
self
.
config
):
with
UpdateConfig
(
self
.
config
):
self
.
config
.
collator
.
batch_size
=
self
.
args
.
benchmark_batch_size
self
.
config
.
batch_size
=
self
.
args
.
benchmark_batch_size
self
.
config
.
training
.
log_interval
=
1
self
.
config
.
log_interval
=
1
logger
.
info
(
logger
.
info
(
f
"Benchmark reset batch-size:
{
self
.
args
.
benchmark_batch_size
}
"
)
f
"Benchmark reset batch-size:
{
self
.
args
.
benchmark_batch_size
}
"
)
...
@@ -260,7 +260,7 @@ class Trainer():
...
@@ -260,7 +260,7 @@ class Trainer():
self
.
before_train
()
self
.
before_train
()
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
self
.
model
.
train
()
try
:
try
:
...
...
paddlespeech/s2t/utils/utility.py
浏览文件 @
c40b6f40
...
@@ -130,7 +130,7 @@ def get_subsample(config):
...
@@ -130,7 +130,7 @@ def get_subsample(config):
Returns:
Returns:
int: subsample rate.
int: subsample rate.
"""
"""
input_layer
=
config
[
"
model"
][
"
encoder_conf"
][
"input_layer"
]
input_layer
=
config
[
"encoder_conf"
][
"input_layer"
]
assert
input_layer
in
[
"conv2d"
,
"conv2d6"
,
"conv2d8"
]
assert
input_layer
in
[
"conv2d"
,
"conv2d6"
,
"conv2d8"
]
if
input_layer
==
"conv2d"
:
if
input_layer
==
"conv2d"
:
return
4
return
4
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录