diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py
index f8b0a027c6499ded95c901ebbc1b652652176f0b..a8043c227bbc070fffa254c79f6f468d9b4174f0 100644
--- a/paddlespeech/vector/cluster/diarization.py
+++ b/paddlespeech/vector/cluster/diarization.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 SpeechBrain Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle and SpeechBrain Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ This script has an optional dependency on open source sklearn library.
 A few sklearn functions are modified in this script as per requirement.
 """
 import argparse
+import copy
 import warnings
 from distutils.util import strtobool
 
 import numpy as np
 import scipy
 import sklearn
+from scipy import linalg
 from scipy import sparse
 from scipy.sparse.csgraph import connected_components
 from scipy.sparse.csgraph import laplacian as csgraph_laplacian
@@ -346,6 +348,8 @@ class EmbeddingMeta:
     ---------
     segset : list
         List of session IDs as an array of strings.
+    modelset : list
+        List of model IDs as an array of strings.
     stats : tensor
         An ndarray of float64. Each line contains embedding
         from the corresponding session.
@@ -354,15 +358,20 @@ class EmbeddingMeta:
     def __init__(
             self,
             segset=None,
+            modelset=None,
             stats=None, ):
 
         if segset is None:
-            self.segset = numpy.empty(0, dtype="|O")
-            self.stats = numpy.array([], dtype=np.float64)
+            self.segset = np.empty(0, dtype="|O")
+            self.modelset = np.empty(0, dtype="|O")
+            self.stats = np.array([], dtype=np.float64)
         else:
             self.segset = segset
+            self.modelset = modelset
             self.stats = stats
 
+        # One zero-order statistic per embedding (single-Gaussian case).
+        self.stat0 = np.array([[1.0]] * self.stats.shape[0])
+
     def norm_stats(self):
         """
         Divide all first-order statistics by their Euclidean norm.
@@ -371,6 +380,188 @@ class EmbeddingMeta:
         vect_norm = np.clip(np.linalg.norm(self.stats, axis=1), 1e-08, np.inf)
         self.stats = (self.stats.transpose() / vect_norm).transpose()
 
+    def get_mean_stats(self):
+        """
+        Return the mean of the first-order statistics.
+        """
+        mu = np.mean(self.stats, axis=0)
+        return mu
+
+    def get_total_covariance_stats(self):
+        """
+        Compute and return the total covariance matrix of the first-order statistics.
+        """
+        C = self.stats - self.stats.mean(axis=0)
+        return np.dot(C.transpose(), C) / self.stats.shape[0]
+
+    def get_model_stat0(self, mod_id):
+        """Return zero-order statistics of a given model.
+
+        Arguments
+        ---------
+        mod_id : str
+            ID of the model whose stat0 will be returned.
+        """
+        S = self.stat0[self.modelset == mod_id, :]
+        return S
+
+    def get_model_stats(self, mod_id):
+        """Return first-order statistics of a given model.
+
+        Arguments
+        ---------
+        mod_id : str
+            ID of the model whose first-order statistics will be returned.
+        """
+        return self.stats[self.modelset == mod_id, :]
+
+    def sum_stat_per_model(self):
+        """
+        Sum the zero- and first-order statistics per model and store them
+        in a new EmbeddingMeta.
+        Returns an EmbeddingMeta object with the statistics summed per model
+        and a numpy array with the number of sessions per model.
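+
+        Example
+        -------
+        An illustrative sketch on toy arrays (names and values assumed):
+
+        >>> segs = np.array(['u1', 'u2', 'u3'], dtype='|O')
+        >>> spks = np.array(['s1', 's1', 's2'], dtype='|O')
+        >>> em = EmbeddingMeta(segset=segs, modelset=spks,
+        ...                    stats=np.ones((3, 4)))
+        >>> per_model, sessions = em.sum_stat_per_model()
+        >>> per_model.stats.shape   # one row per unique speaker
+        (2, 4)
+        >>> sessions                # number of utterances per speaker
+        array([2., 1.])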
+        """
+
+        sts_per_model = EmbeddingMeta()
+        sts_per_model.modelset = np.unique(
+            self.modelset)  # get unique speaker ids
+        sts_per_model.segset = copy.deepcopy(sts_per_model.modelset)
+        sts_per_model.stat0 = np.zeros(
+            (sts_per_model.modelset.shape[0], self.stat0.shape[1]),
+            dtype=np.float64, )
+        sts_per_model.stats = np.zeros(
+            (sts_per_model.modelset.shape[0], self.stats.shape[1]),
+            dtype=np.float64, )
+
+        session_per_model = np.zeros(np.unique(self.modelset).shape[0])
+
+        # For each model, sum the statistics over all of its sessions
+        for idx, model in enumerate(sts_per_model.modelset):
+            sts_per_model.stat0[idx, :] = self.get_model_stat0(model).sum(
+                axis=0)
+            sts_per_model.stats[idx, :] = self.get_model_stats(model).sum(
+                axis=0)
+            session_per_model[idx] += self.get_model_stats(model).shape[0]
+        return sts_per_model, session_per_model
+
+    def center_stats(self, mu):
+        """
+        Center first-order statistics.
+
+        Arguments
+        ---------
+        mu : array
+            Array to center on.
+        """
+
+        dim = self.stats.shape[1] // self.stat0.shape[1]
+        index_map = np.repeat(np.arange(self.stat0.shape[1]), dim)
+        self.stats = self.stats - (self.stat0[:, index_map] *
+                                   mu.astype(np.float64))
+
+    def rotate_stats(self, R):
+        """
+        Rotate first-order statistics by a right-product.
+
+        Arguments
+        ---------
+        R : ndarray
+            Matrix to use for right product on the first order statistics.
+        """
+        self.stats = np.dot(self.stats, R)
+
+    def whiten_stats(self, mu, sigma, isSqrInvSigma=False):
+        """
+        Whiten first-order statistics.
+        If sigma.ndim == 1, case of a diagonal covariance.
+        If sigma.ndim == 2, case of a single Gaussian with full covariance.
+        If sigma.ndim == 3, case of a full covariance UBM.
+
+        Arguments
+        ---------
+        mu : array
+            Mean vector to be subtracted from the statistics.
+        sigma : ndarray
+            Covariance matrix or covariance super-vector.
+        isSqrInvSigma : bool
+            True if the input sigma matrix is the inverse of the square root of a covariance matrix.
+        """
+
+        if sigma.ndim == 1:
+            self.center_stats(mu)
+            self.stats = self.stats / np.sqrt(sigma.astype(np.float64))
+
+        elif sigma.ndim == 2:
+            # Compute the inverse square root of the covariance matrix sigma
+            sqr_inv_sigma = sigma
+
+            if not isSqrInvSigma:
+                eigen_values, eigen_vectors = linalg.eigh(sigma)
+                ind = eigen_values.real.argsort()[::-1]
+                eigen_values = eigen_values.real[ind]
+                eigen_vectors = eigen_vectors.real[:, ind]
+
+                sqr_inv_eval_sigma = 1 / np.sqrt(eigen_values.real)
+                sqr_inv_sigma = np.dot(eigen_vectors,
+                                       np.diag(sqr_inv_eval_sigma))
+
+            # Whitening of the first-order statistics
+            self.center_stats(mu)
+            self.rotate_stats(sqr_inv_sigma)
+
+        elif sigma.ndim == 3:
+            # sigma is a 3D ndarray of size D x n x n, where D is the number
+            # of distributions and n is the dimension of a single distribution
+            n = self.stats.shape[1] // self.stat0.shape[1]
+            sess_nb = self.stat0.shape[0]
+            self.center_stats(mu)
+            self.stats = (np.einsum("ikj,ikl->ilj",
+                                    self.stats.T.reshape(-1, n, sess_nb), sigma)
+                          .reshape(-1, sess_nb).T)
+
+        else:
+            raise Exception("Wrong dimension of sigma, must be 1, 2 or 3")
+
+    def align_models(self, model_list):
+        """
+        Align models of the current EmbeddingMeta to match a list of models
+        provided as input parameter. The size of the EmbeddingMeta might be
+        reduced to match the input list of models.
+
+        Arguments
+        ---------
+        model_list : ndarray of strings
+            List of models to match.
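+
+        Example
+        -------
+        An illustrative sketch on a toy EmbeddingMeta (names assumed):
+
+        >>> em = EmbeddingMeta(segset=np.array(['u1', 'u2'], dtype='|O'),
+        ...                    modelset=np.array(['s1', 's2'], dtype='|O'),
+        ...                    stats=np.eye(2))
+        >>> em.align_models(np.array(['s2'], dtype='|O'))
+        >>> em.modelset
+        array(['s2'], dtype=object)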
+        """
+        indx = np.array(
+            [np.argwhere(self.modelset == v)[0][0] for v in model_list])
+        self.segset = self.segset[indx]
+        self.modelset = self.modelset[indx]
+        self.stat0 = self.stat0[indx, :]
+        self.stats = self.stats[indx, :]
+
+    def align_segments(self, segment_list):
+        """
+        Align segments of the current EmbeddingMeta to match a list of
+        segments provided as input parameter. The size of the EmbeddingMeta
+        might be reduced to match the input list of segments.
+
+        Arguments
+        ---------
+        segment_list : ndarray of strings
+            List of segments to match.
+        """
+        indx = np.array(
+            [np.argwhere(self.segset == v)[0][0] for v in segment_list])
+        self.segset = self.segset[indx]
+        self.modelset = self.modelset[indx]
+        self.stat0 = self.stat0[indx, :]
+        self.stats = self.stats[indx, :]
+
 
 class SpecClustUnorm:
     """
diff --git a/paddlespeech/vector/cluster/plda.py b/paddlespeech/vector/cluster/plda.py
new file mode 100644
index 0000000000000000000000000000000000000000..81def435692fc8ad77cc87ac79a814235b74987e
--- /dev/null
+++ b/paddlespeech/vector/cluster/plda.py
@@ -0,0 +1,575 @@
+# Copyright (c) 2022 PaddlePaddle and SpeechBrain Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A popular speaker recognition/diarization model (LDA and PLDA).
+
+Relevant Papers
+ - This implementation of PLDA is based on the following papers.
+
+ - PLDA model Training
+    * Ye Jiang et al., "PLDA Modeling in I-Vector and Supervector Space for Speaker Verification," in Interspeech, 2012.
+    * Patrick Kenny et al., "PLDA for speaker verification with utterances of arbitrary duration," in ICASSP, 2013.
+
+ - PLDA scoring (fast scoring)
+    * Daniel Garcia-Romero et al., "Analysis of i-vector length normalization in speaker recognition systems," in Interspeech, 2011.
+    * Weiwei Lin et al., "Fast Scoring for PLDA with Uncertainty Propagation," in Odyssey, 2016.
+    * Kong Aik Lee et al., "Multi-session PLDA Scoring of I-vector for Partially Open-Set Speaker Detection," in Interspeech, 2013.
+
+Credits
+    This code is adapted from: https://git-lium.univ-lemans.fr/Larcher/sidekit
+"""
+import copy
+import pickle
+
+import numpy
+from scipy import linalg
+
+from paddlespeech.vector.cluster.diarization import EmbeddingMeta
+
+
+def ismember(list1, list2):
+    c = [item in list2 for item in list1]
+    return c
+
+
+def diff(list1, list2):
+    # Return the sorted elements of list1 that are not in list2
+    # (helper required by Ndx.filter below).
+    c = [item for item in list1 if item not in list2]
+    c.sort()
+    return c
+
+
+class Ndx:
+    """
+    A class that encodes trial index information. It has a list of
+    model names, a list of test segment names, and a matrix
+    indicating which combinations of model and test segment are
+    trials of interest.
+
+    Arguments
+    ---------
+    modelset : list
+        List of unique models in a ndarray.
+    segset : list
+        List of unique test segments in a ndarray.
+    trialmask : 2D ndarray of bool
+        Rows correspond to the models and columns to the test segments. True, if the trial is of interest.
+    """
+
+    def __init__(self,
+                 ndx_file_name="",
+                 models=numpy.array([]),
+                 testsegs=numpy.array([])):
+        """
+        Initialize an Ndx object by loading information from a file,
+        or build it in memory from a list of models and test segments.
+
+        Arguments
+        ---------
+        ndx_file_name : str
+            Name of the file to load. If empty, the object is built from
+            the models and testsegs arrays.
+        models : ndarray of strings
+            Model name for each trial (used when no file is given).
+        testsegs : ndarray of strings
+            Test segment name for each trial (used when no file is given).
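+
+        Example
+        -------
+        An illustrative in-memory sketch (toy model/segment names):
+
+        >>> mods = numpy.array(['m1', 'm2'], dtype='|O')
+        >>> segs = numpy.array(['t1', 't2', 't3'], dtype='|O')
+        >>> ndx = Ndx(models=mods, testsegs=segs)
+        >>> ndx.trialmask.shape   # one row per model, one column per segment
+        (2, 3)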
+        """
+        self.modelset = numpy.empty(0, dtype="|O")
+        self.segset = numpy.empty(0, dtype="|O")
+        self.trialmask = numpy.array([], dtype="bool")
+
+        if ndx_file_name == "":
+            # Pad the shorter of the two lists so their sizes match
+            d = models.shape[0] - testsegs.shape[0]
+            if d != 0:
+                if d > 0:
+                    last = str(testsegs[-1])
+                    pad = numpy.array([last] * d)
+                    testsegs = numpy.hstack((testsegs, pad))
+                else:
+                    d = abs(d)
+                    last = str(models[-1])
+                    pad = numpy.array([last] * d)
+                    models = numpy.hstack((models, pad))
+
+            modelset = numpy.unique(models)
+            segset = numpy.unique(testsegs)
+
+            trialmask = numpy.zeros(
+                (modelset.shape[0], segset.shape[0]), dtype="bool")
+            for m in range(modelset.shape[0]):
+                segs = testsegs[numpy.array(ismember(models, modelset[m]))]
+                trialmask[m, ] = ismember(segset, segs)  # noqa E231
+
+            self.modelset = modelset
+            self.segset = segset
+            self.trialmask = trialmask
+            assert self.validate(), "Wrong Ndx format"
+
+        else:
+            ndx = Ndx.read(ndx_file_name)
+            self.modelset = ndx.modelset
+            self.segset = ndx.segset
+            self.trialmask = ndx.trialmask
+
+    def save_ndx_object(self, output_file_name):
+        with open(output_file_name, "wb") as output:
+            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
+
+    def filter(self, modlist, seglist, keep):
+        """
+        Removes some of the information in an Ndx. Useful for creating a
+        gender specific Ndx from a pooled gender Ndx. Depending on the
+        value of 'keep', the two input lists indicate the strings to
+        retain or the strings to discard.
+
+        Arguments
+        ---------
+        modlist : array
+            A cell array of strings which will be compared with the modelset of 'inndx'.
+        seglist : array
+            A cell array of strings which will be compared with the segset of 'inndx'.
+        keep : bool
+            Indicating whether modlist and seglist are the models to keep or discard.
+        """
+        if keep:
+            keepmods = modlist
+            keepsegs = seglist
+        else:
+            keepmods = diff(self.modelset, modlist)
+            keepsegs = diff(self.segset, seglist)
+
+        keepmodidx = numpy.array(ismember(self.modelset, keepmods))
+        keepsegidx = numpy.array(ismember(self.segset, keepsegs))
+
+        outndx = Ndx()
+        outndx.modelset = self.modelset[keepmodidx]
+        outndx.segset = self.segset[keepsegidx]
+        tmp = self.trialmask[numpy.array(keepmodidx), :]
+        outndx.trialmask = tmp[:, numpy.array(keepsegidx)]
+
+        assert outndx.validate(), "Wrong Ndx format"
+
+        if self.modelset.shape[0] > outndx.modelset.shape[0]:
+            print("Number of models reduced from %d to %d" %
+                  (self.modelset.shape[0], outndx.modelset.shape[0]))
+        if self.segset.shape[0] > outndx.segset.shape[0]:
+            print("Number of test segments reduced from %d to %d" %
+                  (self.segset.shape[0], outndx.segset.shape[0]))
+        return outndx
+
+    def validate(self):
+        """
+        Checks that an object of type Ndx obeys certain rules that
+        must always be true. Returns a boolean value indicating whether the object is valid.
+        """
+        ok = isinstance(self.modelset, numpy.ndarray)
+        ok &= isinstance(self.segset, numpy.ndarray)
+        ok &= isinstance(self.trialmask, numpy.ndarray)
+
+        ok &= self.modelset.ndim == 1
+        ok &= self.segset.ndim == 1
+        ok &= self.trialmask.ndim == 2
+
+        ok &= self.trialmask.shape == (self.modelset.shape[0],
+                                       self.segset.shape[0], )
+        return ok
+
+
+class Scores:
+    """
+    A class for storing scores for trials. The modelset and segset
+    fields are lists of model and test segment names respectively.
+    The element i,j of scoremat and scoremask corresponds to the
+    trial involving model i and test segment j.
+
+    Arguments
+    ---------
+    modelset : list
+        List of unique models in a ndarray.
+    segset : list
+        List of unique test segments in a ndarray.
+    scoremask : 2D ndarray of bool
+        Indicates the trials of interest, i.e.,
+        the entry i,j in scoremat should be ignored if scoremask[i,j] is False.
+    scoremat : 2D ndarray
+        Scores matrix.
+    """
+
+    def __init__(self, scores_file_name=""):
+        """
+        Initialize a Scores object by loading information from a file in HDF5 format.
+
+        Arguments
+        ---------
+        scores_file_name : str
+            Name of the file to load.
+        """
+        self.modelset = numpy.empty(0, dtype="|O")
+        self.segset = numpy.empty(0, dtype="|O")
+        self.scoremask = numpy.array([], dtype="bool")
+        self.scoremat = numpy.array([])
+
+        if scores_file_name == "":
+            pass
+        else:
+            tmp = Scores.read(scores_file_name)
+            self.modelset = tmp.modelset
+            self.segset = tmp.segset
+            self.scoremask = tmp.scoremask
+            self.scoremat = tmp.scoremat
+
+    def __repr__(self):
+        ch = "modelset:\n"
+        ch += self.modelset.__repr__() + "\n"
+        ch += "segset:\n"
+        ch += self.segset.__repr__() + "\n"
+        ch += "scoremask:\n"
+        ch += self.scoremask.__repr__() + "\n"
+        ch += "scoremat:\n"
+        ch += self.scoremat.__repr__() + "\n"
+        return ch
+
+
+def fa_model_loop(
+        batch_start,
+        mini_batch_indices,
+        factor_analyser,
+        stat0,
+        stats,
+        e_h,
+        e_hh, ):
+    """
+    A function for PLDA estimation.
+
+    Arguments
+    ---------
+    batch_start : int
+        Index to start at in the list.
+    mini_batch_indices : list
+        Indices of the elements in the list (should start at zero).
+    factor_analyser : instance of PLDA class
+        PLDA class object.
+    stat0 : ndarray
+        Matrix of zero-order statistics.
+    stats : ndarray
+        Matrix of first-order statistics.
+    e_h : ndarray
+        An accumulator matrix.
+    e_hh : ndarray
+        An accumulator matrix.
+    """
+    rank = factor_analyser.F.shape[1]
+    if factor_analyser.Sigma.ndim == 2:
+        A = factor_analyser.F.T.dot(factor_analyser.F)
+        inv_lambda_unique = dict()
+        for sess in numpy.unique(stat0[:, 0]):
+            inv_lambda_unique[sess] = linalg.inv(sess * A + numpy.eye(A.shape[
+                0]))
+
+    tmp = numpy.zeros(
+        (factor_analyser.F.shape[1], factor_analyser.F.shape[1]),
+        dtype=numpy.float64, )
+
+    for idx in mini_batch_indices:
+        if factor_analyser.Sigma.ndim == 1:
+            inv_lambda = linalg.inv(
+                numpy.eye(rank) + (factor_analyser.F.T * stat0[
+                    idx + batch_start, :]).dot(factor_analyser.F))
+        else:
+            inv_lambda = inv_lambda_unique[stat0[idx + batch_start, 0]]
+
+        aux = factor_analyser.F.T.dot(stats[idx + batch_start, :])
+        numpy.dot(aux, inv_lambda, out=e_h[idx])
+        e_hh[idx] = inv_lambda + numpy.outer(e_h[idx], e_h[idx], tmp)
+
+
+def _check_missing_model(enroll, test, ndx):
+    # Remove missing models and test segments
+    clean_ndx = ndx.filter(enroll.modelset, test.segset, True)
+
+    # Align EmbeddingMeta to match the clean_ndx
+    enroll.align_models(clean_ndx.modelset)
+    test.align_segments(clean_ndx.segset)
+
+    return clean_ndx
+
+
+class PLDA:
+    """
+    A class to train a PLDA model from embeddings.
+
+    The input is in paddlespeech.vector.cluster.diarization.EmbeddingMeta format.
+    Trains a simplified PLDA model: no within-class covariance matrix, but
+    a full residual covariance matrix.
+
+    Arguments
+    ---------
+    mean : ndarray
+        Mean of the vectors.
+    F : ndarray
+        Eigenvoice matrix.
+    Sigma : ndarray
+        Residual matrix.
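+
+    Example
+    -------
+    An illustrative sketch on random embeddings (toy names assumed; see the
+    __main__ block at the bottom of this file for a fuller run):
+
+    >>> dim, N = 10, 100
+    >>> xv = numpy.random.rand(N, dim)
+    >>> spk = numpy.array(['spk%d' % (i % 5) for i in range(N)], dtype='|O')
+    >>> utt = numpy.array(['utt%d' % i for i in range(N)], dtype='|O')
+    >>> meta = EmbeddingMeta(segset=utt, modelset=spk, stats=xv)
+    >>> plda = PLDA(rank_f=5)
+    >>> plda.plda(meta)
+    >>> plda.F.shape
+    (10, 5)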
+    """
+
+    def __init__(
+            self,
+            mean=None,
+            F=None,
+            Sigma=None,
+            rank_f=100,
+            nb_iter=10,
+            scaling_factor=1.0, ):
+        self.mean = None
+        self.F = None
+        self.Sigma = None
+        self.rank_f = rank_f
+        self.nb_iter = nb_iter
+        self.scaling_factor = scaling_factor
+
+        if mean is not None:
+            self.mean = mean
+        if F is not None:
+            self.F = F
+        if Sigma is not None:
+            self.Sigma = Sigma
+
+    def plda(
+            self,
+            emb_meta=None,
+            output_file_name=None, ):
+        """
+        Trains a PLDA model with no within-class covariance matrix but a
+        full residual covariance matrix. The rank of the between-class
+        covariance matrix (rank_f), the number of EM iterations (nb_iter)
+        and the statistics scaling factor (scaling_factor) are taken from
+        the constructor.
+
+        Arguments
+        ---------
+        emb_meta : paddlespeech.vector.cluster.diarization.EmbeddingMeta
+            Contains vectors and meta-information to perform PLDA.
+        output_file_name : str
+            Name of the output file where to store the PLDA model
+            (not used by this implementation).
+        """
+
+        # Dimension of the vector (x-vectors stored in stats)
+        vect_size = emb_meta.stats.shape[1]
+
+        # Initialize mean and residual covariance from the training data
+        self.mean = emb_meta.get_mean_stats()
+        self.Sigma = emb_meta.get_total_covariance_stats()
+
+        # Sum stat0 and stats for each speaker model
+        model_shifted_stat, session_per_model = emb_meta.sum_stat_per_model()
+
+        # Number of speakers (classes) in the training set
+        class_nb = model_shifted_stat.modelset.shape[0]
+
+        # Multiply statistics by scaling_factor
+        model_shifted_stat.stat0 *= self.scaling_factor
+        model_shifted_stat.stats *= self.scaling_factor
+        session_per_model *= self.scaling_factor
+
+        # Covariance of the observed statistics
+        sigma_obs = emb_meta.get_total_covariance_stats()
+        evals, evecs = linalg.eigh(sigma_obs)
+
+        # Initial F (eigenvoice matrix) from the top rank_f eigenvectors
+        idx = numpy.argsort(evals)[::-1]
+        evecs = evecs.real[:, idx[:self.rank_f]]
+        self.F = evecs[:, :self.rank_f]
+
+        # Estimate the PLDA model by iterating the EM algorithm
+        for it in range(self.nb_iter):
+
+            # E-step
+
+            # Copy stats as they will be whitened with a different Sigma for each iteration
+            local_stat = copy.deepcopy(model_shifted_stat)
+
+            # Whiten statistics (with the new mean and Sigma)
+            local_stat.whiten_stats(self.mean, self.Sigma)
+
+            # Whiten the eigenvoice matrix
+            eigen_values, eigen_vectors = linalg.eigh(self.Sigma)
+            ind = eigen_values.real.argsort()[::-1]
+            eigen_values = eigen_values.real[ind]
+            eigen_vectors = eigen_vectors.real[:, ind]
+            sqr_inv_eval_sigma = 1 / numpy.sqrt(eigen_values.real)
+            sqr_inv_sigma = numpy.dot(eigen_vectors,
+                                      numpy.diag(sqr_inv_eval_sigma))
+            self.F = sqr_inv_sigma.T.dot(self.F)
+
+            # Replicate local_stat.stat0 along the vector dimension
+            index_map = numpy.zeros(vect_size, dtype=int)
+            _stat0 = local_stat.stat0[:, index_map]
+
+            e_h = numpy.zeros((class_nb, self.rank_f))
+            e_hh = numpy.zeros((class_nb, self.rank_f, self.rank_f))
+
+            # Loop on model IDs
+            fa_model_loop(
+                batch_start=0,
+                mini_batch_indices=numpy.arange(class_nb),
+                factor_analyser=self,
+                stat0=_stat0,
+                stats=local_stat.stats,
+                e_h=e_h,
+                e_hh=e_hh, )
+
+            # Accumulate for the minimum divergence step
+            _R = numpy.sum(e_hh, axis=0) / session_per_model.shape[0]
+
+            _C = e_h.T.dot(local_stat.stats).dot(linalg.inv(sqr_inv_sigma))
+            _A = numpy.einsum("ijk,i->jk", e_hh, local_stat.stat0.squeeze())
+
+            # M-step
+            self.F = linalg.solve(_A, _C).T
+
+            # Update the residual covariance
+            self.Sigma = sigma_obs - self.F.dot(_C) / session_per_model.sum()
+
+            # Minimum divergence step
+            self.F = self.F.dot(linalg.cholesky(_R))
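+
+    # Note on the scoring method below: with Sigma_ac = F F^T (across-class
+    # covariance) and Sigma_tot = Sigma_ac + Sigma, the PLDA verification
+    # log-likelihood ratio reduces to a quadratic form in the enrollment and
+    # test vectors, built from the Phi and Psi matrices computed in scoring().
+    # This is the "fast scoring" formulation from the papers cited above.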
+    def scoring(
+            self,
+            enroll,
+            test,
+            ndx,
+            test_uncertainty=None,
+            Vtrans=None,
+            p_known=0.0,
+            scaling_factor=1.0,
+            check_missing=True, ):
+        """
+        Compute the PLDA scores between two sets of vectors. The list of
+        trials to perform is given in an Ndx object. PLDA matrices have to be
+        pre-computed. i-vectors/x-vectors are assumed to be whitened beforehand.
+
+        Arguments
+        ---------
+        enroll : paddlespeech.vector.cluster.diarization.EmbeddingMeta
+            An EmbeddingMeta in which stats are xvectors.
+        test : paddlespeech.vector.cluster.diarization.EmbeddingMeta
+            An EmbeddingMeta in which stats are xvectors.
+        ndx : paddlespeech.vector.cluster.plda.Ndx
+            An Ndx object defining the list of trials to perform.
+        p_known : float
+            Probability of having a known speaker for open-set
+            identification case (=1 for the verification task and =0 for the
+            closed-set case).
+        check_missing : bool
+            If True, check that all models and segments exist.
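+
+        Example
+        -------
+        An illustrative sketch (not a doctest): ``plda`` is a trained PLDA
+        and ``enroll``/``test`` are EmbeddingMeta objects, names assumed::
+
+            ndx = Ndx(models=enroll.modelset, testsegs=test.segset)
+            scores = plda.scoring(enroll, test, ndx)
+            # scores.scoremat has one row per model, one column per segment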
+        """
+
+        enroll_ctr = copy.deepcopy(enroll)
+        test_ctr = copy.deepcopy(test)
+
+        # Remove missing models and test segments
+        if check_missing:
+            clean_ndx = _check_missing_model(enroll_ctr, test_ctr, ndx)
+        else:
+            clean_ndx = ndx
+
+        # Center the i-vectors around the PLDA mean
+        enroll_ctr.center_stats(self.mean)
+        test_ctr.center_stats(self.mean)
+
+        # Compute the constant component of the PLDA distribution
+        invSigma = linalg.inv(self.Sigma)
+        I_spk = numpy.eye(self.F.shape[1], dtype="float")
+
+        K = self.F.T.dot(invSigma * scaling_factor).dot(self.F)
+        K1 = linalg.inv(K + I_spk)
+        K2 = linalg.inv(2 * K + I_spk)
+
+        # Compute the Gaussian distribution constant
+        alpha1 = numpy.linalg.slogdet(K1)[1]
+        alpha2 = numpy.linalg.slogdet(K2)[1]
+        plda_cst = alpha2 / 2.0 - alpha1
+
+        # Compute intermediate matrices
+        Sigma_ac = numpy.dot(self.F, self.F.T)
+        Sigma_tot = Sigma_ac + self.Sigma
+        Sigma_tot_inv = linalg.inv(Sigma_tot)
+
+        Tmp = linalg.inv(Sigma_tot - Sigma_ac.dot(Sigma_tot_inv).dot(Sigma_ac))
+        Phi = Sigma_tot_inv - Tmp
+        Psi = Sigma_tot_inv.dot(Sigma_ac).dot(Tmp)
+
+        # Compute the different parts of the PLDA score
+        model_part = 0.5 * numpy.einsum("ij, ji->i",
+                                        enroll_ctr.stats.dot(Phi),
+                                        enroll_ctr.stats.T)
+        seg_part = 0.5 * numpy.einsum("ij, ji->i",
+                                      test_ctr.stats.dot(Phi), test_ctr.stats.T)
+
+        # Compute verification scores
+        score = Scores()
+        score.modelset = clean_ndx.modelset
+        score.segset = clean_ndx.segset
+        score.scoremask = clean_ndx.trialmask
+
+        score.scoremat = model_part[:, numpy.newaxis] + seg_part + plda_cst
+        score.scoremat += enroll_ctr.stats.dot(Psi).dot(test_ctr.stats.T)
+        score.scoremat *= scaling_factor
+
+        # Case of open-set identification: compute the log-likelihood
+        # by taking into account the probability of having a known impostor
+        # or an out-of-set class
+        if p_known != 0:
+            N = score.scoremat.shape[0]
+            open_set_scores = numpy.empty(score.scoremat.shape)
+            tmp = numpy.exp(score.scoremat)
+            for ii in range(N):
+                # open-set term
+                open_set_scores[ii, :] = score.scoremat[ii, :] - numpy.log(
+                    p_known * tmp[~(numpy.arange(N) == ii)].sum(axis=0) / (
+                        N - 1) + (1 - p_known))
+            score.scoremat = open_set_scores
+
+        return score
+
+
+if __name__ == '__main__':
+    import random
+
+    dim, N, n_spkrs = 10, 100, 10
+    train_xv = numpy.random.rand(N, dim)
+    md = ['md' + str(random.randrange(1, n_spkrs, 1)) for i in range(N)]  # spk
+    modelset = numpy.array(md, dtype="|O")
+    sg = ['sg' + str(i) for i in range(N)]  # utt
+    segset = numpy.array(sg, dtype="|O")
+    xvectors_stat = EmbeddingMeta(
+        modelset=modelset, segset=segset, stats=train_xv)
+    # Training PLDA model: M ~ (mean, F, Sigma)
+    plda = PLDA(rank_f=5)
+    plda.plda(xvectors_stat)
+    print(plda.mean.shape)  # (10,)
+    print(plda.F.shape)  # (10, 5)
+    print(plda.Sigma.shape)  # (10, 10)
+    # Enrollment (20 utts)
+    en_N = 20
+    en_xv = numpy.random.rand(en_N, dim)
+    en_sgs = ['en' + str(i) for i in range(en_N)]
+    en_sets = numpy.array(en_sgs, dtype="|O")
+    en_stat = EmbeddingMeta(modelset=en_sets, segset=en_sets, stats=en_xv)
+    # Test (30 utts)
+    te_N = 30
+    te_xv = numpy.random.rand(te_N, dim)
+    te_sgs = ['te' + str(i) for i in range(te_N)]
+    te_sets = numpy.array(te_sgs, dtype="|O")
+    te_stat = EmbeddingMeta(modelset=te_sets, segset=te_sets, stats=te_xv)
+    ndx = Ndx(models=en_sets, testsegs=te_sets)  # trials
+    # PLDA scoring
+    scores_plda = plda.scoring(en_stat, te_stat, ndx)
+    print(scores_plda.scoremat.shape)  # (20, 30)
diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py
index 5c5ce13e39a828cd501d1ece6bcf40d2e3f97fb8..a4d8c4524a504fc7349ca908fe416df7b52b6e35 100644
--- a/paddlespeech/vector/io/dataset_from_json.py
+++ b/paddlespeech/vector/io/dataset_from_json.py
@@ -26,14 +26,14 @@ from paddleaudio.compliance.librosa import mfcc
 class meta_info:
     """the audio meta info in the vector JSONDataset
     Args:
-        id (str): the segment name
+        utt_id (str): the segment name
         duration (float): segment time
         wav (str): wav file path
         start (int): start point in the original wav file
        stop (int): stop point in the original wav file
         lab_id (str): the record id
     """
-    id: str
+    utt_id: str
     duration: float
     wav: str
    start: int