提交 dc28ebe4 编写于 作者: X xiongxinlei

move the csv vox format to paddleaudio, test=doc

上级 3a943ca9
......@@ -11,319 +11,182 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
"""Prepare VoxCeleb1 dataset
create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
researchers should download the voxceleb1 dataset yourselves
through google form to get the username & password and unpack the data
"""
import argparse
import codecs
import glob
import json
import os
import random
from typing import Dict
from typing import List
from typing import Tuple
from paddle.io import Dataset
from pathos.multiprocessing import Pool
from tqdm import tqdm
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets.dataset import feat_funcs
from paddleaudio.utils import DATA_HOME
from paddleaudio.utils import decompress
from paddleaudio.utils import download_and_decompress
import subprocess
from pathlib import Path
import soundfile
from utils.utility import check_md5sum
from utils.utility import download
from utils.utility import unpack
__all__ = ['VoxCeleb1']
class VoxCeleb1(Dataset):
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
archieves_audio_dev = [
{
'url': source_url + 'vox1_dev_wav_partaa',
'md5': 'e395d020928bc15670b570a21695ed96',
},
{
'url': source_url + 'vox1_dev_wav_partab',
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
},
{
'url': source_url + 'vox1_dev_wav_partac',
'md5': '017d579a2a96a077f40042ec33e51512',
},
{
'url': source_url + 'vox1_dev_wav_partad',
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
},
]
archieves_audio_test = [
{
'url': source_url + 'vox1_test_wav.zip',
'md5': '185fdc63c3c739954633d50379a3d102',
},
]
archieves_meta = [
{
'url':
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
'md5':
'b73110731c9223c1461fe49cb48dddfc',
},
]
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
sample_rate = 16000
meta_info = collections.namedtuple(
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
base_path = os.path.join(DATA_HOME, 'vox1')
wav_path = os.path.join(base_path, 'wav')
subsets = ['train', 'dev', 'enrol', 'test']
def __init__(
self,
subset: str='train',
feat_type: str='raw',
random_chunk: bool=True,
chunk_duration: float=3.0, # seconds
split_ratio: float=0.9, # train split ratio
seed: int=0,
target_dir: str=None,
**kwargs):
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
self.subset = subset
self.spk_id2label = {}
self.feat_type = feat_type
self.feat_config = kwargs
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else self.base_path
self.csv_path = os.path.join(
target_dir, 'csv') if target_dir else os.path.join(self.base_path,
'csv')
self.meta_path = os.path.join(
target_dir, 'meta') if target_dir else os.path.join(base_path,
'meta')
self.veri_test_file = os.path.join(self.meta_path, 'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data()
super(VoxCeleb1, self).__init__()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def _get_data(self):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
print("wav base path: {}".format(self.wav_path))
if not os.path.isdir(self.wav_path):
print("start to download the voxceleb1 dataset")
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
self.archieves_audio_dev,
self.base_path,
decompress=False)
download_and_decompress( # download the vox1_test_wav.zip and unzip
self.archieves_audio_test,
self.base_path,
decompress=True)
# Download all parts and concatenate the files into one zip file.
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
print(f'Concatenating all parts to: {dev_zipfile}')
os.system(
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
)
# Extract all audio files of dev and test set.
decompress(dev_zipfile, self.base_path)
# Download meta files.
if not os.path.isdir(self.meta_path):
download_and_decompress(
self.archieves_meta, self.meta_path, decompress=False)
# Data preparation.
if not os.path.isdir(self.csv_path):
os.makedirs(self.csv_path)
self.prepare_data()
data = []
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav, start, stop, spk_id = line.strip(
).split(',')
data.append(
self.meta_info(audio_id,
float(duration), wav,
int(start), int(stop), spk_id))
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
for line in f.readlines():
spk_id, label = line.strip().split(' ')
self.spk_id2label[spk_id] = int(label)
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple: `type(sample)._fields`
for field in type(sample)._fields:
record[field] = getattr(sample, field)
waveform, sr = load_audio(record['wav'])
# random select a chunk audio samples from the audio
if self.random_chunk:
num_wav_samples = waveform.shape[0]
num_chunk_samples = int(self.chunk_duration * sr)
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
stop = start + num_chunk_samples
from utils.utility import unzip
# all the data will be download in the current data/voxceleb directory default
DATA_HOME = os.path.expanduser('.')
# if you use the http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/ as the download base url
# you need to get the username & password via the google form
# if you use the https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a as the download base url,
# you need use --no-check-certificate to connect the target download url
BASE_URL = "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a"
# dev data
DEV_LIST = {
"vox1_dev_wav_partaa": "e395d020928bc15670b570a21695ed96",
"vox1_dev_wav_partab": "bbfaaccefab65d82b21903e81a8a8020",
"vox1_dev_wav_partac": "017d579a2a96a077f40042ec33e51512",
"vox1_dev_wav_partad": "7bb1e9f70fddc7a678fa998ea8b3ba19",
}
DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f532ba230b"
# test data
TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"}
TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102"
# kaldi trial
# this trial file is organized by kaldi according the official file,
# which is a little different with the official trial veri_test2.txt
KALDI_BASE_URL = "http://www.openslr.org/resources/49/"
TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"}
TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/voxceleb1/",
type=str,
help="Directory to save the voxceleb1 dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_path = os.path.join(data_dir, "wav", "**", "*.wav")
total_sec = 0.0
total_text = 0.0
total_num = 0
speakers = set()
for audio_path in glob.glob(data_path, recursive=True):
audio_id = "-".join(audio_path.split("/")[-3:])
utt2spk = audio_path.split("/")[-3]
duration = soundfile.info(audio_path).duration
text = ""
json_lines.append(
json.dumps(
{
"utt": audio_id,
"utt2spk": str(utt2spk),
"feat": audio_path,
"feat_shape": (duration, ),
"text": text # compatible with asr data format
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
speakers.add(utt2spk)
# data_dir_name refer to dev or test
# voxceleb1 is given explicit in the path
data_dir_name = Path(data_dir).name
manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
for line in json_lines:
f.write(line + "\n")
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, "voxceleb1." +
data_dir_name) + ".meta"
with codecs.open(meta_path, 'w', encoding='utf-8') as f:
print(f"{total_num} utts", file=f)
print(f"{len(speakers)} speakers", file=f)
print(f"{total_sec / (60 * 60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(base_url, data_list, target_dir, manifest_path,
target_data):
if not os.path.exists(target_dir):
os.mkdir(target_dir)
# wav directory already exists, it need do nothing
if not os.path.exists(os.path.join(target_dir, "wav")):
# download all dataset part
for zip_part in data_list.keys():
download_url = " --no-check-certificate " + base_url + "/" + zip_part
download(
url=download_url,
md5sum=data_list[zip_part],
target_dir=target_dir)
# pack the all part to target zip file
all_target_part, target_name, target_md5sum = target_data.split()
target_name = os.path.join(target_dir, target_name)
if not os.path.exists(target_name):
pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part,
target_name)
subprocess.call(pack_part_cmd, shell=True)
# check the target zip file md5sum
if not check_md5sum(target_name, target_md5sum):
raise RuntimeError("{} MD5 checkssum failed".format(target_name))
else:
start = record['start']
stop = record['stop']
waveform = waveform[start:stop]
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
if self.subset in ['train',
'dev']: # Labels are available in train and dev.
record.update({'label': self.spk_id2label[record['spk_id']]})
return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
split_chunks: bool) -> List[List[str]]:
waveform, sr = load_audio(wav_file)
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks: # Split into pieces of self.chunk_duration seconds.
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
audio_duration)
for chunk in uniq_chunks_list:
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, spk_id
ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
spk_id
])
else: # Keep whole audio.
ret.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
])
return ret
def generate_csv(self,
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
with Pool(64) as p:
infos = list(
tqdm(
p.imap(lambda x: self._get_audio_info(x, split_chunks),
wav_files),
total=len(wav_files)))
csv_lines = []
for info in infos:
csv_lines.extend(info)
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(self):
# Audio of speakers in veri_test_file should not be included in training set.
print("start to prepare the data csv file")
enrol_files = set()
test_files = set()
# get the enroll and test audio file path
with open(self.veri_test_file, 'r') as f:
for line in f.readlines():
_, enrol_file, test_file = line.strip().split(' ')
enrol_files.add(os.path.join(self.wav_path, enrol_file))
test_files.add(os.path.join(self.wav_path, test_file))
enrol_files = sorted(enrol_files)
test_files = sorted(test_files)
# get the enroll and test speakers
test_spks = set()
for file in (enrol_files + test_files):
spk = file.split('/wav/')[1].split('/')[0]
test_spks.add(spk)
# get all the train and dev audios file path
audio_files = []
speakers = set()
for path in [self.wav_path]:
for file in glob.glob(
os.path.join(path, "**", "*.wav"), recursive=True):
spk = file.split('/wav/')[1].split('/')[0]
if spk in test_spks:
continue
speakers.add(spk)
audio_files.append(file)
print("start to generate the {}".format(
os.path.join(self.meta_path, 'spk_id2label.txt')))
# encode the train and dev speakers label to spk_id2label.txt
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
for label, spk_id in enumerate(
sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
f.write(f'{spk_id} {label}\n')
audio_files = sorted(audio_files)
random.shuffle(audio_files)
split_idx = int(self.split_ratio * len(audio_files))
# split_ratio to train
train_files, dev_files = audio_files[:split_idx], audio_files[
split_idx:]
self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
self.generate_csv(
enrol_files,
os.path.join(self.csv_path, 'enrol.csv'),
split_chunks=False)
self.generate_csv(
test_files,
os.path.join(self.csv_path, 'test.csv'),
split_chunks=False)
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
print("Check {} md5sum successfully".format(target_name))
# unzip the all zip file
if target_name.endswith(".zip"):
unzip(target_name, target_dir)
# create the manifest file
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
base_url=BASE_URL,
data_list=DEV_LIST,
target_dir=os.path.join(args.target_dir, "dev"),
manifest_path=args.manifest_prefix,
target_data=DEV_TARGET_DATA)
prepare_dataset(
base_url=BASE_URL,
data_list=TEST_LIST,
target_dir=os.path.join(args.target_dir, "test"),
manifest_path=args.manifest_prefix,
target_data=TEST_TARGET_DATA)
print("Manifest prepare done!")
if __name__ == '__main__':
main()
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import glob
import os
import random
from typing import Dict
from typing import List
from typing import Tuple
from paddle.io import Dataset
from pathos.multiprocessing import Pool
from tqdm import tqdm
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets.dataset import feat_funcs
from paddleaudio.utils import DATA_HOME
from paddleaudio.utils import decompress
from paddleaudio.utils import download_and_decompress
from utils.utility import download
from utils.utility import unpack
__all__ = ['VoxCeleb1']
class VoxCeleb1(Dataset):
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
archieves_audio_dev = [
{
'url': source_url + 'vox1_dev_wav_partaa',
'md5': 'e395d020928bc15670b570a21695ed96',
},
{
'url': source_url + 'vox1_dev_wav_partab',
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
},
{
'url': source_url + 'vox1_dev_wav_partac',
'md5': '017d579a2a96a077f40042ec33e51512',
},
{
'url': source_url + 'vox1_dev_wav_partad',
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
},
]
archieves_audio_test = [
{
'url': source_url + 'vox1_test_wav.zip',
'md5': '185fdc63c3c739954633d50379a3d102',
},
]
archieves_meta = [
{
'url':
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
'md5':
'b73110731c9223c1461fe49cb48dddfc',
},
]
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
sample_rate = 16000
meta_info = collections.namedtuple(
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
base_path = os.path.join(DATA_HOME, 'vox1')
wav_path = os.path.join(base_path, 'wav')
subsets = ['train', 'dev', 'enrol', 'test']
def __init__(
self,
subset: str='train',
feat_type: str='raw',
random_chunk: bool=True,
chunk_duration: float=3.0, # seconds
split_ratio: float=0.9, # train split ratio
seed: int=0,
target_dir: str=None,
**kwargs):
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
self.subset = subset
self.spk_id2label = {}
self.feat_type = feat_type
self.feat_config = kwargs
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else self.base_path
self.csv_path = os.path.join(
target_dir, 'csv') if target_dir else os.path.join(self.base_path,
'csv')
self.meta_path = os.path.join(
target_dir, 'meta') if target_dir else os.path.join(self.base_path,
'meta')
self.veri_test_file = os.path.join(self.meta_path, 'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data()
super(VoxCeleb1, self).__init__()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def _get_data(self):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
print("wav base path: {}".format(self.wav_path))
if not os.path.isdir(self.wav_path):
print("start to download the voxceleb1 dataset")
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
self.archieves_audio_dev,
self.base_path,
decompress=False)
download_and_decompress( # download the vox1_test_wav.zip and unzip
self.archieves_audio_test,
self.base_path,
decompress=True)
# Download all parts and concatenate the files into one zip file.
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
print(f'Concatenating all parts to: {dev_zipfile}')
os.system(
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
)
# Extract all audio files of dev and test set.
decompress(dev_zipfile, self.base_path)
# Download meta files.
if not os.path.isdir(self.meta_path):
download_and_decompress(
self.archieves_meta, self.meta_path, decompress=False)
# Data preparation.
if not os.path.isdir(self.csv_path):
os.makedirs(self.csv_path)
self.prepare_data()
data = []
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav, start, stop, spk_id = line.strip(
).split(',')
data.append(
self.meta_info(audio_id,
float(duration), wav,
int(start), int(stop), spk_id))
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
for line in f.readlines():
spk_id, label = line.strip().split(' ')
self.spk_id2label[spk_id] = int(label)
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple: `type(sample)._fields`
for field in type(sample)._fields:
record[field] = getattr(sample, field)
waveform, sr = load_audio(record['wav'])
# random select a chunk audio samples from the audio
if self.random_chunk:
num_wav_samples = waveform.shape[0]
num_chunk_samples = int(self.chunk_duration * sr)
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
stop = start + num_chunk_samples
else:
start = record['start']
stop = record['stop']
waveform = waveform[start:stop]
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
if self.subset in ['train',
'dev']: # Labels are available in train and dev.
record.update({'label': self.spk_id2label[record['spk_id']]})
return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
split_chunks: bool) -> List[List[str]]:
waveform, sr = load_audio(wav_file)
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks: # Split into pieces of self.chunk_duration seconds.
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
audio_duration)
for chunk in uniq_chunks_list:
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, spk_id
ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
spk_id
])
else: # Keep whole audio.
ret.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
])
return ret
def generate_csv(self,
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
with Pool(64) as p:
infos = list(
tqdm(
p.imap(lambda x: self._get_audio_info(x, split_chunks),
wav_files),
total=len(wav_files)))
csv_lines = []
for info in infos:
csv_lines.extend(info)
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(self):
# Audio of speakers in veri_test_file should not be included in training set.
print("start to prepare the data csv file")
enrol_files = set()
test_files = set()
# get the enroll and test audio file path
with open(self.veri_test_file, 'r') as f:
for line in f.readlines():
_, enrol_file, test_file = line.strip().split(' ')
enrol_files.add(os.path.join(self.wav_path, enrol_file))
test_files.add(os.path.join(self.wav_path, test_file))
enrol_files = sorted(enrol_files)
test_files = sorted(test_files)
# get the enroll and test speakers
test_spks = set()
for file in (enrol_files + test_files):
spk = file.split('/wav/')[1].split('/')[0]
test_spks.add(spk)
# get all the train and dev audios file path
audio_files = []
speakers = set()
for path in [self.wav_path]:
for file in glob.glob(
os.path.join(path, "**", "*.wav"), recursive=True):
spk = file.split('/wav/')[1].split('/')[0]
if spk in test_spks:
continue
speakers.add(spk)
audio_files.append(file)
print("start to generate the {}".format(
os.path.join(self.meta_path, 'spk_id2label.txt')))
# encode the train and dev speakers label to spk_id2label.txt
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
for label, spk_id in enumerate(
sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
f.write(f'{spk_id} {label}\n')
audio_files = sorted(audio_files)
random.shuffle(audio_files)
split_idx = int(self.split_ratio * len(audio_files))
# split_ratio to train
train_files, dev_files = audio_files[:split_idx], audio_files[
split_idx:]
self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
self.generate_csv(
enrol_files,
os.path.join(self.csv_path, 'enrol.csv'),
split_chunks=False)
self.generate_csv(
test_files,
os.path.join(self.csv_path, 'test.csv'),
split_chunks=False)
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册