Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
bc93bffb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bc93bffb
编写于
7月 01, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace logger.info with logger.debug in cli, change default log level to INFO
上级
cf846f9e
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
98 addition
and
95 deletion
+98
-95
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+20
-20
paddlespeech/cli/cls/infer.py
paddlespeech/cli/cls/infer.py
+3
-3
paddlespeech/cli/download.py
paddlespeech/cli/download.py
+8
-8
paddlespeech/cli/kws/infer.py
paddlespeech/cli/kws/infer.py
+2
-2
paddlespeech/cli/log.py
paddlespeech/cli/log.py
+1
-1
paddlespeech/cli/st/infer.py
paddlespeech/cli/st/infer.py
+5
-5
paddlespeech/cli/text/infer.py
paddlespeech/cli/text/infer.py
+1
-1
paddlespeech/cli/tts/infer.py
paddlespeech/cli/tts/infer.py
+7
-7
paddlespeech/cli/vector/infer.py
paddlespeech/cli/vector/infer.py
+31
-30
paddlespeech/s2t/frontend/augmentor/spec_augment.py
paddlespeech/s2t/frontend/augmentor/spec_augment.py
+3
-3
paddlespeech/s2t/frontend/featurizer/text_featurizer.py
paddlespeech/s2t/frontend/featurizer/text_featurizer.py
+6
-6
paddlespeech/s2t/models/u2/u2.py
paddlespeech/s2t/models/u2/u2.py
+2
-2
paddlespeech/s2t/modules/loss.py
paddlespeech/s2t/modules/loss.py
+4
-3
paddlespeech/s2t/transform/spec_augment.py
paddlespeech/s2t/transform/spec_augment.py
+4
-3
paddlespeech/s2t/utils/tensor_utils.py
paddlespeech/s2t/utils/tensor_utils.py
+1
-1
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
bc93bffb
...
@@ -133,11 +133,11 @@ class ASRExecutor(BaseExecutor):
...
@@ -133,11 +133,11 @@ class ASRExecutor(BaseExecutor):
"""
"""
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
logger
.
info
(
"start to init the model"
)
logger
.
debug
(
"start to init the model"
)
# default max_len: unit:second
# default max_len: unit:second
self
.
max_len
=
50
self
.
max_len
=
50
if
hasattr
(
self
,
'model'
):
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
logger
.
debug
(
'Model had been initialized.'
)
return
return
if
cfg_path
is
None
or
ckpt_path
is
None
:
if
cfg_path
is
None
or
ckpt_path
is
None
:
...
@@ -151,15 +151,15 @@ class ASRExecutor(BaseExecutor):
...
@@ -151,15 +151,15 @@ class ASRExecutor(BaseExecutor):
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
res_path
,
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'ckpt_path'
]
+
".pdparams"
)
self
.
task_resource
.
res_dict
[
'ckpt_path'
]
+
".pdparams"
)
logger
.
info
(
self
.
res_path
)
logger
.
debug
(
self
.
res_path
)
else
:
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
+
".pdparams"
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
+
".pdparams"
)
self
.
res_path
=
os
.
path
.
dirname
(
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
self
.
cfg_path
)
logger
.
debug
(
self
.
cfg_path
)
logger
.
info
(
self
.
ckpt_path
)
logger
.
debug
(
self
.
ckpt_path
)
#Init body.
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
@@ -216,7 +216,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -216,7 +216,7 @@ class ASRExecutor(BaseExecutor):
max_len
=
self
.
config
.
encoder_conf
.
max_len
max_len
=
self
.
config
.
encoder_conf
.
max_len
self
.
max_len
=
frame_shift_ms
*
max_len
*
subsample_rate
self
.
max_len
=
frame_shift_ms
*
max_len
*
subsample_rate
logger
.
info
(
logger
.
debug
(
f
"The asr server limit max duration len:
{
self
.
max_len
}
"
)
f
"The asr server limit max duration len:
{
self
.
max_len
}
"
)
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
...
@@ -227,15 +227,15 @@ class ASRExecutor(BaseExecutor):
...
@@ -227,15 +227,15 @@ class ASRExecutor(BaseExecutor):
audio_file
=
input
audio_file
=
input
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
logger
.
debug
(
"Preprocess audio_file:"
+
audio_file
)
# Get the object for feature extraction
# Get the object for feature extraction
if
"deepspeech2"
in
model_type
or
"conformer"
in
model_type
or
"transformer"
in
model_type
:
if
"deepspeech2"
in
model_type
or
"conformer"
in
model_type
or
"transformer"
in
model_type
:
logger
.
info
(
"get the preprocess conf"
)
logger
.
debug
(
"get the preprocess conf"
)
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_args
=
{
"train"
:
False
}
preprocess_args
=
{
"train"
:
False
}
preprocessing
=
Transformation
(
preprocess_conf
)
preprocessing
=
Transformation
(
preprocess_conf
)
logger
.
info
(
"read the audio file"
)
logger
.
debug
(
"read the audio file"
)
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
if
self
.
change_format
:
if
self
.
change_format
:
...
@@ -255,7 +255,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -255,7 +255,7 @@ class ASRExecutor(BaseExecutor):
else
:
else
:
audio
=
audio
[:,
0
]
audio
=
audio
[:,
0
]
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
logger
.
debug
(
f
"audio shape:
{
audio
.
shape
}
"
)
# fbank
# fbank
audio
=
preprocessing
(
audio
,
**
preprocess_args
)
audio
=
preprocessing
(
audio
,
**
preprocess_args
)
...
@@ -264,19 +264,19 @@ class ASRExecutor(BaseExecutor):
...
@@ -264,19 +264,19 @@ class ASRExecutor(BaseExecutor):
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio"
]
=
audio
self
.
_inputs
[
"audio_len"
]
=
audio_len
self
.
_inputs
[
"audio_len"
]
=
audio_len
logger
.
info
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
logger
.
debug
(
f
"audio feat shape:
{
audio
.
shape
}
"
)
else
:
else
:
raise
Exception
(
"wrong type"
)
raise
Exception
(
"wrong type"
)
logger
.
info
(
"audio feat process success"
)
logger
.
debug
(
"audio feat process success"
)
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
infer
(
self
,
model_type
:
str
):
def
infer
(
self
,
model_type
:
str
):
"""
"""
Model inference and result stored in self.output.
Model inference and result stored in self.output.
"""
"""
logger
.
info
(
"start to infer the model to get the output"
)
logger
.
debug
(
"start to infer the model to get the output"
)
cfg
=
self
.
config
.
decode
cfg
=
self
.
config
.
decode
audio
=
self
.
_inputs
[
"audio"
]
audio
=
self
.
_inputs
[
"audio"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
audio_len
=
self
.
_inputs
[
"audio_len"
]
...
@@ -293,7 +293,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -293,7 +293,7 @@ class ASRExecutor(BaseExecutor):
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
]
self
.
_outputs
[
"result"
]
=
result_transcripts
[
0
]
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
logger
.
info
(
logger
.
debug
(
f
"we will use the transformer like model :
{
model_type
}
"
)
f
"we will use the transformer like model :
{
model_type
}
"
)
try
:
try
:
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
...
@@ -352,7 +352,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -352,7 +352,7 @@ class ASRExecutor(BaseExecutor):
logger
.
error
(
"Please input the right audio file path"
)
logger
.
error
(
"Please input the right audio file path"
)
return
False
return
False
logger
.
info
(
"checking the audio file format......"
)
logger
.
debug
(
"checking the audio file format......"
)
try
:
try
:
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
...
@@ -374,7 +374,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -374,7 +374,7 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav
\n
\
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav
\n
\
"
)
"
)
return
False
return
False
logger
.
info
(
"The sample rate is %d"
%
audio_sample_rate
)
logger
.
debug
(
"The sample rate is %d"
%
audio_sample_rate
)
if
audio_sample_rate
!=
self
.
sample_rate
:
if
audio_sample_rate
!=
self
.
sample_rate
:
logger
.
warning
(
"The sample rate of the input file is not {}.
\n
\
logger
.
warning
(
"The sample rate of the input file is not {}.
\n
\
The program will resample the wav file to {}.
\n
\
The program will resample the wav file to {}.
\n
\
...
@@ -383,28 +383,28 @@ class ASRExecutor(BaseExecutor):
...
@@ -383,28 +383,28 @@ class ASRExecutor(BaseExecutor):
"
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
"
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
if
force_yes
is
False
:
if
force_yes
is
False
:
while
(
True
):
while
(
True
):
logger
.
info
(
logger
.
debug
(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
)
content
=
input
(
"Input(Y/N):"
)
content
=
input
(
"Input(Y/N):"
)
if
content
.
strip
()
==
"Y"
or
content
.
strip
(
if
content
.
strip
()
==
"Y"
or
content
.
strip
(
)
==
"y"
or
content
.
strip
()
==
"yes"
or
content
.
strip
(
)
==
"y"
or
content
.
strip
()
==
"yes"
or
content
.
strip
(
)
==
"Yes"
:
)
==
"Yes"
:
logger
.
info
(
logger
.
debug
(
"change the sampele rate, channel to 16k and 1 channel"
"change the sampele rate, channel to 16k and 1 channel"
)
)
break
break
elif
content
.
strip
()
==
"N"
or
content
.
strip
(
elif
content
.
strip
()
==
"N"
or
content
.
strip
(
)
==
"n"
or
content
.
strip
()
==
"no"
or
content
.
strip
(
)
==
"n"
or
content
.
strip
()
==
"no"
or
content
.
strip
(
)
==
"No"
:
)
==
"No"
:
logger
.
info
(
"Exit the program"
)
logger
.
debug
(
"Exit the program"
)
return
False
return
False
else
:
else
:
logger
.
warning
(
"Not regular input, please input again"
)
logger
.
warning
(
"Not regular input, please input again"
)
self
.
change_format
=
True
self
.
change_format
=
True
else
:
else
:
logger
.
info
(
"The audio file format is right"
)
logger
.
debug
(
"The audio file format is right"
)
self
.
change_format
=
False
self
.
change_format
=
False
return
True
return
True
...
...
paddlespeech/cli/cls/infer.py
浏览文件 @
bc93bffb
...
@@ -92,7 +92,7 @@ class CLSExecutor(BaseExecutor):
...
@@ -92,7 +92,7 @@ class CLSExecutor(BaseExecutor):
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
hasattr
(
self
,
'model'
):
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
logger
.
debug
(
'Model had been initialized.'
)
return
return
if
label_file
is
None
or
ckpt_path
is
None
:
if
label_file
is
None
or
ckpt_path
is
None
:
...
@@ -135,14 +135,14 @@ class CLSExecutor(BaseExecutor):
...
@@ -135,14 +135,14 @@ class CLSExecutor(BaseExecutor):
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
"""
feat_conf
=
self
.
_conf
[
'feature'
]
feat_conf
=
self
.
_conf
[
'feature'
]
logger
.
info
(
feat_conf
)
logger
.
debug
(
feat_conf
)
waveform
,
_
=
load
(
waveform
,
_
=
load
(
file
=
audio_file
,
file
=
audio_file
,
sr
=
feat_conf
[
'sample_rate'
],
sr
=
feat_conf
[
'sample_rate'
],
mono
=
True
,
mono
=
True
,
dtype
=
'float32'
)
dtype
=
'float32'
)
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
logger
.
info
(
"Preprocessing audio_file:"
+
audio_file
)
logger
.
debug
(
"Preprocessing audio_file:"
+
audio_file
)
# Feature extraction
# Feature extraction
feature_extractor
=
LogMelSpectrogram
(
feature_extractor
=
LogMelSpectrogram
(
...
...
paddlespeech/cli/download.py
浏览文件 @
bc93bffb
...
@@ -61,7 +61,7 @@ def _get_unique_endpoints(trainer_endpoints):
...
@@ -61,7 +61,7 @@ def _get_unique_endpoints(trainer_endpoints):
continue
continue
ips
.
add
(
ip
)
ips
.
add
(
ip
)
unique_endpoints
.
add
(
endpoint
)
unique_endpoints
.
add
(
endpoint
)
logger
.
info
(
"unique_endpoints {}"
.
format
(
unique_endpoints
))
logger
.
debug
(
"unique_endpoints {}"
.
format
(
unique_endpoints
))
return
unique_endpoints
return
unique_endpoints
...
@@ -96,7 +96,7 @@ def get_path_from_url(url,
...
@@ -96,7 +96,7 @@ def get_path_from_url(url,
# data, and the same ip will only download data once.
# data, and the same ip will only download data once.
unique_endpoints
=
_get_unique_endpoints
(
ParallelEnv
().
trainer_endpoints
[:])
unique_endpoints
=
_get_unique_endpoints
(
ParallelEnv
().
trainer_endpoints
[:])
if
osp
.
exists
(
fullpath
)
and
check_exist
and
_md5check
(
fullpath
,
md5sum
):
if
osp
.
exists
(
fullpath
)
and
check_exist
and
_md5check
(
fullpath
,
md5sum
):
logger
.
info
(
"Found {}"
.
format
(
fullpath
))
logger
.
debug
(
"Found {}"
.
format
(
fullpath
))
else
:
else
:
if
ParallelEnv
().
current_endpoint
in
unique_endpoints
:
if
ParallelEnv
().
current_endpoint
in
unique_endpoints
:
fullpath
=
_download
(
url
,
root_dir
,
md5sum
,
method
=
method
)
fullpath
=
_download
(
url
,
root_dir
,
md5sum
,
method
=
method
)
...
@@ -118,7 +118,7 @@ def _get_download(url, fullname):
...
@@ -118,7 +118,7 @@ def _get_download(url, fullname):
try
:
try
:
req
=
requests
.
get
(
url
,
stream
=
True
)
req
=
requests
.
get
(
url
,
stream
=
True
)
except
Exception
as
e
:
# requests.exceptions.ConnectionError
except
Exception
as
e
:
# requests.exceptions.ConnectionError
logger
.
info
(
"Downloading {} from {} failed with exception {}"
.
format
(
logger
.
debug
(
"Downloading {} from {} failed with exception {}"
.
format
(
fname
,
url
,
str
(
e
)))
fname
,
url
,
str
(
e
)))
return
False
return
False
...
@@ -190,7 +190,7 @@ def _download(url, path, md5sum=None, method='get'):
...
@@ -190,7 +190,7 @@ def _download(url, path, md5sum=None, method='get'):
fullname
=
osp
.
join
(
path
,
fname
)
fullname
=
osp
.
join
(
path
,
fname
)
retry_cnt
=
0
retry_cnt
=
0
logger
.
info
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
logger
.
debug
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
while
not
(
osp
.
exists
(
fullname
)
and
_md5check
(
fullname
,
md5sum
)):
while
not
(
osp
.
exists
(
fullname
)
and
_md5check
(
fullname
,
md5sum
)):
if
retry_cnt
<
DOWNLOAD_RETRY_LIMIT
:
if
retry_cnt
<
DOWNLOAD_RETRY_LIMIT
:
retry_cnt
+=
1
retry_cnt
+=
1
...
@@ -209,7 +209,7 @@ def _md5check(fullname, md5sum=None):
...
@@ -209,7 +209,7 @@ def _md5check(fullname, md5sum=None):
if
md5sum
is
None
:
if
md5sum
is
None
:
return
True
return
True
logger
.
info
(
"File {} md5 checking..."
.
format
(
fullname
))
logger
.
debug
(
"File {} md5 checking..."
.
format
(
fullname
))
md5
=
hashlib
.
md5
()
md5
=
hashlib
.
md5
()
with
open
(
fullname
,
'rb'
)
as
f
:
with
open
(
fullname
,
'rb'
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
...
@@ -217,8 +217,8 @@ def _md5check(fullname, md5sum=None):
...
@@ -217,8 +217,8 @@ def _md5check(fullname, md5sum=None):
calc_md5sum
=
md5
.
hexdigest
()
calc_md5sum
=
md5
.
hexdigest
()
if
calc_md5sum
!=
md5sum
:
if
calc_md5sum
!=
md5sum
:
logger
.
info
(
"File {} md5 check failed, {}(calc) != "
logger
.
debug
(
"File {} md5 check failed, {}(calc) != "
"{}(base)"
.
format
(
fullname
,
calc_md5sum
,
md5sum
))
"{}(base)"
.
format
(
fullname
,
calc_md5sum
,
md5sum
))
return
False
return
False
return
True
return
True
...
@@ -227,7 +227,7 @@ def _decompress(fname):
...
@@ -227,7 +227,7 @@ def _decompress(fname):
"""
"""
Decompress for zip and tar file
Decompress for zip and tar file
"""
"""
logger
.
info
(
"Decompressing {}..."
.
format
(
fname
))
logger
.
debug
(
"Decompressing {}..."
.
format
(
fname
))
# For protecting decompressing interupted,
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# decompress to fpath_tmp directory firstly, if decompress
...
...
paddlespeech/cli/kws/infer.py
浏览文件 @
bc93bffb
...
@@ -88,7 +88,7 @@ class KWSExecutor(BaseExecutor):
...
@@ -88,7 +88,7 @@ class KWSExecutor(BaseExecutor):
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
hasattr
(
self
,
'model'
):
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
logger
.
debug
(
'Model had been initialized.'
)
return
return
if
ckpt_path
is
None
:
if
ckpt_path
is
None
:
...
@@ -141,7 +141,7 @@ class KWSExecutor(BaseExecutor):
...
@@ -141,7 +141,7 @@ class KWSExecutor(BaseExecutor):
assert
os
.
path
.
isfile
(
audio_file
)
assert
os
.
path
.
isfile
(
audio_file
)
waveform
,
_
=
load
(
audio_file
)
waveform
,
_
=
load
(
audio_file
)
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
logger
.
info
(
"Preprocessing audio_file:"
+
audio_file
)
logger
.
debug
(
"Preprocessing audio_file:"
+
audio_file
)
# Feature extraction
# Feature extraction
waveform
=
paddle
.
to_tensor
(
waveform
).
unsqueeze
(
0
)
waveform
=
paddle
.
to_tensor
(
waveform
).
unsqueeze
(
0
)
...
...
paddlespeech/cli/log.py
浏览文件 @
bc93bffb
...
@@ -49,7 +49,7 @@ class Logger(object):
...
@@ -49,7 +49,7 @@ class Logger(object):
self
.
handler
.
setFormatter
(
self
.
format
)
self
.
handler
.
setFormatter
(
self
.
format
)
self
.
logger
.
addHandler
(
self
.
handler
)
self
.
logger
.
addHandler
(
self
.
handler
)
self
.
logger
.
setLevel
(
logging
.
DEBUG
)
self
.
logger
.
setLevel
(
logging
.
INFO
)
self
.
logger
.
propagate
=
False
self
.
logger
.
propagate
=
False
def
__call__
(
self
,
log_level
:
str
,
msg
:
str
):
def
__call__
(
self
,
log_level
:
str
,
msg
:
str
):
...
...
paddlespeech/cli/st/infer.py
浏览文件 @
bc93bffb
...
@@ -110,7 +110,7 @@ class STExecutor(BaseExecutor):
...
@@ -110,7 +110,7 @@ class STExecutor(BaseExecutor):
"""
"""
decompressed_path
=
download_and_decompress
(
self
.
kaldi_bins
,
MODEL_HOME
)
decompressed_path
=
download_and_decompress
(
self
.
kaldi_bins
,
MODEL_HOME
)
decompressed_path
=
os
.
path
.
abspath
(
decompressed_path
)
decompressed_path
=
os
.
path
.
abspath
(
decompressed_path
)
logger
.
info
(
"Kaldi_bins stored in: {}"
.
format
(
decompressed_path
))
logger
.
debug
(
"Kaldi_bins stored in: {}"
.
format
(
decompressed_path
))
if
"LD_LIBRARY_PATH"
in
os
.
environ
:
if
"LD_LIBRARY_PATH"
in
os
.
environ
:
os
.
environ
[
"LD_LIBRARY_PATH"
]
+=
f
":
{
decompressed_path
}
"
os
.
environ
[
"LD_LIBRARY_PATH"
]
+=
f
":
{
decompressed_path
}
"
else
:
else
:
...
@@ -128,7 +128,7 @@ class STExecutor(BaseExecutor):
...
@@ -128,7 +128,7 @@ class STExecutor(BaseExecutor):
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
hasattr
(
self
,
'model'
):
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
logger
.
debug
(
'Model had been initialized.'
)
return
return
if
cfg_path
is
None
or
ckpt_path
is
None
:
if
cfg_path
is
None
or
ckpt_path
is
None
:
...
@@ -140,8 +140,8 @@ class STExecutor(BaseExecutor):
...
@@ -140,8 +140,8 @@ class STExecutor(BaseExecutor):
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
task_resource
.
res_dir
,
self
.
task_resource
.
res_dir
,
self
.
task_resource
.
res_dict
[
'ckpt_path'
])
self
.
task_resource
.
res_dict
[
'ckpt_path'
])
logger
.
info
(
self
.
cfg_path
)
logger
.
debug
(
self
.
cfg_path
)
logger
.
info
(
self
.
ckpt_path
)
logger
.
debug
(
self
.
ckpt_path
)
res_path
=
self
.
task_resource
.
res_dir
res_path
=
self
.
task_resource
.
res_dir
else
:
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
...
@@ -192,7 +192,7 @@ class STExecutor(BaseExecutor):
...
@@ -192,7 +192,7 @@ class STExecutor(BaseExecutor):
Input content can be a file(wav).
Input content can be a file(wav).
"""
"""
audio_file
=
os
.
path
.
abspath
(
wav_file
)
audio_file
=
os
.
path
.
abspath
(
wav_file
)
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
logger
.
debug
(
"Preprocess audio_file:"
+
audio_file
)
if
"fat_st"
in
model_type
:
if
"fat_st"
in
model_type
:
cmvn
=
self
.
config
.
cmvn_path
cmvn
=
self
.
config
.
cmvn_path
...
...
paddlespeech/cli/text/infer.py
浏览文件 @
bc93bffb
...
@@ -98,7 +98,7 @@ class TextExecutor(BaseExecutor):
...
@@ -98,7 +98,7 @@ class TextExecutor(BaseExecutor):
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
hasattr
(
self
,
'model'
):
if
hasattr
(
self
,
'model'
):
logger
.
info
(
'Model had been initialized.'
)
logger
.
debug
(
'Model had been initialized.'
)
return
return
self
.
task
=
task
self
.
task
=
task
...
...
paddlespeech/cli/tts/infer.py
浏览文件 @
bc93bffb
...
@@ -173,7 +173,7 @@ class TTSExecutor(BaseExecutor):
...
@@ -173,7 +173,7 @@ class TTSExecutor(BaseExecutor):
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
hasattr
(
self
,
'am_inference'
)
and
hasattr
(
self
,
'voc_inference'
):
if
hasattr
(
self
,
'am_inference'
)
and
hasattr
(
self
,
'voc_inference'
):
logger
.
info
(
'Models had been initialized.'
)
logger
.
debug
(
'Models had been initialized.'
)
return
return
# am
# am
...
@@ -200,9 +200,9 @@ class TTSExecutor(BaseExecutor):
...
@@ -200,9 +200,9 @@ class TTSExecutor(BaseExecutor):
# must have phones_dict in acoustic
# must have phones_dict in acoustic
self
.
phones_dict
=
os
.
path
.
join
(
self
.
phones_dict
=
os
.
path
.
join
(
self
.
am_res_path
,
self
.
task_resource
.
res_dict
[
'phones_dict'
])
self
.
am_res_path
,
self
.
task_resource
.
res_dict
[
'phones_dict'
])
logger
.
info
(
self
.
am_res_path
)
logger
.
debug
(
self
.
am_res_path
)
logger
.
info
(
self
.
am_config
)
logger
.
debug
(
self
.
am_config
)
logger
.
info
(
self
.
am_ckpt
)
logger
.
debug
(
self
.
am_ckpt
)
else
:
else
:
self
.
am_config
=
os
.
path
.
abspath
(
am_config
)
self
.
am_config
=
os
.
path
.
abspath
(
am_config
)
self
.
am_ckpt
=
os
.
path
.
abspath
(
am_ckpt
)
self
.
am_ckpt
=
os
.
path
.
abspath
(
am_ckpt
)
...
@@ -248,9 +248,9 @@ class TTSExecutor(BaseExecutor):
...
@@ -248,9 +248,9 @@ class TTSExecutor(BaseExecutor):
self
.
voc_stat
=
os
.
path
.
join
(
self
.
voc_stat
=
os
.
path
.
join
(
self
.
voc_res_path
,
self
.
voc_res_path
,
self
.
task_resource
.
voc_res_dict
[
'speech_stats'
])
self
.
task_resource
.
voc_res_dict
[
'speech_stats'
])
logger
.
info
(
self
.
voc_res_path
)
logger
.
debug
(
self
.
voc_res_path
)
logger
.
info
(
self
.
voc_config
)
logger
.
debug
(
self
.
voc_config
)
logger
.
info
(
self
.
voc_ckpt
)
logger
.
debug
(
self
.
voc_ckpt
)
else
:
else
:
self
.
voc_config
=
os
.
path
.
abspath
(
voc_config
)
self
.
voc_config
=
os
.
path
.
abspath
(
voc_config
)
self
.
voc_ckpt
=
os
.
path
.
abspath
(
voc_ckpt
)
self
.
voc_ckpt
=
os
.
path
.
abspath
(
voc_ckpt
)
...
...
paddlespeech/cli/vector/infer.py
浏览文件 @
bc93bffb
...
@@ -117,7 +117,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -117,7 +117,7 @@ class VectorExecutor(BaseExecutor):
# stage 2: read the input data and store them as a list
# stage 2: read the input data and store them as a list
task_source
=
self
.
get_input_source
(
parser_args
.
input
)
task_source
=
self
.
get_input_source
(
parser_args
.
input
)
logger
.
info
(
f
"task source:
{
task_source
}
"
)
logger
.
debug
(
f
"task source:
{
task_source
}
"
)
# stage 3: process the audio one by one
# stage 3: process the audio one by one
# we do action according the task type
# we do action according the task type
...
@@ -127,13 +127,13 @@ class VectorExecutor(BaseExecutor):
...
@@ -127,13 +127,13 @@ class VectorExecutor(BaseExecutor):
try
:
try
:
# extract the speaker audio embedding
# extract the speaker audio embedding
if
parser_args
.
task
==
"spk"
:
if
parser_args
.
task
==
"spk"
:
logger
.
info
(
"do vector spk task"
)
logger
.
debug
(
"do vector spk task"
)
res
=
self
(
input_
,
model
,
sample_rate
,
config
,
ckpt_path
,
res
=
self
(
input_
,
model
,
sample_rate
,
config
,
ckpt_path
,
device
)
device
)
task_result
[
id_
]
=
res
task_result
[
id_
]
=
res
elif
parser_args
.
task
==
"score"
:
elif
parser_args
.
task
==
"score"
:
logger
.
info
(
"do vector score task"
)
logger
.
debug
(
"do vector score task"
)
logger
.
info
(
f
"input content
{
input_
}
"
)
logger
.
debug
(
f
"input content
{
input_
}
"
)
if
len
(
input_
.
split
())
!=
2
:
if
len
(
input_
.
split
())
!=
2
:
logger
.
error
(
logger
.
error
(
f
"vector score task input
{
input_
}
wav num is not two,"
f
"vector score task input
{
input_
}
wav num is not two,"
...
@@ -142,7 +142,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -142,7 +142,7 @@ class VectorExecutor(BaseExecutor):
# get the enroll and test embedding
# get the enroll and test embedding
enroll_audio
,
test_audio
=
input_
.
split
()
enroll_audio
,
test_audio
=
input_
.
split
()
logger
.
info
(
logger
.
debug
(
f
"score task, enroll audio:
{
enroll_audio
}
, test audio:
{
test_audio
}
"
f
"score task, enroll audio:
{
enroll_audio
}
, test audio:
{
test_audio
}
"
)
)
enroll_embedding
=
self
(
enroll_audio
,
model
,
sample_rate
,
enroll_embedding
=
self
(
enroll_audio
,
model
,
sample_rate
,
...
@@ -158,8 +158,8 @@ class VectorExecutor(BaseExecutor):
...
@@ -158,8 +158,8 @@ class VectorExecutor(BaseExecutor):
has_exceptions
=
True
has_exceptions
=
True
task_result
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
task_result
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
logger
.
info
(
"task result as follows: "
)
logger
.
debug
(
"task result as follows: "
)
logger
.
info
(
f
"
{
task_result
}
"
)
logger
.
debug
(
f
"
{
task_result
}
"
)
# stage 4: process the all the task results
# stage 4: process the all the task results
self
.
process_task_results
(
parser_args
.
input
,
task_result
,
self
.
process_task_results
(
parser_args
.
input
,
task_result
,
...
@@ -207,7 +207,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -207,7 +207,7 @@ class VectorExecutor(BaseExecutor):
"""
"""
if
not
hasattr
(
self
,
"score_func"
):
if
not
hasattr
(
self
,
"score_func"
):
self
.
score_func
=
paddle
.
nn
.
CosineSimilarity
(
axis
=
0
)
self
.
score_func
=
paddle
.
nn
.
CosineSimilarity
(
axis
=
0
)
logger
.
info
(
"create the cosine score function "
)
logger
.
debug
(
"create the cosine score function "
)
score
=
self
.
score_func
(
score
=
self
.
score_func
(
paddle
.
to_tensor
(
enroll_embedding
),
paddle
.
to_tensor
(
enroll_embedding
),
...
@@ -244,7 +244,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -244,7 +244,7 @@ class VectorExecutor(BaseExecutor):
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
# stage 1: set the paddle runtime host device
# stage 1: set the paddle runtime host device
logger
.
info
(
f
"device type:
{
device
}
"
)
logger
.
debug
(
f
"device type:
{
device
}
"
)
paddle
.
device
.
set_device
(
device
)
paddle
.
device
.
set_device
(
device
)
# stage 2: read the specific pretrained model
# stage 2: read the specific pretrained model
...
@@ -283,7 +283,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -283,7 +283,7 @@ class VectorExecutor(BaseExecutor):
# stage 0: avoid to init the mode again
# stage 0: avoid to init the mode again
self
.
task
=
task
self
.
task
=
task
if
hasattr
(
self
,
"model"
):
if
hasattr
(
self
,
"model"
):
logger
.
info
(
"Model has been initialized"
)
logger
.
debug
(
"Model has been initialized"
)
return
return
# stage 1: get the model and config path
# stage 1: get the model and config path
...
@@ -294,7 +294,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -294,7 +294,7 @@ class VectorExecutor(BaseExecutor):
sample_rate_str
=
"16k"
if
sample_rate
==
16000
else
"8k"
sample_rate_str
=
"16k"
if
sample_rate
==
16000
else
"8k"
tag
=
model_type
+
"-"
+
sample_rate_str
tag
=
model_type
+
"-"
+
sample_rate_str
self
.
task_resource
.
set_task_model
(
tag
,
version
=
None
)
self
.
task_resource
.
set_task_model
(
tag
,
version
=
None
)
logger
.
info
(
f
"load the pretrained model:
{
tag
}
"
)
logger
.
debug
(
f
"load the pretrained model:
{
tag
}
"
)
# get the model from the pretrained list
# get the model from the pretrained list
# we download the pretrained model and store it in the res_path
# we download the pretrained model and store it in the res_path
self
.
res_path
=
self
.
task_resource
.
res_dir
self
.
res_path
=
self
.
task_resource
.
res_dir
...
@@ -312,19 +312,19 @@ class VectorExecutor(BaseExecutor):
...
@@ -312,19 +312,19 @@ class VectorExecutor(BaseExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
f
"start to read the ckpt from
{
self
.
ckpt_path
}
"
)
logger
.
debug
(
f
"start to read the ckpt from
{
self
.
ckpt_path
}
"
)
logger
.
info
(
f
"read the config from
{
self
.
cfg_path
}
"
)
logger
.
debug
(
f
"read the config from
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
"get the res path
{
self
.
res_path
}
"
)
logger
.
debug
(
f
"get the res path
{
self
.
res_path
}
"
)
# stage 2: read and config and init the model body
# stage 2: read and config and init the model body
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
# stage 3: get the model name to instance the model network with dynamic_import
# stage 3: get the model name to instance the model network with dynamic_import
logger
.
info
(
"start to dynamic import the model class"
)
logger
.
debug
(
"start to dynamic import the model class"
)
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
logger
.
info
(
f
"model name
{
model_name
}
"
)
logger
.
debug
(
f
"model name
{
model_name
}
"
)
model_conf
=
self
.
config
.
model
model_conf
=
self
.
config
.
model
backbone
=
model_class
(
**
model_conf
)
backbone
=
model_class
(
**
model_conf
)
model
=
SpeakerIdetification
(
model
=
SpeakerIdetification
(
...
@@ -333,11 +333,11 @@ class VectorExecutor(BaseExecutor):
...
@@ -333,11 +333,11 @@ class VectorExecutor(BaseExecutor):
self
.
model
.
eval
()
self
.
model
.
eval
()
# stage 4: load the model parameters
# stage 4: load the model parameters
logger
.
info
(
"start to set the model parameters to model"
)
logger
.
debug
(
"start to set the model parameters to model"
)
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
self
.
model
.
set_state_dict
(
model_dict
)
self
.
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"create the model instance success"
)
logger
.
debug
(
"create the model instance success"
)
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
infer
(
self
,
model_type
:
str
):
def
infer
(
self
,
model_type
:
str
):
...
@@ -349,14 +349,14 @@ class VectorExecutor(BaseExecutor):
...
@@ -349,14 +349,14 @@ class VectorExecutor(BaseExecutor):
# stage 0: get the feat and length from _inputs
# stage 0: get the feat and length from _inputs
feats
=
self
.
_inputs
[
"feats"
]
feats
=
self
.
_inputs
[
"feats"
]
lengths
=
self
.
_inputs
[
"lengths"
]
lengths
=
self
.
_inputs
[
"lengths"
]
logger
.
info
(
"start to do backbone network model forward"
)
logger
.
debug
(
"start to do backbone network model forward"
)
logger
.
info
(
logger
.
debug
(
f
"feats shape:
{
feats
.
shape
}
, lengths shape:
{
lengths
.
shape
}
"
)
f
"feats shape:
{
feats
.
shape
}
, lengths shape:
{
lengths
.
shape
}
"
)
# stage 1: get the audio embedding
# stage 1: get the audio embedding
# embedding from (1, emb_size, 1) -> (emb_size)
# embedding from (1, emb_size, 1) -> (emb_size)
embedding
=
self
.
model
.
backbone
(
feats
,
lengths
).
squeeze
().
numpy
()
embedding
=
self
.
model
.
backbone
(
feats
,
lengths
).
squeeze
().
numpy
()
logger
.
info
(
f
"embedding size:
{
embedding
.
shape
}
"
)
logger
.
debug
(
f
"embedding size:
{
embedding
.
shape
}
"
)
# stage 2: put the embedding and dim info to _outputs property
# stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.array
# the embedding type is numpy.array
...
@@ -380,12 +380,13 @@ class VectorExecutor(BaseExecutor):
...
@@ -380,12 +380,13 @@ class VectorExecutor(BaseExecutor):
"""
"""
audio_file
=
input_file
audio_file
=
input_file
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
logger
.
info
(
f
"Preprocess audio file:
{
audio_file
}
"
)
logger
.
debug
(
f
"Preprocess audio file:
{
audio_file
}
"
)
# stage 1: load the audio sample points
# stage 1: load the audio sample points
# Note: this process must match the training process
# Note: this process must match the training process
waveform
,
sr
=
load_audio
(
audio_file
)
waveform
,
sr
=
load_audio
(
audio_file
)
logger
.
info
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
logger
.
debug
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
# stage 2: get the audio feat
# stage 2: get the audio feat
# Note: Now we only support fbank feature
# Note: Now we only support fbank feature
...
@@ -396,9 +397,9 @@ class VectorExecutor(BaseExecutor):
...
@@ -396,9 +397,9 @@ class VectorExecutor(BaseExecutor):
n_mels
=
self
.
config
.
n_mels
,
n_mels
=
self
.
config
.
n_mels
,
window_size
=
self
.
config
.
window_size
,
window_size
=
self
.
config
.
window_size
,
hop_length
=
self
.
config
.
hop_size
)
hop_length
=
self
.
config
.
hop_size
)
logger
.
info
(
f
"extract the audio feat, shape is:
{
feat
.
shape
}
"
)
logger
.
debug
(
f
"extract the audio feat, shape is:
{
feat
.
shape
}
"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
info
(
f
"feat occurs exception
{
e
}
"
)
logger
.
debug
(
f
"feat occurs exception
{
e
}
"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
...
@@ -411,11 +412,11 @@ class VectorExecutor(BaseExecutor):
...
@@ -411,11 +412,11 @@ class VectorExecutor(BaseExecutor):
# stage 4: store the feat and length in the _inputs,
# stage 4: store the feat and length in the _inputs,
# which will be used in other function
# which will be used in other function
logger
.
info
(
f
"feats shape:
{
feat
.
shape
}
"
)
logger
.
debug
(
f
"feats shape:
{
feat
.
shape
}
"
)
self
.
_inputs
[
"feats"
]
=
feat
self
.
_inputs
[
"feats"
]
=
feat
self
.
_inputs
[
"lengths"
]
=
lengths
self
.
_inputs
[
"lengths"
]
=
lengths
logger
.
info
(
"audio extract the feat success"
)
logger
.
debug
(
"audio extract the feat success"
)
def
_check
(
self
,
audio_file
:
str
,
sample_rate
:
int
):
def
_check
(
self
,
audio_file
:
str
,
sample_rate
:
int
):
"""Check if the model sample match the audio sample rate
"""Check if the model sample match the audio sample rate
...
@@ -441,7 +442,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -441,7 +442,7 @@ class VectorExecutor(BaseExecutor):
logger
.
error
(
"Please input the right audio file path"
)
logger
.
error
(
"Please input the right audio file path"
)
return
False
return
False
logger
.
info
(
"checking the aduio file format......"
)
logger
.
debug
(
"checking the aduio file format......"
)
try
:
try
:
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"float32"
,
always_2d
=
True
)
audio_file
,
dtype
=
"float32"
,
always_2d
=
True
)
...
@@ -458,7 +459,7 @@ class VectorExecutor(BaseExecutor):
...
@@ -458,7 +459,7 @@ class VectorExecutor(BaseExecutor):
"
)
"
)
return
False
return
False
logger
.
info
(
f
"The sample rate is
{
audio_sample_rate
}
"
)
logger
.
debug
(
f
"The sample rate is
{
audio_sample_rate
}
"
)
if
audio_sample_rate
!=
self
.
sample_rate
:
if
audio_sample_rate
!=
self
.
sample_rate
:
logger
.
error
(
"The sample rate of the input file is not {}.
\n
\
logger
.
error
(
"The sample rate of the input file is not {}.
\n
\
...
@@ -468,6 +469,6 @@ class VectorExecutor(BaseExecutor):
...
@@ -468,6 +469,6 @@ class VectorExecutor(BaseExecutor):
"
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
"
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
else
:
else
:
logger
.
info
(
"The audio file format is right"
)
logger
.
debug
(
"The audio file format is right"
)
return
True
return
True
paddlespeech/s2t/frontend/augmentor/spec_augment.py
浏览文件 @
bc93bffb
...
@@ -16,7 +16,7 @@ import random
...
@@ -16,7 +16,7 @@ import random
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
from
PIL.Image
import
BICUBIC
from
PIL.Image
import
Resampling
from
paddlespeech.s2t.frontend.augmentor.base
import
AugmentorBase
from
paddlespeech.s2t.frontend.augmentor.base
import
AugmentorBase
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
...
@@ -164,9 +164,9 @@ class SpecAugmentor(AugmentorBase):
...
@@ -164,9 +164,9 @@ class SpecAugmentor(AugmentorBase):
window
)
+
1
# 1 ... t - 1
window
)
+
1
# 1 ... t - 1
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
BICUBIC
)
Resampling
.
BICUBIC
)
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
BICUBIC
)
Resampling
.
BICUBIC
)
if
self
.
inplace
:
if
self
.
inplace
:
x
[:
warped
]
=
left
x
[:
warped
]
=
left
x
[
warped
:]
=
right
x
[
warped
:]
=
right
...
...
paddlespeech/s2t/frontend/featurizer/text_featurizer.py
浏览文件 @
bc93bffb
...
@@ -226,10 +226,10 @@ class TextFeaturizer():
...
@@ -226,10 +226,10 @@ class TextFeaturizer():
sos_id
=
vocab_list
.
index
(
SOS
)
if
SOS
in
vocab_list
else
-
1
sos_id
=
vocab_list
.
index
(
SOS
)
if
SOS
in
vocab_list
else
-
1
space_id
=
vocab_list
.
index
(
SPACE
)
if
SPACE
in
vocab_list
else
-
1
space_id
=
vocab_list
.
index
(
SPACE
)
if
SPACE
in
vocab_list
else
-
1
logger
.
info
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
debug
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
info
(
f
"UNK id:
{
unk_id
}
"
)
logger
.
debug
(
f
"UNK id:
{
unk_id
}
"
)
logger
.
info
(
f
"EOS id:
{
eos_id
}
"
)
logger
.
debug
(
f
"EOS id:
{
eos_id
}
"
)
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
debug
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
debug
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
logger
.
debug
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
,
blank_id
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
,
blank_id
paddlespeech/s2t/models/u2/u2.py
浏览文件 @
bc93bffb
...
@@ -827,7 +827,7 @@ class U2Model(U2DecodeModel):
...
@@ -827,7 +827,7 @@ class U2Model(U2DecodeModel):
# encoder
# encoder
encoder_type
=
configs
.
get
(
'encoder'
,
'transformer'
)
encoder_type
=
configs
.
get
(
'encoder'
,
'transformer'
)
logger
.
info
(
f
"U2 Encoder type:
{
encoder_type
}
"
)
logger
.
debug
(
f
"U2 Encoder type:
{
encoder_type
}
"
)
if
encoder_type
==
'transformer'
:
if
encoder_type
==
'transformer'
:
encoder
=
TransformerEncoder
(
encoder
=
TransformerEncoder
(
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
...
@@ -894,7 +894,7 @@ class U2Model(U2DecodeModel):
...
@@ -894,7 +894,7 @@ class U2Model(U2DecodeModel):
if
checkpoint_path
:
if
checkpoint_path
:
infos
=
checkpoint
.
Checkpoint
().
load_parameters
(
infos
=
checkpoint
.
Checkpoint
().
load_parameters
(
model
,
checkpoint_path
=
checkpoint_path
)
model
,
checkpoint_path
=
checkpoint_path
)
logger
.
info
(
f
"checkpoint info:
{
infos
}
"
)
logger
.
debug
(
f
"checkpoint info:
{
infos
}
"
)
layer_tools
.
summary
(
model
)
layer_tools
.
summary
(
model
)
return
model
return
model
...
...
paddlespeech/s2t/modules/loss.py
浏览文件 @
bc93bffb
...
@@ -37,9 +37,9 @@ class CTCLoss(nn.Layer):
...
@@ -37,9 +37,9 @@ class CTCLoss(nn.Layer):
self
.
loss
=
nn
.
CTCLoss
(
blank
=
blank
,
reduction
=
reduction
)
self
.
loss
=
nn
.
CTCLoss
(
blank
=
blank
,
reduction
=
reduction
)
self
.
batch_average
=
batch_average
self
.
batch_average
=
batch_average
logger
.
info
(
logger
.
debug
(
f
"CTCLoss Loss reduction:
{
reduction
}
, div-bs:
{
batch_average
}
"
)
f
"CTCLoss Loss reduction:
{
reduction
}
, div-bs:
{
batch_average
}
"
)
logger
.
info
(
f
"CTCLoss Grad Norm Type:
{
grad_norm_type
}
"
)
logger
.
debug
(
f
"CTCLoss Grad Norm Type:
{
grad_norm_type
}
"
)
assert
grad_norm_type
in
(
'instance'
,
'batch'
,
'frame'
,
None
)
assert
grad_norm_type
in
(
'instance'
,
'batch'
,
'frame'
,
None
)
self
.
norm_by_times
=
False
self
.
norm_by_times
=
False
...
@@ -70,7 +70,8 @@ class CTCLoss(nn.Layer):
...
@@ -70,7 +70,8 @@ class CTCLoss(nn.Layer):
param
=
{}
param
=
{}
self
.
_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
param
}
self
.
_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
param
}
_notin
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
not
in
param
}
_notin
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
not
in
param
}
logger
.
info
(
f
"
{
self
.
loss
}
kwargs:
{
self
.
_kwargs
}
, not support:
{
_notin
}
"
)
logger
.
debug
(
f
"
{
self
.
loss
}
kwargs:
{
self
.
_kwargs
}
, not support:
{
_notin
}
"
)
def
forward
(
self
,
logits
,
ys_pad
,
hlens
,
ys_lens
):
def
forward
(
self
,
logits
,
ys_pad
,
hlens
,
ys_lens
):
"""Compute CTC loss.
"""Compute CTC loss.
...
...
paddlespeech/s2t/transform/spec_augment.py
浏览文件 @
bc93bffb
...
@@ -17,7 +17,7 @@ import random
...
@@ -17,7 +17,7 @@ import random
import
numpy
import
numpy
from
PIL
import
Image
from
PIL
import
Image
from
PIL.Image
import
BICUBIC
from
PIL.Image
import
Resampling
from
paddlespeech.s2t.transform.functional
import
FuncTrans
from
paddlespeech.s2t.transform.functional
import
FuncTrans
...
@@ -46,9 +46,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
...
@@ -46,9 +46,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
warped
=
random
.
randrange
(
center
-
window
,
center
+
warped
=
random
.
randrange
(
center
-
window
,
center
+
window
)
+
1
# 1 ... t - 1
window
)
+
1
# 1 ... t - 1
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
BICUBIC
)
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
Resampling
.
BICUBIC
)
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
BICUBIC
)
Resampling
.
BICUBIC
)
if
inplace
:
if
inplace
:
x
[:
warped
]
=
left
x
[:
warped
]
=
left
x
[
warped
:]
=
right
x
[
warped
:]
=
right
...
...
paddlespeech/s2t/utils/tensor_utils.py
浏览文件 @
bc93bffb
...
@@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
...
@@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
for
i
,
tensor
in
enumerate
(
sequences
):
for
i
,
tensor
in
enumerate
(
sequences
):
length
=
tensor
.
shape
[
0
]
length
=
tensor
.
shape
[
0
]
# use index notation to prevent duplicate references to the tensor
# use index notation to prevent duplicate references to the tensor
logger
.
info
(
logger
.
debug
(
f
"length
{
length
}
, out_tensor
{
out_tensor
.
shape
}
, tensor
{
tensor
.
shape
}
"
f
"length
{
length
}
, out_tensor
{
out_tensor
.
shape
}
, tensor
{
tensor
.
shape
}
"
)
)
if
batch_first
:
if
batch_first
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录