提交 016ed6d6 编写于 作者: X xiongxinlei

repair the code according to the part comment, test=doc

上级 97ec0126
......@@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
import paddle.nn.functional as F
from paddlespeech.vector.training.metrics import compute_eer
from tqdm import tqdm
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification
from tqdm import tqdm
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.metrics import compute_eer
def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
......@@ -44,7 +44,7 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
return np.pad(x, pad_width, mode=mode, **kwargs)
def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
ids = [item['id'] for item in batch]
lengths = np.asarray([item['feat'].shape[1] for item in batch])
feats = list(
......@@ -58,8 +58,8 @@ def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
std = feat.std(axis=-1, keepdims=True) if std_norm else 1
feats[i][:, :lengths[i]] = (feat - mean) / std
assert feats[i][:, lengths[i]:].sum(
) == 0 # Padding valus should all be 0.
assert feats[i][:, lengths[
i]:].sum() == 0 # Padding valus should all be 0.
# Converts into ratios.
lengths = (lengths / lengths.max()).astype(np.float32)
......@@ -98,16 +98,16 @@ def main(args):
print(f'Checkpoint loaded from {args.load_checkpoint}')
# stage4: construct the enroll and test dataloader
enrol_ds = VoxCeleb1(subset='enrol',
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
enrol_ds = VoxCeleb1(
subset='enrol',
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
enrol_sampler = BatchSampler(
enrol_ds,
batch_size=args.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_ds, batch_size=args.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enrol_ds,
batch_sampler=enrol_sampler,
collate_fn=lambda x: feature_normalize(
......@@ -115,16 +115,16 @@ def main(args):
num_workers=args.num_workers,
return_list=True,)
test_ds = VoxCeleb1(subset='test',
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
test_ds = VoxCeleb1(
subset='test',
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
test_sampler = BatchSampler(test_ds,
batch_size=args.batch_size,
shuffle=True)
test_sampler = BatchSampler(
test_ds, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_ds,
batch_sampler=test_sampler,
collate_fn=lambda x: feature_normalize(
......@@ -169,12 +169,13 @@ def main(args):
embedding_mean, embedding_std = mean, std
else:
weight = 1 / batch_count # Weight decay by batches.
embedding_mean = (
1 - weight) * embedding_mean + weight * mean
embedding_std = (
1 - weight) * embedding_std + weight * std
embedding_mean = (1 - weight
) * embedding_mean + weight * mean
embedding_std = (1 - weight
) * embedding_std + weight * std
# Apply global embedding normalization.
embeddings = (embeddings - embedding_mean) / embedding_std
embeddings = (
embeddings - embedding_mean) / embedding_std
# Update embedding dict.
id2embedding.update(dict(zip(ids, embeddings)))
......@@ -201,38 +202,39 @@ def main(args):
f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
)
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device',
choices=['cpu', 'gpu'],
default="gpu",
parser.add_argument('--device',
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--batch-size",
type=int,
default=16,
parser.add_argument("--batch-size",
type=int,
default=16,
help="Total examples' number in batch for training.")
parser.add_argument("--num-workers",
type=int,
default=0,
parser.add_argument("--num-workers",
type=int,
default=0,
help="Number of workers in dataloader.")
parser.add_argument("--load-checkpoint",
type=str,
default='',
parser.add_argument("--load-checkpoint",
type=str,
default='',
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--global-embedding-norm",
type=bool,
default=True,
parser.add_argument("--global-embedding-norm",
type=bool,
default=True,
help="Apply global normalization on speaker embeddings.")
parser.add_argument("--embedding-mean-norm",
type=bool,
default=True,
parser.add_argument("--embedding-mean-norm",
type=bool,
default=True,
help="Apply mean normalization on speaker embeddings.")
parser.add_argument("--embedding-std-norm",
type=bool,
default=False,
parser.add_argument("--embedding-std-norm",
type=bool,
default=False,
help="Apply std normalization on speaker embeddings.")
args = parser.parse_args()
# yapf: enable
main(args)
\ No newline at end of file
main(args)
......@@ -22,22 +22,23 @@ from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.features.core import melspectrogram
from paddlespeech.vector.training.time import Timer
from paddlespeech.vector.datasets.batch import feature_normalize
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.modules.lr import CyclicLRScheduler
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.utils.time import Timer
# feat configuration
cpu_feat_conf = {
'n_mels': 80,
'window_size': 400,
'hop_length': 160,
'window_size': 400, #ms
'hop_length': 160, #ms
}
def main(args):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
......
......@@ -76,6 +76,9 @@ class VoxCeleb1(Dataset):
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
base_path = os.path.join(DATA_HOME, 'vox1')
wav_path = os.path.join(base_path, 'wav')
meta_path = os.path.join(base_path, 'meta')
veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
csv_path = os.path.join(base_path, 'csv')
subsets = ['train', 'dev', 'enrol', 'test']
def __init__(
......
......@@ -22,30 +22,22 @@ from .log import logger
download.logger = logger
__all__ = [
'decompress',
'download_and_decompress',
'load_state_dict_from_url',
]
def decompress(file: str, path: str=os.PathLike):
def decompress(file: str):
"""
Extracts all files from a compressed file to specific path.
Extracts all files from a compressed file.
"""
assert os.path.isfile(file), "File: {} not exists.".format(file)
download._decompress(file)
if path is None:
print("decompress the data: {}".format(file))
download._decompress(file)
else:
print("decompress the data: {} to {}".format(file, path))
if not os.path.isdir(path):
os.makedirs(path)
tmp_file = os.path.join(path, os.path.basename(file))
os.rename(file, tmp_file)
download._decompress(tmp_file)
os.rename(tmp_file, file)
def download_and_decompress(archives: List[Dict[str, str]],
path: str,
decompress: bool=True):
def download_and_decompress(archives: List[Dict[str, str]], path: str):
"""
Download archieves and decompress to specific path.
"""
......@@ -55,8 +47,8 @@ def download_and_decompress(archives: List[Dict[str, str]],
for archive in archives:
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
download.get_path_from_url(
archive['url'], path, archive['md5'], decompress=decompress)
download.get_path_from_url(archive['url'], path, archive['md5'])
def load_state_dict_from_url(url: str, path: str, md5: str=None):
......@@ -67,4 +59,4 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None):
os.makedirs(path)
download.get_path_from_url(url, path, md5)
return load_state_dict(os.path.join(path, os.path.basename(url)))
return load_state_dict(os.path.join(path, os.path.basename(url)))
\ No newline at end of file
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
......@@ -67,4 +66,4 @@ class LogSoftmaxWrapper(nn.Layer):
predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
\ No newline at end of file
return loss
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from sklearn.metrics import roc_curve
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
'''
Compute EER and return score threshold.
'''
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
fnr = 1 - tpr
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
return eer, eer_threshold
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from typing import List
from paddle.framework import load as load_state_dict
from paddle.utils import download
__all__ = [
'decompress',
'download_and_decompress',
'load_state_dict_from_url',
]
def decompress(file: str, path: str=os.PathLike):
"""
Extracts all files from a compressed file to specific path.
"""
assert os.path.isfile(file), "File: {} not exists.".format(file)
if path is None:
print("decompress the data: {}".format(file))
download._decompress(file)
else:
print("decompress the data: {} to {}".format(file, path))
if not os.path.isdir(path):
os.makedirs(path)
tmp_file = os.path.join(path, os.path.basename(file))
os.rename(file, tmp_file)
download._decompress(tmp_file)
os.rename(tmp_file, file)
def download_and_decompress(archives: List[Dict[str, str]],
path: str,
decompress: bool=True):
"""
Download archieves and decompress to specific path.
"""
if not os.path.isdir(path):
os.makedirs(path)
for archive in archives:
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
download.get_path_from_url(
archive['url'], path, archive['md5'], decompress=decompress)
def load_state_dict_from_url(url: str, path: str, md5: str=None):
"""
Download and load a state dict from url
"""
if not os.path.isdir(path):
os.makedirs(path)
download.get_path_from_url(url, path, md5)
return load_state_dict(os.path.join(path, os.path.basename(url)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册