Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
14d9e80b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
14d9e80b
编写于
1月 18, 2022
作者:
Q
qingen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[vector] add AMI data preparation scripts
上级
98788ca2
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
167 addition
and
156 deletion
+167
-156
dataset/ami/ami_prepare.py
dataset/ami/ami_prepare.py
+102
-131
dataset/ami/ami_splits.py
dataset/ami/ami_splits.py
+14
-2
dataset/ami/dataio.py
dataset/ami/dataio.py
+17
-2
paddlespeech/vector/utils/DER.py
paddlespeech/vector/utils/DER.py
+34
-21
paddlespeech/vector/utils/md-eval.pl
paddlespeech/vector/utils/md-eval.pl
+0
-0
未找到文件。
dataset/ami/ami_prepare.py
浏览文件 @
14d9e80b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
Data preparation.
Data preparation.
...
@@ -21,8 +34,7 @@ from distutils.util import strtobool
...
@@ -21,8 +34,7 @@ from distutils.util import strtobool
from
utils.dataio
import
(
from
utils.dataio
import
(
load_pkl
,
load_pkl
,
save_pkl
,
save_pkl
,
)
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
SAMPLERATE
=
16000
SAMPLERATE
=
16000
...
@@ -39,8 +51,7 @@ def prepare_ami(
...
@@ -39,8 +51,7 @@ def prepare_ami(
mic_type
=
"Mix-Headset"
,
mic_type
=
"Mix-Headset"
,
vad_type
=
"oracle"
,
vad_type
=
"oracle"
,
max_subseg_dur
=
3.0
,
max_subseg_dur
=
3.0
,
overlap
=
1.5
,
overlap
=
1.5
,
):
):
"""
"""
Prepares reference RTTM and JSON files for the AMI dataset.
Prepares reference RTTM and JSON files for the AMI dataset.
...
@@ -72,12 +83,12 @@ def prepare_ami(
...
@@ -72,12 +83,12 @@ def prepare_ami(
Example
Example
-------
-------
>>> from
recipes.AMI
.ami_prepare import prepare_ami
>>> from
dataset.ami
.ami_prepare import prepare_ami
>>> data_folder = '/
network/datasets
/ami/amicorpus/'
>>> data_folder = '/
home/data
/ami/amicorpus/'
>>> manual_annot_folder = '/home/
mila/d/dawalatn/nauman
/ami_public_manual/'
>>> manual_annot_folder = '/home/
data/ami
/ami_public_manual/'
>>> save_folder = '
results/save/'
>>> save_folder = '
./results/
>>> split_type = 'full_corpus_asr'
>>> split_type = 'full_corpus_asr'
>>> mic_type = '
Lapel
'
>>> mic_type = '
Mix-Headset
'
>>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type)
>>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type)
"""
"""
...
@@ -112,8 +123,7 @@ def prepare_ami(
...
@@ -112,8 +123,7 @@ def prepare_ami(
# Check if this phase is already done (if so, skip it)
# Check if this phase is already done (if so, skip it)
if
skip
(
save_folder
,
conf
,
meta_files
,
opt_file
):
if
skip
(
save_folder
,
conf
,
meta_files
,
opt_file
):
logger
.
info
(
logger
.
info
(
"Skipping data preparation, as it was completed in previous run."
"Skipping data preparation, as it was completed in previous run."
)
)
return
return
msg
=
"
\t
Creating meta-data file for the AMI Dataset.."
msg
=
"
\t
Creating meta-data file for the AMI Dataset.."
...
@@ -138,8 +148,7 @@ def prepare_ami(
...
@@ -138,8 +148,7 @@ def prepare_ami(
data_folder
,
data_folder
,
manual_annot_folder
,
manual_annot_folder
,
i
,
i
,
skip_TNO
,
skip_TNO
,
)
)
if
i
==
"dev"
:
if
i
==
"dev"
:
prepare_segs_for_RTTM
(
prepare_segs_for_RTTM
(
dev_set
,
dev_set
,
...
@@ -147,8 +156,7 @@ def prepare_ami(
...
@@ -147,8 +156,7 @@ def prepare_ami(
data_folder
,
data_folder
,
manual_annot_folder
,
manual_annot_folder
,
i
,
i
,
skip_TNO
,
skip_TNO
,
)
)
if
i
==
"eval"
:
if
i
==
"eval"
:
prepare_segs_for_RTTM
(
prepare_segs_for_RTTM
(
eval_set
,
eval_set
,
...
@@ -156,8 +164,7 @@ def prepare_ami(
...
@@ -156,8 +164,7 @@ def prepare_ami(
data_folder
,
data_folder
,
manual_annot_folder
,
manual_annot_folder
,
i
,
i
,
skip_TNO
,
skip_TNO
,
)
)
# Create meta_files for splits
# Create meta_files for splits
meta_data_dir
=
meta_data_dir
meta_data_dir
=
meta_data_dir
...
@@ -174,8 +181,7 @@ def prepare_ami(
...
@@ -174,8 +181,7 @@ def prepare_ami(
meta_filename_prefix
,
meta_filename_prefix
,
max_subseg_dur
,
max_subseg_dur
,
overlap
,
overlap
,
mic_type
,
mic_type
,
)
)
save_opt_file
=
os
.
path
.
join
(
save_folder
,
opt_file
)
save_opt_file
=
os
.
path
.
join
(
save_folder
,
opt_file
)
save_pkl
(
conf
,
save_opt_file
)
save_pkl
(
conf
,
save_opt_file
)
...
@@ -190,13 +196,8 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
...
@@ -190,13 +196,8 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
# Prepare header
# Prepare header
for
spkr_id
in
spkrs_list
:
for
spkr_id
in
spkrs_list
:
# e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
# e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
line
=
(
line
=
(
"SPKR-INFO "
+
rec_id
+
" 0 <NA> <NA> <NA> unknown "
+
spkr_id
+
"SPKR-INFO "
" <NA> <NA>"
)
+
rec_id
+
" 0 <NA> <NA> <NA> unknown "
+
spkr_id
+
" <NA> <NA>"
)
rttm
.
append
(
line
)
rttm
.
append
(
line
)
# Append remaining lines
# Append remaining lines
...
@@ -206,57 +207,35 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
...
@@ -206,57 +207,35 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
if
float
(
row
[
1
])
<
float
(
row
[
0
]):
if
float
(
row
[
1
])
<
float
(
row
[
0
]):
msg1
=
(
msg1
=
(
"Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)"
"Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)"
%
(
row
[
0
],
row
[
1
])
%
(
row
[
0
],
row
[
1
]))
)
msg2
=
(
msg2
=
(
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s"
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s"
%
%
(
(
rec_id
,
row
[
0
],
str
(
round
(
float
(
row
[
1
])
-
float
(
row
[
0
]),
4
)),
rec_id
,
str
(
row
[
2
]),
))
row
[
0
],
str
(
round
(
float
(
row
[
1
])
-
float
(
row
[
0
]),
4
)),
str
(
row
[
2
]),
)
)
logger
.
info
(
msg1
)
logger
.
info
(
msg1
)
logger
.
info
(
msg2
)
logger
.
info
(
msg2
)
continue
continue
line
=
(
line
=
(
"SPEAKER "
+
rec_id
+
" 0 "
+
str
(
round
(
float
(
row
[
0
]),
4
))
+
" "
"SPEAKER "
+
str
(
round
(
float
(
row
[
1
])
-
float
(
row
[
0
]),
4
))
+
" <NA> <NA> "
+
+
rec_id
str
(
row
[
2
])
+
" <NA> <NA>"
)
+
" 0 "
+
str
(
round
(
float
(
row
[
0
]),
4
))
+
" "
+
str
(
round
(
float
(
row
[
1
])
-
float
(
row
[
0
]),
4
))
+
" <NA> <NA> "
+
str
(
row
[
2
])
+
" <NA> <NA>"
)
rttm
.
append
(
line
)
rttm
.
append
(
line
)
return
rttm
return
rttm
def
prepare_segs_for_RTTM
(
def
prepare_segs_for_RTTM
(
list_ids
,
out_rttm_file
,
audio_dir
,
annot_dir
,
list_ids
,
out_rttm_file
,
audio_dir
,
annot_dir
,
split_type
,
skip_TNO
split_type
,
skip_TNO
):
):
RTTM
=
[]
# Stores all RTTMs clubbed together for a given dataset split
RTTM
=
[]
# Stores all RTTMs clubbed together for a given dataset split
for
main_meet_id
in
list_ids
:
for
main_meet_id
in
list_ids
:
# Skip TNO meetings from dev and eval sets
# Skip TNO meetings from dev and eval sets
if
(
if
(
main_meet_id
.
startswith
(
"TS"
)
and
split_type
!=
"train"
and
main_meet_id
.
startswith
(
"TS"
)
skip_TNO
is
True
):
and
split_type
!=
"train"
msg
=
(
"Skipping TNO meeting in AMI "
+
str
(
split_type
)
+
" set : "
and
skip_TNO
is
True
+
str
(
main_meet_id
))
):
msg
=
(
"Skipping TNO meeting in AMI "
+
str
(
split_type
)
+
" set : "
+
str
(
main_meet_id
)
)
logger
.
info
(
msg
)
logger
.
info
(
msg
)
continue
continue
...
@@ -271,8 +250,7 @@ def prepare_segs_for_RTTM(
...
@@ -271,8 +250,7 @@ def prepare_segs_for_RTTM(
list_spkr_xmls
.
sort
()
# A, B, C, D, E etc (Speakers)
list_spkr_xmls
.
sort
()
# A, B, C, D, E etc (Speakers)
segs
=
[]
segs
=
[]
spkrs_list
=
(
spkrs_list
=
(
[]
[])
# Since non-scenario recordings contains 3-5 speakers
)
# Since non-scenario recordings contains 3-5 speakers
for
spkr_xml_file
in
list_spkr_xmls
:
for
spkr_xml_file
in
list_spkr_xmls
:
...
@@ -286,14 +264,11 @@ def prepare_segs_for_RTTM(
...
@@ -286,14 +264,11 @@ def prepare_segs_for_RTTM(
root
=
tree
.
getroot
()
root
=
tree
.
getroot
()
# Start, end and speaker_ID from xml file
# Start, end and speaker_ID from xml file
segs
=
segs
+
[
segs
=
segs
+
[[
[
elem
.
attrib
[
"transcriber_start"
],
elem
.
attrib
[
"transcriber_start"
],
elem
.
attrib
[
"transcriber_end"
],
elem
.
attrib
[
"transcriber_end"
],
spkr_ID
,
spkr_ID
,
]
]
for
elem
in
root
.
iter
(
"segment"
)]
for
elem
in
root
.
iter
(
"segment"
)
]
# Sort rows as per the start time (per recording)
# Sort rows as per the start time (per recording)
segs
.
sort
(
key
=
lambda
x
:
float
(
x
[
0
]))
segs
.
sort
(
key
=
lambda
x
:
float
(
x
[
0
]))
...
@@ -404,9 +379,8 @@ def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5):
...
@@ -404,9 +379,8 @@ def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5):
return
subsegments
return
subsegments
def
prepare_metadata
(
def
prepare_metadata
(
rttm_file
,
save_dir
,
data_dir
,
filename
,
max_subseg_dur
,
rttm_file
,
save_dir
,
data_dir
,
filename
,
max_subseg_dur
,
overlap
,
mic_type
overlap
,
mic_type
):
):
# Read RTTM, get unique meeting_IDs (from RTTM headers)
# Read RTTM, get unique meeting_IDs (from RTTM headers)
# For each MeetingID. select that meetID -> merge -> subsegment -> json -> append
# For each MeetingID. select that meetID -> merge -> subsegment -> json -> append
...
@@ -425,15 +399,13 @@ def prepare_metadata(
...
@@ -425,15 +399,13 @@ def prepare_metadata(
MERGED_SEGMENTS
=
[]
MERGED_SEGMENTS
=
[]
SUBSEGMENTS
=
[]
SUBSEGMENTS
=
[]
for
rec_id
in
rec_ids
:
for
rec_id
in
rec_ids
:
segs_iter
=
filter
(
segs_iter
=
filter
(
lambda
x
:
x
.
startswith
(
"SPEAKER "
+
str
(
rec_id
)),
lambda
x
:
x
.
startswith
(
"SPEAKER "
+
str
(
rec_id
)),
RTTM
RTTM
)
)
gt_rttm_segs
=
[
row
.
split
(
" "
)
for
row
in
segs_iter
]
gt_rttm_segs
=
[
row
.
split
(
" "
)
for
row
in
segs_iter
]
# Merge, subsegment and then convert to json format.
# Merge, subsegment and then convert to json format.
merged_segs
=
merge_rttm_intervals
(
merged_segs
=
merge_rttm_intervals
(
gt_rttm_segs
gt_rttm_segs
)
# We lose speaker_ID after merging
)
# We lose speaker_ID after merging
MERGED_SEGMENTS
=
MERGED_SEGMENTS
+
merged_segs
MERGED_SEGMENTS
=
MERGED_SEGMENTS
+
merged_segs
# Divide segments into smaller sub-segments
# Divide segments into smaller sub-segments
...
@@ -467,16 +439,8 @@ def prepare_metadata(
...
@@ -467,16 +439,8 @@ def prepare_metadata(
# If multi-mic audio is selected
# If multi-mic audio is selected
if
mic_type
==
"Array1"
:
if
mic_type
==
"Array1"
:
wav_file_base_path
=
(
wav_file_base_path
=
(
data_dir
+
"/"
+
rec_id
+
"/audio/"
+
rec_id
+
data_dir
"."
+
mic_type
+
"-"
)
+
"/"
+
rec_id
+
"/audio/"
+
rec_id
+
"."
+
mic_type
+
"-"
)
f
=
[]
# adding all 8 mics
f
=
[]
# adding all 8 mics
for
i
in
range
(
8
):
for
i
in
range
(
8
):
...
@@ -494,16 +458,8 @@ def prepare_metadata(
...
@@ -494,16 +458,8 @@ def prepare_metadata(
}
}
else
:
else
:
# Single mic audio
# Single mic audio
wav_file_path
=
(
wav_file_path
=
(
data_dir
+
"/"
+
rec_id
+
"/audio/"
+
rec_id
+
"."
data_dir
+
mic_type
+
".wav"
)
+
"/"
+
rec_id
+
"/audio/"
+
rec_id
+
"."
+
mic_type
+
".wav"
)
# Note: key "file" without 's' is used for single-mic
# Note: key "file" without 's' is used for single-mic
json_dict
[
subsegment_ID
]
=
{
json_dict
[
subsegment_ID
]
=
{
...
@@ -554,6 +510,7 @@ def skip(save_folder, conf, meta_files, opt_file):
...
@@ -554,6 +510,7 @@ def skip(save_folder, conf, meta_files, opt_file):
return
skip
return
skip
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -563,40 +520,54 @@ if __name__ == '__main__':
...
@@ -563,40 +520,54 @@ if __name__ == '__main__':
--meta_data_dir ./results/metadata'
,
--meta_data_dir ./results/metadata'
,
description
=
'AMI Data preparation'
)
description
=
'AMI Data preparation'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--data_folder'
,
required
=
True
,
help
=
'Path to the folder where the original amicorpus is stored'
)
'--data_folder'
,
required
=
True
,
help
=
'Path to the folder where the original amicorpus is stored'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--manual_annot_folder'
,
required
=
True
,
help
=
'Directory where the manual annotations are stored'
)
'--manual_annot_folder'
,
required
=
True
,
help
=
'Directory where the manual annotations are stored'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--save_folder'
,
required
=
True
,
help
=
'The save directory in results'
)
'--save_folder'
,
required
=
True
,
help
=
'The save directory in results'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--ref_rttm_dir'
,
required
=
True
,
help
=
'Directory to store reference RTTM files'
)
'--ref_rttm_dir'
,
required
=
True
,
help
=
'Directory to store reference RTTM files'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--meta_data_dir'
,
required
=
True
,
help
=
'Directory to store the meta data (json) files'
)
'--meta_data_dir'
,
required
=
True
,
help
=
'Directory to store the meta data (json) files'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--split_type'
,
'--split_type'
,
default
=
"full_corpus_asr"
,
default
=
"full_corpus_asr"
,
help
=
'Standard dataset split. See ami_splits.py for more information'
)
help
=
'Standard dataset split. See ami_splits.py for more information'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--skip_TNO'
,
default
=
True
,
type
=
strtobool
,
help
=
'Skips TNO meeting recordings if True'
)
'--skip_TNO'
,
default
=
True
,
type
=
strtobool
,
help
=
'Skips TNO meeting recordings if True'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--mic_type'
,
default
=
"Mix-Headset"
,
help
=
'Type of microphone to be used'
)
'--mic_type'
,
default
=
"Mix-Headset"
,
help
=
'Type of microphone to be used'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--vad_type'
,
default
=
"oracle"
,
help
=
'Type of VAD. Kept for future when VAD will be added'
)
'--vad_type'
,
default
=
"oracle"
,
help
=
'Type of VAD. Kept for future when VAD will be added'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--max_subseg_dur'
,
'--max_subseg_dur'
,
default
=
3.0
,
default
=
3.0
,
type
=
float
,
type
=
float
,
help
=
'Duration in seconds of a subsegments to be prepared from larger segments'
)
help
=
'Duration in seconds of a subsegments to be prepared from larger segments'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--overlap'
,
default
=
1.5
,
type
=
float
,
help
=
'Overlap duration in seconds between adjacent subsegments'
)
'--overlap'
,
default
=
1.5
,
type
=
float
,
help
=
'Overlap duration in seconds between adjacent subsegments'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
prepare_ami
(
prepare_ami
(
args
.
data_folder
,
args
.
manual_annot_folder
,
args
.
save_folder
,
args
.
data_folder
,
args
.
ref_rttm_dir
,
args
.
meta_data_dir
)
args
.
manual_annot_folder
,
args
.
save_folder
,
args
.
ref_rttm_dir
,
args
.
meta_data_dir
)
\ No newline at end of file
dataset/ami/ami_splits.py
浏览文件 @
14d9e80b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
AMI corpus contained 100 hours of meeting recording.
AMI corpus contained 100 hours of meeting recording.
This script returns the standard train, dev and eval split for AMI corpus.
This script returns the standard train, dev and eval split for AMI corpus.
...
@@ -29,8 +42,7 @@ def get_AMI_split(split_option):
...
@@ -29,8 +42,7 @@ def get_AMI_split(split_option):
if
split_option
not
in
ALLOWED_OPTIONS
:
if
split_option
not
in
ALLOWED_OPTIONS
:
print
(
print
(
f
'Invalid split "
{
split_option
}
" requested!
\n
Valid split_options are: '
,
f
'Invalid split "
{
split_option
}
" requested!
\n
Valid split_options are: '
,
ALLOWED_OPTIONS
,
ALLOWED_OPTIONS
,
)
)
return
return
if
split_option
==
"scenario_only"
:
if
split_option
==
"scenario_only"
:
...
...
utils
/dataio.py
→
dataset/ami
/dataio.py
浏览文件 @
14d9e80b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
Data reading and writing.
Data reading and writing.
...
@@ -5,10 +18,10 @@ Authors
...
@@ -5,10 +18,10 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022
* qingenz123@126.com (Qingen ZHAO) 2022
"""
"""
import
os
import
os
import
pickle
import
pickle
def
save_pkl
(
obj
,
file
):
def
save_pkl
(
obj
,
file
):
"""Save an object in pkl format.
"""Save an object in pkl format.
...
@@ -31,6 +44,7 @@ def save_pkl(obj, file):
...
@@ -31,6 +44,7 @@ def save_pkl(obj, file):
with
open
(
file
,
"wb"
)
as
f
:
with
open
(
file
,
"wb"
)
as
f
:
pickle
.
dump
(
obj
,
f
)
pickle
.
dump
(
obj
,
f
)
def
load_pickle
(
pickle_path
):
def
load_pickle
(
pickle_path
):
"""Utility function for loading .pkl pickle files.
"""Utility function for loading .pkl pickle files.
...
@@ -48,6 +62,7 @@ def load_pickle(pickle_path):
...
@@ -48,6 +62,7 @@ def load_pickle(pickle_path):
out
=
pickle
.
load
(
f
)
out
=
pickle
.
load
(
f
)
return
out
return
out
def
load_pkl
(
file
):
def
load_pkl
(
file
):
"""Loads a pkl file.
"""Loads a pkl file.
...
...
utils/DER.py
→
paddlespeech/vector/
utils/DER.py
浏览文件 @
14d9e80b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
...
@@ -26,7 +39,6 @@ ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
...
@@ -26,7 +39,6 @@ ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
def
rectify
(
arr
):
def
rectify
(
arr
):
"""Corrects corner cases and converts scores into percentage.
"""Corrects corner cases and converts scores into percentage.
"""
"""
# Numerator and denominator both 0.
# Numerator and denominator both 0.
arr
[
np
.
isnan
(
arr
)]
=
0
arr
[
np
.
isnan
(
arr
)]
=
0
...
@@ -42,8 +54,7 @@ def DER(
...
@@ -42,8 +54,7 @@ def DER(
sys_rttm
,
sys_rttm
,
ignore_overlap
=
False
,
ignore_overlap
=
False
,
collar
=
0.25
,
collar
=
0.25
,
individual_file_scores
=
False
,
individual_file_scores
=
False
,
):
):
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
Speaker Error Rate (SER), and Diarization Error Rate (DER).
Speaker Error Rate (SER), and Diarization Error Rate (DER).
...
@@ -118,25 +129,20 @@ def DER(
...
@@ -118,25 +129,20 @@ def DER(
]
]
scored_speaker_times
=
np
.
array
(
scored_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
SCORED_SPEAKER_TIME
.
findall
(
stdout
)]
[
float
(
m
)
for
m
in
SCORED_SPEAKER_TIME
.
findall
(
stdout
)])
)
miss_speaker_times
=
np
.
array
(
miss_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
MISS_SPEAKER_TIME
.
findall
(
stdout
)]
[
float
(
m
)
for
m
in
MISS_SPEAKER_TIME
.
findall
(
stdout
)])
)
fa_speaker_times
=
np
.
array
(
fa_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
FA_SPEAKER_TIME
.
findall
(
stdout
)]
[
float
(
m
)
for
m
in
FA_SPEAKER_TIME
.
findall
(
stdout
)])
)
error_speaker_times
=
np
.
array
(
error_speaker_times
=
np
.
array
(
[
float
(
m
)
for
m
in
ERROR_SPEAKER_TIME
.
findall
(
stdout
)]
[
float
(
m
)
for
m
in
ERROR_SPEAKER_TIME
.
findall
(
stdout
)])
)
with
np
.
errstate
(
invalid
=
"ignore"
,
divide
=
"ignore"
):
with
np
.
errstate
(
invalid
=
"ignore"
,
divide
=
"ignore"
):
tot_error_times
=
(
tot_error_times
=
(
miss_speaker_times
+
fa_speaker_times
+
error_speaker_times
miss_speaker_times
+
fa_speaker_times
+
error_speaker_times
)
)
miss_speaker_frac
=
miss_speaker_times
/
scored_speaker_times
miss_speaker_frac
=
miss_speaker_times
/
scored_speaker_times
fa_speaker_frac
=
fa_speaker_times
/
scored_speaker_times
fa_speaker_frac
=
fa_speaker_times
/
scored_speaker_times
sers_frac
=
error_speaker_times
/
scored_speaker_times
sers_frac
=
error_speaker_times
/
scored_speaker_times
...
@@ -153,13 +159,19 @@ def DER(
...
@@ -153,13 +159,19 @@ def DER(
else
:
else
:
return
miss_speaker
[
-
1
],
fa_speaker
[
-
1
],
sers
[
-
1
],
ders
[
-
1
]
return
miss_speaker
[
-
1
],
fa_speaker
[
-
1
],
sers
[
-
1
],
ders
[
-
1
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Compute Diarization Error Rate'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Compute Diarization Error Rate'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--ref_rttm'
,
required
=
True
,
help
=
'the path of reference/groundtruth RTTM file'
)
'--ref_rttm'
,
required
=
True
,
help
=
'the path of reference/groundtruth RTTM file'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--sys_rttm'
,
required
=
True
,
help
=
'the path of the system generated RTTM file'
)
'--sys_rttm'
,
required
=
True
,
help
=
'the path of the system generated RTTM file'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--individual_file'
,
'--individual_file'
,
default
=
False
,
default
=
False
,
...
@@ -176,4 +188,5 @@ if __name__ == '__main__':
...
@@ -176,4 +188,5 @@ if __name__ == '__main__':
print
(
args
)
print
(
args
)
der
=
DER
(
args
.
ref_rttm
,
args
.
sys_rttm
)
der
=
DER
(
args
.
ref_rttm
,
args
.
sys_rttm
)
print
(
"miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%"
%
(
der
[
0
],
der
[
1
],
der
[
2
],
der
[
-
1
]))
print
(
"miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%"
%
\ No newline at end of file
(
der
[
0
],
der
[
1
],
der
[
2
],
der
[
-
1
]))
utils/md-eval.pl
→
paddlespeech/vector/
utils/md-eval.pl
浏览文件 @
14d9e80b
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录