From c962eec51d3c507999bd638017f711c0cece273a Mon Sep 17 00:00:00 2001 From: qingen Date: Mon, 28 Feb 2022 20:47:19 +0800 Subject: [PATCH] [wip][vec] add clustering of vectors #1304 --- examples/ami/sd0/local/diarization.py | 1087 +++++++++++++++++++++++++ 1 file changed, 1087 insertions(+) create mode 100644 examples/ami/sd0/local/diarization.py diff --git a/examples/ami/sd0/local/diarization.py b/examples/ami/sd0/local/diarization.py new file mode 100644 index 00000000..863cff4a --- /dev/null +++ b/examples/ami/sd0/local/diarization.py @@ -0,0 +1,1087 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script contains basic functions used for speaker diarization. +This script has an optional dependency on open source sklearn library. +A few sklearn functions are modified in this script as per requirement. + +Authors + * qingenz123@126.com (Qingen ZHAO) 2022 + +""" + +import argparse +import warnings +import scipy +import numpy as np +from distutils.util import strtobool + +from scipy import sparse +from scipy.sparse.linalg import eigsh +from scipy.sparse.csgraph import connected_components +from scipy.sparse.csgraph import laplacian as csgraph_laplacian + +import sklearn +from sklearn.neighbors import kneighbors_graph +from sklearn.cluster import SpectralClustering +from sklearn.cluster._kmeans import k_means + + +def _graph_connected_component(graph, node_id): + """ + Find the largest graph connected components that contains one + given node. + + Arguments + --------- + graph : array-like, shape: (n_samples, n_samples) + Adjacency matrix of the graph, non-zero weight means an edge + between the nodes. + node_id : int + The index of the query node of the graph. + + Returns + ------- + connected_components_matrix : array-like + shape - (n_samples,). + An array of bool value indicating the indexes of the nodes belonging + to the largest connected components of the given query node. + """ + + n_node = graph.shape[0] + if sparse.issparse(graph): + # speed up row-wise access to boolean connection mask + graph = graph.tocsr() + connected_nodes = np.zeros(n_node, dtype=bool) + nodes_to_explore = np.zeros(n_node, dtype=bool) + nodes_to_explore[node_id] = True + for _ in range(n_node): + last_num_component = connected_nodes.sum() + np.logical_or(connected_nodes, nodes_to_explore, out=connected_nodes) + if last_num_component >= connected_nodes.sum(): + break + indices = np.where(nodes_to_explore)[0] + nodes_to_explore.fill(False) + for i in indices: + if sparse.issparse(graph): + neighbors = graph[i].toarray().ravel() + else: + neighbors = graph[i] + np.logical_or(nodes_to_explore, neighbors, out=nodes_to_explore) + return connected_nodes + + +def _graph_is_connected(graph): + """ + Return whether the graph is connected (True) or Not (False) + + Arguments + --------- + graph : array-like or sparse matrix, shape: (n_samples, n_samples) + Adjacency matrix of the graph, non-zero weight means an edge between the nodes. + + Returns + ------- + is_connected : bool + True means the graph is fully connected and False means not. + """ + + if sparse.isspmatrix(graph): + # sparse graph, find all the connected components + n_connected_components, _ = connected_components(graph) + return n_connected_components == 1 + else: + # dense graph, find all connected components start from node 0 + return _graph_connected_component(graph, 0).sum() == graph.shape[0] + + +def _set_diag(laplacian, value, norm_laplacian): + """ + Set the diagonal of the laplacian matrix and convert it to a sparse + format well suited for eigenvalue decomposition. + + Arguments + --------- + laplacian : array or sparse matrix + The graph laplacian. + value : float + The value of the diagonal. + norm_laplacian : bool + Whether the value of the diagonal should be changed or not. + + Returns + ------- + laplacian : array or sparse matrix + An array of matrix in a form that is well suited to fast eigenvalue + decomposition, depending on the bandwidth of the matrix. + """ + + n_nodes = laplacian.shape[0] + # We need all entries in the diagonal to values + if not sparse.isspmatrix(laplacian): + if norm_laplacian: + laplacian.flat[::n_nodes + 1] = value + else: + laplacian = laplacian.tocoo() + if norm_laplacian: + diag_idx = laplacian.row == laplacian.col + laplacian.data[diag_idx] = value + # If the matrix has a small number of diagonals (as in the + # case of structured matrices coming from images), the + # dia format might be best suited for matvec products: + n_diags = np.unique(laplacian.row - laplacian.col).size + if n_diags <= 7: + # 3 or less outer diagonals on each side + laplacian = laplacian.todia() + else: + # csr has the fastest matvec and is thus best suited to + # arpack + laplacian = laplacian.tocsr() + return laplacian + + +def _deterministic_vector_sign_flip(u): + """ + Modify the sign of vectors for reproducibility. Flips the sign of + elements of all the vectors (rows of u) such that the absolute + maximum element of each vector is positive. + + Arguments + --------- + u : ndarray + Array with vectors as its rows. + + Returns + ------- + u_flipped : ndarray + Array with the sign flipped vectors as its rows. The same shape as `u`. + """ + + max_abs_rows = np.argmax(np.abs(u), axis=1) + signs = np.sign(u[range(u.shape[0]), max_abs_rows]) + u *= signs[:, np.newaxis] + return u + + +def _check_random_state(seed): + """ + Turn seed into a np.random.RandomState instance. + + Arguments + --------- + seed : None | int | instance of RandomState + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + """ + + if seed is None or seed is np.random: + return np.random.mtrand._rand + if isinstance(seed, numbers.Integral): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError("%r cannot be used to seed a np.random.RandomState" + " instance" % seed) + + +def spectral_embedding( + adjacency, + n_components=8, + norm_laplacian=True, + drop_first=True, ): + """ + Returns spectral embeddings. + + Arguments + --------- + adjacency : array-like or sparse graph + shape - (n_samples, n_samples) + The adjacency matrix of the graph to embed. + n_components : int + The dimension of the projection subspace. + norm_laplacian : bool + If True, then compute normalized Laplacian. + drop_first : bool + Whether to drop the first eigenvector. + + Returns + ------- + embedding : array + Spectral embeddings for each sample. + + Example + ------- + >>> import numpy as np + >>> import diarization as diar + >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + >>> embs = diar.spectral_embedding(affinity, 3) + >>> # Notice similar embeddings + >>> print(np.around(embs , decimals=3)) + [[ 0.075 0.244 0.285] + [ 0.083 0.356 -0.203] + [ 0.083 0.356 -0.203] + [ 0.26 -0.149 0.154] + [ 0.29 -0.218 -0.11 ] + [ 0.29 -0.218 -0.11 ] + [-0.198 -0.084 -0.122] + [-0.198 -0.084 -0.122] + [-0.198 -0.084 -0.122] + [-0.167 -0.044 0.316]] + """ + + # Whether to drop the first eigenvector + if drop_first: + n_components = n_components + 1 + + if not _graph_is_connected(adjacency): + warnings.warn("Graph is not fully connected, spectral embedding" + " may not work as expected.") + + laplacian, dd = csgraph_laplacian( + adjacency, normed=norm_laplacian, return_diag=True) + + laplacian = _set_diag(laplacian, 1, norm_laplacian) + + laplacian *= -1 + + vals, diffusion_map = eigsh( + laplacian, + k=n_components, + sigma=1.0, + which="LM", ) + + embedding = diffusion_map.T[n_components::-1] + + if norm_laplacian: + embedding = embedding / dd + + embedding = _deterministic_vector_sign_flip(embedding) + if drop_first: + return embedding[1:n_components].T + else: + return embedding[:n_components].T + + +def spectral_clustering( + affinity, + n_clusters=8, + n_components=None, + random_state=None, + n_init=10, ): + """ + Performs spectral clustering. + + Arguments + --------- + affinity : matrix + Affinity matrix. + n_clusters : int + Number of clusters for kmeans. + n_components : int + Number of components to retain while estimating spectral embeddings. + random_state : int + A pseudo random number generator used by kmeans. + n_init : int + Number of time the k-means algorithm will be run with different centroid seeds. + + Returns + ------- + labels : array + Cluster label for each sample. + + Example + ------- + >>> import numpy as np + >>> diarization as diar + >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + >>> labs = diar.spectral_clustering(affinity, 3) + >>> # print (labs) # [2 2 2 1 1 1 0 0 0 0] + """ + + random_state = _check_random_state(random_state) + n_components = n_clusters if n_components is None else n_components + + maps = spectral_embedding( + affinity, + n_components=n_components, + drop_first=False, ) + + _, labels, _ = k_means( + maps, n_clusters, random_state=random_state, n_init=n_init) + + return labels + + +class EmbeddingMeta: + """ + A utility class to pack deep embeddings and meta-information in one object. + + Arguments + --------- + segset : list + List of session IDs as an array of strings. + stats : tensor + An ndarray of float64. Each line contains embedding + from the corresponding session. + """ + + def __init__( + self, + segset=None, + stats=None, ): + + if segset is None: + self.segset = numpy.empty(0, dtype="|O") + self.stats = numpy.array([], dtype=STAT_TYPE) + else: + self.segset = segset + self.stats = stats + + def norm_stats(self): + """ + Divide all first-order statistics by their Euclidean norm. + """ + + vect_norm = np.clip(np.linalg.norm(self.stats, axis=1), 1e-08, np.inf) + self.stats = (self.stats.transpose() / vect_norm).transpose() + + +class Spec_Clust_unorm: + """ + This class implements the spectral clustering with unnormalized affinity matrix. + Useful when affinity matrix is based on cosine similarities. + + Reference + --------- + Von Luxburg, U. A tutorial on spectral clustering. Stat Comput 17, 395–416 (2007). + https://doi.org/10.1007/s11222-007-9033-z + + Example + ------- + >>> import diarization as diar + >>> clust = diar.Spec_Clust_unorm(min_num_spkrs=2, max_num_spkrs=10) + >>> emb = [[ 2.1, 3.1, 4.1, 4.2, 3.1], + ... [ 2.2, 3.1, 4.2, 4.2, 3.2], + ... [ 2.0, 3.0, 4.0, 4.1, 3.0], + ... [ 8.0, 7.0, 7.0, 8.1, 9.0], + ... [ 8.1, 7.1, 7.2, 8.1, 9.2], + ... [ 8.3, 7.4, 7.0, 8.4, 9.0], + ... [ 0.3, 0.4, 0.4, 0.5, 0.8], + ... [ 0.4, 0.3, 0.6, 0.7, 0.8], + ... [ 0.2, 0.3, 0.2, 0.3, 0.7], + ... [ 0.3, 0.4, 0.4, 0.4, 0.7],] + >>> # Estimating similarity matrix + >>> sim_mat = clust.get_sim_mat(emb) + >>> print (np.around(sim_mat[5:,5:], decimals=3)) + [[1. 0.957 0.961 0.904 0.966] + [0.957 1. 0.977 0.982 0.997] + [0.961 0.977 1. 0.928 0.972] + [0.904 0.982 0.928 1. 0.976] + [0.966 0.997 0.972 0.976 1. ]] + >>> # Prunning + >>> prunned_sim_mat = clust.p_pruning(sim_mat, 0.3) + >>> print (np.around(prunned_sim_mat[5:,5:], decimals=3)) + [[1. 0. 0. 0. 0. ] + [0. 1. 0. 0.982 0.997] + [0. 0.977 1. 0. 0.972] + [0. 0.982 0. 1. 0.976] + [0. 0.997 0. 0.976 1. ]] + >>> # Symmetrization + >>> sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + >>> print (np.around(sym_prund_sim_mat[5:,5:], decimals=3)) + [[1. 0. 0. 0. 0. ] + [0. 1. 0.489 0.982 0.997] + [0. 0.489 1. 0. 0.486] + [0. 0.982 0. 1. 0.976] + [0. 0.997 0.486 0.976 1. ]] + >>> # Laplacian + >>> laplacian = clust.get_laplacian(sym_prund_sim_mat) + >>> print (np.around(laplacian[5:,5:], decimals=3)) + [[ 1.999 0. 0. 0. 0. ] + [ 0. 2.468 -0.489 -0.982 -0.997] + [ 0. -0.489 0.975 0. -0.486] + [ 0. -0.982 0. 1.958 -0.976] + [ 0. -0.997 -0.486 -0.976 2.458]] + >>> # Spectral Embeddings + >>> spec_emb, num_of_spk = clust.get_spec_embs(laplacian, 3) + >>> print(num_of_spk) + 3 + >>> # Clustering + >>> clust.cluster_embs(spec_emb, num_of_spk) + >>> # print (clust.labels_) # [0 0 0 2 2 2 1 1 1 1] + >>> # Complete spectral clustering + >>> clust.do_spec_clust(emb, k_oracle=3, p_val=0.3) + >>> # print(clust.labels_) # [0 0 0 2 2 2 1 1 1 1] + """ + + def __init__(self, min_num_spkrs=2, max_num_spkrs=10): + + self.min_num_spkrs = min_num_spkrs + self.max_num_spkrs = max_num_spkrs + + def do_spec_clust(self, X, k_oracle, p_val): + """ + Function for spectral clustering. + + Arguments + --------- + X : array + (n_samples, n_features). + Embeddings extracted from the model. + k_oracle : int + Number of speakers (when oracle number of speakers). + p_val : float + p percent value to prune the affinity matrix. + """ + + # Similarity matrix computation + sim_mat = self.get_sim_mat(X) + + # Refining similarity matrix with p_val + prunned_sim_mat = self.p_pruning(sim_mat, p_val) + + # Symmetrization + sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + + # Laplacian calculation + laplacian = self.get_laplacian(sym_prund_sim_mat) + + # Get Spectral Embeddings + emb, num_of_spk = self.get_spec_embs(laplacian, k_oracle) + + # Perform clustering + self.cluster_embs(emb, num_of_spk) + + def get_sim_mat(self, X): + """ + Returns the similarity matrix based on cosine similarities. + + Arguments + --------- + X : array + (n_samples, n_features). + Embeddings extracted from the model. + + Returns + ------- + M : array + (n_samples, n_samples). + Similarity matrix with cosine similarities between each pair of embedding. + """ + + # Cosine similarities + M = sklearn.metrics.pairwise.cosine_similarity(X, X) + return M + + def p_pruning(self, A, pval): + """ + Refine the affinity matrix by zeroing less similar values. + + Arguments + --------- + A : array + (n_samples, n_samples). + Affinity matrix. + pval : float + p-value to be retained in each row of the affinity matrix. + + Returns + ------- + A : array + (n_samples, n_samples). + Prunned affinity matrix based on p_val. + """ + + n_elems = int((1 - pval) * A.shape[0]) + + # For each row in a affinity matrix + for i in range(A.shape[0]): + low_indexes = np.argsort(A[i, :]) + low_indexes = low_indexes[0:n_elems] + + # Replace smaller similarity values by 0s + A[i, low_indexes] = 0 + + return A + + def get_laplacian(self, M): + """ + Returns the un-normalized laplacian for the given affinity matrix. + + Arguments + --------- + M : array + (n_samples, n_samples) + Affinity matrix. + + Returns + ------- + L : array + (n_samples, n_samples) + Laplacian matrix. + """ + + M[np.diag_indices(M.shape[0])] = 0 + D = np.sum(np.abs(M), axis=1) + D = np.diag(D) + L = D - M + return L + + def get_spec_embs(self, L, k_oracle=4): + """ + Returns spectral embeddings and estimates the number of speakers + using maximum Eigen gap. + + Arguments + --------- + L : array (n_samples, n_samples) + Laplacian matrix. + k_oracle : int + Number of speakers when the condition is oracle number of speakers, + else None. + + Returns + ------- + emb : array (n_samples, n_components) + Spectral embedding for each sample with n Eigen components. + num_of_spk : int + Estimated number of speakers. If the condition is set to the oracle + number of speakers then returns k_oracle. + """ + + lambdas, eig_vecs = scipy.linalg.eigh(L) + + # if params["oracle_n_spkrs"] is True: + if k_oracle is not None: + num_of_spk = k_oracle + else: + lambda_gap_list = self.getEigenGaps(lambdas[1:self.max_num_spkrs]) + + num_of_spk = (np.argmax( + lambda_gap_list[:min(self.max_num_spkrs, len(lambda_gap_list))]) + + 2) + + if num_of_spk < self.min_num_spkrs: + num_of_spk = self.min_num_spkrs + + emb = eig_vecs[:, 0:num_of_spk] + + return emb, num_of_spk + + def cluster_embs(self, emb, k): + """ + Clusters the embeddings using kmeans. + + Arguments + --------- + emb : array (n_samples, n_components) + Spectral embedding for each sample with n Eigen components. + k : int + Number of clusters to kmeans. + + Returns + ------- + self.labels_ : self + Labels for each sample embedding. + """ + _, self.labels_, _ = k_means(emb, k) + + def getEigenGaps(self, eig_vals): + """ + Returns the difference (gaps) between the Eigen values. + + Arguments + --------- + eig_vals : list + List of eigen values + + Returns + ------- + eig_vals_gap_list : list + List of differences (gaps) between adjacent Eigen values. + """ + + eig_vals_gap_list = [] + for i in range(len(eig_vals) - 1): + gap = float(eig_vals[i + 1]) - float(eig_vals[i]) + # eig_vals_gap_list.append(float(eig_vals[i + 1]) - float(eig_vals[i])) + eig_vals_gap_list.append(gap) + + return eig_vals_gap_list + + +class Spec_Cluster(SpectralClustering): + def perform_sc(self, X, n_neighbors=10): + """ + Performs spectral clustering using sklearn on embeddings. + + Arguments + --------- + X : array (n_samples, n_features) + Embeddings to be clustered. + n_neighbors : int + Number of neighbors in estimating affinity matrix. + """ + + # Computation of affinity matrix + connectivity = kneighbors_graph( + X, + n_neighbors=n_neighbors, + include_self=True, ) + self.affinity_matrix_ = 0.5 * (connectivity + connectivity.T) + + # Perform spectral clustering on affinity matrix + self.labels_ = spectral_clustering_sb( + self.affinity_matrix_, + n_clusters=self.n_clusters, ) + return self + + +def is_overlapped(end1, start2): + """ + Returns True if segments are overlapping. + + Arguments + --------- + end1 : float + End time of the first segment. + start2 : float + Start time of the second segment. + + Returns + ------- + overlapped : bool + True of segments overlapped else False. + + Example + ------- + >>> import diarization as diar + >>> diar.is_overlapped(5.5, 3.4) + True + >>> diar.is_overlapped(5.5, 6.4) + False + """ + + if start2 > end1: + return False + else: + return True + + +def merge_ssegs_same_speaker(lol): + """ + Merge adjacent sub-segs from the same speaker. + + Arguments + --------- + lol : list of list + Each list contains [rec_id, seg_start, seg_end, spkr_id]. + + Returns + ------- + new_lol : list of list + new_lol contains adjacent segments merged from the same speaker ID. + + Example + ------- + >>> import diarization as diar + >>> lol=[['r1', 5.5, 7.0, 's1'], + ... ['r1', 6.5, 9.0, 's1'], + ... ['r1', 8.0, 11.0, 's1'], + ... ['r1', 11.5, 13.0, 's2'], + ... ['r1', 14.0, 15.0, 's2'], + ... ['r1', 14.5, 15.0, 's1']] + >>> diar.merge_ssegs_same_speaker(lol) + [['r1', 5.5, 11.0, 's1'], ['r1', 11.5, 13.0, 's2'], ['r1', 14.0, 15.0, 's2'], ['r1', 14.5, 15.0, 's1']] + """ + + new_lol = [] + + # Start from the first sub-seg + sseg = lol[0] + flag = False + for i in range(1, len(lol)): + next_sseg = lol[i] + + # IF sub-segments overlap AND has same speaker THEN merge + if is_overlapped(sseg[2], next_sseg[1]) and sseg[3] == next_sseg[3]: + sseg[2] = next_sseg[2] # just update the end time + # This is important. For the last sseg, if it is the same speaker the merge + # Make sure we don't append the last segment once more. Hence, set FLAG=True + if i == len(lol) - 1: + flag = True + new_lol.append(sseg) + else: + new_lol.append(sseg) + sseg = next_sseg + + # Add last segment only when it was skipped earlier. + if flag is False: + new_lol.append(lol[-1]) + + return new_lol + + +def distribute_overlap(lol): + """ + Distributes the overlapped speech equally among the adjacent segments + with different speakers. + + Arguments + --------- + lol : list of list + It has each list structure as [rec_id, seg_start, seg_end, spkr_id]. + + Returns + ------- + new_lol : list of list + It contains the overlapped part equally divided among the adjacent + segments with different speaker IDs. + + Example + ------- + >>> import diarization as diar + >>> lol = [['r1', 5.5, 9.0, 's1'], + ... ['r1', 8.0, 11.0, 's2'], + ... ['r1', 11.5, 13.0, 's2'], + ... ['r1', 12.0, 15.0, 's1']] + >>> diar.distribute_overlap(lol) + [['r1', 5.5, 8.5, 's1'], ['r1', 8.5, 11.0, 's2'], ['r1', 11.5, 12.5, 's2'], ['r1', 12.5, 15.0, 's1']] + """ + + new_lol = [] + sseg = lol[0] + + # Add first sub-segment here to avoid error at: "if new_lol[-1] != sseg:" when new_lol is empty + # new_lol.append(sseg) + + for i in range(1, len(lol)): + next_sseg = lol[i] + # No need to check if they are different speakers. + # Because if segments are overlapped then they always have different speakers. + # This is because similar speaker's adjacent sub-segments are already merged by "merge_ssegs_same_speaker()" + + if is_overlapped(sseg[2], next_sseg[1]): + + # Get overlap duration. + # Now this overlap will be divided equally between adjacent segments. + overlap = sseg[2] - next_sseg[1] + + # Update end time of old seg + sseg[2] = sseg[2] - (overlap / 2.0) + + # Update start time of next seg + next_sseg[1] = next_sseg[1] + (overlap / 2.0) + + if len(new_lol) == 0: + # For first sub-segment entry + new_lol.append(sseg) + else: + # To avoid duplicate entries + if new_lol[-1] != sseg: + new_lol.append(sseg) + + # Current sub-segment is next sub-segment + sseg = next_sseg + + else: + # For the first sseg + if len(new_lol) == 0: + new_lol.append(sseg) + else: + # To avoid duplicate entries + if new_lol[-1] != sseg: + new_lol.append(sseg) + + # Update the current sub-segment + sseg = next_sseg + + # Add the remaining last sub-segment + new_lol.append(next_sseg) + + return new_lol + + +def write_rttm(segs_list, out_rttm_file): + """ + Writes the segment list in RTTM format (A standard NIST format). + + Arguments + --------- + segs_list : list of list + Each list contains [rec_id, seg_start, seg_end, spkr_id]. + out_rttm_file : str + Path of the output RTTM file. + """ + + rttm = [] + rec_id = segs_list[0][0] + + for seg in segs_list: + new_row = [ + "SPEAKER", + rec_id, + "0", + str(round(seg[1], 4)), + str(round(seg[2] - seg[1], 4)), + "", + "", + seg[3], + "", + "", + ] + rttm.append(new_row) + + with open(out_rttm_file, "w") as f: + for row in rttm: + line_str = " ".join(row) + f.write("%s\n" % line_str) + + +def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3): + """ + Performs Agglomerative Hierarchical Clustering on embeddings. + + Arguments + --------- + diary_obj : EmbeddingMeta type + Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset. + out_rttm_file : str + Path of the output RTTM file. + rec_id : str + Recording ID for the recording under processing. + k : int + Number of speaker (None, if it has to be estimated). + pval : float + `pval` for prunning affinity matrix. Used only when number of speakers + are unknown. Note that this is just for experiment. Prefer Spectral clustering + for better clustering results. + """ + + from sklearn.cluster import AgglomerativeClustering + + # p_val is the threshold_val (for AHC) + diary_obj.norm_stats() + + # processing + if k_oracle is not None: + num_of_spk = k_oracle + + clustering = AgglomerativeClustering( + n_clusters=num_of_spk, + affinity="cosine", + linkage="average", ).fit(diary_obj.stats) + labels = clustering.labels_ + + else: + # Estimate num of using max eigen gap with `cos` affinity matrix. + # This is just for experimentation. + clustering = AgglomerativeClustering( + n_clusters=None, + affinity="cosine", + linkage="average", + distance_threshold=p_val, ).fit(diary_obj.stats) + labels = clustering.labels_ + + # Convert labels to speaker boundaries + subseg_ids = diary_obj.segset + lol = [] + + for i in range(labels.shape[0]): + spkr_id = rec_id + "_" + str(labels[i]) + + sub_seg = subseg_ids[i] + + splitted = sub_seg.rsplit("_", 2) + rec_id = str(splitted[0]) + sseg_start = float(splitted[1]) + sseg_end = float(splitted[2]) + + a = [rec_id, sseg_start, sseg_end, spkr_id] + lol.append(a) + + # Sorting based on start time of sub-segment + lol.sort(key=lambda x: float(x[1])) + + # Merge and split in 2 simple steps: (i) Merge sseg of same speakers then (ii) split different speakers + # Step 1: Merge adjacent sub-segments that belong to same speaker (or cluster) + lol = merge_ssegs_same_speaker(lol) + + # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or cluster) + # Taking mid-point as the splitting time location. + lol = distribute_overlap(lol) + + # logger.info("Completed diarizing " + rec_id) + write_rttm(lol, out_rttm_file) + + +def do_spec_clustering(diary_obj, out_rttm_file, rec_id, k, pval, affinity_type, + n_neighbors): + """ + Performs spectral clustering on embeddings. This function calls specific + clustering algorithms as per affinity. + + Arguments + --------- + diary_obj : EmbeddingMeta type + Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset. + out_rttm_file : str + Path of the output RTTM file. + rec_id : str + Recording ID for the recording under processing. + k : int + Number of speaker (None, if it has to be estimated). + pval : float + `pval` for prunning affinity matrix. + affinity_type : str + Type of similarity to be used to get affinity matrix (cos or nn). + """ + + if affinity_type == "cos": + clust_obj = Spec_Clust_unorm(min_num_spkrs=2, max_num_spkrs=10) + k_oracle = k # use it only when oracle num of speakers + clust_obj.do_spec_clust(diary_obj.stats, k_oracle, pval) + labels = clust_obj.labels_ + else: + clust_obj = Spec_Cluster( + n_clusters=k, + assign_labels="kmeans", + random_state=1234, + affinity="nearest_neighbors", ) + clust_obj.perform_sc(diary_obj.stats, n_neighbors) + labels = clust_obj.labels_ + + # Convert labels to speaker boundaries + subseg_ids = diary_obj.segset + lol = [] + + for i in range(labels.shape[0]): + spkr_id = rec_id + "_" + str(labels[i]) + + sub_seg = subseg_ids[i] + + splitted = sub_seg.rsplit("_", 2) + rec_id = str(splitted[0]) + sseg_start = float(splitted[1]) + sseg_end = float(splitted[2]) + + a = [rec_id, sseg_start, sseg_end, spkr_id] + lol.append(a) + + # Sorting based on start time of sub-segment + lol.sort(key=lambda x: float(x[1])) + + # Merge and split in 2 simple steps: (i) Merge sseg of same speakers then (ii) split different speakers + # Step 1: Merge adjacent sub-segments that belong to same speaker (or cluster) + lol = merge_ssegs_same_speaker(lol) + + # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or cluster) + # Taking mid-point as the splitting time location. + lol = distribute_overlap(lol) + + # logger.info("Completed diarizing " + rec_id) + write_rttm(lol, out_rttm_file) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + prog='python diarization.py --backend AHC', description='diarizing') + parser.add_argument( + '--sys_rttm_dir', + required=False, + help='Directory to store system RTTM files') + parser.add_argument( + '--ref_rttm_dir', + required=False, + help='Directory to store reference RTTM files') + parser.add_argument( + '--backend', default="AHC", help='type of backend, AHC or SC or kmeans') + parser.add_argument( + '--oracle_n_spkrs', + default=True, + type=strtobool, + help='Oracle num of speakers') + parser.add_argument( + '--mic_type', + default="Mix-Headset", + help='Type of microphone to be used') + parser.add_argument( + '--affinity', default="cos", help='affinity matrix, cos or nn') + parser.add_argument( + '--max_subseg_dur', + default=3.0, + type=float, + help='Duration in seconds of a subsegments to be prepared from larger segments' + ) + parser.add_argument( + '--overlap', + default=1.5, + type=float, + help='Overlap duration in seconds between adjacent subsegments') + + args = parser.parse_args() + + pval = 0.3 + rec_id = "utt0001" + n_neighbors = 10 + out_rttm_file = "./out.rttm" + + embeddings = np.empty(shape=[0, 32], dtype=np.float64) + segset = [] + + for i in range(10): + seg = [rec_id + "_" + str(i) + "_" + str(i + 1)] + segset = segset + seg + emb = np.random.rand(1, 32) + embeddings = np.concatenate((embeddings, emb), axis=0) + + segset = np.array(segset, dtype="|O") + stat_obj = EmbeddingMeta(segset, embeddings) + if args.oracle_n_spkrs is True: + num_spkrs = 2 + + if args.backend == "SC": + print("begin SC ") + do_spec_clustering( + stat_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, + args.affinity, + n_neighbors, ) + if args.backend == "AHC": + print("begin AHC ") + do_AHC(stat_obj, out_rttm_file, rec_id, num_spkrs, pval) -- GitLab