未验证 提交 962a2789 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #1558 from KPatr1ck/kws

[KWS]Add kws example on HeySnips dataset.
......@@ -13,6 +13,7 @@
# limitations under the License.
from .esc50 import ESC50
from .gtzan import GTZAN
from .hey_snips import HeySnips
from .rirs_noises import OpenRIRNoise
from .tess import TESS
from .urban_sound import UrbanSound8K
......
......@@ -17,6 +17,8 @@ import numpy as np
import paddle
from ..backends import load as load_audio
from ..compliance.kaldi import fbank as kaldi_fbank
from ..compliance.kaldi import mfcc as kaldi_mfcc
from ..compliance.librosa import melspectrogram
from ..compliance.librosa import mfcc
......@@ -24,6 +26,8 @@ feat_funcs = {
'raw': None,
'melspectrogram': melspectrogram,
'mfcc': mfcc,
'kaldi_fbank': kaldi_fbank,
'kaldi_mfcc': kaldi_mfcc,
}
......@@ -73,6 +77,11 @@ class AudioClassificationDataset(paddle.io.Dataset):
feat_func = feat_funcs[self.feat_type]
record = {}
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T)
record['feat'] = feat_func(
waveform=waveform, sr=self.sample_rate, **self.feat_config)
else:
record['feat'] = feat_func(
waveform, sample_rate,
**self.feat_config) if feat_func else waveform
......@@ -81,6 +90,9 @@ class AudioClassificationDataset(paddle.io.Dataset):
def __getitem__(self, idx):
record = self._convert_to_record(idx)
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
return self.keys[idx], record['feat'], record['label']
else:
return np.array(record['feat']).transpose(), np.array(
record['label'], dtype=np.int64)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import os
from typing import List
from typing import Tuple
from .dataset import AudioClassificationDataset
__all__ = ['HeySnips']
class HeySnips(AudioClassificationDataset):
meta_info = collections.namedtuple('META_INFO',
('key', 'label', 'duration', 'wav'))
def __init__(self,
data_dir: os.PathLike,
mode: str='train',
feat_type: str='kaldi_fbank',
sample_rate: int=16000,
**kwargs):
self.data_dir = data_dir
files, labels = self._get_data(mode)
super(HeySnips, self).__init__(
files=files,
labels=labels,
feat_type=feat_type,
sample_rate=sample_rate,
**kwargs)
def _get_meta_info(self, mode) -> List[collections.namedtuple]:
ret = []
with open(os.path.join(self.data_dir, '{}.json'.format(mode)),
'r') as f:
data = json.load(f)
for item in data:
sample = collections.OrderedDict()
if item['duration'] > 0:
sample['key'] = item['id']
sample['label'] = 0 if item['is_hotword'] == 1 else -1
sample['duration'] = item['duration']
sample['wav'] = os.path.join(self.data_dir,
item['audio_file_path'])
ret.append(self.meta_info(*sample.values()))
return ret
def _get_data(self, mode: str) -> Tuple[List[str], List[int]]:
meta_info = self._get_meta_info(mode)
files = []
labels = []
self.keys = []
self.durations = []
for sample in meta_info:
key, target, duration, wav = sample
files.append(wav)
labels.append(int(target))
self.keys.append(key)
self.durations.append(float(duration))
return files, labels
## Metrics
We mesure FRRs with fixing false alarms in one hour:
|Model|False Alarm| False Reject Rate|
|--|--|--|
|MDTC| 1| 0.003559 |
# MDTC Keyword Spotting with HeySnips Dataset
## Dataset
Before running scripts, you **MUST** follow this instruction to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets
After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` to complete dataset config.
## Get Started
In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate on "Hey Snips" dataset.
```sh
CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml
```
This script contains training and scoring steps. You can just set the `CUDA_VISIBLE_DEVICES` environment var to run on single gpu or multi-gpus.
The vars `stage` and `stop_stage` in `./run.sh` controls the running steps:
- stage 1: Training from scratch.
- stage 2: Evaluating model on test dataset and computing detection error tradeoff(DET) of all trigger thresholds.
- stage 3: Plotting the DET cruve for visualizaiton.
data:
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
dataset: 'paddleaudio.datasets:HeySnips'
model:
num_keywords: 1
backbone: 'paddlespeech.kws.models:MDTC'
config:
stack_num: 3
stack_size: 4
in_channels: 80
res_channels: 32
kernel_size: 5
feature:
feat_type: 'kaldi_fbank'
sample_rate: 16000
frame_shift: 10
frame_length: 25
n_mels: 80
training:
epochs: 100
num_workers: 16
batch_size: 100
checkpoint_dir: './checkpoint'
save_freq: 10
log_freq: 10
learning_rate: 0.001
weight_decay: 0.00005
grad_clip: 5.0
scoring:
batch_size: 100
num_workers: 16
checkpoint: './checkpoint/epoch_100/model.pdparams'
score_file: './scores.txt'
stats_file: './stats.0.txt'
img_file: './det.png'
\ No newline at end of file
#!/bin/bash
python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
#!/bin/bash
python3 ${BIN_DIR}/score.py --cfg_path=$1
python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
#!/bin/bash
ngpu=$1
cfg_path=$2
if [ ${ngpu} -gt 0 ]; then
python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
else
echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
python3 ${BIN_DIR}/train.py \
--cfg_path ${cfg_path}
fi
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=mdtc
export BIN_DIR=${MAIN_ROOT}/paddlespeech/kws/exps/${MODEL}
\ No newline at end of file
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
source path.sh
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
if [ $# != 1 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path"
exit -1
fi
stage=1
stop_stage=3
cfg_path=$1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/train.sh ${ngpu} ${cfg_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/score.sh ${cfg_path} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/plot.sh ${cfg_path} || exit -1
fi
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models.mdtc import MDTC
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import paddle
def collate_features(batch):
# (key, feat, label)
collate_start = time.time()
keys = []
feats = []
labels = []
lengths = []
for sample in batch:
keys.append(sample[0])
feats.append(sample[1])
labels.append(sample[2])
lengths.append(sample[1].shape[0])
max_length = max(lengths)
for i in range(len(feats)):
feats[i] = paddle.nn.functional.pad(
feats[i], [0, max_length - feats[i].shape[0], 0, 0],
data_format='NLC')
return keys, paddle.stack(feats), paddle.to_tensor(
labels), paddle.to_tensor(lengths)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import paddle
import yaml
from tqdm import tqdm
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
args = parser.parse_args()
# yapf: enable
def load_label_and_score(keyword_index: int,
ds: paddle.io.Dataset,
score_file: os.PathLike):
score_table = {} # {utt_id: scores_over_frames}
with open(score_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0]
current_keyword = arr[1]
str_list = arr[2:]
if int(current_keyword) == keyword_index:
scores = list(map(float, str_list))
if key not in score_table:
score_table.update({key: scores})
keyword_table = {} # scores of keyword utt_id
filler_table = {} # scores of non-keyword utt_id
filler_duration = 0.0
for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
assert key in score_table
if index == keyword_index:
keyword_table[key] = score_table[key]
else:
filler_table[key] = score_table[key]
filler_duration += duration
return keyword_table, filler_table, filler_duration
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
score_file = os.path.abspath(scoring_conf['score_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, test_ds, score_file)
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
pbar = tqdm(total=int(1.0 / args.step))
with open(stats_file, 'w', encoding='utf8') as fout:
keyword_index = args.keyword_index
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
# transverse the all keyword_table
for key, score_list in keyword_table.items():
# computer positive test sample, use the max score of list.
score = max(score_list)
if float(score) < threshold:
num_false_reject += 1
num_false_alarm = 0
# transverse the all filler_table
for key, score_list in filler_table.items():
i = 0
while i < len(score_list):
if score_list[i] >= threshold:
num_false_alarm += 1
i += args.window_shift
else:
i += 1
if len(keyword_table) != 0:
false_reject_rate = num_false_reject / len(keyword_table)
num_false_alarm = max(num_false_alarm, 1e-6)
if filler_duration != 0:
false_alarm_per_hour = num_false_alarm / \
(filler_duration / 3600.0)
fout.write('{:.6f} {:.6f} {:.6f}\n'.format(
threshold, false_alarm_per_hour, false_reject_rate))
threshold += args.step
pbar.update(1)
pbar.close()
print('DET saved to: {}'.format(stats_file))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import yaml
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument("--keyword", type=str, required=True)
args = parser.parse_args()
# yapf: enable
def load_stats_file(stats_file):
values = []
with open(stats_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
threshold, fa_per_hour, frr = arr
values.append([float(fa_per_hour), float(frr) * 100])
values.reverse()
return np.array(values)
def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
y_step):
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['font.size'] = 12
for index, keyword in enumerate(keywords):
values = load_stats_file(stats_file)
plt.plot(values[:, 0], values[:, 1], label=keyword)
plt.xlim([0, xlim])
plt.ylim([0, ylim])
plt.xticks(range(0, xlim + x_step, x_step))
plt.yticks(range(0, ylim + y_step, y_step))
plt.xlabel('False Alarm Per Hour')
plt.ylabel('False Rejection Rate (\\%)')
plt.grid(linestyle='--')
plt.legend(loc='best', fontsize=16)
plt.savefig(figure_file)
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
scoring_conf = config['scoring']
img_file = os.path.abspath(scoring_conf['img_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
keywords = [args.keyword]
plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
print('DET curve image saved to: {}'.format(img_file))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import paddle
import yaml
from tqdm import tqdm
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
scoring_conf = config['scoring']
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
test_sampler = paddle.io.BatchSampler(
test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
test_loader = paddle.io.DataLoader(
test_ds,
batch_sampler=test_sampler,
num_workers=scoring_conf['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
# Model
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
model.eval()
with paddle.no_grad(), open(
scoring_conf['score_file'], 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(
tqdm(test_loader, total=len(test_loader))):
keys, feats, labels, lengths = batch
logits = model(feats)
num_keywords = logits.shape[2]
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
for keyword_i in range(num_keywords):
keyword_scores = score[:, keyword_i]
score_frames = ' '.join(
['{:.6f}'.format(x) for x in keyword_scores.tolist()])
fout.write(
'{} {} {}\n'.format(key, keyword_i, score_frames))
print('Result saved to: {}'.format(scoring_conf['score_file']))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle
import yaml
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__':
nranks = paddle.distributed.get_world_size()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
local_rank = paddle.distributed.get_rank()
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
training_conf = config['training']
# Dataset
ds_class = dynamic_import(data_conf['dataset'])
train_ds = ds_class(
data_dir=data_conf['data_dir'], mode='train', **feat_conf)
dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
train_sampler = paddle.io.DistributedBatchSampler(
train_ds,
batch_size=training_conf['batch_size'],
shuffle=True,
drop_last=False)
train_loader = paddle.io.DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=training_conf['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
# Model
backbone_class = dynamic_import(model_conf['backbone'])
backbone = backbone_class(**model_conf['config'])
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
model = paddle.DataParallel(model)
clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
optimizer = paddle.optimizer.Adam(
learning_rate=training_conf['learning_rate'],
weight_decay=training_conf['weight_decay'],
parameters=model.parameters(),
grad_clip=clip)
criterion = max_pooling_loss
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * training_conf['epochs'])
timer.start()
for epoch in range(1, training_conf['epochs'] + 1):
model.train()
avg_loss = 0
num_corrects = 0
num_samples = 0
for batch_idx, batch in enumerate(train_loader):
keys, feats, labels, lengths = batch
logits = model(feats)
loss, corrects, acc = criterion(logits, labels, lengths)
loss.backward()
optimizer.step()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
# Calculate loss
avg_loss += loss.numpy()[0]
# Calculate metrics
num_corrects += corrects
num_samples += feats.shape[0]
timer.count()
if (batch_idx + 1
) % training_conf['log_freq'] == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= training_conf['log_freq']
avg_acc = num_corrects / num_samples
print_msg = 'Epoch={}/{}, Step={}/{}'.format(
epoch, training_conf['epochs'], batch_idx + 1,
steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
logger.train(print_msg)
avg_loss = 0
num_corrects = 0
num_samples = 0
if epoch % training_conf[
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler(
dev_ds,
batch_size=training_conf['batch_size'],
shuffle=False,
drop_last=False)
dev_loader = paddle.io.DataLoader(
dev_ds,
batch_sampler=dev_sampler,
num_workers=training_conf['num_workers'],
return_list=True,
use_buffer_reader=True,
collate_fn=collate_features, )
model.eval()
num_corrects = 0
num_samples = 0
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(dev_loader):
keys, feats, labels, lengths = batch
logits = model(feats)
loss, corrects, acc = criterion(logits, labels, lengths)
num_corrects += corrects
num_samples += feats.shape[0]
eval_acc = num_corrects / num_samples
print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(eval_acc)
logger.eval(print_msg)
# Save model
save_dir = os.path.join(training_conf['checkpoint_dir'],
'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(save_dir, 'model.pdopt'))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mdtc import KWSModel
from .mdtc import MDTC
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
batch_size = lengths.shape[0]
max_len = int(lengths.max().item())
seq = paddle.arange(max_len, dtype=paddle.int64)
seq = seq.expand((batch_size, max_len))
return seq >= lengths.unsqueeze(1)
def fill_mask_elements(condition: paddle.Tensor, value: float,
x: paddle.Tensor) -> paddle.Tensor:
assert condition.shape == x.shape
values = paddle.ones_like(x, dtype=x.dtype) * value
return paddle.where(condition, values, x)
def max_pooling_loss(logits: paddle.Tensor,
target: paddle.Tensor,
lengths: paddle.Tensor,
min_duration: int=0):
mask = padding_mask(lengths)
num_utts = logits.shape[0]
num_keywords = logits.shape[2]
loss = 0.0
for i in range(num_utts):
for j in range(num_keywords):
# Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
if target[i] == j:
# For the keyword, do max-polling
prob = logits[i, :, j]
m = mask[i]
if min_duration > 0:
m[:min_duration] = True
prob = fill_mask_elements(m, 0.0, prob)
prob = paddle.clip(prob, 1e-8, 1.0)
max_prob = prob.max()
loss += -paddle.log(max_prob)
else:
# For other keywords or filler, do min-polling
prob = 1 - logits[i, :, j]
prob = fill_mask_elements(mask[i], 1.0, prob)
prob = paddle.clip(prob, 1e-8, 1.0)
min_prob = prob.min()
loss += -paddle.log(min_prob)
loss = loss / num_utts
# Compute accuracy of current batch
mask = mask.unsqueeze(-1)
logits = fill_mask_elements(mask, 0.0, logits)
max_logits = logits.max(1)
num_correct = 0
for i in range(num_utts):
max_p = max_logits[i].max(0).item()
idx = max_logits[i].argmax(0).item()
# Predict correct as the i'th keyword
if max_p > 0.5 and idx == target[i].item():
num_correct += 1
# Predict correct as the filler, filler id < 0
if max_p < 0.5 and target[i].item() < 0:
num_correct += 1
acc = num_correct / num_utts
# acc = 0.0
return loss, num_correct, acc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class DSDilatedConv1d(nn.Layer):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
dilation: int=1,
stride: int=1,
bias: bool=True, ):
super(DSDilatedConv1d, self).__init__()
self.receptive_fields = dilation * (kernel_size - 1)
self.conv = nn.Conv1D(
in_channels,
in_channels,
kernel_size,
padding=0,
dilation=dilation,
stride=stride,
groups=in_channels,
bias_attr=bias, )
self.bn = nn.BatchNorm1D(in_channels)
self.pointwise = nn.Conv1D(
in_channels,
out_channels,
kernel_size=1,
padding=0,
dilation=1,
bias_attr=bias)
def forward(self, inputs: paddle.Tensor):
outputs = self.conv(inputs)
outputs = self.bn(outputs)
outputs = self.pointwise(outputs)
return outputs
class TCNBlock(nn.Layer):
def __init__(
self,
in_channels: int,
res_channels: int,
kernel_size: int,
dilation: int,
causal: bool, ):
super(TCNBlock, self).__init__()
self.in_channels = in_channels
self.res_channels = res_channels
self.kernel_size = kernel_size
self.dilation = dilation
self.causal = causal
self.receptive_fields = dilation * (kernel_size - 1)
self.half_receptive_fields = self.receptive_fields // 2
self.conv1 = DSDilatedConv1d(
in_channels=in_channels,
out_channels=res_channels,
kernel_size=kernel_size,
dilation=dilation, )
self.bn1 = nn.BatchNorm1D(res_channels)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv1D(
in_channels=res_channels, out_channels=res_channels, kernel_size=1)
self.bn2 = nn.BatchNorm1D(res_channels)
self.relu2 = nn.ReLU()
def forward(self, inputs: paddle.Tensor):
outputs = self.relu1(self.bn1(self.conv1(inputs)))
outputs = self.bn2(self.conv2(outputs))
if self.causal:
inputs = inputs[:, :, self.receptive_fields:]
else:
inputs = inputs[:, :, self.half_receptive_fields:
-self.half_receptive_fields]
if self.in_channels == self.res_channels:
res_out = self.relu2(outputs + inputs)
else:
res_out = self.relu2(outputs)
return res_out
class TCNStack(nn.Layer):
def __init__(
self,
in_channels: int,
stack_num: int,
stack_size: int,
res_channels: int,
kernel_size: int,
causal: bool, ):
super(TCNStack, self).__init__()
self.in_channels = in_channels
self.stack_num = stack_num
self.stack_size = stack_size
self.res_channels = res_channels
self.kernel_size = kernel_size
self.causal = causal
self.res_blocks = self.stack_tcn_blocks()
self.receptive_fields = self.calculate_receptive_fields()
self.res_blocks = nn.Sequential(*self.res_blocks)
def calculate_receptive_fields(self):
receptive_fields = 0
for block in self.res_blocks:
receptive_fields += block.receptive_fields
return receptive_fields
def build_dilations(self):
dilations = []
for s in range(0, self.stack_size):
for l in range(0, self.stack_num):
dilations.append(2**l)
return dilations
def stack_tcn_blocks(self):
dilations = self.build_dilations()
res_blocks = nn.LayerList()
res_blocks.append(
TCNBlock(
self.in_channels,
self.res_channels,
self.kernel_size,
dilations[0],
self.causal, ))
for dilation in dilations[1:]:
res_blocks.append(
TCNBlock(
self.res_channels,
self.res_channels,
self.kernel_size,
dilation,
self.causal, ))
return res_blocks
def forward(self, inputs: paddle.Tensor):
outputs = self.res_blocks(inputs)
return outputs
class MDTC(nn.Layer):
def __init__(
self,
stack_num: int,
stack_size: int,
in_channels: int,
res_channels: int,
kernel_size: int,
causal: bool=True, ):
super(MDTC, self).__init__()
assert kernel_size % 2 == 1
self.kernel_size = kernel_size
self.causal = causal
self.preprocessor = TCNBlock(
in_channels, res_channels, kernel_size, dilation=1, causal=causal)
self.relu = nn.ReLU()
self.blocks = nn.LayerList()
self.receptive_fields = self.preprocessor.receptive_fields
for i in range(stack_num):
self.blocks.append(
TCNStack(res_channels, stack_size, 1, res_channels, kernel_size,
causal))
self.receptive_fields += self.blocks[-1].receptive_fields
self.half_receptive_fields = self.receptive_fields // 2
self.hidden_dim = res_channels
def forward(self, x: paddle.Tensor):
if self.causal:
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
'constant')
else:
outputs = F.pad(
x,
(0, 0, self.half_receptive_fields, self.half_receptive_fields,
0, 0),
'constant', )
outputs = outputs.transpose([0, 2, 1])
outputs_list = []
outputs = self.relu(self.preprocessor(outputs))
for block in self.blocks:
outputs = block(outputs)
outputs_list.append(outputs)
normalized_outputs = []
output_size = outputs_list[-1].shape[-1]
for x in outputs_list:
remove_length = x.shape[-1] - output_size
if self.causal and remove_length > 0:
normalized_outputs.append(x[:, :, remove_length:])
elif not self.causal and remove_length > 1:
half_remove_length = remove_length // 2
normalized_outputs.append(
x[:, :, half_remove_length:-half_remove_length])
else:
normalized_outputs.append(x)
outputs = paddle.zeros_like(
outputs_list[-1], dtype=outputs_list[-1].dtype)
for x in normalized_outputs:
outputs += x
outputs = outputs.transpose([0, 2, 1])
return outputs, None
class KWSModel(nn.Layer):
def __init__(self, backbone, num_keywords):
super(KWSModel, self).__init__()
self.backbone = backbone
self.linear = nn.Linear(self.backbone.hidden_dim, num_keywords)
self.activation = nn.Sigmoid()
def forward(self, x):
outputs = self.backbone(x)
outputs = self.linear(outputs)
return self.activation(outputs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册