Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
d64f6e9e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d64f6e9e
编写于
11月 01, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the feature: caculating the perplexity of transformerLM
上级
fc8a7a15
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
357 addition
and
5 deletion
+357
-5
deepspeech/exps/lm/transformer/__init__.py
deepspeech/exps/lm/transformer/__init__.py
+13
-0
deepspeech/exps/lm/transformer/bin/cacu_perplexity.py
deepspeech/exps/lm/transformer/bin/cacu_perplexity.py
+82
-0
deepspeech/exps/lm/transformer/lm_cacu_perplexity.py
deepspeech/exps/lm/transformer/lm_cacu_perplexity.py
+132
-0
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+2
-2
deepspeech/io/collator.py
deepspeech/io/collator.py
+39
-1
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+19
-0
deepspeech/models/lm/transformer.py
deepspeech/models/lm/transformer.py
+5
-2
examples/librispeech/s2/conf/lm/transformer.yaml
examples/librispeech/s2/conf/lm/transformer.yaml
+8
-0
examples/librispeech/s2/local/cacu_perplexity.sh
examples/librispeech/s2/local/cacu_perplexity.sh
+53
-0
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+4
-0
未找到文件。
deepspeech/exps/lm/transformer/__init__.py
0 → 100644
浏览文件 @
d64f6e9e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/exps/lm/transformer/bin/cacu_perplexity.py
0 → 100644
浏览文件 @
d64f6e9e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
configargparse
def get_parser():
    """Build the command-line parser of the perplexity tool.

    Options may also be supplied through a YAML config file, because the
    parser is created with ``configargparse.YAMLConfigFileParser``.

    Returns:
        configargparse.ArgumentParser: parser with every option registered.
    """
    cli = configargparse.ArgumentParser(
        description="The parser for caculating the perplexity of transformer language model ",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )

    # One (flag, add_argument keyword options) entry per option; they are
    # registered in this order, which is also the order shown by --help.
    option_specs = (
        ("--rnnlm",
         dict(type=str, default=None, help="RNNLM model file to read")),
        ("--rnnlm-conf",
         dict(type=str, default=None,
              help="RNNLM model config file to read")),
        ("--vocab_path",
         dict(type=str, default=None, help="vocab path to for token2id")),
        ("--bpeprefix",
         dict(type=str, default=None,
              help="The path of bpeprefix for loading")),
        ("--text_path",
         dict(type=str, default=None,
              help="The path of text file for testing ")),
        ("--ngpu",
         dict(type=int, default=0,
              help="The number of gpu to use, 0 for using cpu instead")),
        ("--dtype",
         dict(choices=("float16", "float32", "float64"),
              default="float32",
              help="Float precision (only available in --api v2)")),
        ("--output_dir",
         dict(type=str, default=".",
              help="The output directory to store the sentence PPL")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli
def main(args):
    """Parse ``args`` and run the perplexity computation.

    Args:
        args: list of command-line tokens (without the program name).
    """
    parsed = get_parser().parse_args(args)

    # Imported only after argument parsing succeeded, as in the original
    # flow, so the project import is not triggered when parsing exits early.
    from deepspeech.exps.lm.transformer.lm_cacu_perplexity import \
        run_get_perplexity

    run_get_perplexity(parsed)


if __name__ == "__main__":
    main(sys.argv[1:])
deepspeech/exps/lm/transformer/lm_cacu_perplexity.py
0 → 100644
浏览文件 @
d64f6e9e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Calculating the PPL of LM model
import
os
import
numpy
as
np
import
paddle
from
paddle.io
import
DataLoader
from
yacs.config
import
CfgNode
from
deepspeech.io.collator
import
TextCollatorSpm
from
deepspeech.io.dataset
import
TextDataset
from
deepspeech.models.lm_interface
import
dynamic_import_lm
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
def get_config(config_path):
    """Read a configuration file into a fresh ``CfgNode``.

    Args:
        config_path: path handed to ``CfgNode.merge_from_file``.

    Returns:
        CfgNode: node created with ``new_allowed=True`` and then merged with
        the contents of ``config_path``.
    """
    node = CfgNode(new_allowed=True)
    node.merge_from_file(config_path)
    return node
def load_trained_lm(args):
    """Instantiate the language model and load its trained parameters.

    Args:
        args: parsed options; ``args.rnnlm_conf`` is the model config path
            and ``args.rnnlm`` is the parameter file given to ``paddle.load``.

    Returns:
        tuple: ``(lm, lm_config)`` - the model with its state dict set, and
        the configuration it was built from.
    """
    lm_config = get_config(args.rnnlm_conf)

    # The config names the model module; its ``model`` section supplies the
    # constructor keyword arguments.
    model_cls = dynamic_import_lm(lm_config.model_module)
    lm = model_cls(**lm_config.model)

    lm.set_state_dict(paddle.load(args.rnnlm))
    return lm, lm_config
def write_dict_into_file(ppl_dict, name):
    """Write a mapping to a text file, one ``"<key> <value>"`` pair per line.

    Entries are written in the mapping's iteration order. An existing file
    is overwritten; an empty mapping produces an empty file.

    Args:
        ppl_dict: mapping to dump. Keys and values are formatted with
            ``str()``, so non-string values (e.g. floats) are accepted.
        name: path of the output file.
    """
    # Explicit UTF-8 keeps the output independent of the platform's default
    # locale encoding.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{key} {value}\n" for key, value in ppl_dict.items())
def _perplexity(nll, ntokens, log_base=None):
    """Convert a summed negative log-likelihood into a perplexity value.

    With ``log_base=None`` this is ``exp(nll / ntokens)``. Otherwise it is
    ``log_base ** (nll / ntokens / ln(log_base))``, which is algebraically
    the same quantity and can differ only by floating-point rounding.
    """
    if log_base is None:
        return np.exp(nll / ntokens)
    return log_base**(nll / ntokens / np.log(log_base))


def cacu_perplexity(
        lm_model,
        lm_config,
        args,
        log_base=None, ):
    """Compute the perplexity of ``lm_model`` over a text file.

    Per-utterance perplexities and token counts are written to the files
    ``uttPPL`` and ``uttLEN`` inside ``args.output_dir``.

    Args:
        lm_model: language model; ``lm_model.forward(input, target)`` must
            return a 5-tuple whose last two items are tensors iterated here
            as one negative log-likelihood and one token count per utterance.
        lm_config: configuration providing ``data.unit_type``,
            ``decoding.batch_size`` and ``decoding.num_workers``.
        args: parsed options; ``text_path``, ``vocab_path``, ``bpeprefix``
            and ``output_dir`` are read.
        log_base: base used in the perplexity formula, or ``None`` for the
            natural base.

    Returns:
        tuple: ``(ppl, log_base)`` - the corpus-level perplexity and the base
        that was used (``np.e`` when ``log_base`` was ``None``).
    """
    unit_type = lm_config.data.unit_type
    batch_size = lm_config.decoding.batch_size
    num_workers = lm_config.decoding.num_workers

    total_nll = 0.0
    total_ntokens = 0
    ppl_dict = {}
    len_dict = {}

    text_dataset = TextDataset.from_file(args.text_path)
    collate_fn_text = TextCollatorSpm(
        unit_type=unit_type,
        vocab_filepath=args.vocab_path,
        spm_model_prefix=args.bpeprefix)
    data_loader = DataLoader(
        text_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn_text,
        num_workers=num_workers)

    logger.info("start caculating PPL......")
    # Evaluation only: the forward passes do not need an autograd graph.
    with paddle.no_grad():
        for keys, ys_input_pad, ys_output_pad, _y_lens in data_loader():
            ys_input_pad = paddle.to_tensor(ys_input_pad)
            ys_output_pad = paddle.to_tensor(ys_output_pad)
            # Only the last two outputs of forward() are used here.
            _, _, _, nll, nll_count = lm_model.forward(ys_input_pad,
                                                       ys_output_pad)
            nll = nll.numpy()
            nll_count = nll_count.numpy()

            # Keep the PPL of each utterance for debugging or analysis.
            for key, utt_nll, ntoken in zip(keys, nll, nll_count):
                ppl_dict[key] = str(_perplexity(utt_nll, ntoken, log_base))
                len_dict[key] = str(ntoken)

            total_nll += nll.sum()
            total_ntokens += nll_count.sum()
            logger.info("Current total nll: " + str(total_nll))
            logger.info("Current total tokens: " + str(total_ntokens))

    write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL"))
    write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN"))

    ppl = _perplexity(total_nll, total_ntokens, log_base)
    return ppl, (np.e if log_base is None else log_base)
def run_get_perplexity(args):
    """Load the language model and log its perplexity on the test text.

    Args:
        args: parsed options; ``ngpu`` selects the device and ``dtype``
            names the paddle dtype the model is moved to.

    Raises:
        NotImplementedError: if ``args.ngpu`` is greater than 1.
    """
    gpu_count = args.ngpu
    if gpu_count > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Exactly one GPU selects "gpu:0"; any other accepted value runs on CPU.
    device = "gpu:0" if gpu_count == 1 else "cpu"
    paddle.set_device(device)

    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")

    lm_model, lm_config = load_trained_lm(args)
    lm_model.to(device=device, dtype=dtype)
    lm_model.eval()

    final_ppl, used_base = cacu_perplexity(lm_model, lm_config, args, None)
    logger.info("Final PPL: " + str(final_ppl))
    logger.info("The log base is:" + ("%.2f" % used_base))
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
d64f6e9e
...
@@ -53,7 +53,7 @@ class TextFeaturizer():
...
@@ -53,7 +53,7 @@ class TextFeaturizer():
self
.
maskctc
=
maskctc
self
.
maskctc
=
maskctc
if
vocab_filepath
:
if
vocab_filepath
:
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
self
.
unk_id
,
self
.
eos_id
=
self
.
_load_vocabulary_from_file
(
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
self
.
unk_id
,
self
.
eos_id
,
self
.
blank_id
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
,
maskctc
)
vocab_filepath
,
maskctc
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
...
@@ -227,4 +227,4 @@ class TextFeaturizer():
...
@@ -227,4 +227,4 @@ class TextFeaturizer():
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
,
blank_id
deepspeech/io/collator.py
浏览文件 @
d64f6e9e
...
@@ -19,6 +19,7 @@ from yacs.config import CfgNode
...
@@ -19,6 +19,7 @@ from yacs.config import CfgNode
from
deepspeech.frontend.augmentor.augmentation
import
AugmentationPipeline
from
deepspeech.frontend.augmentor.augmentation
import
AugmentationPipeline
from
deepspeech.frontend.featurizer.speech_featurizer
import
SpeechFeaturizer
from
deepspeech.frontend.featurizer.speech_featurizer
import
SpeechFeaturizer
from
deepspeech.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.speech
import
SpeechSegment
from
deepspeech.frontend.speech
import
SpeechSegment
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.frontend.utility
import
IGNORE_ID
...
@@ -33,7 +34,7 @@ logger = Log(__name__).getlog()
...
@@ -33,7 +34,7 @@ logger = Log(__name__).getlog()
def
_tokenids
(
text
,
keep_transcription_text
):
def
_tokenids
(
text
,
keep_transcription_text
):
# for training text is token ids
# for training text is token ids
tokens
=
text
# token ids
tokens
=
text
# token ids
if
keep_transcription_text
:
if
keep_transcription_text
:
...
@@ -45,6 +46,43 @@ def _tokenids(text, keep_transcription_text):
...
@@ -45,6 +46,43 @@ def _tokenids(text, keep_transcription_text):
return
tokens
return
tokens
class TextCollatorSpm():
    """Collate ``"<key> <text>"`` lines into padded token-id batches."""

    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
        """Create the text featurizer used to tokenize each line.

        Args:
            unit_type: forwarded to ``TextFeaturizer``.
            vocab_filepath: vocabulary file path; must not be ``None``.
            spm_model_prefix: forwarded to ``TextFeaturizer``.
        """
        assert vocab_filepath is not None, "vocab_filepath must be provided"
        self.text_featurizer = TextFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix)
        self.eos_id = self.text_featurizer.eos_id
        self.blank_id = self.text_featurizer.blank_id

    def __call__(self, batch):
        """Build one batch from a list of text lines.

        Each item is split at its first space into an utterance key and the
        text. The text is tokenized; the input sequence is the token ids
        prefixed with ``eos_id`` and the output sequence is the token ids
        followed by ``eos_id``. Both are padded with ``blank_id``.

        return type [List, np.array [B, T], np.array [B, T], np.array[B]]
        """
        keys = []
        texts_input = []
        texts_output = []
        text_lens = []

        for item in batch:
            # Everything before the first space is the key, the rest is text.
            key, _, text = item.partition(" ")
            keys.append(key.strip())

            token_ids = self.text_featurizer.featurize(text)
            texts_input.append(
                np.array([self.eos_id] + token_ids, dtype=np.int64))
            texts_output.append(
                np.array(token_ids + [self.eos_id], dtype=np.int64))
            # +1 accounts for the eos token added to each sequence.
            text_lens.append(len(token_ids) + 1)

        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
        y_lens = np.array(text_lens).astype(np.int64)
        return keys, ys_input_pad, ys_output_pad, y_lens
class
SpeechCollatorBase
():
class
SpeechCollatorBase
():
def
__init__
(
def
__init__
(
self
,
self
,
...
...
deepspeech/io/dataset.py
浏览文件 @
d64f6e9e
...
@@ -24,6 +24,25 @@ __all__ = ["ManifestDataset", "TransformDataset"]
...
@@ -24,6 +24,25 @@ __all__ = ["ManifestDataset", "TransformDataset"]
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
class TextDataset(Dataset):
    """Dataset whose items are the whitespace-stripped lines of a text file."""

    @classmethod
    def from_file(cls, file_path):
        """Alternate constructor: build the dataset from ``file_path``."""
        return cls(file_path)

    def __init__(self, file_path):
        """Read every line of ``file_path`` into memory.

        Args:
            file_path: path of the text file; each line becomes one item
                after ``str.strip()``.
        """
        # Explicit UTF-8 keeps reading independent of the platform's
        # default locale encoding.
        with open(file_path, encoding="utf-8") as f:
            self._manifest = [line.strip() for line in f]

    def __len__(self):
        """Return the number of lines that were read."""
        return len(self._manifest)

    def __getitem__(self, idx):
        """Return the stripped line stored at position ``idx``."""
        return self._manifest[idx]
class
ManifestDataset
(
Dataset
):
class
ManifestDataset
(
Dataset
):
@
classmethod
@
classmethod
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
...
...
deepspeech/models/lm/transformer.py
浏览文件 @
d64f6e9e
...
@@ -111,6 +111,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
...
@@ -111,6 +111,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
"""
"""
batch_size
=
x
.
size
(
0
)
xm
=
x
!=
0
xm
=
x
!=
0
xlen
=
xm
.
sum
(
axis
=
1
)
xlen
=
xm
.
sum
(
axis
=
1
)
if
self
.
embed_drop
is
not
None
:
if
self
.
embed_drop
is
not
None
:
...
@@ -121,11 +122,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
...
@@ -121,11 +122,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
y
=
self
.
decoder
(
h
)
y
=
self
.
decoder
(
h
)
loss
=
F
.
cross_entropy
(
loss
=
F
.
cross_entropy
(
y
.
view
(
-
1
,
y
.
shape
[
-
1
]),
t
.
view
(
-
1
),
reduction
=
"none"
)
y
.
view
(
-
1
,
y
.
shape
[
-
1
]),
t
.
view
(
-
1
),
reduction
=
"none"
)
mask
=
xm
.
to
(
dtype
=
loss
.
dtype
)
mask
=
xm
.
to
(
loss
.
dtype
)
logp
=
loss
*
mask
.
view
(
-
1
)
logp
=
loss
*
mask
.
view
(
-
1
)
nll
=
logp
.
view
(
batch_size
,
-
1
).
sum
(
-
1
)
nll_count
=
mask
.
sum
(
-
1
)
logp
=
logp
.
sum
()
logp
=
logp
.
sum
()
count
=
mask
.
sum
()
count
=
mask
.
sum
()
return
logp
/
count
,
logp
,
count
return
logp
/
count
,
logp
,
count
,
nll
,
nll_count
# beam search API (see ScorerInterface)
# beam search API (see ScorerInterface)
def
score
(
self
,
y
:
paddle
.
Tensor
,
state
:
Any
,
def
score
(
self
,
y
:
paddle
.
Tensor
,
state
:
Any
,
...
...
examples/librispeech/s2/conf/lm/transformer.yaml
浏览文件 @
d64f6e9e
model_module
:
transformer
model_module
:
transformer
data
:
unit_type
:
spm
model
:
model
:
n_vocab
:
5002
n_vocab
:
5002
pos_enc
:
null
pos_enc
:
null
...
@@ -11,3 +15,7 @@ model:
...
@@ -11,3 +15,7 @@ model:
emb_dropout_rate
:
0.0
emb_dropout_rate
:
0.0
att_dropout_rate
:
0.0
att_dropout_rate
:
0.0
tie_weights
:
False
tie_weights
:
False
decoding
:
batch_size
:
30
num_workers
:
2
examples/librispeech/s2/local/cacu_perplexity.sh
0 → 100755
浏览文件 @
d64f6e9e
#!/bin/bash
# Compute the perplexity of the transformer LM on the lower-cased test text.
# Requires LM_BIN_DIR to point at the directory holding cacu_perplexity.py.

set -e

# Fail early with a clear message instead of running "/cacu_perplexity.py".
: "${LM_BIN_DIR:?LM_BIN_DIR must be set to the directory containing cacu_perplexity.py}"

# stage / stop_stage are not referenced below; kept from the original script.
stage=-1
stop_stage=100

expdir=exp
datadir=data

ngpu=0

# lm params
rnnlm_config_path=conf/lm/transformer.yaml
lmexpdir=exp/lm/transformer
lang_model=transformerLM.pdparams

# data path
test_set="${datadir}/test_clean/text"
test_set_lower="${datadir}/test_clean/text_lower"
train_set=train_960

# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="${datadir}/lang_char/${train_set}_${bpemode}${nbpe}"
# bpemodel is not referenced below; kept from the original script.
bpemodel="${bpeprefix}.model"

vocabfile="${bpeprefix}_units.txt"
vocabfile_lower="${bpeprefix}_units_lower.txt"

output_dir="${expdir}/lm/transformer/perplexity"

mkdir -p "${output_dir}"

# Transform the data upper case to lower
if [ -f "${vocabfile}" ]; then
    tr 'A-Z' 'a-z' < "${vocabfile}" > "${vocabfile_lower}"
fi

if [ -f "${test_set}" ]; then
    tr 'A-Z' 'a-z' < "${test_set}" > "${test_set_lower}"
fi

python "${LM_BIN_DIR}/cacu_perplexity.py" \
    --rnnlm "${lmexpdir}/${lang_model}" \
    --rnnlm-conf "${rnnlm_config_path}" \
    --vocab_path "${vocabfile_lower}" \
    --bpeprefix "${bpeprefix}" \
    --text_path "${test_set_lower}" \
    --output_dir "${output_dir}" \
    --ngpu "${ngpu}"
examples/librispeech/s2/run.sh
浏览文件 @
d64f6e9e
...
@@ -51,3 +51,7 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
...
@@ -51,3 +51,7 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# export ckpt avg_n
# export ckpt avg_n
CUDA_VISIBLE_DEVICES
=
./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
CUDA_VISIBLE_DEVICES
=
./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
fi
fi
if
[
${
stage
}
-le
7
]
&&
[
${
stop_stage
}
-ge
7
]
;
then
CUDA_VISIBLE_DEVICES
=
./local/cacu_perplexity.sh
||
exit
-1
fi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录