Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
30b5b3cb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
30b5b3cb
编写于
4月 02, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vector csv dataset format, test=doc
上级
5b05300e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
149 addition
and
131 deletion
+149
-131
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
+2
-2
examples/voxceleb/sv0/local/data.sh
examples/voxceleb/sv0/local/data.sh
+3
-3
examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
...xceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
+42
-31
examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
...ples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
+11
-30
paddlespeech/vector/exps/ecapa_tdnn/train.py
paddlespeech/vector/exps/ecapa_tdnn/train.py
+9
-5
paddlespeech/vector/io/augment.py
paddlespeech/vector/io/augment.py
+8
-8
paddlespeech/vector/io/dataset.py
paddlespeech/vector/io/dataset.py
+42
-52
paddlespeech/vector/utils/utils.py
paddlespeech/vector/utils/utils.py
+32
-0
未找到文件。
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
浏览文件 @
30b5b3cb
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
# we should explicitly specify the wav path of vox2 audio data converted from m4a
# we should explicitly specify the wav path of vox2 audio data converted from m4a
vox2_base_path
:
vox2_base_path
:
augment
:
True
augment
:
True
batch_size
:
16
batch_size
:
32
num_workers
:
2
num_workers
:
2
num_speakers
:
7205
# 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
num_speakers
:
1211
# 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle
:
True
shuffle
:
True
skip_prep
:
False
skip_prep
:
False
split_ratio
:
0.9
split_ratio
:
0.9
...
...
examples/voxceleb/sv0/local/data.sh
浏览文件 @
30b5b3cb
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
stage
=
7
stage
=
0
stop_stage
=
100
stop_stage
=
100
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
...
@@ -32,7 +32,7 @@ mkdir -p ${dir}
...
@@ -32,7 +32,7 @@ mkdir -p ${dir}
# Generally the `MAIN_ROOT` refers to the root of PaddleSpeech,
# Generally the `MAIN_ROOT` refers to the root of PaddleSpeech,
# which is defined in the path.sh
# which is defined in the path.sh
# And we will download the
# And we will download the
voxceleb data and rirs noise to ${MAIN_ROOT}/dataset
TARGET_DIR
=
${
MAIN_ROOT
}
/dataset
TARGET_DIR
=
${
MAIN_ROOT
}
/dataset
mkdir
-p
${
TARGET_DIR
}
mkdir
-p
${
TARGET_DIR
}
...
@@ -98,7 +98,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
...
@@ -98,7 +98,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# generate the vox csv file
# generate the vox csv file
# Currently, our training system use csv file for dataset
# Currently, our training system use csv file for dataset
echo
"convert the json format to csv format to be compatible with training process"
echo
"convert the json format to csv format to be compatible with training process"
python3
local
/make_csv_dataset_from_json.py
\
python3
local
/make_
vox_
csv_dataset_from_json.py
\
--train
"
${
dir
}
/vox1/manifest.dev"
\
--train
"
${
dir
}
/vox1/manifest.dev"
\
--test
"
${
dir
}
/vox1/manifest.test"
\
--test
"
${
dir
}
/vox1/manifest.test"
\
--target_dir
"
${
dir
}
/vox/"
\
--target_dir
"
${
dir
}
/vox/"
\
...
...
examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
浏览文件 @
30b5b3cb
...
@@ -20,31 +20,29 @@ import csv
...
@@ -20,31 +20,29 @@ import csv
import
os
import
os
from
typing
import
List
from
typing
import
List
import
paddle
import
tqdm
import
tqdm
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
from
paddleaudio
import
load
as
load_audio
from
paddleaudio
import
load
as
load_audio
from
paddle
audio
import
save
as
save_wav
from
paddle
speech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.utils
import
get_chunks
def
get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
logger
=
Log
(
__name__
).
getlog
()
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
def
get_chunks_list
(
wav_file
:
str
,
split_chunks
:
bool
,
base_path
:
str
,
chunk_duration
:
float
=
3.0
)
->
List
[
List
[
str
]]:
"""Get the single audio file info
def
get_audio_info
(
wav_file
:
str
,
Args:
split_chunks
:
bool
,
wav_file (list): the wav audio file and get this audio segment info list
base_path
:
str
,
split_chunks (bool): audio split flag
chunk_duration
:
float
=
3.0
)
->
List
[
List
[
str
]]:
base_path (str): the audio base path
chunk_duration (float): the chunk duration.
if set the split_chunks, we split the audio into multi-chunks segment.
"""
waveform
,
sr
=
load_audio
(
wav_file
)
waveform
,
sr
=
load_audio
(
wav_file
)
audio_id
=
wav_file
.
split
(
"/rir_noise/"
)[
-
1
].
split
(
"."
)[
0
]
audio_id
=
wav_file
.
split
(
"/rir_noise/"
)[
-
1
].
split
(
"."
)[
0
]
audio_duration
=
waveform
.
shape
[
0
]
/
sr
audio_duration
=
waveform
.
shape
[
0
]
/
sr
...
@@ -57,13 +55,16 @@ def get_audio_info(wav_file: str,
...
@@ -57,13 +55,16 @@ def get_audio_info(wav_file: str,
s
,
e
=
chunk
.
split
(
"_"
)[
-
2
:]
# Timestamps of start and end
s
,
e
=
chunk
.
split
(
"_"
)[
-
2
:]
# Timestamps of start and end
start_sample
=
int
(
float
(
s
)
*
sr
)
start_sample
=
int
(
float
(
s
)
*
sr
)
end_sample
=
int
(
float
(
e
)
*
sr
)
end_sample
=
int
(
float
(
e
)
*
sr
)
new_wav_file
=
os
.
path
.
join
(
base_path
,
audio_id
+
f
'_chunk_
{
idx
+
1
:
02
}
.wav'
)
# currently, all vector csv data format use one representation
save_wav
(
waveform
[
start_sample
:
end_sample
],
sr
,
new_wav_file
)
# id, duration, wav, start, stop, spk_id
# id, duration, new_wav
ret
.
append
([
ret
.
append
([
chunk
,
chunk_duration
,
new_wav_file
])
chunk
,
audio_duration
,
wav_file
,
start_sample
,
end_sample
,
"noise"
])
else
:
# Keep whole audio.
else
:
# Keep whole audio.
ret
.
append
([
audio_id
,
audio_duration
,
wav_file
])
ret
.
append
(
[
audio_id
,
audio_duration
,
wav_file
,
0
,
waveform
.
shape
[
0
],
"noise"
])
return
ret
return
ret
...
@@ -71,12 +72,20 @@ def generate_csv(wav_files,
...
@@ -71,12 +72,20 @@ def generate_csv(wav_files,
output_file
:
str
,
output_file
:
str
,
base_path
:
str
,
base_path
:
str
,
split_chunks
:
bool
=
True
):
split_chunks
:
bool
=
True
):
print
(
f
'Generating csv:
{
output_file
}
'
)
"""Prepare the csv file according the wav files
header
=
[
"id"
,
"duration"
,
"wav"
]
Args:
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
"""
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"utt_id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"lab_id"
]
csv_lines
=
[]
csv_lines
=
[]
for
item
in
tqdm
.
tqdm
(
wav_files
):
for
item
in
tqdm
.
tqdm
(
wav_files
):
csv_lines
.
extend
(
csv_lines
.
extend
(
get_
audio_info
(
get_
chunks_list
(
item
,
base_path
=
base_path
,
split_chunks
=
split_chunks
))
item
,
base_path
=
base_path
,
split_chunks
=
split_chunks
))
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
...
@@ -91,11 +100,12 @@ def generate_csv(wav_files,
...
@@ -91,11 +100,12 @@ def generate_csv(wav_files,
def
prepare_data
(
args
,
config
):
def
prepare_data
(
args
,
config
):
# stage0: set the cpu device,
"""Convert the jsonline format to csv format
# all data prepare process will be done in cpu mode
paddle
.
device
.
set_device
(
"cpu"
)
Args:
# set the random seed, it is a must for multiprocess training
args (argparse.Namespace): scripts args
seed_everything
(
config
.
seed
)
config (CfgNode): yaml configuration content
"""
# if the external config sets the skip_prep flag, we do nothing
# if the external config sets the skip_prep flag, we do nothing
if
config
.
skip_prep
:
if
config
.
skip_prep
:
return
return
...
@@ -119,6 +129,7 @@ def prepare_data(args, config):
...
@@ -119,6 +129,7 @@ def prepare_data(args, config):
noise_files
.
append
(
os
.
path
.
join
(
base_path
,
noise_file
))
noise_files
.
append
(
os
.
path
.
join
(
base_path
,
noise_file
))
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
'csv'
)
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
'csv'
)
logger
.
info
(
f
"csv path:
{
csv_path
}
"
)
generate_csv
(
generate_csv
(
rir_files
,
os
.
path
.
join
(
csv_path
,
'rir.csv'
),
base_path
=
base_path
)
rir_files
,
os
.
path
.
join
(
csv_path
,
'rir.csv'
),
base_path
=
base_path
)
generate_csv
(
generate_csv
(
...
...
examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
浏览文件 @
30b5b3cb
...
@@ -21,51 +21,34 @@ import json
...
@@ -21,51 +21,34 @@ import json
import
os
import
os
import
random
import
random
import
paddle
import
tqdm
import
tqdm
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddleaudio
import
load
as
load_audio
from
paddleaudio
import
load
as
load_audio
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.utils.utils
import
get_chunks
logger
=
Log
(
__name__
).
getlog
()
def
get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
"""Get all chunk segments from a utterance
Args:
logger
=
Log
(
__name__
).
getlog
()
seg_dur (float): segment chunk duration
audio_id (str): utterance name
audio_duration (float): utterance duration
Returns:
List: all the chunk segments
"""
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
def
prepare_csv
(
wav_files
,
output_file
,
config
,
split_chunks
=
True
):
def
prepare_csv
(
wav_files
,
output_file
,
config
,
split_chunks
=
True
):
"""Prepare the csv file according the wav files
"""Prepare the csv file according the wav files
Args:
Args:
dataset_list (list): all the dataset to get the test utterances
wav_files (list): all the audio list to prepare the csv file
verification_file (str): voxceleb1 trial file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
"""
"""
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
os
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
os
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
csv_lines
=
[]
csv_lines
=
[]
header
=
[
"
id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"spk
_id"
]
header
=
[
"
utt_id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"lab
_id"
]
# voxceleb meta info for each training utterance segment
# voxceleb meta info for each training utterance segment
# we extract a segment from an utterance to train
# we extract a segment from an utterance to train
# and the segment's period lies between the start and stop time points in the original wav file
# and the segment's period lies between the start and stop time points in the original wav file
# each field in the meta means as follows:
# each field in the meta means as follows:
# id: the utterance segment name
#
utt_
id: the utterance segment name
# duration: utterance segment time
# duration: utterance segment time
# wav: utterance file path
# wav: utterance file path
# start: start point in the original wav file
# start: start point in the original wav file
...
@@ -194,11 +177,9 @@ def prepare_data(args, config):
...
@@ -194,11 +177,9 @@ def prepare_data(args, config):
args (argparse.Namespace): scripts args
args (argparse.Namespace): scripts args
config (CfgNode): yaml configuration content
config (CfgNode): yaml configuration content
"""
"""
# stage0: set the cpu device,
# stage0: set the random seed
# all data prepare process will be done in cpu mode
random
.
seed
(
config
.
seed
)
paddle
.
device
.
set_device
(
"cpu"
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
config
.
seed
)
# if the external config sets the skip_prep flag, we do nothing
# if the external config sets the skip_prep flag, we do nothing
if
config
.
skip_prep
:
if
config
.
skip_prep
:
return
return
...
...
paddlespeech/vector/exps/ecapa_tdnn/train.py
浏览文件 @
30b5b3cb
...
@@ -29,7 +29,7 @@ from paddlespeech.vector.io.augment import waveform_augment
...
@@ -29,7 +29,7 @@ from paddlespeech.vector.io.augment import waveform_augment
from
paddlespeech.vector.io.batch
import
batch_pad_right
from
paddlespeech.vector.io.batch
import
batch_pad_right
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.io.dataset
import
VoxCeleb
Dataset
from
paddlespeech.vector.io.dataset
import
CSV
Dataset
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.loss
import
AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
...
@@ -55,11 +55,11 @@ def main(args, config):
...
@@ -55,11 +55,11 @@ def main(args, config):
# stage2: data preparation, such as vox1 and vox2 data, plus the augmentation noise data and pipeline
# stage2: data preparation, such as vox1 and vox2 data, plus the augmentation noise data and pipeline
# note: some commands must run only on rank==0, so we will refactor the data preparation code
# note: some commands must run only on rank==0, so we will refactor the data preparation code
train_dataset
=
VoxCeleb
Dataset
(
train_dataset
=
CSV
Dataset
(
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/train.csv"
),
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/train.csv"
),
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/meta/spk_id2label.txt"
))
"vox/meta/spk_id2label.txt"
))
dev_dataset
=
VoxCeleb
Dataset
(
dev_dataset
=
CSV
Dataset
(
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/dev.csv"
),
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/dev.csv"
),
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/meta/spk_id2label.txt"
))
"vox/meta/spk_id2label.txt"
))
...
@@ -74,7 +74,7 @@ def main(args, config):
...
@@ -74,7 +74,7 @@ def main(args, config):
# stage4: build the speaker verification train instance with backbone model
# stage4: build the speaker verification train instance with backbone model
model
=
SpeakerIdetification
(
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
VoxCeleb
.
num_speakers
)
backbone
=
ecapa_tdnn
,
num_class
=
config
.
num_speakers
)
# stage5: build the optimizer, we now only construct the AdamW optimizer
# stage5: build the optimizer, we now only construct the AdamW optimizer
# 140000 is single gpu steps
# 140000 is single gpu steps
...
@@ -148,6 +148,7 @@ def main(args, config):
...
@@ -148,6 +148,7 @@ def main(args, config):
train_reader_cost
=
0.0
train_reader_cost
=
0.0
train_feat_cost
=
0.0
train_feat_cost
=
0.0
train_run_cost
=
0.0
train_run_cost
=
0.0
train_misce_cost
=
0.0
reader_start
=
time
.
time
()
reader_start
=
time
.
time
()
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
...
@@ -203,12 +204,14 @@ def main(args, config):
...
@@ -203,12 +204,14 @@ def main(args, config):
train_run_cost
+=
time
.
time
()
-
train_start
train_run_cost
+=
time
.
time
()
-
train_start
# stage 9-8: Calculate average loss per batch
# stage 9-8: Calculate average loss per batch
avg_loss
+=
loss
.
numpy
()[
0
]
train_misce_start
=
time
.
time
()
avg_loss
=
loss
.
item
()
# stage 9-9: Calculate metrics, which is one-best accuracy
# stage 9-9: Calculate metrics, which is one-best accuracy
preds
=
paddle
.
argmax
(
logits
,
axis
=
1
)
preds
=
paddle
.
argmax
(
logits
,
axis
=
1
)
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_samples
+=
feats
.
shape
[
0
]
num_samples
+=
feats
.
shape
[
0
]
timer
.
count
()
# step plus one in timer
timer
.
count
()
# step plus one in timer
# stage 9-10: print the log information only on rank 0, once every log-freq batches
# stage 9-10: print the log information only on rank 0, once every log-freq batches
...
@@ -227,6 +230,7 @@ def main(args, config):
...
@@ -227,6 +230,7 @@ def main(args, config):
train_feat_cost
/
config
.
log_interval
)
train_feat_cost
/
config
.
log_interval
)
print_msg
+=
' avg_train_cost: {:.5f} sec,'
.
format
(
print_msg
+=
' avg_train_cost: {:.5f} sec,'
.
format
(
train_run_cost
/
config
.
log_interval
)
train_run_cost
/
config
.
log_interval
)
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
lr
,
timer
.
timing
,
timer
.
eta
)
lr
,
timer
.
timing
,
timer
.
eta
)
logger
.
info
(
print_msg
)
logger
.
info
(
print_msg
)
...
...
paddlespeech/vector/io/augment.py
浏览文件 @
30b5b3cb
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# this is modified from SpeechBrain
# this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import
math
import
math
import
os
from
typing
import
List
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
...
@@ -22,13 +23,12 @@ import paddle.nn as nn
...
@@ -22,13 +23,12 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.dataset
import
RIRSNoise
Dataset
from
paddlespeech.vector.io.dataset
import
CSV
Dataset
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
convolve1d
from
paddlespeech.vector.io.signal_processing
import
convolve1d
from
paddlespeech.vector.io.signal_processing
import
dB_to_amplitude
from
paddlespeech.vector.io.signal_processing
import
dB_to_amplitude
from
paddlespeech.vector.io.signal_processing
import
notch_filter
from
paddlespeech.vector.io.signal_processing
import
notch_filter
from
paddlespeech.vector.io.signal_processing
import
reverberate
from
paddlespeech.vector.io.signal_processing
import
reverberate
# from paddleaudio.datasets.rirs_noises import OpenRIRNoise
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -510,7 +510,7 @@ class AddNoise(nn.Layer):
...
@@ -510,7 +510,7 @@ class AddNoise(nn.Layer):
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
ids
=
[
item
[
'
utt_
id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
max
(
max_length
,
lengths
.
max
().
item
())),
map
(
lambda
x
:
pad
(
x
,
max
(
max_length
,
lengths
.
max
().
item
())),
...
@@ -590,7 +590,7 @@ class AddReverb(nn.Layer):
...
@@ -590,7 +590,7 @@ class AddReverb(nn.Layer):
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
ids
=
[
item
[
'
utt_
id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
lengths
.
max
().
item
()),
map
(
lambda
x
:
pad
(
x
,
lengths
.
max
().
item
()),
...
@@ -840,10 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
...
@@ -840,10 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
List[paddle.nn.Layer]: all augment process
List[paddle.nn.Layer]: all augment process
"""
"""
logger
.
info
(
"start to build the augment pipeline"
)
logger
.
info
(
"start to build the augment pipeline"
)
noise_dataset
=
RIRSNoiseDataset
(
csv_path
=
os
.
path
.
join
(
noise_dataset
=
CSVDataset
(
csv_path
=
os
.
path
.
join
(
target_dir
,
target_dir
,
"rir_noise/csv/noise.csv"
))
"rir_noise/csv/noise.csv"
))
rir_dataset
=
OpenRIRNoise
(
csv_path
=
os
.
path
.
join
(
target_dir
,
rir_dataset
=
CSVDataset
(
csv_path
=
os
.
path
.
join
(
target_dir
,
"rir_noise/csv/rir.csv"
))
"rir_noise/csv/rir.csv"
))
wavedrop
=
TimeDomainSpecAugment
(
wavedrop
=
TimeDomainSpecAugment
(
sample_rate
=
16000
,
sample_rate
=
16000
,
...
...
paddlespeech/vector/io/dataset.py
浏览文件 @
30b5b3cb
...
@@ -11,18 +11,38 @@
...
@@ -11,18 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
collection
s
from
dataclasses
import
dataclas
s
from
dataclasses
import
fields
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
paddleaudio
import
load
as
load_audio
from
paddleaudio
import
load
as
load_audio
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# lab_id: the utterance segment's label id
@
dataclass
class
meta_info
:
utt_id
:
str
duration
:
float
wav
:
str
start
:
int
stop
:
int
lab_id
:
str
class
VoxCelebDataset
(
Dataset
):
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
,
'start'
,
'stop'
,
'spk_id'
))
def
__init__
(
self
,
csv_path
,
spk_id2label_path
,
config
):
class
CSVDataset
(
Dataset
):
# meta_info = collections.namedtuple(
# 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
def
__init__
(
self
,
csv_path
,
spk_id2label_path
=
None
,
config
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
csv_path
=
csv_path
self
.
csv_path
=
csv_path
self
.
spk_id2label_path
=
spk_id2label_path
self
.
spk_id2label_path
=
spk_id2label_path
...
@@ -32,34 +52,41 @@ class VoxCelebDataset(Dataset):
...
@@ -32,34 +52,41 @@ class VoxCelebDataset(Dataset):
def
load_data_csv
(
self
):
def
load_data_csv
(
self
):
data
=
[]
data
=
[]
with
open
(
self
.
csv_path
,
'r'
)
as
rf
:
with
open
(
self
.
csv_path
,
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
,
start
,
stop
,
spk_id
=
line
.
strip
(
audio_id
,
duration
,
wav
,
start
,
stop
,
spk_id
=
line
.
strip
(
).
split
(
','
)
).
split
(
','
)
data
.
append
(
data
.
append
(
self
.
meta_info
(
audio_id
,
meta_info
(
audio_id
,
float
(
duration
),
wav
,
float
(
duration
),
wav
,
int
(
start
),
int
(
stop
),
spk_id
))
int
(
start
),
int
(
stop
),
spk_id
))
return
data
return
data
def
load_speaker_to_label
(
self
):
def
load_speaker_to_label
(
self
):
if
not
self
.
spk_id2label_path
:
logger
.
warning
(
"No speaker id to label file"
)
return
spk_id2label
=
{}
with
open
(
self
.
spk_id2label_path
,
'r'
)
as
f
:
with
open
(
self
.
spk_id2label_path
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
for
line
in
f
.
readlines
():
spk_id
,
label
=
line
.
strip
().
split
(
' '
)
spk_id
,
label
=
line
.
strip
().
split
(
' '
)
self
.
spk_id2label
[
spk_id
]
=
int
(
label
)
spk_id2label
[
spk_id
]
=
int
(
label
)
return
spk_id2label
def
convert_to_record
(
self
,
idx
:
int
):
def
convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
data
[
idx
]
sample
=
self
.
data
[
idx
]
record
=
{}
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
for
field
in
fields
(
sample
)
:
record
[
field
]
=
getattr
(
sample
,
field
)
record
[
field
.
name
]
=
getattr
(
sample
,
field
.
name
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
# random select a chunk audio samples from the audio
# random select a chunk audio samples from the audio
if
self
.
config
.
random_chunk
:
if
self
.
config
and
self
.
config
.
random_chunk
:
num_wav_samples
=
waveform
.
shape
[
0
]
num_wav_samples
=
waveform
.
shape
[
0
]
num_chunk_samples
=
int
(
self
.
config
.
chunk_duration
*
sr
)
num_chunk_samples
=
int
(
self
.
config
.
chunk_duration
*
sr
)
start
=
random
.
randint
(
0
,
num_wav_samples
-
num_chunk_samples
-
1
)
start
=
random
.
randint
(
0
,
num_wav_samples
-
num_chunk_samples
-
1
)
...
@@ -71,46 +98,9 @@ class VoxCelebDataset(Dataset):
...
@@ -71,46 +98,9 @@ class VoxCelebDataset(Dataset):
# we only return the waveform as feat
# we only return the waveform as feat
waveform
=
waveform
[
start
:
stop
]
waveform
=
waveform
[
start
:
stop
]
record
.
update
({
'feat'
:
waveform
})
record
.
update
({
'feat'
:
waveform
})
record
.
update
({
'label'
:
self
.
spk_id2label
[
record
[
'spk_id'
]]})
if
self
.
spk_id2label
:
record
.
update
({
'label'
:
self
.
spk_id2label
[
record
[
'lab_id'
]]})
return
record
def
__getitem__
(
self
,
idx
):
return
self
.
convert_to_record
(
idx
)
def
__len__
(
self
):
return
len
(
self
.
data
)
class
RIRSNoiseDataset
(
Dataset
):
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
))
def
__init__
(
self
,
csv_path
):
super
().
__init__
()
self
.
csv_path
=
csv_path
self
.
data
=
self
.
load_csv_data
()
def
load_csv_data
(
self
):
data
=
[]
with
open
(
self
.
csv_path
,
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
=
line
.
strip
().
split
(
','
)
data
.
append
(
self
.
meta_info
(
audio_id
,
float
(
duration
),
wav
))
random
.
shuffle
(
data
)
return
data
def
convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
data
[
idx
]
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
record
[
field
]
=
getattr
(
sample
,
field
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
record
.
update
({
'feat'
:
waveform
})
return
record
return
record
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
...
...
paddlespeech/vector/utils/utils.py
0 → 100644
浏览文件 @
30b5b3cb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def get_chunks(seg_dur, audio_id, audio_duration):
    """Split an utterance into consecutive, fixed-length chunk names.

    Every chunk is named ``<audio_id>_<start>_<stop>``, where ``start`` and
    ``stop`` are successive multiples of ``seg_dur``. A trailing remainder
    shorter than ``seg_dur`` does not produce a chunk.

    Args:
        seg_dur (float): duration of each chunk
        audio_id (str): utterance name, used as the prefix of every chunk name
        audio_duration (float): total utterance duration, in the same unit as ``seg_dur``

    Returns:
        List[str]: the chunk names, ordered by increasing start time
    """
    # NOTE(review): the unit is whatever the caller passes in; the original
    # comment said "milliseconds", but callers that multiply the parsed
    # start/stop by the sample rate would need seconds -- confirm.
    # int() truncates, so only whole chunks are counted.
    chunk_count = int(audio_duration / seg_dur)

    names = []
    for index in range(chunk_count):
        begin = index * seg_dur
        names.append("_".join((audio_id, str(begin), str(begin + seg_dur))))
    return names
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录