提交 caa8eb4d 编写于 作者: K KP

Add KWS example.

上级 43659b98
......@@ -7,6 +7,7 @@ if [ ${ngpu} -gt 0 ]; then
python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
else
echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
python3 ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
fi
......@@ -18,6 +18,11 @@ source path.sh
# Number of visible GPUs: count the comma-separated ids in CUDA_VISIBLE_DEVICES.
# awk reports 0 fields when the variable is empty or unset.
ngpu=$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F "," '{print NF}')

# Exactly one positional argument (the config path) is required.
if [ $# -ne 1 ]; then
    # Usage goes to stderr so it is not mixed into captured stdout.
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path" >&2
    # Exit statuses must be in 0-255; "exit -1" is non-portable.
    exit 1
fi

# First and last pipeline stages to run.
stage=1
stop_stage=3
......
......@@ -15,6 +15,7 @@
import argparse
import os
import paddle
import yaml
from tqdm import tqdm
......@@ -23,32 +24,34 @@ from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument('--keyword', type=int, default=0, help='keyword label')
parser.add_argument('--step', type=float, default=0.01, help='threshold step')
parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
args = parser.parse_args()
# yapf: enable
def load_label_and_score(keyword, ds, score_file):
score_table = {}
def load_label_and_score(keyword_index: int,
ds: paddle.io.Dataset,
score_file: os.PathLike):
score_table = {} # {utt_id: scores_over_frames}
with open(score_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0]
current_keyword = arr[1]
str_list = arr[2:]
if int(current_keyword) == keyword:
if int(current_keyword) == keyword_index:
scores = list(map(float, str_list))
if key not in score_table:
score_table.update({key: scores})
keyword_table = {}
filler_table = {}
keyword_table = {} # scores of keyword utt_id
filler_table = {} # scores of non-keyword utt_id
filler_duration = 0.0
for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
assert key in score_table
if index == keyword:
if index == keyword_index:
keyword_table[key] = score_table[key]
else:
filler_table[key] = score_table[key]
......@@ -78,7 +81,7 @@ if __name__ == '__main__':
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
pbar = tqdm(total=int(1.0 / args.step))
with open(stats_file, 'w', encoding='utf8') as fout:
keyword_index = args.keyword
keyword_index = args.keyword_index
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
......
......@@ -15,7 +15,16 @@
import paddle
def fill_mask_elements(condition, value, x):
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a boolean mask that flags padded positions in a batch.

    Args:
        lengths: Per-sequence lengths; the first axis is the batch.
            Presumably a 1-D integer tensor of shape (batch,) -- confirm
            against callers.

    Returns:
        A boolean tensor with one row per batch entry and
        ``max(lengths)`` columns. An element is True when its position
        index is greater than or equal to that row's length.
    """
    longest = int(lengths.max().item())
    num_seqs = lengths.shape[0]
    # Row vector 0..longest-1, repeated once per sequence in the batch.
    positions = paddle.arange(
        longest, dtype=paddle.int64).expand((num_seqs, longest))
    # Broadcast each length down its row; positions past the length are padding.
    return positions >= lengths.unsqueeze(1)
def fill_mask_elements(condition: paddle.Tensor, value: float,
                       x: paddle.Tensor) -> paddle.Tensor:
    """Return ``x`` with the elements selected by ``condition`` set to ``value``.

    Args:
        condition: Boolean selector; must have exactly the same shape as ``x``.
        value: Scalar written wherever ``condition`` is True.
        x: Source tensor supplying the elements where ``condition`` is False.

    Returns:
        The result of ``paddle.where`` choosing between the constant
        ``value`` and the corresponding element of ``x``.
    """
    assert condition.shape == x.shape
    # Constant tensor shaped like x that holds the replacement value.
    replacement = paddle.ones_like(x, dtype=x.dtype) * value
    return paddle.where(condition, x=replacement, y=x)
......@@ -70,11 +79,3 @@ def max_pooling_loss(logits: paddle.Tensor,
acc = num_correct / num_utts
# acc = 0.0
return loss, num_correct, acc
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Return a boolean mask that is True at padded positions.

    Args:
        lengths: Per-sequence lengths indexed by batch on the first axis
            (assumed to be 1-D -- TODO confirm).

    Returns:
        Boolean tensor with ``lengths.shape[0]`` rows and ``max(lengths)``
        columns; True where the column index is >= the row's length.
    """
    batch_size = lengths.shape[0]
    # Longest sequence in the batch fixes the mask width.
    max_len = int(lengths.max().item())
    # Position indices 0..max_len-1, repeated for every batch entry.
    seq = paddle.arange(max_len, dtype=paddle.int64)
    seq = seq.expand((batch_size, max_len))
    # Compare each position against its row's length (broadcast over columns).
    return seq >= lengths.unsqueeze(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册