PaddlePaddle / DeepSpeech

Commit b5315657 (unverified)
Authored Mar 22, 2022 by Hui Zhang; committed via GitHub on Mar 22, 2022.

Merge pull request #1509 from qingen/cluster

[vec] add clustering of vectors

Parents: e5e8b8a1, 0f7ede11
Showing 4 changed files with 1,083 additions and 12 deletions:

examples/ami/sd0/local/ami_prepare.py        +1     -4
examples/ami/sd0/local/ami_splits.py         +0     -4
examples/ami/sd0/local/dataio.py             +0     -4
paddlespeech/vector/cluster/diarization.py   +1082  -0
examples/ami/sd0/local/ami_prepare.py

```diff
@@ -17,11 +17,8 @@ Data preparation.
 Download: http://groups.inf.ed.ac.uk/ami/download/
 Prepares metadata files (JSON) from manual annotations "segments/" using RTTM format (Oracle VAD).
-Authors
- * qingenz123@126.com (Qingen ZHAO) 2022
 """
 import argparse
 import glob
 import json
 ...
```
examples/ami/sd0/local/ami_splits.py

```diff
@@ -15,10 +15,6 @@
 AMI corpus contained 100 hours of meeting recording.
 This script returns the standard train, dev and eval split for AMI corpus.
 For more information on dataset please refer to http://groups.inf.ed.ac.uk/ami/corpus/datasets.shtml
-Authors
- * qingenz123@126.com (Qingen ZHAO) 2022
 """
 ALLOWED_OPTIONS = ["scenario_only", "full_corpus", "full_corpus_asr"]
 ...
```
examples/ami/sd0/local/dataio.py

```diff
@@ -13,10 +13,6 @@
 # limitations under the License.
 """
 Data reading and writing.
-Authors
- * qingenz123@126.com (Qingen ZHAO) 2022
 """
 import os
 import pickle
 ...
```
paddlespeech/vector/cluster/diarization.py (new file, mode 100644):
```python
# Copyright (c) 2022 SpeechBrain Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script contains basic functions used for speaker diarization.
This script has an optional dependency on open source sklearn library.
A few sklearn functions are modified in this script as per requirement.
"""
import argparse
import numbers  # needed by _check_random_state below
import warnings
from distutils.util import strtobool

import numpy as np
import scipy
import sklearn
from scipy import sparse
from scipy.sparse.csgraph import connected_components
from scipy.sparse.csgraph import laplacian as csgraph_laplacian
from scipy.sparse.linalg import eigsh
from sklearn.cluster import SpectralClustering
from sklearn.cluster._kmeans import k_means
from sklearn.neighbors import kneighbors_graph

def _graph_connected_component(graph, node_id):
    """
    Find the largest graph connected components that contains one
    given node.

    Arguments
    ---------
    graph : array-like, shape: (n_samples, n_samples)
        Adjacency matrix of the graph, non-zero weight means an edge
        between the nodes.
    node_id : int
        The index of the query node of the graph.

    Returns
    -------
    connected_components_matrix : array-like
        shape - (n_samples,).
        An array of bool value indicating the indexes of the nodes belonging
        to the largest connected components of the given query node.
    """
    n_node = graph.shape[0]
    if sparse.issparse(graph):
        # speed up row-wise access to boolean connection mask
        graph = graph.tocsr()
    connected_nodes = np.zeros(n_node, dtype=bool)
    nodes_to_explore = np.zeros(n_node, dtype=bool)
    nodes_to_explore[node_id] = True
    for _ in range(n_node):
        last_num_component = connected_nodes.sum()
        np.logical_or(connected_nodes, nodes_to_explore, out=connected_nodes)
        if last_num_component >= connected_nodes.sum():
            break
        indices = np.where(nodes_to_explore)[0]
        nodes_to_explore.fill(False)
        for i in indices:
            if sparse.issparse(graph):
                neighbors = graph[i].toarray().ravel()
            else:
                neighbors = graph[i]
            np.logical_or(nodes_to_explore, neighbors, out=nodes_to_explore)
    return connected_nodes


def _graph_is_connected(graph):
    """
    Return whether the graph is connected (True) or not (False).

    Arguments
    ---------
    graph : array-like or sparse matrix, shape: (n_samples, n_samples)
        Adjacency matrix of the graph, non-zero weight means an edge between the nodes.

    Returns
    -------
    is_connected : bool
        True means the graph is fully connected and False means not.
    """
    if sparse.isspmatrix(graph):
        # sparse graph, find all the connected components
        n_connected_components, _ = connected_components(graph)
        return n_connected_components == 1
    else:
        # dense graph, find all connected components start from node 0
        return _graph_connected_component(graph, 0).sum() == graph.shape[0]


def _set_diag(laplacian, value, norm_laplacian):
    """
    Set the diagonal of the laplacian matrix and convert it to a sparse
    format well suited for eigenvalue decomposition.

    Arguments
    ---------
    laplacian : array or sparse matrix
        The graph laplacian.
    value : float
        The value of the diagonal.
    norm_laplacian : bool
        Whether the value of the diagonal should be changed or not.

    Returns
    -------
    laplacian : array or sparse matrix
        An array or matrix in a form that is well suited to fast eigenvalue
        decomposition, depending on the bandwidth of the matrix.
    """
    n_nodes = laplacian.shape[0]
    # We need all entries in the diagonal to values
    if not sparse.isspmatrix(laplacian):
        if norm_laplacian:
            laplacian.flat[::n_nodes + 1] = value
    else:
        laplacian = laplacian.tocoo()
        if norm_laplacian:
            diag_idx = laplacian.row == laplacian.col
            laplacian.data[diag_idx] = value
        # If the matrix has a small number of diagonals (as in the
        # case of structured matrices coming from images), the
        # dia format might be best suited for matvec products:
        n_diags = np.unique(laplacian.row - laplacian.col).size
        if n_diags <= 7:
            # 3 or less outer diagonals on each side
            laplacian = laplacian.todia()
        else:
            # csr has the fastest matvec and is thus best suited to
            # arpack
            laplacian = laplacian.tocsr()
    return laplacian


def _deterministic_vector_sign_flip(u):
    """
    Modify the sign of vectors for reproducibility. Flips the sign of
    elements of all the vectors (rows of u) such that the absolute
    maximum element of each vector is positive.

    Arguments
    ---------
    u : ndarray
        Array with vectors as its rows.

    Returns
    -------
    u_flipped : ndarray
        Array with the sign flipped vectors as its rows. The same shape as `u`.
    """
    max_abs_rows = np.argmax(np.abs(u), axis=1)
    signs = np.sign(u[range(u.shape[0]), max_abs_rows])
    u *= signs[:, np.newaxis]
    return u


def _check_random_state(seed):
    """
    Turn seed into a np.random.RandomState instance.

    Arguments
    ---------
    seed : None | int | instance of RandomState
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError("%r cannot be used to seed a np.random.RandomState"
                     " instance" % seed)
```
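As a quick sanity check on these helpers (a hypothetical snippet, not part of the commit; it assumes the module is importable as `diarization`), the dense path of `_graph_is_connected` walks outward from node 0 and compares the reachable set against the node count:

```python
# Hypothetical check of _graph_is_connected (illustrative, not in the commit).
import numpy as np
from diarization import _graph_is_connected

chain = np.array([[1, 1, 0],
                  [1, 1, 1],
                  [0, 1, 1]])  # nodes 0-1-2 form one component
split = np.array([[1, 0],
                  [0, 1]])     # two isolated nodes
print(_graph_is_connected(chain))  # True
print(_graph_is_connected(split))  # False
```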
```python
def spectral_embedding(
        adjacency,
        n_components=8,
        norm_laplacian=True,
        drop_first=True, ):
    """
    Returns spectral embeddings.

    Arguments
    ---------
    adjacency : array-like or sparse graph
        shape - (n_samples, n_samples)
        The adjacency matrix of the graph to embed.
    n_components : int
        The dimension of the projection subspace.
    norm_laplacian : bool
        If True, then compute normalized Laplacian.
    drop_first : bool
        Whether to drop the first eigenvector.

    Returns
    -------
    embedding : array
        Spectral embeddings for each sample.

    Example
    -------
    >>> import numpy as np
    >>> import diarization as diar
    >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5],
    ...                      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    ...                      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    ...                      [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
    >>> embs = diar.spectral_embedding(affinity, 3)
    >>> # Notice similar embeddings
    >>> print(np.around(embs , decimals=3))
    [[ 0.075  0.244  0.285]
     [ 0.083  0.356 -0.203]
     [ 0.083  0.356 -0.203]
     [ 0.26  -0.149  0.154]
     [ 0.29  -0.218 -0.11 ]
     [ 0.29  -0.218 -0.11 ]
     [-0.198 -0.084 -0.122]
     [-0.198 -0.084 -0.122]
     [-0.198 -0.084 -0.122]
     [-0.167 -0.044  0.316]]
    """
    # Whether to drop the first eigenvector
    if drop_first:
        n_components = n_components + 1

    if not _graph_is_connected(adjacency):
        warnings.warn("Graph is not fully connected, spectral embedding"
                      " may not work as expected.")

    laplacian, dd = csgraph_laplacian(
        adjacency, normed=norm_laplacian, return_diag=True)

    laplacian = _set_diag(laplacian, 1, norm_laplacian)

    laplacian *= -1

    vals, diffusion_map = eigsh(
        laplacian,
        k=n_components,
        sigma=1.0,
        which="LM", )

    embedding = diffusion_map.T[n_components::-1]

    if norm_laplacian:
        embedding = embedding / dd

    embedding = _deterministic_vector_sign_flip(embedding)
    if drop_first:
        return embedding[1:n_components].T
    else:
        return embedding[:n_components].T

def spectral_clustering(
        affinity,
        n_clusters=8,
        n_components=None,
        random_state=None,
        n_init=10, ):
    """
    Performs spectral clustering.

    Arguments
    ---------
    affinity : matrix
        Affinity matrix.
    n_clusters : int
        Number of clusters for kmeans.
    n_components : int
        Number of components to retain while estimating spectral embeddings.
    random_state : int
        A pseudo random number generator used by kmeans.
    n_init : int
        Number of times the k-means algorithm will be run with different centroid seeds.

    Returns
    -------
    labels : array
        Cluster label for each sample.

    Example
    -------
    >>> import numpy as np
    >>> import diarization as diar
    >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5],
    ...                      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    ...                      [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    ...                      [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ...                      [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
    >>> labs = diar.spectral_clustering(affinity, 3)
    >>> # print (labs) # [2 2 2 1 1 1 0 0 0 0]
    """
    random_state = _check_random_state(random_state)
    n_components = n_clusters if n_components is None else n_components

    maps = spectral_embedding(
        affinity,
        n_components=n_components,
        drop_first=False, )

    _, labels, _ = k_means(
        maps, n_clusters, random_state=random_state, n_init=n_init)

    return labels

class EmbeddingMeta:
    """
    A utility class to pack deep embeddings and meta-information in one object.

    Arguments
    ---------
    segset : list
        List of session IDs as an array of strings.
    stats : tensor
        An ndarray of float64. Each line contains embedding
        from the corresponding session.
    """

    def __init__(
            self,
            segset=None,
            stats=None, ):
        if segset is None:
            self.segset = np.empty(0, dtype="|O")
            self.stats = np.array([], dtype=np.float64)
        else:
            self.segset = segset
            self.stats = stats

    def norm_stats(self):
        """
        Divide all first-order statistics by their Euclidean norm.
        """
        vect_norm = np.clip(np.linalg.norm(self.stats, axis=1), 1e-08, np.inf)
        self.stats = (self.stats.transpose() / vect_norm).transpose()
```
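The `__main__` demo at the bottom of the file builds an `EmbeddingMeta` from random embeddings; a minimal sketch of the same pattern (the segment IDs and 32-dim embeddings are illustrative, not prescribed by the API):

```python
# Minimal sketch: pack segment IDs and embeddings, then L2-normalize rows.
import numpy as np
from diarization import EmbeddingMeta

segset = np.array(["utt0001_0_1", "utt0001_1_2"], dtype="|O")
stats = np.random.rand(2, 32)  # one 32-dim embedding per sub-segment
meta = EmbeddingMeta(segset, stats)
meta.norm_stats()
print(np.linalg.norm(meta.stats, axis=1))  # -> approximately [1. 1.]
```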
```python
class SpecClustUnorm:
    """
    This class implements spectral clustering with an unnormalized affinity matrix.
    Useful when the affinity matrix is based on cosine similarities.

    Reference
    ---------
    Von Luxburg, U. A tutorial on spectral clustering. Stat Comput 17, 395–416 (2007).
    https://doi.org/10.1007/s11222-007-9033-z

    Example
    -------
    >>> import diarization as diar
    >>> clust = diar.SpecClustUnorm(min_num_spkrs=2, max_num_spkrs=10)
    >>> emb = [[ 2.1, 3.1, 4.1, 4.2, 3.1],
    ...        [ 2.2, 3.1, 4.2, 4.2, 3.2],
    ...        [ 2.0, 3.0, 4.0, 4.1, 3.0],
    ...        [ 8.0, 7.0, 7.0, 8.1, 9.0],
    ...        [ 8.1, 7.1, 7.2, 8.1, 9.2],
    ...        [ 8.3, 7.4, 7.0, 8.4, 9.0],
    ...        [ 0.3, 0.4, 0.4, 0.5, 0.8],
    ...        [ 0.4, 0.3, 0.6, 0.7, 0.8],
    ...        [ 0.2, 0.3, 0.2, 0.3, 0.7],
    ...        [ 0.3, 0.4, 0.4, 0.4, 0.7],]
    >>> # Estimating similarity matrix
    >>> sim_mat = clust.get_sim_mat(emb)
    >>> print (np.around(sim_mat[5:,5:], decimals=3))
    [[1.    0.957 0.961 0.904 0.966]
     [0.957 1.    0.977 0.982 0.997]
     [0.961 0.977 1.    0.928 0.972]
     [0.904 0.982 0.928 1.    0.976]
     [0.966 0.997 0.972 0.976 1.   ]]
    >>> # Pruning
    >>> prunned_sim_mat = clust.p_pruning(sim_mat, 0.3)
    >>> print (np.around(prunned_sim_mat[5:,5:], decimals=3))
    [[1.    0.    0.    0.    0.   ]
     [0.    1.    0.    0.982 0.997]
     [0.    0.977 1.    0.    0.972]
     [0.    0.982 0.    1.    0.976]
     [0.    0.997 0.    0.976 1.   ]]
    >>> # Symmetrization
    >>> sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
    >>> print (np.around(sym_prund_sim_mat[5:,5:], decimals=3))
    [[1.    0.    0.    0.    0.   ]
     [0.    1.    0.489 0.982 0.997]
     [0.    0.489 1.    0.    0.486]
     [0.    0.982 0.    1.    0.976]
     [0.    0.997 0.486 0.976 1.   ]]
    >>> # Laplacian
    >>> laplacian = clust.get_laplacian(sym_prund_sim_mat)
    >>> print (np.around(laplacian[5:,5:], decimals=3))
    [[ 1.999  0.     0.     0.     0.   ]
     [ 0.     2.468 -0.489 -0.982 -0.997]
     [ 0.    -0.489  0.975  0.    -0.486]
     [ 0.    -0.982  0.     1.958 -0.976]
     [ 0.    -0.997 -0.486 -0.976  2.458]]
    >>> # Spectral Embeddings
    >>> spec_emb, num_of_spk = clust.get_spec_embs(laplacian, 3)
    >>> print(num_of_spk)
    3
    >>> # Clustering
    >>> clust.cluster_embs(spec_emb, num_of_spk)
    >>> # print (clust.labels_) # [0 0 0 2 2 2 1 1 1 1]
    >>> # Complete spectral clustering
    >>> clust.do_spec_clust(emb, k_oracle=3, p_val=0.3)
    >>> # print(clust.labels_) # [0 0 0 2 2 2 1 1 1 1]
    """

    def __init__(self, min_num_spkrs=2, max_num_spkrs=10):
        self.min_num_spkrs = min_num_spkrs
        self.max_num_spkrs = max_num_spkrs

    def do_spec_clust(self, X, k_oracle, p_val):
        """
        Function for spectral clustering.

        Arguments
        ---------
        X : array
            (n_samples, n_features).
            Embeddings extracted from the model.
        k_oracle : int
            Number of speakers (when oracle number of speakers).
        p_val : float
            p percent value to prune the affinity matrix.
        """
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(X)

        # Refining similarity matrix with p_val
        prunned_sim_mat = self.p_pruning(sim_mat, p_val)

        # Symmetrization
        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)

        # Laplacian calculation
        laplacian = self.get_laplacian(sym_prund_sim_mat)

        # Get Spectral Embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, k_oracle)

        # Perform clustering
        self.cluster_embs(emb, num_of_spk)

    def get_sim_mat(self, X):
        """
        Returns the similarity matrix based on cosine similarities.

        Arguments
        ---------
        X : array
            (n_samples, n_features).
            Embeddings extracted from the model.

        Returns
        -------
        M : array
            (n_samples, n_samples).
            Similarity matrix with cosine similarities between each pair of embeddings.
        """
        # Cosine similarities
        M = sklearn.metrics.pairwise.cosine_similarity(X, X)
        return M

    def p_pruning(self, A, pval):
        """
        Refine the affinity matrix by zeroing less similar values.

        Arguments
        ---------
        A : array
            (n_samples, n_samples).
            Affinity matrix.
        pval : float
            p-value to be retained in each row of the affinity matrix.

        Returns
        -------
        A : array
            (n_samples, n_samples).
            Pruned affinity matrix based on p_val.
        """
        n_elems = int((1 - pval) * A.shape[0])

        # For each row in the affinity matrix
        for i in range(A.shape[0]):
            low_indexes = np.argsort(A[i, :])
            low_indexes = low_indexes[0:n_elems]

            # Replace smaller similarity values by 0s
            A[i, low_indexes] = 0
        return A

    def get_laplacian(self, M):
        """
        Returns the un-normalized laplacian for the given affinity matrix.

        Arguments
        ---------
        M : array
            (n_samples, n_samples)
            Affinity matrix.

        Returns
        -------
        L : array
            (n_samples, n_samples)
            Laplacian matrix.
        """
        M[np.diag_indices(M.shape[0])] = 0
        D = np.sum(np.abs(M), axis=1)
        D = np.diag(D)
        L = D - M
        return L

    def get_spec_embs(self, L, k_oracle=4):
        """
        Returns spectral embeddings and estimates the number of speakers
        using maximum Eigen gap.

        Arguments
        ---------
        L : array (n_samples, n_samples)
            Laplacian matrix.
        k_oracle : int
            Number of speakers when the condition is oracle number of speakers,
            else None.

        Returns
        -------
        emb : array (n_samples, n_components)
            Spectral embedding for each sample with n Eigen components.
        num_of_spk : int
            Estimated number of speakers. If the condition is set to the oracle
            number of speakers then returns k_oracle.
        """
        lambdas, eig_vecs = scipy.linalg.eigh(L)

        # if params["oracle_n_spkrs"] is True:
        if k_oracle is not None:
            num_of_spk = k_oracle
        else:
            lambda_gap_list = self.get_eigen_gaps(lambdas[1:self.max_num_spkrs])

            num_of_spk = (np.argmax(
                lambda_gap_list[:min(self.max_num_spkrs, len(lambda_gap_list))])
                          + 2)

            if num_of_spk < self.min_num_spkrs:
                num_of_spk = self.min_num_spkrs

        emb = eig_vecs[:, 0:num_of_spk]
        return emb, num_of_spk

    def cluster_embs(self, emb, k):
        """
        Clusters the embeddings using kmeans.

        Arguments
        ---------
        emb : array (n_samples, n_components)
            Spectral embedding for each sample with n Eigen components.
        k : int
            Number of clusters for kmeans.

        Returns
        -------
        self.labels_ : self
            Labels for each sample embedding.
        """
        _, self.labels_, _ = k_means(emb, k)

    def get_eigen_gaps(self, eig_vals):
        """
        Returns the difference (gaps) between the Eigen values.

        Arguments
        ---------
        eig_vals : list
            List of eigen values.

        Returns
        -------
        eig_vals_gap_list : list
            List of differences (gaps) between adjacent Eigen values.
        """
        eig_vals_gap_list = []
        for i in range(len(eig_vals) - 1):
            gap = float(eig_vals[i + 1]) - float(eig_vals[i])
            eig_vals_gap_list.append(gap)
        return eig_vals_gap_list

class SpecCluster(SpectralClustering):
    def perform_sc(self, X, n_neighbors=10):
        """
        Performs spectral clustering using sklearn on embeddings.

        Arguments
        ---------
        X : array (n_samples, n_features)
            Embeddings to be clustered.
        n_neighbors : int
            Number of neighbors in estimating affinity matrix.
        """
        # Computation of affinity matrix
        connectivity = kneighbors_graph(
            X,
            n_neighbors=n_neighbors,
            include_self=True, )
        self.affinity_matrix_ = 0.5 * (connectivity + connectivity.T)

        # Perform spectral clustering on affinity matrix
        self.labels_ = spectral_clustering(
            self.affinity_matrix_,
            n_clusters=self.n_clusters, )
        return self
```
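A minimal sketch of this sklearn-backed path on synthetic data (illustrative, not from the commit; `n_neighbors` is chosen large enough that the k-NN graph stays connected across the two blobs):

```python
# Hypothetical usage of SpecCluster.perform_sc on two separable blobs.
import numpy as np
from diarization import SpecCluster

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(20, 8) + 5.0,    # blob 1
               rng.randn(20, 8) - 5.0])   # blob 2
clust = SpecCluster(
    n_clusters=2,
    assign_labels="kmeans",
    random_state=1234,
    affinity="nearest_neighbors")
clust.perform_sc(X, n_neighbors=25)
print(clust.labels_)  # expected: one label per blob
```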
```python
def is_overlapped(end1, start2):
    """
    Returns True if segments are overlapping.

    Arguments
    ---------
    end1 : float
        End time of the first segment.
    start2 : float
        Start time of the second segment.

    Returns
    -------
    overlapped : bool
        True if segments overlap, else False.

    Example
    -------
    >>> import diarization as diar
    >>> diar.is_overlapped(5.5, 3.4)
    True
    >>> diar.is_overlapped(5.5, 6.4)
    False
    """
    if start2 > end1:
        return False
    else:
        return True


def merge_ssegs_same_speaker(lol):
    """
    Merge adjacent sub-segs from the same speaker.

    Arguments
    ---------
    lol : list of list
        Each list contains [rec_id, seg_start, seg_end, spkr_id].

    Returns
    -------
    new_lol : list of list
        new_lol contains adjacent segments merged from the same speaker ID.

    Example
    -------
    >>> import diarization as diar
    >>> lol=[['r1', 5.5, 7.0, 's1'],
    ...      ['r1', 6.5, 9.0, 's1'],
    ...      ['r1', 8.0, 11.0, 's1'],
    ...      ['r1', 11.5, 13.0, 's2'],
    ...      ['r1', 14.0, 15.0, 's2'],
    ...      ['r1', 14.5, 15.0, 's1']]
    >>> diar.merge_ssegs_same_speaker(lol)
    [['r1', 5.5, 11.0, 's1'], ['r1', 11.5, 13.0, 's2'], ['r1', 14.0, 15.0, 's2'], ['r1', 14.5, 15.0, 's1']]
    """
    new_lol = []

    # Start from the first sub-seg
    sseg = lol[0]
    flag = False
    for i in range(1, len(lol)):
        next_sseg = lol[i]
        # IF sub-segments overlap AND have the same speaker THEN merge
        if is_overlapped(sseg[2], next_sseg[1]) and sseg[3] == next_sseg[3]:
            sseg[2] = next_sseg[2]  # just update the end time

            # This is important. For the last sseg, if it is the same speaker then merge.
            # Make sure we don't append the last segment once more. Hence, set FLAG=True
            if i == len(lol) - 1:
                flag = True
                new_lol.append(sseg)
        else:
            new_lol.append(sseg)
            sseg = next_sseg

    # Add last segment only when it was skipped earlier.
    if flag is False:
        new_lol.append(lol[-1])
    return new_lol


def distribute_overlap(lol):
    """
    Distributes the overlapped speech equally among the adjacent segments
    with different speakers.

    Arguments
    ---------
    lol : list of list
        It has each list structure as [rec_id, seg_start, seg_end, spkr_id].

    Returns
    -------
    new_lol : list of list
        It contains the overlapped part equally divided among the adjacent
        segments with different speaker IDs.

    Example
    -------
    >>> import diarization as diar
    >>> lol = [['r1', 5.5, 9.0, 's1'],
    ...        ['r1', 8.0, 11.0, 's2'],
    ...        ['r1', 11.5, 13.0, 's2'],
    ...        ['r1', 12.0, 15.0, 's1']]
    >>> diar.distribute_overlap(lol)
    [['r1', 5.5, 8.5, 's1'], ['r1', 8.5, 11.0, 's2'], ['r1', 11.5, 12.5, 's2'], ['r1', 12.5, 15.0, 's1']]
    """
    new_lol = []
    sseg = lol[0]

    # Add first sub-segment here to avoid error at: "if new_lol[-1] != sseg:" when new_lol is empty
    # new_lol.append(sseg)

    for i in range(1, len(lol)):
        next_sseg = lol[i]
        # No need to check if they are different speakers.
        # Because if segments are overlapped then they always have different speakers.
        # This is because similar speaker's adjacent sub-segments are already merged by "merge_ssegs_same_speaker()"
        if is_overlapped(sseg[2], next_sseg[1]):
            # Get overlap duration.
            # Now this overlap will be divided equally between adjacent segments.
            overlap = sseg[2] - next_sseg[1]

            # Update end time of old seg
            sseg[2] = sseg[2] - (overlap / 2.0)

            # Update start time of next seg
            next_sseg[1] = next_sseg[1] + (overlap / 2.0)

            if len(new_lol) == 0:
                # For first sub-segment entry
                new_lol.append(sseg)
            else:
                # To avoid duplicate entries
                if new_lol[-1] != sseg:
                    new_lol.append(sseg)

            # Current sub-segment is next sub-segment
            sseg = next_sseg
        else:
            # For the first sseg
            if len(new_lol) == 0:
                new_lol.append(sseg)
            else:
                # To avoid duplicate entries
                if new_lol[-1] != sseg:
                    new_lol.append(sseg)

            # Update the current sub-segment
            sseg = next_sseg

    # Add the remaining last sub-segment
    new_lol.append(next_sseg)
    return new_lol


def write_rttm(segs_list, out_rttm_file):
    """
    Writes the segment list in RTTM format (a standard NIST format).

    Arguments
    ---------
    segs_list : list of list
        Each list contains [rec_id, seg_start, seg_end, spkr_id].
    out_rttm_file : str
        Path of the output RTTM file.
    """
    rttm = []
    rec_id = segs_list[0][0]

    for seg in segs_list:
        new_row = [
            "SPEAKER",
            rec_id,
            "0",
            str(round(seg[1], 4)),
            str(round(seg[2] - seg[1], 4)),
            "<NA>",
            "<NA>",
            seg[3],
            "<NA>",
            "<NA>",
        ]
        rttm.append(new_row)

    with open(out_rttm_file, "w") as f:
        for row in rttm:
            line_str = " ".join(row)
            f.write("%s\n" % line_str)
```
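For orientation (an illustrative line, not taken from the commit's tests): a segment entry `['utt0001', 5.5, 11.0, 'utt0001_0']` is written as one space-separated RTTM row, with the fourth and fifth fields holding onset and duration in seconds:

```
SPEAKER utt0001 0 5.5 5.5 <NA> <NA> utt0001_0 <NA> <NA>
```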
```python
def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3):
    """
    Performs Agglomerative Hierarchical Clustering on embeddings.

    Arguments
    ---------
    diary_obj : EmbeddingMeta type
        Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset.
    out_rttm_file : str
        Path of the output RTTM file.
    rec_id : str
        Recording ID for the recording under processing.
    k_oracle : int
        Number of speakers (None, if it has to be estimated).
    p_val : float
        `p_val` for pruning the affinity matrix. Used only when the number of
        speakers is unknown. Note that this is just for experimentation; prefer
        spectral clustering for better clustering results.
    """
    from sklearn.cluster import AgglomerativeClustering

    # p_val is the threshold_val (for AHC)
    diary_obj.norm_stats()

    # processing
    if k_oracle is not None:
        num_of_spk = k_oracle

        clustering = AgglomerativeClustering(
            n_clusters=num_of_spk,
            affinity="cosine",
            linkage="average", ).fit(diary_obj.stats)
        labels = clustering.labels_
    else:
        # Estimate num of speakers using max eigen gap with `cos` affinity matrix.
        # This is just for experimentation.
        clustering = AgglomerativeClustering(
            n_clusters=None,
            affinity="cosine",
            linkage="average",
            distance_threshold=p_val, ).fit(diary_obj.stats)
        labels = clustering.labels_

    # Convert labels to speaker boundaries
    subseg_ids = diary_obj.segset
    lol = []

    for i in range(labels.shape[0]):
        spkr_id = rec_id + "_" + str(labels[i])

        sub_seg = subseg_ids[i]

        splitted = sub_seg.rsplit("_", 2)
        rec_id = str(splitted[0])
        sseg_start = float(splitted[1])
        sseg_end = float(splitted[2])

        a = [rec_id, sseg_start, sseg_end, spkr_id]
        lol.append(a)

    # Sorting based on start time of sub-segment
    lol.sort(key=lambda x: float(x[1]))

    # Merge and split in 2 simple steps: (i) merge ssegs of same speakers, then (ii) split different speakers.
    # Step 1: Merge adjacent sub-segments that belong to same speaker (or cluster)
    lol = merge_ssegs_same_speaker(lol)

    # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or clusters),
    # taking the mid-point as the splitting time location.
    lol = distribute_overlap(lol)

    # logger.info("Completed diarizing " + rec_id)
    write_rttm(lol, out_rttm_file)


def do_spec_clustering(diary_obj, out_rttm_file, rec_id, k, pval, affinity_type,
                       n_neighbors):
    """
    Performs spectral clustering on embeddings. This function calls specific
    clustering algorithms as per affinity.

    Arguments
    ---------
    diary_obj : EmbeddingMeta type
        Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset.
    out_rttm_file : str
        Path of the output RTTM file.
    rec_id : str
        Recording ID for the recording under processing.
    k : int
        Number of speakers (None, if it has to be estimated).
    pval : float
        `pval` for pruning the affinity matrix.
    affinity_type : str
        Type of similarity to be used to get affinity matrix (cos or nn).
    n_neighbors : int
        Number of neighbors in estimating the affinity matrix (used with `nn`).
    """
    if affinity_type == "cos":
        clust_obj = SpecClustUnorm(min_num_spkrs=2, max_num_spkrs=10)
        k_oracle = k  # use it only when oracle num of speakers
        clust_obj.do_spec_clust(diary_obj.stats, k_oracle, pval)
        labels = clust_obj.labels_
    else:
        clust_obj = SpecCluster(
            n_clusters=k,
            assign_labels="kmeans",
            random_state=1234,
            affinity="nearest_neighbors", )
        clust_obj.perform_sc(diary_obj.stats, n_neighbors)
        labels = clust_obj.labels_

    # Convert labels to speaker boundaries
    subseg_ids = diary_obj.segset
    lol = []

    for i in range(labels.shape[0]):
        spkr_id = rec_id + "_" + str(labels[i])

        sub_seg = subseg_ids[i]

        splitted = sub_seg.rsplit("_", 2)
        rec_id = str(splitted[0])
        sseg_start = float(splitted[1])
        sseg_end = float(splitted[2])

        a = [rec_id, sseg_start, sseg_end, spkr_id]
        lol.append(a)

    # Sorting based on start time of sub-segment
    lol.sort(key=lambda x: float(x[1]))

    # Step 1: Merge adjacent sub-segments that belong to same speaker (or cluster)
    lol = merge_ssegs_same_speaker(lol)

    # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or clusters),
    # taking the mid-point as the splitting time location.
    lol = distribute_overlap(lol)

    # logger.info("Completed diarizing " + rec_id)
    write_rttm(lol, out_rttm_file)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='python diarization.py --backend AHC', description='diarizing')
    parser.add_argument(
        '--sys_rttm_dir',
        required=False,
        help='Directory to store system RTTM files')
    parser.add_argument(
        '--ref_rttm_dir',
        required=False,
        help='Directory to store reference RTTM files')
    parser.add_argument(
        '--backend', default="AHC", help='type of backend, AHC or SC or kmeans')
    parser.add_argument(
        '--oracle_n_spkrs',
        default=True,
        type=strtobool,
        help='Oracle num of speakers')
    parser.add_argument(
        '--mic_type',
        default="Mix-Headset",
        help='Type of microphone to be used')
    parser.add_argument(
        '--affinity', default="cos", help='affinity matrix, cos or nn')
    parser.add_argument(
        '--max_subseg_dur',
        default=3.0,
        type=float,
        help='Duration in seconds of subsegments to be prepared from larger segments')
    parser.add_argument(
        '--overlap',
        default=1.5,
        type=float,
        help='Overlap duration in seconds between adjacent subsegments')
    args = parser.parse_args()

    pval = 0.3
    rec_id = "utt0001"
    n_neighbors = 10
    out_rttm_file = "./out.rttm"

    # Build a dummy EmbeddingMeta: ten random 32-dim embeddings with
    # segment IDs of the form "<rec_id>_<start>_<end>".
    embeddings = np.empty(shape=[0, 32], dtype=np.float64)
    segset = []
    for i in range(10):
        seg = [rec_id + "_" + str(i) + "_" + str(i + 1)]
        segset = segset + seg
        emb = np.random.rand(1, 32)
        embeddings = np.concatenate((embeddings, emb), axis=0)

    segset = np.array(segset, dtype="|O")
    stat_obj = EmbeddingMeta(segset, embeddings)

    # strtobool returns 0/1, so test truthiness here
    if args.oracle_n_spkrs:
        num_spkrs = 2

    if args.backend == "SC":
        print("begin SC ")
        do_spec_clustering(
            stat_obj,
            out_rttm_file,
            rec_id,
            num_spkrs,
            pval,
            args.affinity,
            n_neighbors, )

    if args.backend == "AHC":
        print("begin AHC ")
        do_AHC(stat_obj, out_rttm_file, rec_id, num_spkrs, pval)
```
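Given the defaults above, the built-in demo can be exercised straight from the command line (a sketch; both invocations write `./out.rttm` from random embeddings for the dummy recording `utt0001`):

```bash
python diarization.py --backend AHC
python diarization.py --backend SC --affinity cos
```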