Commit 14d9e80b authored by qingen

[vector] add AMI data preparation scripts

Parent 98788ca2
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data preparation.
@@ -21,26 +34,24 @@ from distutils.util import strtobool
from utils.dataio import (
load_pkl,
save_pkl,
)
save_pkl, )
logger = logging.getLogger(__name__)
SAMPLERATE = 16000
def prepare_ami(
data_folder,
manual_annot_folder,
save_folder,
ref_rttm_dir,
meta_data_dir,
split_type="full_corpus_asr",
skip_TNO=True,
mic_type="Mix-Headset",
vad_type="oracle",
max_subseg_dur=3.0,
overlap=1.5,
):
data_folder,
manual_annot_folder,
save_folder,
ref_rttm_dir,
meta_data_dir,
split_type="full_corpus_asr",
skip_TNO=True,
mic_type="Mix-Headset",
vad_type="oracle",
max_subseg_dur=3.0,
overlap=1.5, ):
"""
Prepares reference RTTM and JSON files for the AMI dataset.
@@ -72,12 +83,12 @@ def prepare_ami(
Example
-------
>>> from recipes.AMI.ami_prepare import prepare_ami
>>> data_folder = '/network/datasets/ami/amicorpus/'
>>> manual_annot_folder = '/home/mila/d/dawalatn/nauman/ami_public_manual/'
>>> save_folder = 'results/save/'
>>> from dataset.ami.ami_prepare import prepare_ami
>>> data_folder = '/home/data/ami/amicorpus/'
>>> manual_annot_folder = '/home/data/ami/ami_public_manual/'
>>> save_folder = './results/'
>>> ref_rttm_dir = './results/ref_rttms'
>>> meta_data_dir = './results/metadata'
>>> split_type = 'full_corpus_asr'
>>> mic_type = 'Lapel'
>>> mic_type = 'Mix-Headset'
>>> prepare_ami(data_folder, manual_annot_folder, save_folder, ref_rttm_dir, meta_data_dir, split_type=split_type, mic_type=mic_type)
"""
@@ -112,8 +123,7 @@ def prepare_ami(
# Check if this phase is already done (if so, skip it)
if skip(save_folder, conf, meta_files, opt_file):
logger.info(
"Skipping data preparation, as it was completed in previous run."
)
"Skipping data preparation, as it was completed in previous run.")
return
msg = "\tCreating meta-data file for the AMI Dataset.."
@@ -138,8 +148,7 @@ def prepare_ami(
data_folder,
manual_annot_folder,
i,
skip_TNO,
)
skip_TNO, )
if i == "dev":
prepare_segs_for_RTTM(
dev_set,
@@ -147,8 +156,7 @@ def prepare_ami(
data_folder,
manual_annot_folder,
i,
skip_TNO,
)
skip_TNO, )
if i == "eval":
prepare_segs_for_RTTM(
eval_set,
@@ -156,8 +164,7 @@ def prepare_ami(
data_folder,
manual_annot_folder,
i,
skip_TNO,
)
skip_TNO, )
# Create meta_files for splits
meta_data_dir = meta_data_dir
@@ -174,8 +181,7 @@ def prepare_ami(
meta_filename_prefix,
max_subseg_dur,
overlap,
mic_type,
)
mic_type, )
save_opt_file = os.path.join(save_folder, opt_file)
save_pkl(conf, save_opt_file)
@@ -190,13 +196,8 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
# Prepare header
for spkr_id in spkrs_list:
# e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
line = (
"SPKR-INFO "
+ rec_id
+ " 0 <NA> <NA> <NA> unknown "
+ spkr_id
+ " <NA> <NA>"
)
line = ("SPKR-INFO " + rec_id + " 0 <NA> <NA> <NA> unknown " + spkr_id +
" <NA> <NA>")
rttm.append(line)
# Append remaining lines
@@ -206,57 +207,35 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
if float(row[1]) < float(row[0]):
msg1 = (
"Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)"
% (row[0], row[1])
)
% (row[0], row[1]))
msg2 = (
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s"
% (
rec_id,
row[0],
str(round(float(row[1]) - float(row[0]), 4)),
str(row[2]),
)
)
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s" %
(rec_id, row[0], str(round(float(row[1]) - float(row[0]), 4)),
str(row[2]), ))
logger.info(msg1)
logger.info(msg2)
continue
line = (
"SPEAKER "
+ rec_id
+ " 0 "
+ str(round(float(row[0]), 4))
+ " "
+ str(round(float(row[1]) - float(row[0]), 4))
+ " <NA> <NA> "
+ str(row[2])
+ " <NA> <NA>"
)
line = ("SPEAKER " + rec_id + " 0 " + str(round(float(row[0]), 4)) + " "
+ str(round(float(row[1]) - float(row[0]), 4)) + " <NA> <NA> " +
str(row[2]) + " <NA> <NA>")
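# e.g. an appended line (values are illustrative): "SPEAKER ES2008a 0 13.28 4.29 <NA> <NA> ES2008a.A_PM <NA> <NA>"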
rttm.append(line)
return rttm
def prepare_segs_for_RTTM(
list_ids, out_rttm_file, audio_dir, annot_dir, split_type, skip_TNO
):
def prepare_segs_for_RTTM(list_ids, out_rttm_file, audio_dir, annot_dir,
split_type, skip_TNO):
RTTM = [] # Stores all RTTMs clubbed together for a given dataset split
for main_meet_id in list_ids:
# Skip TNO meetings from dev and eval sets
if (
main_meet_id.startswith("TS")
and split_type != "train"
and skip_TNO is True
):
msg = (
"Skipping TNO meeting in AMI "
+ str(split_type)
+ " set : "
+ str(main_meet_id)
)
if (main_meet_id.startswith("TS") and split_type != "train" and
skip_TNO is True):
msg = ("Skipping TNO meeting in AMI " + str(split_type) + " set : "
+ str(main_meet_id))
logger.info(msg)
continue
@@ -271,8 +250,7 @@ def prepare_segs_for_RTTM(
list_spkr_xmls.sort() # A, B, C, D, E etc (Speakers)
segs = []
spkrs_list = (
[]
) # Since non-scenario recordings contains 3-5 speakers
]) # Since non-scenario recordings contain 3-5 speakers
for spkr_xml_file in list_spkr_xmls:
@@ -286,14 +264,11 @@ def prepare_segs_for_RTTM(
root = tree.getroot()
# Start, end and speaker_ID from xml file
segs = segs + [
[
elem.attrib["transcriber_start"],
elem.attrib["transcriber_end"],
spkr_ID,
]
for elem in root.iter("segment")
]
segs = segs + [[
elem.attrib["transcriber_start"],
elem.attrib["transcriber_end"],
spkr_ID,
] for elem in root.iter("segment")]
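# Each appended row is [transcriber_start, transcriber_end, spkr_ID], e.g. (illustrative) ["12.3", "15.8", "ES2008a.A_PM"]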
# Sort rows as per the start time (per recording)
segs.sort(key=lambda x: float(x[0]))
@@ -404,9 +379,8 @@ def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5):
return subsegments
def prepare_metadata(
rttm_file, save_dir, data_dir, filename, max_subseg_dur, overlap, mic_type
):
def prepare_metadata(rttm_file, save_dir, data_dir, filename, max_subseg_dur,
overlap, mic_type):
# Read RTTM, get unique meeting_IDs (from RTTM headers)
# For each MeetingID: select that meetID -> merge -> subsegment -> json -> append
@@ -425,15 +399,13 @@ def prepare_metadata(
MERGED_SEGMENTS = []
SUBSEGMENTS = []
for rec_id in rec_ids:
segs_iter = filter(
lambda x: x.startswith("SPEAKER " + str(rec_id)), RTTM
)
segs_iter = filter(lambda x: x.startswith("SPEAKER " + str(rec_id)),
RTTM)
gt_rttm_segs = [row.split(" ") for row in segs_iter]
# Merge, subsegment and then convert to json format.
merged_segs = merge_rttm_intervals(
gt_rttm_segs
) # We lose speaker_ID after merging
gt_rttm_segs) # We lose speaker_ID after merging
MERGED_SEGMENTS = MERGED_SEGMENTS + merged_segs
# Divide segments into smaller sub-segments
@@ -467,16 +439,8 @@ def prepare_metadata(
# If multi-mic audio is selected
if mic_type == "Array1":
wav_file_base_path = (
data_dir
+ "/"
+ rec_id
+ "/audio/"
+ rec_id
+ "."
+ mic_type
+ "-"
)
wav_file_base_path = (data_dir + "/" + rec_id + "/audio/" + rec_id +
"." + mic_type + "-")
f = [] # adding all 8 mics
for i in range(8):
@@ -494,16 +458,8 @@ def prepare_metadata(
}
else:
# Single mic audio
wav_file_path = (
data_dir
+ "/"
+ rec_id
+ "/audio/"
+ rec_id
+ "."
+ mic_type
+ ".wav"
)
wav_file_path = (data_dir + "/" + rec_id + "/audio/" + rec_id + "."
+ mic_type + ".wav")
# Note: key "file" without 's' is used for single-mic
json_dict[subsegment_ID] = {
@@ -554,6 +510,7 @@ def skip(save_folder, conf, meta_files, opt_file):
return skip
if __name__ == '__main__':
parser = argparse.ArgumentParser(
@@ -561,42 +518,56 @@ if __name__ == '__main__':
--manual_annot_folder /home/data/ami/ami_public_manual_1.6.2 \
--save_folder ./results/ --ref_rttm_dir ./results/ref_rttms \
--meta_data_dir ./results/metadata',
description='AMI Data preparation')
description='AMI Data preparation')
parser.add_argument(
'--data_folder', required=True, help='Path to the folder where the original amicorpus is stored')
'--data_folder',
required=True,
help='Path to the folder where the original amicorpus is stored')
parser.add_argument(
'--manual_annot_folder', required=True, help='Directory where the manual annotations are stored')
'--manual_annot_folder',
required=True,
help='Directory where the manual annotations are stored')
parser.add_argument(
'--save_folder', required=True, help='The save directory in results')
parser.add_argument(
'--ref_rttm_dir', required=True, help='Directory to store reference RTTM files')
'--ref_rttm_dir',
required=True,
help='Directory to store reference RTTM files')
parser.add_argument(
'--meta_data_dir', required=True, help='Directory to store the meta data (json) files')
'--meta_data_dir',
required=True,
help='Directory to store the meta data (json) files')
parser.add_argument(
'--split_type',
default="full_corpus_asr",
'--split_type',
default="full_corpus_asr",
help='Standard dataset split. See ami_splits.py for more information')
parser.add_argument(
'--skip_TNO', default=True, type=strtobool, help='Skips TNO meeting recordings if True')
'--skip_TNO',
default=True,
type=strtobool,
help='Skips TNO meeting recordings if True')
parser.add_argument(
'--mic_type', default="Mix-Headset", help='Type of microphone to be used')
'--mic_type',
default="Mix-Headset",
help='Type of microphone to be used')
parser.add_argument(
'--vad_type', default="oracle", help='Type of VAD. Kept for future when VAD will be added')
'--vad_type',
default="oracle",
help='Type of VAD. Kept for future use, when VAD support is added')
parser.add_argument(
'--max_subseg_dur',
default=3.0,
type=float,
help='Duration in seconds of a subsegments to be prepared from larger segments')
'--max_subseg_dur',
default=3.0,
type=float,
help='Duration in seconds of the subsegments to be prepared from larger segments'
)
parser.add_argument(
'--overlap', default=1.5, type=float, help='Overlap duration in seconds between adjacent subsegments')
'--overlap',
default=1.5,
type=float,
help='Overlap duration in seconds between adjacent subsegments')
args = parser.parse_args()
print(args)
prepare_ami(
args.data_folder,
args.manual_annot_folder,
args.save_folder,
args.ref_rttm_dir,
args.meta_data_dir
)
\ No newline at end of file
# Note: pass the remaining CLI options through, and cast skip_TNO to bool since
# strtobool() returns 0/1 while prepare_segs_for_RTTM checks `skip_TNO is True`.
prepare_ami(args.data_folder, args.manual_annot_folder, args.save_folder,
args.ref_rttm_dir, args.meta_data_dir, split_type=args.split_type,
skip_TNO=bool(args.skip_TNO), mic_type=args.mic_type, vad_type=args.vad_type,
max_subseg_dur=args.max_subseg_dur, overlap=args.overlap)
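# After a successful run, the reference RTTM files are written under --ref_rttm_dir
# and the JSON metadata files under --meta_data_dir (see the argument help above).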
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The AMI corpus contains 100 hours of meeting recordings.
This script returns the standard train, dev and eval splits for the AMI corpus.
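
Example
-------
A minimal sketch (assuming this file is importable as ami_splits and that each
returned split is a list of AMI meeting IDs):

>>> from ami_splits import get_AMI_split
>>> train_set, dev_set, eval_set = get_AMI_split("full_corpus_asr")
>>> len(train_set) > 0
True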
@@ -29,8 +42,7 @@ def get_AMI_split(split_option):
if split_option not in ALLOWED_OPTIONS:
print(
f'Invalid split "{split_option}" requested!\nValid split_options are: ',
ALLOWED_OPTIONS,
)
ALLOWED_OPTIONS, )
return
if split_option == "scenario_only":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data reading and writing.
@@ -5,10 +18,10 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022
"""
import os
import pickle
def save_pkl(obj, file):
"""Save an object in pkl format.
@@ -31,6 +44,7 @@ def save_pkl(obj, file):
with open(file, "wb") as f:
pickle.dump(obj, f)
def load_pickle(pickle_path):
"""Utility function for loading .pkl pickle files.
@@ -48,6 +62,7 @@ def load_pickle(pickle_path):
out = pickle.load(f)
return out
def load_pkl(file):
"""Loads a pkl file.
@@ -79,4 +94,4 @@ def load_pkl(file):
return pickle.load(f)
finally:
if os.path.isfile(file + ".lock"):
os.remove(file + ".lock")
\ No newline at end of file
os.remove(file + ".lock")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
@@ -26,7 +39,6 @@ ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
def rectify(arr):
"""Corrects corner cases and converts scores into percentage.
"""
# Numerator and denominator both 0.
arr[np.isnan(arr)] = 0
@@ -38,12 +50,11 @@ def rectify(arr):
def DER(
ref_rttm,
sys_rttm,
ignore_overlap=False,
collar=0.25,
individual_file_scores=False,
):
ref_rttm,
sys_rttm,
ignore_overlap=False,
collar=0.25,
individual_file_scores=False, ):
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
Speaker Error Rate (SER), and Diarization Error Rate (DER).
@@ -118,25 +129,20 @@ def DER(
]
scored_speaker_times = np.array(
[float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]
)
[float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)])
miss_speaker_times = np.array(
[float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]
)
[float(m) for m in MISS_SPEAKER_TIME.findall(stdout)])
fa_speaker_times = np.array(
[float(m) for m in FA_SPEAKER_TIME.findall(stdout)]
)
[float(m) for m in FA_SPEAKER_TIME.findall(stdout)])
error_speaker_times = np.array(
[float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]
)
[float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)])
with np.errstate(invalid="ignore", divide="ignore"):
tot_error_times = (
miss_speaker_times + fa_speaker_times + error_speaker_times
)
miss_speaker_times + fa_speaker_times + error_speaker_times)
miss_speaker_frac = miss_speaker_times / scored_speaker_times
fa_speaker_frac = fa_speaker_times / scored_speaker_times
sers_frac = error_speaker_times / scored_speaker_times
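# By convention, DER = (missed + false-alarm + speaker-error time) / scored speaker time;
# the rectify() helper above converts such fractional scores into percentages.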
@@ -153,13 +159,19 @@ def DER(
else:
return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Compute Diarization Error Rate')
parser = argparse.ArgumentParser(
description='Compute Diarization Error Rate')
parser.add_argument(
'--ref_rttm', required=True, help='the path of reference/groundtruth RTTM file')
'--ref_rttm',
required=True,
help='the path of the reference/ground-truth RTTM file')
parser.add_argument(
'--sys_rttm', required=True, help='the path of the system generated RTTM file')
'--sys_rttm',
required=True,
help='the path of the system generated RTTM file')
parser.add_argument(
'--individual_file',
default=False,
@@ -176,4 +188,5 @@ if __name__ == '__main__':
print(args)
der = DER(args.ref_rttm, args.sys_rttm)
print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" % (der[0], der[1], der[2], der[-1]))
\ No newline at end of file
print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" %
(der[0], der[1], der[2], der[-1]))
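# Example invocation (a sketch; the script name and RTTM file names are illustrative):
#   python measure_der.py --ref_rttm ./results/ref_rttms/fullref_ami_dev.rttm \
#       --sys_rttm ./results/sys_rttms/sys_ami_dev.rttm
# which prints something like:
#   miss_speaker: 2.100% fa_speaker: 1.500% sers: 3.200% ders: 6.800%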