Commit 016ed6d6 authored by xiongxinlei

repair the code according to the part comment, test=doc

Parent 97ec0126
@@ -11,21 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
 import ast
 import os
 import numpy as np
 import paddle
+import paddle.nn.functional as F
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
-import paddle.nn.functional as F
-from paddlespeech.vector.training.metrics import compute_eer
+from tqdm import tqdm
 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.training.sid_model import SpeakerIdetification
-from tqdm import tqdm
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.metrics import compute_eer

 def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
@@ -44,7 +44,7 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
     return np.pad(x, pad_width, mode=mode, **kwargs)

-def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
+def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
     ids = [item['id'] for item in batch]
     lengths = np.asarray([item['feat'].shape[1] for item in batch])
     feats = list(
@@ -58,8 +58,8 @@ def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
         mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
         std = feat.std(axis=-1, keepdims=True) if std_norm else 1
         feats[i][:, :lengths[i]] = (feat - mean) / std
-        assert feats[i][:, lengths[i]:].sum(
-        ) == 0  # Padding values should all be 0.
+        assert feats[i][:, lengths[
+            i]:].sum() == 0  # Padding values should all be 0.
     # Converts into ratios.
     lengths = (lengths / lengths.max()).astype(np.float32)
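
For context, the collate function in this hunk normalizes each utterance's features using statistics from only the valid (unpadded) frames, then converts absolute lengths into ratios of the batch maximum. A minimal standalone sketch of the same idea, with hypothetical names rather than the project's API:

import numpy as np

def normalize_padded_feats(feats, lengths, eps=1e-8):
    """Mean/std-normalize only the valid frames of right-padded (n_mels, T) arrays."""
    for i in range(len(feats)):
        valid = feats[i][:, :lengths[i]]
        mean = valid.mean(axis=-1, keepdims=True)
        std = valid.std(axis=-1, keepdims=True) + eps  # guard against zero std
        feats[i][:, :lengths[i]] = (valid - mean) / std
        # The padded tail is untouched, so it must still be all zeros.
        assert feats[i][:, lengths[i]:].sum() == 0
    # Lengths become ratios relative to the longest utterance in the batch.
    return feats, (np.asarray(lengths) / max(lengths)).astype(np.float32)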
@@ -98,16 +98,16 @@ def main(args):
     print(f'Checkpoint loaded from {args.load_checkpoint}')

     # stage4: construct the enroll and test dataloader
-    enrol_ds = VoxCeleb1(subset='enrol',
-                         feat_type='melspectrogram',
-                         random_chunk=False,
-                         n_mels=80,
-                         window_size=400,
-                         hop_length=160)
+    enrol_ds = VoxCeleb1(
+        subset='enrol',
+        feat_type='melspectrogram',
+        random_chunk=False,
+        n_mels=80,
+        window_size=400,
+        hop_length=160)
     enrol_sampler = BatchSampler(
-        enrol_ds,
-        batch_size=args.batch_size,
-        shuffle=True)  # Shuffle to make embedding normalization more robust.
+        enrol_ds, batch_size=args.batch_size,
+        shuffle=True)  # Shuffle to make embedding normalization more robust.
     enrol_loader = DataLoader(enrol_ds,
                               batch_sampler=enrol_sampler,
                               collate_fn=lambda x: feature_normalize(
@@ -115,16 +115,16 @@ def main(args):
         num_workers=args.num_workers,
         return_list=True,)

-    test_ds = VoxCeleb1(subset='test',
-                        feat_type='melspectrogram',
-                        random_chunk=False,
-                        n_mels=80,
-                        window_size=400,
-                        hop_length=160)
+    test_ds = VoxCeleb1(
+        subset='test',
+        feat_type='melspectrogram',
+        random_chunk=False,
+        n_mels=80,
+        window_size=400,
+        hop_length=160)
-    test_sampler = BatchSampler(test_ds,
-                                batch_size=args.batch_size,
-                                shuffle=True)
+    test_sampler = BatchSampler(
+        test_ds, batch_size=args.batch_size, shuffle=True)
     test_loader = DataLoader(test_ds,
                              batch_sampler=test_sampler,
                              collate_fn=lambda x: feature_normalize(
@@ -169,12 +169,13 @@ def main(args):
                 embedding_mean, embedding_std = mean, std
             else:
                 weight = 1 / batch_count  # Weight decay by batches.
-                embedding_mean = (
-                    1 - weight) * embedding_mean + weight * mean
-                embedding_std = (
-                    1 - weight) * embedding_std + weight * std
+                embedding_mean = (1 - weight
+                                  ) * embedding_mean + weight * mean
+                embedding_std = (1 - weight
+                                 ) * embedding_std + weight * std
             # Apply global embedding normalization.
-            embeddings = (embeddings - embedding_mean) / embedding_std
+            embeddings = (
+                embeddings - embedding_mean) / embedding_std

             # Update embedding dict.
             id2embedding.update(dict(zip(ids, embeddings)))
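
The normalization branch above keeps running estimates of the global embedding mean and std: with weight w = 1/batch_count, each batch updates embedding_mean to (1 - w) * embedding_mean + w * mean, so a single batch's influence shrinks as more batches are seen. A self-contained sketch of that update rule (hypothetical names):

import numpy as np

def update_embedding_stats(run_mean, run_std, embeddings, batch_count):
    """Running mean/std of speaker embeddings, each batch weighted 1/batch_count."""
    mean = embeddings.mean(axis=0)
    std = embeddings.std(axis=0)
    if batch_count == 1:  # the first batch initializes the estimates
        return mean, std
    w = 1.0 / batch_count  # later batches contribute progressively less
    return (1 - w) * run_mean + w * mean, (1 - w) * run_std + w * std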
@@ -201,38 +202,39 @@ def main(args):
         f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
     )

 if __name__ == "__main__":
     # yapf: disable
     parser = argparse.ArgumentParser(__doc__)
     parser.add_argument('--device',
                         choices=['cpu', 'gpu'],
                         default="gpu",
                         help="Select which device to run the model on, defaults to gpu.")
     parser.add_argument("--batch-size",
                         type=int,
                         default=16,
                         help="Number of examples per batch.")
     parser.add_argument("--num-workers",
                         type=int,
                         default=0,
                         help="Number of workers in dataloader.")
     parser.add_argument("--load-checkpoint",
                         type=str,
                         default='',
                         help="Directory from which to load the model checkpoint to continue training.")
     parser.add_argument("--global-embedding-norm",
                         type=bool,
                         default=True,
                         help="Apply global normalization on speaker embeddings.")
     parser.add_argument("--embedding-mean-norm",
                         type=bool,
                         default=True,
                         help="Apply mean normalization on speaker embeddings.")
     parser.add_argument("--embedding-std-norm",
                         type=bool,
                         default=False,
                         help="Apply std normalization on speaker embeddings.")
     args = parser.parse_args()
     # yapf: enable

     main(args)
\ No newline at end of file
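
One caveat in the CLI above: argparse's type=bool does not parse flag values the way the names suggest, since bool('False') is True (any non-empty string is truthy), so --embedding-std-norm False would still enable the normalization. A common workaround, sketched here with ast.literal_eval (the script already imports ast):

import argparse
import ast

def str2bool(value: str) -> bool:
    """Parse 'True'/'False' literals; plain bool() would treat both as True."""
    return ast.literal_eval(value.capitalize())

parser = argparse.ArgumentParser()
parser.add_argument("--embedding-std-norm", type=str2bool, default=False)
print(parser.parse_args(["--embedding-std-norm", "False"]).embedding_std_norm)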
@@ -22,22 +22,23 @@ from paddle.io import DistributedBatchSampler
 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddleaudio.features.core import melspectrogram
-from paddlespeech.vector.training.time import Timer
-from paddlespeech.vector.datasets.batch import feature_normalize
-from paddlespeech.vector.datasets.batch import waveform_collate_fn
-from paddlespeech.vector.layers.loss import AdditiveAngularMargin
-from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
-from paddlespeech.vector.layers.lr import CyclicLRScheduler
+from paddlespeech.vector.io.batch import feature_normalize
+from paddlespeech.vector.io.batch import waveform_collate_fn
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.training.sid_model import SpeakerIdetification
+from paddlespeech.vector.modules.loss import AdditiveAngularMargin
+from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
+from paddlespeech.vector.modules.lr import CyclicLRScheduler
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.utils.time import Timer
 # feat configuration
 cpu_feat_conf = {
     'n_mels': 80,
-    'window_size': 400,
-    'hop_length': 160,
+    'window_size': 400,  # frame length in samples (25 ms at 16 kHz)
+    'hop_length': 160,  # frame shift in samples (10 ms at 16 kHz)
 }
 def main(args):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
...
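
A note on cpu_feat_conf above: window_size=400 and hop_length=160 are sample counts, not milliseconds; at the 16 kHz sampling rate of VoxCeleb1 audio they correspond to the conventional 25 ms analysis window and 10 ms hop:

SAMPLE_RATE = 16000  # VoxCeleb1 audio is sampled at 16 kHz

window_ms = 400 / SAMPLE_RATE * 1000  # 25.0 ms analysis window
hop_ms = 160 / SAMPLE_RATE * 1000     # 10.0 ms frame shift
print(window_ms, hop_ms)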
@@ -76,6 +76,9 @@ class VoxCeleb1(Dataset):
         'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
     base_path = os.path.join(DATA_HOME, 'vox1')
     wav_path = os.path.join(base_path, 'wav')
+    meta_path = os.path.join(base_path, 'meta')
+    veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
+    csv_path = os.path.join(base_path, 'csv')
     subsets = ['train', 'dev', 'enrol', 'test']

     def __init__(
...
@@ -22,30 +22,22 @@ from .log import logger
 download.logger = logger

-__all__ = [
-    'decompress',
-    'download_and_decompress',
-    'load_state_dict_from_url',
-]

-def decompress(file: str, path: str=os.PathLike):
+def decompress(file: str):
     """
-    Extracts all files from a compressed file to specific path.
+    Extracts all files from a compressed file.
     """
     assert os.path.isfile(file), "File: {} does not exist.".format(file)
-
-    if path is None:
-        print("decompress the data: {}".format(file))
-        download._decompress(file)
-    else:
-        print("decompress the data: {} to {}".format(file, path))
-        if not os.path.isdir(path):
-            os.makedirs(path)
-        tmp_file = os.path.join(path, os.path.basename(file))
-        os.rename(file, tmp_file)
-        download._decompress(tmp_file)
-        os.rename(tmp_file, file)
+    download._decompress(file)

-def download_and_decompress(archives: List[Dict[str, str]],
-                            path: str,
-                            decompress: bool=True):
+def download_and_decompress(archives: List[Dict[str, str]], path: str):
     """
     Download archives and decompress to specific path.
     """
@@ -55,8 +47,8 @@ def download_and_decompress(archives: List[Dict[str, str]],
     for archive in archives:
         assert 'url' in archive and 'md5' in archive, \
             f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
-        download.get_path_from_url(
-            archive['url'], path, archive['md5'], decompress=decompress)
+        download.get_path_from_url(archive['url'], path, archive['md5'])

 def load_state_dict_from_url(url: str, path: str, md5: str=None):
@@ -67,4 +59,4 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None):
         os.makedirs(path)

     download.get_path_from_url(url, path, md5)
     return load_state_dict(os.path.join(path, os.path.basename(url)))
\ No newline at end of file
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 import paddle

@@ -67,4 +66,4 @@ class LogSoftmaxWrapper(nn.Layer):
         predictions = F.log_softmax(predictions, axis=1)
         loss = self.criterion(predictions, targets) / targets.sum()
         return loss
\ No newline at end of file
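
For readers unfamiliar with the wrapper touched in the hunk above: LogSoftmaxWrapper applies log-softmax to the model outputs and feeds them to a criterion against one-hot targets, dividing by targets.sum() to normalize. A rough sketch under the assumption (borrowed from the SpeechBrain-style design this module follows) that the criterion is a sum-reduced KL divergence; this is not the module's actual code:

import paddle
import paddle.nn.functional as F

def log_softmax_loss(logits, labels, smoothing=0.0):
    """Sketch: KL divergence between log-probs and (smoothed) one-hot labels."""
    n_classes = logits.shape[1]
    targets = F.one_hot(labels, num_classes=n_classes).astype('float32')
    if smoothing > 0:
        targets = (1 - smoothing) * targets + smoothing / n_classes
    log_probs = F.log_softmax(logits, axis=1)
    # Sum-reduced KL divided by the target mass mirrors loss / targets.sum().
    return F.kl_div(log_probs, targets, reduction='sum') / targets.sum()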
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import numpy as np
from sklearn.metrics import roc_curve
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> Tuple[float, float]:
    '''
    Compute the equal error rate (EER) and the corresponding score threshold.
    '''
    fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
    fnr = 1 - tpr
    eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    return eer, eer_threshold
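
The EER is the operating point where the false-acceptance and false-rejection rates are equal; compute_eer approximates it by taking the ROC threshold that minimizes |FNR - FPR|. A toy usage example of the helper defined above:

import numpy as np

# Toy verification trials: label 1 = same speaker, 0 = different speaker.
labels = np.array([1, 1, 1, 0, 0, 0])
scores = np.array([0.9, 0.8, 0.4, 0.6, 0.2, 0.1])

eer, threshold = compute_eer(labels, scores)
print(f'EER: {eer*100:.2f}%, threshold: {threshold:.3f}')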
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from typing import List
from paddle.framework import load as load_state_dict
from paddle.utils import download
__all__ = [
    'decompress',
    'download_and_decompress',
    'load_state_dict_from_url',
]

def decompress(file: str, path: str=None):
    """
    Extracts all files from a compressed file to a specific path.
    """
    assert os.path.isfile(file), "File: {} does not exist.".format(file)

    if path is None:
        print("decompress the data: {}".format(file))
        download._decompress(file)
    else:
        print("decompress the data: {} to {}".format(file, path))
        if not os.path.isdir(path):
            os.makedirs(path)

        tmp_file = os.path.join(path, os.path.basename(file))
        os.rename(file, tmp_file)
        download._decompress(tmp_file)
        os.rename(tmp_file, file)
def download_and_decompress(archives: List[Dict[str, str]],
                            path: str,
                            decompress: bool=True):
    """
    Download archives and decompress them to the given path.
    """
    if not os.path.isdir(path):
        os.makedirs(path)

    for archive in archives:
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
        download.get_path_from_url(
            archive['url'], path, archive['md5'], decompress=decompress)
def load_state_dict_from_url(url: str, path: str, md5: str=None):
    """
    Download and load a state dict from url.
    """
    if not os.path.isdir(path):
        os.makedirs(path)

    download.get_path_from_url(url, path, md5)
    return load_state_dict(os.path.join(path, os.path.basename(url)))
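
For illustration, a call to the helper above might look like the following; the URL and md5 are placeholders, not real artifacts from this repository:

archives = [{
    'url': 'https://example.com/datasets/sample_corpus.tar.gz',  # placeholder
    'md5': '0123456789abcdef0123456789abcdef',  # placeholder
}]
# Downloads each archive into ./data, verifies its md5, and decompresses it.
download_and_decompress(archives, path='./data', decompress=True)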