Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
30b5b3cb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
30b5b3cb
编写于
4月 02, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vector csv dataset format, test=doc
上级
5b05300e
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
149 addition
and
131 deletion
+149
-131
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
+2
-2
examples/voxceleb/sv0/local/data.sh
examples/voxceleb/sv0/local/data.sh
+3
-3
examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
...xceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
+42
-31
examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
...ples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
+11
-30
paddlespeech/vector/exps/ecapa_tdnn/train.py
paddlespeech/vector/exps/ecapa_tdnn/train.py
+9
-5
paddlespeech/vector/io/augment.py
paddlespeech/vector/io/augment.py
+8
-8
paddlespeech/vector/io/dataset.py
paddlespeech/vector/io/dataset.py
+42
-52
paddlespeech/vector/utils/utils.py
paddlespeech/vector/utils/utils.py
+32
-0
未找到文件。
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
浏览文件 @
30b5b3cb
...
...
@@ -4,9 +4,9 @@
# we should explicitly specify the wav path of vox2 audio data converted from m4a
vox2_base_path
:
augment
:
True
batch_size
:
16
batch_size
:
32
num_workers
:
2
num_speakers
:
7205
# 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
num_speakers
:
1211
# 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle
:
True
skip_prep
:
False
split_ratio
:
0.9
...
...
examples/voxceleb/sv0/local/data.sh
浏览文件 @
30b5b3cb
...
...
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage
=
7
stage
=
0
stop_stage
=
100
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
...
...
@@ -32,7 +32,7 @@ mkdir -p ${dir}
# Generally the `MAIN_ROOT` refers to the root of PaddleSpeech,
# which is defined in the path.sh
# And we will download the
# And we will download the
voxceleb data and rirs noise to ${MAIN_ROOT}/dataset
TARGET_DIR
=
${
MAIN_ROOT
}
/dataset
mkdir
-p
${
TARGET_DIR
}
...
...
@@ -98,7 +98,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# generate the vox csv file
# Currently, our training system use csv file for dataset
echo
"convert the json format to csv format to be compatible with training process"
python3
local
/make_csv_dataset_from_json.py
\
python3
local
/make_
vox_
csv_dataset_from_json.py
\
--train
"
${
dir
}
/vox1/manifest.dev"
\
--test
"
${
dir
}
/vox1/manifest.test"
\
--target_dir
"
${
dir
}
/vox/"
\
...
...
examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
浏览文件 @
30b5b3cb
...
...
@@ -20,31 +20,29 @@ import csv
import
os
from
typing
import
List
import
paddle
import
tqdm
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
from
paddleaudio
import
load
as
load_audio
from
paddleaudio
import
save
as
save_wav
def
get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.utils
import
get_chunks
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
logger
=
Log
(
__name__
).
getlog
()
def
get_
audio_info
(
wav_file
:
str
,
def
get_
chunks_list
(
wav_file
:
str
,
split_chunks
:
bool
,
base_path
:
str
,
chunk_duration
:
float
=
3.0
)
->
List
[
List
[
str
]]:
"""Get the single audio file info
Args:
wav_file (list): the wav audio file and get this audio segment info list
split_chunks (bool): audio split flag
base_path (str): the audio base path
chunk_duration (float): the chunk duration.
if set the split_chunks, we split the audio into multi-chunks segment.
"""
waveform
,
sr
=
load_audio
(
wav_file
)
audio_id
=
wav_file
.
split
(
"/rir_noise/"
)[
-
1
].
split
(
"."
)[
0
]
audio_duration
=
waveform
.
shape
[
0
]
/
sr
...
...
@@ -57,13 +55,16 @@ def get_audio_info(wav_file: str,
s
,
e
=
chunk
.
split
(
"_"
)[
-
2
:]
# Timestamps of start and end
start_sample
=
int
(
float
(
s
)
*
sr
)
end_sample
=
int
(
float
(
e
)
*
sr
)
new_wav_file
=
os
.
path
.
join
(
base_path
,
audio_id
+
f
'_chunk_
{
idx
+
1
:
02
}
.wav'
)
save_wav
(
waveform
[
start_sample
:
end_sample
],
sr
,
new_wav_file
)
# id, duration, new_wav
ret
.
append
([
chunk
,
chunk_duration
,
new_wav_file
])
# currently, all vector csv data format use one representation
# id, duration, wav, start, stop, spk_id
ret
.
append
([
chunk
,
audio_duration
,
wav_file
,
start_sample
,
end_sample
,
"noise"
])
else
:
# Keep whole audio.
ret
.
append
([
audio_id
,
audio_duration
,
wav_file
])
ret
.
append
(
[
audio_id
,
audio_duration
,
wav_file
,
0
,
waveform
.
shape
[
0
],
"noise"
])
return
ret
...
...
@@ -71,12 +72,20 @@ def generate_csv(wav_files,
output_file
:
str
,
base_path
:
str
,
split_chunks
:
bool
=
True
):
print
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"id"
,
"duration"
,
"wav"
]
"""Prepare the csv file according the wav files
Args:
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
"""
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"utt_id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"lab_id"
]
csv_lines
=
[]
for
item
in
tqdm
.
tqdm
(
wav_files
):
csv_lines
.
extend
(
get_
audio_info
(
get_
chunks_list
(
item
,
base_path
=
base_path
,
split_chunks
=
split_chunks
))
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
...
...
@@ -91,11 +100,12 @@ def generate_csv(wav_files,
def
prepare_data
(
args
,
config
):
# stage0: set the cpu device,
# all data prepare process will be done in cpu mode
paddle
.
device
.
set_device
(
"cpu"
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
config
.
seed
)
"""Convert the jsonline format to csv format
Args:
args (argparse.Namespace): scripts args
config (CfgNode): yaml configuration content
"""
# if external config set the skip_prep flat, we will do nothing
if
config
.
skip_prep
:
return
...
...
@@ -119,6 +129,7 @@ def prepare_data(args, config):
noise_files
.
append
(
os
.
path
.
join
(
base_path
,
noise_file
))
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
'csv'
)
logger
.
info
(
f
"csv path:
{
csv_path
}
"
)
generate_csv
(
rir_files
,
os
.
path
.
join
(
csv_path
,
'rir.csv'
),
base_path
=
base_path
)
generate_csv
(
...
...
examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
浏览文件 @
30b5b3cb
...
...
@@ -21,51 +21,34 @@ import json
import
os
import
random
import
paddle
import
tqdm
from
yacs.config
import
CfgNode
from
paddleaudio
import
load
as
load_audio
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
def
get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
"""Get all chunk segments from a utterance
from
paddlespeech.vector.utils.utils
import
get_chunks
Args:
seg_dur (float): segment chunk duration
audio_id (str): utterance name
audio_duration (float): utterance duration
Returns:
List: all the chunk segments
"""
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
logger
=
Log
(
__name__
).
getlog
()
def
prepare_csv
(
wav_files
,
output_file
,
config
,
split_chunks
=
True
):
"""Prepare the csv file according the wav files
Args:
dataset_list (list): all the dataset to get the test utterances
verification_file (str): voxceleb1 trial file
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
"""
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
output_file
)):
os
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
csv_lines
=
[]
header
=
[
"
id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"spk
_id"
]
header
=
[
"
utt_id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"lab
_id"
]
# voxceleb meta info for each training utterance segment
# we extract a segment from a utterance to train
# and the segment' period is between start and stop time point in the original wav file
# each field in the meta means as follows:
# id: the utterance segment name
#
utt_
id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
...
...
@@ -194,11 +177,9 @@ def prepare_data(args, config):
args (argparse.Namespace): scripts args
config (CfgNode): yaml configuration content
"""
# stage0: set the cpu device,
# all data prepare process will be done in cpu mode
paddle
.
device
.
set_device
(
"cpu"
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
config
.
seed
)
# stage0: set the random seed
random
.
seed
(
config
.
seed
)
# if external config set the skip_prep flat, we will do nothing
if
config
.
skip_prep
:
return
...
...
paddlespeech/vector/exps/ecapa_tdnn/train.py
浏览文件 @
30b5b3cb
...
...
@@ -29,7 +29,7 @@ from paddlespeech.vector.io.augment import waveform_augment
from
paddlespeech.vector.io.batch
import
batch_pad_right
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.io.dataset
import
VoxCeleb
Dataset
from
paddlespeech.vector.io.dataset
import
CSV
Dataset
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.loss
import
AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
...
...
@@ -55,11 +55,11 @@ def main(args, config):
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset
=
VoxCeleb
Dataset
(
train_dataset
=
CSV
Dataset
(
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/train.csv"
),
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/meta/spk_id2label.txt"
))
dev_dataset
=
VoxCeleb
Dataset
(
dev_dataset
=
CSV
Dataset
(
csv_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/csv/dev.csv"
),
spk_id2label_path
=
os
.
path
.
join
(
args
.
data_dir
,
"vox/meta/spk_id2label.txt"
))
...
...
@@ -74,7 +74,7 @@ def main(args, config):
# stage4: build the speaker verification train instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
VoxCeleb
.
num_speakers
)
backbone
=
ecapa_tdnn
,
num_class
=
config
.
num_speakers
)
# stage5: build the optimizer, we now only construct the AdamW optimizer
# 140000 is single gpu steps
...
...
@@ -148,6 +148,7 @@ def main(args, config):
train_reader_cost
=
0.0
train_feat_cost
=
0.0
train_run_cost
=
0.0
train_misce_cost
=
0.0
reader_start
=
time
.
time
()
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
...
...
@@ -203,12 +204,14 @@ def main(args, config):
train_run_cost
+=
time
.
time
()
-
train_start
# stage 9-8: Calculate average loss per batch
avg_loss
+=
loss
.
numpy
()[
0
]
train_misce_start
=
time
.
time
()
avg_loss
=
loss
.
item
()
# stage 9-9: Calculate metrics, which is one-best accuracy
preds
=
paddle
.
argmax
(
logits
,
axis
=
1
)
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_samples
+=
feats
.
shape
[
0
]
timer
.
count
()
# step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
...
...
@@ -227,6 +230,7 @@ def main(args, config):
train_feat_cost
/
config
.
log_interval
)
print_msg
+=
' avg_train_cost: {:.5f} sec,'
.
format
(
train_run_cost
/
config
.
log_interval
)
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
lr
,
timer
.
timing
,
timer
.
eta
)
logger
.
info
(
print_msg
)
...
...
paddlespeech/vector/io/augment.py
浏览文件 @
30b5b3cb
...
...
@@ -14,6 +14,7 @@
# this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import
math
import
os
from
typing
import
List
import
numpy
as
np
...
...
@@ -22,13 +23,12 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.dataset
import
RIRSNoise
Dataset
from
paddlespeech.vector.io.dataset
import
CSV
Dataset
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
convolve1d
from
paddlespeech.vector.io.signal_processing
import
dB_to_amplitude
from
paddlespeech.vector.io.signal_processing
import
notch_filter
from
paddlespeech.vector.io.signal_processing
import
reverberate
# from paddleaudio.datasets.rirs_noises import OpenRIRNoise
logger
=
Log
(
__name__
).
getlog
()
...
...
@@ -510,7 +510,7 @@ class AddNoise(nn.Layer):
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
ids
=
[
item
[
'
utt_
id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
max
(
max_length
,
lengths
.
max
().
item
())),
...
...
@@ -590,7 +590,7 @@ class AddReverb(nn.Layer):
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
ids
=
[
item
[
'
utt_
id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
lengths
.
max
().
item
()),
...
...
@@ -840,9 +840,9 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
List[paddle.nn.Layer]: all augment process
"""
logger
.
info
(
"start to build the augment pipeline"
)
noise_dataset
=
RIRSNoiseDataset
(
csv_path
=
os
.
path
.
join
(
target_dir
,
"rir_noise/csv/noise.csv"
))
rir_dataset
=
OpenRIRNoise
(
csv_path
=
os
.
path
.
join
(
target_dir
,
noise_dataset
=
CSVDataset
(
csv_path
=
os
.
path
.
join
(
target_dir
,
"rir_noise/csv/noise.csv"
))
rir_dataset
=
CSVDataset
(
csv_path
=
os
.
path
.
join
(
target_dir
,
"rir_noise/csv/rir.csv"
))
wavedrop
=
TimeDomainSpecAugment
(
...
...
paddlespeech/vector/io/dataset.py
浏览文件 @
30b5b3cb
...
...
@@ -11,18 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collection
s
from
dataclasses
import
dataclas
s
from
dataclasses
import
fields
from
paddle.io
import
Dataset
from
paddleaudio
import
load
as
load_audio
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# lab_id: the utterance segment's label id
@
dataclass
class
meta_info
:
utt_id
:
str
duration
:
float
wav
:
str
start
:
int
stop
:
int
lab_id
:
str
class
VoxCelebDataset
(
Dataset
):
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
,
'start'
,
'stop'
,
'spk_id'
))
def
__init__
(
self
,
csv_path
,
spk_id2label_path
,
config
):
class
CSVDataset
(
Dataset
):
# meta_info = collections.namedtuple(
# 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
def
__init__
(
self
,
csv_path
,
spk_id2label_path
=
None
,
config
=
None
):
super
().
__init__
()
self
.
csv_path
=
csv_path
self
.
spk_id2label_path
=
spk_id2label_path
...
...
@@ -32,34 +52,41 @@ class VoxCelebDataset(Dataset):
def
load_data_csv
(
self
):
data
=
[]
with
open
(
self
.
csv_path
,
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
,
start
,
stop
,
spk_id
=
line
.
strip
(
).
split
(
','
)
data
.
append
(
self
.
meta_info
(
audio_id
,
meta_info
(
audio_id
,
float
(
duration
),
wav
,
int
(
start
),
int
(
stop
),
spk_id
))
return
data
def
load_speaker_to_label
(
self
):
if
not
self
.
spk_id2label_path
:
logger
.
warning
(
"No speaker id to label file"
)
return
spk_id2label
=
{}
with
open
(
self
.
spk_id2label_path
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
spk_id
,
label
=
line
.
strip
().
split
(
' '
)
self
.
spk_id2label
[
spk_id
]
=
int
(
label
)
spk_id2label
[
spk_id
]
=
int
(
label
)
return
spk_id2label
def
convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
data
[
idx
]
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
record
[
field
]
=
getattr
(
sample
,
field
)
for
field
in
fields
(
sample
)
:
record
[
field
.
name
]
=
getattr
(
sample
,
field
.
name
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
# random select a chunk audio samples from the audio
if
self
.
config
.
random_chunk
:
if
self
.
config
and
self
.
config
.
random_chunk
:
num_wav_samples
=
waveform
.
shape
[
0
]
num_chunk_samples
=
int
(
self
.
config
.
chunk_duration
*
sr
)
start
=
random
.
randint
(
0
,
num_wav_samples
-
num_chunk_samples
-
1
)
...
...
@@ -71,46 +98,9 @@ class VoxCelebDataset(Dataset):
# we only return the waveform as feat
waveform
=
waveform
[
start
:
stop
]
record
.
update
({
'feat'
:
waveform
})
record
.
update
({
'label'
:
self
.
spk_id2label
[
record
[
'spk_id'
]]})
return
record
def
__getitem__
(
self
,
idx
):
return
self
.
convert_to_record
(
idx
)
def
__len__
(
self
):
return
len
(
self
.
data
)
class
RIRSNoiseDataset
(
Dataset
):
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
))
def
__init__
(
self
,
csv_path
):
super
().
__init__
()
self
.
csv_path
=
csv_path
self
.
data
=
self
.
load_csv_data
()
if
self
.
spk_id2label
:
record
.
update
({
'label'
:
self
.
spk_id2label
[
record
[
'lab_id'
]]})
def
load_csv_data
(
self
):
data
=
[]
with
open
(
self
.
csv_path
,
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
=
line
.
strip
().
split
(
','
)
data
.
append
(
self
.
meta_info
(
audio_id
,
float
(
duration
),
wav
))
random
.
shuffle
(
data
)
return
data
def
convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
data
[
idx
]
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
record
[
field
]
=
getattr
(
sample
,
field
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
record
.
update
({
'feat'
:
waveform
})
return
record
def
__getitem__
(
self
,
idx
):
...
...
paddlespeech/vector/utils/utils.py
0 → 100644
浏览文件 @
30b5b3cb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def
get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
"""Get all chunk segments from a utterance
Args:
seg_dur (float): segment chunk duration
audio_id (str): utterance name
audio_duration (float): utterance duration
Returns:
List: all the chunk segments
"""
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录