Commit 311fa87a
Authored Mar 13, 2022 by xiongxinlei
Parent: 8ed5c287

    add some comments to the code

Showing 9 changed files with 99 additions and 19 deletions (+99 −19)
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml                            +13  −7
examples/voxceleb/sv0/run.sh                                           +2  −1
paddleaudio/paddleaudio/metric/__init__.py                             +2  −1
paddleaudio/paddleaudio/metric/eer.py                                 +66  −0
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py                     +0  −0
paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py     +8  −6
paddlespeech/vector/exps/ecapa_tdnn/train.py                           +3  −3
paddlespeech/vector/io/augment.py                                      +2  −1
paddlespeech/vector/io/batch.py                                        +3  −0
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml

+###########################################
+#                   Data                  #
+###########################################
+batch_size: 32
+num_workers: 2
+num_speakers: 7205  # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+shuffle: True
+random_chunk: True
 ###########################################################
 #                FEATURE EXTRACTION SETTING               #
 ###########################################################
@@ -7,7 +16,6 @@ feature:
   window_size: 400  # 25ms, sample rate 16000, 25 * 16000 / 1000 = 400
   hop_length: 160   # 10ms, sample rate 16000, 10 * 16000 / 1000 = 160
 ###########################################################
 #                      MODEL SETTING                      #
 ###########################################################
@@ -15,9 +23,8 @@ feature:
 # if we want use another model, please choose another configuration yaml file
 model:
   input_size: 80
-  ##"channels": [1024, 1024, 1024, 1024, 3072],
   # "channels": [512, 512, 512, 512, 1536],
-  channels: [512, 512, 512, 512, 1536]
+  channels: [1024, 1024, 1024, 1024, 3072]
   kernel_sizes: [5, 3, 3, 3, 1]
   dilations: [1, 2, 3, 4, 1]
   attention_channels: 128
@@ -26,10 +33,9 @@ model:
 ###########################################
 #                 Training                #
 ###########################################
-seed: 0
+seed: 1986  # according from speechbrain configuration
 epochs: 10
-batch_size: 32
-num_workers: 2
 save_freq: 10
-log_freq: 10
+log_interval: 10
 learning_rate: 1e-8
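For a quick check of the keys this commit touches, the file can be loaded with plain PyYAML. This is an illustrative sketch only; the recipe's own scripts pass the file via --config and may wrap it in their own config loader:

# Illustrative sketch; assumes PyYAML is installed and the path below exists on disk.
import yaml

with open("examples/voxceleb/sv0/conf/ecapa_tdnn.yaml") as f:
    config = yaml.safe_load(f)

print(config["num_speakers"])        # 7205 (vox1 + vox2 training speakers)
print(config["model"]["channels"])   # [1024, 1024, 1024, 1024, 3072]
print(config["seed"])                # 1986, following the SpeechBrain recipe
print(config["log_interval"])        # renamed from log_freq in this commit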
examples/voxceleb/sv0/run.sh

@@ -47,7 +47,8 @@ mkdir -p ${exp_dir}
 if [ $stage -le 0 ]; then
     # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
     python3 local/data_prepare.py \
-        --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path}
+        --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path} \
+        --config conf/ecapa_tdnn.yaml
 fi

 if [ $stage -le 1 ]; then
paddleaudio/paddleaudio/metric/__init__.py

@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .dtw import dtw_distance
-from .mcd import mcd_distance
 from .eer import compute_eer
+from .eer import compute_minDCF
+from .mcd import mcd_distance
paddleaudio/paddleaudio/metric/eer.py

@@ -14,6 +14,7 @@
 from typing import List

 import numpy as np
+import paddle
 from sklearn.metrics import roc_curve

@@ -26,3 +27,68 @@ def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
     eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
     eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
     return eer, eer_threshold
+
+
+def compute_minDCF(positive_scores,
+                   negative_scores,
+                   c_miss=1.0,
+                   c_fa=1.0,
+                   p_target=0.01):
+    """
+    This is modified from SpeechBrain
+    https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509
+    Computes the minDCF metric normally used to evaluate speaker verification
+    systems. The min_DCF is the minimum of the following C_det function computed
+    within the defined threshold range:
+
+    C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
+
+    where p_miss is the missing probability and p_fa is the probability of having
+    a false alarm.
+
+    Args:
+        positive_scores (Paddle.Tensor): The scores from entries of the same class.
+        negative_scores (Paddle.Tensor): The scores from entries of different classes.
+        c_miss (float, optional): Cost assigned to a missing error (default 1.0).
+        c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
+        p_target (float, optional): Prior probability of having a target (default 0.01).
+
+    Returns:
+        _type_: min dcf
+    """
+    # Computing candidate thresholds
+    if len(positive_scores.shape) > 1:
+        positive_scores = positive_scores.squeeze()
+
+    if len(negative_scores.shape) > 1:
+        negative_scores = negative_scores.squeeze()
+
+    thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores]))
+    thresholds = paddle.unique(thresholds)
+
+    # Adding intermediate thresholds
+    interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2
+    thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds]))
+
+    # Computing False Rejection Rate (miss detection)
+    positive_scores = paddle.concat(
+        len(thresholds) * [positive_scores.unsqueeze(0)])
+    pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds
+    p_miss = (pos_scores_threshold.sum(0)).astype("float32") / positive_scores.shape[1]
+    del positive_scores
+    del pos_scores_threshold
+
+    # Computing False Acceptance Rate (false alarm)
+    negative_scores = paddle.concat(
+        len(thresholds) * [negative_scores.unsqueeze(0)])
+    neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds
+    p_fa = (neg_scores_threshold.sum(0)).astype("float32") / negative_scores.shape[1]
+    del negative_scores
+    del neg_scores_threshold
+
+    c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
+    c_min = paddle.min(c_det, axis=0)
+    min_index = paddle.argmin(c_det, axis=0)
+    return float(c_min), float(thresholds[min_index])
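The new compute_minDCF complements the existing compute_eer: the former takes separate positive/negative score tensors, the latter labeled trials as NumPy arrays, and the commit above exports both from paddleaudio.metric. A minimal usage sketch with toy scores (the values are invented purely for illustration):

# Toy example; the scores below are made up just to exercise the two metrics.
import numpy as np
import paddle
from paddleaudio.metric import compute_eer, compute_minDCF

labels = np.array([1, 1, 1, 0, 0, 0])                 # 1 = same speaker, 0 = different
scores = np.array([0.9, 0.8, 0.6, 0.7, 0.4, 0.3])     # cosine score per trial
eer, eer_threshold = compute_eer(labels, scores)

positive_scores = paddle.to_tensor([0.9, 0.8, 0.6])   # same-speaker trials
negative_scores = paddle.to_tensor([0.7, 0.4, 0.3])   # different-speaker trials
min_dcf, dcf_threshold = compute_minDCF(
    positive_scores, negative_scores, c_miss=1.0, c_fa=1.0, p_target=0.01)

print(f"EER={eer:.3f} at threshold {eer_threshold:.2f}")
print(f"minDCF={min_dcf:.3f} at threshold {dcf_threshold:.2f}")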
paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py → paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py

File moved without content changes.
paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py

@@ -45,7 +45,7 @@ def main(args, config):
     # stage2: build the speaker verification eval instance with backbone model
     model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
+        backbone=ecapa_tdnn, num_class=config.num_speakers)

     # stage3: load the pre-trained model
     args.load_checkpoint = os.path.abspath(

@@ -93,6 +93,7 @@ def main(args, config):
     model.eval()

     # stage7: global embedding norm to imporve the performance
+    print("global embedding norm: {}".format(args.global_embedding_norm))
     if args.global_embedding_norm:
         global_embedding_mean = None
         global_embedding_std = None

@@ -118,6 +119,8 @@ def main(args, config):
                 -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)

             # Global embedding normalization.
+            # if we use the global embedding norm
+            # eer can reduece about relative 10%
             if args.global_embedding_norm:
                 batch_count += 1
                 current_mean = embeddings.mean(

@@ -150,8 +153,8 @@ def main(args, config):
         for line in f.readlines():
             label, enrol_id, test_id = line.strip().split(' ')
             labels.append(int(label))
-            enrol_ids.append(enrol_id.split('.')[0].replace('/', '-'))
-            test_ids.append(test_id.split('.')[0].replace('/', '-'))
+            enrol_ids.append(enrol_id.split('.')[0].replace('/', '--'))
+            test_ids.append(test_id.split('.')[0].replace('/', '--'))

     cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
     enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(

@@ -185,11 +188,10 @@ if __name__ == "__main__":
         default='',
         help="Directory to load model checkpoint to contiune trainning.")
     parser.add_argument("--global-embedding-norm",
-                        type=bool,
-                        default=False,
+                        default=True,
+                        action="store_true",
                         help="Apply global normalization on speaker embeddings.")
     parser.add_argument("--embedding-mean-norm",
-                        type=bool,
                         default=True,
                         help="Apply mean normalization on speaker embeddings.")
     parser.add_argument("--embedding-std-norm",
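In the last hunk, --global-embedding-norm is switched from type=bool to action="store_true". That is the usual fix for a common argparse pitfall: with type=bool, any non-empty string (including "False") parses as True. A small standalone illustration of the difference (the flag names below are hypothetical, not from the repo):

# Standalone illustration of the argparse pitfall; flag names here are hypothetical.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--norm-bool", type=bool, default=False)   # bool("False") == True
parser.add_argument("--norm-flag", action="store_true")        # True only if the flag is present

args = parser.parse_args(["--norm-bool", "False", "--norm-flag"])
print(args.norm_bool)   # True -- any non-empty string is truthy
print(args.norm_flag)   # True -- set by the flag's presence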
paddlespeech/vector/exps/ecapa_tdnn/train.py

@@ -178,9 +178,9 @@ def main(args, config):
             timer.count()  # step plus one in timer

             # stage 9-10: print the log information only on 0-rank per log-freq batchs
-            if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0:
+            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
-                avg_loss /= config.log_freq
+                avg_loss /= config.log_interval
                 avg_acc = num_corrects / num_samples

                 print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(

@@ -196,7 +196,7 @@ def main(args, config):
                 num_samples = 0

         # stage 9-11: save the model parameters only on 0-rank per save-freq batchs
-        if epoch % config.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
+        if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
             if local_rank != 0:
                 paddle.distributed.barrier()  # Wait for valid step in main process
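The renamed log_interval key gates the periodic logging shown in the first hunk. A simplified sketch of that pattern (the real loop also tracks accuracy, learning rate, and the distributed rank):

# Simplified sketch of the periodic-logging pattern controlled by log_interval.
log_interval = 10
avg_loss = 0.0
for batch_idx in range(100):
    loss = 0.5  # stand-in for the real per-batch training loss
    avg_loss += loss
    if (batch_idx + 1) % log_interval == 0:
        avg_loss /= log_interval
        print(f"step {batch_idx + 1}: avg_loss={avg_loss:.3f}")
        avg_loss = 0.0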
paddlespeech/vector/io/augment.py

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# this is modified from https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
+# this is modified from SpeechBrain
+# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
 import math
 import os
 from typing import List
paddlespeech/vector/io/batch.py

@@ -75,6 +75,9 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
             i]:].sum() == 0  # Padding valus should all be 0.

     # Converts into ratios.
+    # the utterance of the max length doesn't need to padding
+    # the remaining utterances need to padding and all of them will be padded to max length
+    # we convert the original length of each utterance to the ratio of the max length
     lengths = (lengths / lengths.max()).astype(np.float32)

     return {'ids': ids, 'feats': feats, 'lengths': lengths}
\ No newline at end of file
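The added comments describe the length-to-ratio conversion: the longest utterance in the batch keeps ratio 1.0, and shorter (padded) utterances get their fraction of that maximum length. A toy NumPy illustration of just that line (batch_feature_normalize itself also pads and normalizes the features):

# Toy illustration of the ratio conversion described by the new comments.
import numpy as np

lengths = np.array([200, 160, 120])                       # frames per utterance before padding
ratios = (lengths / lengths.max()).astype(np.float32)     # -> [1.0, 0.8, 0.6]
print(ratios)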