# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#               2022 Shaoqing Yu(954793264@qq.com)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import os

K
KP 已提交
19
import paddle
K
KP 已提交
20
from tqdm import tqdm
K
KP 已提交
21
from yacs.config import CfgNode
K
KP 已提交
22

K
KP 已提交
23
from paddlespeech.s2t.training.cli import default_argument_parser
K
KP 已提交
24
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
K
KP 已提交
25

K
KP 已提交
26

K
KP 已提交
27 28 29 30
def load_label_and_score(keyword_index: int,
                         ds: 'paddle.io.Dataset',
                         score_file: os.PathLike):
    """Split per-frame trigger scores into keyword and filler tables.

    Args:
        keyword_index: label index of the keyword being evaluated. Only
            score lines whose second column equals this index are read, and
            only utterances whose dataset label equals it are treated as
            keyword samples; every other utterance is a filler sample.
        ds: dataset exposing parallel ``keys``, ``labels`` and ``durations``
            sequences, one entry per utterance.
        score_file: text file whose non-blank lines look like
            ``<utt_id> <keyword_index> <score_0> <score_1> ...``. If an
            utterance appears several times for the same keyword, only the
            first line is used.

    Returns:
        A tuple ``(keyword_table, filler_table, filler_duration)``: two dicts
        mapping utt_id to its list of per-frame scores (keyword utterances
        and non-keyword utterances respectively), and the sum of ``durations``
        over the filler utterances.

    Raises:
        KeyError: if an utterance of ``ds`` has no score line for
            ``keyword_index`` in ``score_file``.
    """
    score_table = {}  # {utt_id: scores_over_frames}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            key, current_keyword, *str_list = fields
            # Keep only the first score line seen for each utterance.
            if int(current_keyword) == keyword_index and key not in score_table:
                score_table[key] = [float(s) for s in str_list]

    keyword_table = {}  # scores of keyword utt_id
    filler_table = {}  # scores of non-keyword utt_id
    filler_duration = 0.0

    for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
        # Explicit check rather than `assert`, which is stripped under -O.
        if key not in score_table:
            raise KeyError(
                'utterance {!r} has no scores for keyword {} in {}'.format(
                    key, keyword_index, score_file))
        if index == keyword_index:
            keyword_table[key] = score_table[key]
        else:
            filler_table[key] = score_table[key]
            filler_duration += duration

    return keyword_table, filler_table, filler_duration


K
KP 已提交
56
if __name__ == '__main__':
    # Sweep trigger-score thresholds over [0, 1] and write one DET point per
    # line to --stats_file: "<threshold> <false alarms per hour> <FRR>".
    parser = default_argument_parser()
    parser.add_argument(
        '--keyword_index', type=int, default=0, help='keyword index')
    parser.add_argument(
        '--step',
        type=float,
        default=0.01,
        help='threshold step of trigger score')
    parser.add_argument(
        '--window_shift',
        type=int,
        default=50,
        help='window_shift is used to skip the frames after triggered')
    parser.add_argument(
        "--score_file",
        type=str,
        required=True,
        help='output file of trigger scores')
    parser.add_argument(
        '--stats_file',
        type=str,
        default='./stats.0.txt',
        help='output file of detection error tradeoff')
    args = parser.parse_args()

    # A non-positive step never advances the threshold sweep, and a
    # non-positive window_shift never advances the false-alarm frame scan.
    if args.step <= 0:
        parser.error('--step must be positive, got {}'.format(args.step))
    if args.window_shift <= 0:
        parser.error('--window_shift must be positive, got {}'.format(
            args.window_shift))

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Dataset
    ds_class = dynamic_import(config['dataset'])
    test_ds = ds_class(
        data_dir=config['data_dir'],
        mode='test',
        feat_type=config['feat_type'],
        sample_rate=config['sample_rate'],
        frame_shift=config['frame_shift'],
        frame_length=config['frame_length'],
        n_mels=config['n_mels'], )

    keyword_table, filler_table, filler_duration = load_label_and_score(
        args.keyword_index, test_ds, args.score_file)
    print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
    if not keyword_table:
        print('Warning: no keyword utterances found; '
              'false reject rate is reported as 0.')
    if filler_duration == 0:
        print('Warning: filler duration is 0; '
              'false alarms per hour are reported as 0.')

    # Derive each threshold from an integer index instead of repeatedly
    # adding `step`: accumulated float error (100 * 0.01 sums to
    # 1.0000000000000007) would otherwise silently drop the 1.0 point.
    # The small epsilon guards against 1/step landing just below an integer.
    num_steps = int(1.0 / args.step + 1e-6)

    with open(args.stats_file, 'w', encoding='utf8') as fout, \
            tqdm(total=num_steps + 1) as pbar:
        for step_index in range(num_steps + 1):
            threshold = min(step_index * args.step, 1.0)

            # A keyword utterance is falsely rejected when even its peak
            # score stays below the threshold; an utterance with no scores
            # can never trigger, so it always counts as a reject.
            num_false_reject = sum(
                1 for score_list in keyword_table.values()
                if max(score_list, default=float('-inf')) < threshold)

            # Scan every filler utterance frame by frame; each frame at or
            # above the threshold is one false alarm, after which the next
            # window_shift frames are skipped so one activation counts once.
            num_false_alarm = 0
            for score_list in filler_table.values():
                i = 0
                while i < len(score_list):
                    if score_list[i] >= threshold:
                        num_false_alarm += 1
                        i += args.window_shift
                    else:
                        i += 1

            false_reject_rate = (num_false_reject / len(keyword_table)
                                 if keyword_table else 0.0)
            # Floor the count at a tiny positive value; presumably so that
            # downstream (e.g. log-scale) plotting never sees 0 — confirm.
            num_false_alarm = max(num_false_alarm, 1e-6)
            false_alarm_per_hour = (
                num_false_alarm / (filler_duration / 3600.0)
                if filler_duration != 0 else 0.0)

            fout.write('{:.6f} {:.6f} {:.6f}\n'.format(
                threshold, false_alarm_per_hour, false_reject_rate))
            pbar.update(1)

    print('DET saved to: {}'.format(args.stats_file))