提交 98aaa381 编写于 作者: Q qingen

add DER scripts to calculate Diarization Error Rate

上级 52a8b2f3
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
Authors
* Neville Ryant 2018
* Nauman Dawalatabad 2020
Credits
This code is adapted from https://github.com/nryant/dscore
"""
import os
import re
import subprocess
import numpy as np
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+")
FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+")
ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
def rectify(arr):
"""Corrects corner cases and converts scores into percentage.
"""
# Numerator and denominator both 0.
arr[np.isnan(arr)] = 0
# Numerator > 0, but denominator = 0.
arr[np.isinf(arr)] = 1
arr *= 100.0
return arr
def DER(
ref_rttm,
sys_rttm,
ignore_overlap=False,
collar=0.25,
individual_file_scores=False,
):
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
Speaker Error Rate (SER), and Diarization Error Rate (DER).
Arguments
---------
ref_rttm : str
The path of reference/groundtruth RTTM file.
sys_rttm : str
The path of the system generated RTTM file.
individual_file : bool
If True, returns scores for each file in order.
collar : float
Forgiveness collar.
ignore_overlap : bool
If True, ignores overlapping speech during evaluation.
Returns
-------
MS : float array
Missed Speech.
FA : float array
False Alarms.
SER : float array
Speaker Error Rates.
DER : float array
Diarization Error Rates.
Example
-------
>>> import pytest
>>> pytest.skip('Skipping because of Perl dependency')
>>> ref_rttm = "../../samples/rttm_samples/ref_rttm/ES2014c.rttm"
>>> sys_rttm = "../../samples/rttm_samples/sys_rttm/ES2014c.rttm"
>>> ignore_overlap = True
>>> collar = 0.25
>>> individual_file_scores = True
>>> Scores = DER(ref_rttm, sys_rttm, ignore_overlap, collar, individual_file_scores)
>>> print (Scores)
(array([0., 0.]), array([0., 0.]), array([7.16923618, 7.16923618]), array([7.16923618, 7.16923618]))
"""
curr = os.path.abspath(os.path.dirname(__file__))
mdEval = os.path.join(curr, "../../tools/der_eval/md-eval.pl")
cmd = [
mdEval,
"-af",
"-r",
ref_rttm,
"-s",
sys_rttm,
"-c",
str(collar),
]
if ignore_overlap:
cmd.append("-1")
try:
stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as ex:
stdout = ex.output
else:
stdout = stdout.decode("utf-8")
# Get all recording IDs
file_ids = [m.strip() for m in FILE_IDS.findall(stdout)]
file_ids = [
file_id[2:] if file_id.startswith("f=") else file_id
for file_id in file_ids
]
scored_speaker_times = np.array(
[float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]
)
miss_speaker_times = np.array(
[float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]
)
fa_speaker_times = np.array(
[float(m) for m in FA_SPEAKER_TIME.findall(stdout)]
)
error_speaker_times = np.array(
[float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]
)
with np.errstate(invalid="ignore", divide="ignore"):
tot_error_times = (
miss_speaker_times + fa_speaker_times + error_speaker_times
)
miss_speaker_frac = miss_speaker_times / scored_speaker_times
fa_speaker_frac = fa_speaker_times / scored_speaker_times
sers_frac = error_speaker_times / scored_speaker_times
ders_frac = tot_error_times / scored_speaker_times
# Values in percentage of scored_speaker_time
miss_speaker = rectify(miss_speaker_frac)
fa_speaker = rectify(fa_speaker_frac)
sers = rectify(sers_frac)
ders = rectify(ders_frac)
if individual_file_scores:
return miss_speaker, fa_speaker, sers, ders
else:
return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册