From abb15ac6e8671e80cd0cb5c656db850a69856e63 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Mon, 25 Apr 2022 15:45:55 +0800 Subject: [PATCH] Update KWS example. --- examples/hey_snips/kws0/conf/mdtc.yaml | 80 ++++++++++-------- examples/hey_snips/kws0/local/plot.sh | 25 +++++- examples/hey_snips/kws0/local/score.sh | 26 +++++- examples/hey_snips/kws0/local/train.sh | 22 ++++- examples/hey_snips/kws0/run.sh | 10 ++- paddlespeech/kws/exps/mdtc/compute_det.py | 67 +++++++++------ paddlespeech/kws/exps/mdtc/plot_det_curve.py | 18 ++-- paddlespeech/kws/exps/mdtc/score.py | 71 +++++++++------- paddlespeech/kws/exps/mdtc/train.py | 87 +++++++++++--------- 9 files changed, 258 insertions(+), 148 deletions(-) diff --git a/examples/hey_snips/kws0/conf/mdtc.yaml b/examples/hey_snips/kws0/conf/mdtc.yaml index 3ce9f9d0..4bd0708c 100644 --- a/examples/hey_snips/kws0/conf/mdtc.yaml +++ b/examples/hey_snips/kws0/conf/mdtc.yaml @@ -1,39 +1,49 @@ -data: - data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' - dataset: 'paddleaudio.datasets:HeySnips' +# https://yaml.org/type/float.html +########################################### +# Data # +########################################### +dataset: 'paddleaudio.datasets:HeySnips' +data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' -model: - num_keywords: 1 - backbone: 'paddlespeech.kws.models:MDTC' - config: - stack_num: 3 - stack_size: 4 - in_channels: 80 - res_channels: 32 - kernel_size: 5 +############################################ +# Network Architecture # +############################################ +backbone: 'paddlespeech.kws.models:MDTC' +num_keywords: 1 +stack_num: 3 +stack_size: 4 +in_channels: 80 +res_channels: 32 +kernel_size: 5 -feature: - feat_type: 'kaldi_fbank' - sample_rate: 16000 - frame_shift: 10 - frame_length: 25 - n_mels: 80 +########################################### +# Feature # +########################################### +feat_type: 'kaldi_fbank' +sample_rate: 16000 +frame_shift: 10 +frame_length: 25 +n_mels: 80 -training: - epochs: 100 - num_workers: 16 - batch_size: 100 - checkpoint_dir: './checkpoint' - save_freq: 10 - log_freq: 10 - learning_rate: 0.001 - weight_decay: 0.00005 - grad_clip: 5.0 +########################################### +# Training # +########################################### +epochs: 100 +num_workers: 16 +batch_size: 100 +checkpoint_dir: './checkpoint' +save_freq: 10 +log_freq: 10 +learning_rate: 0.001 +weight_decay: 0.00005 +grad_clip: 5.0 -scoring: - batch_size: 100 - num_workers: 16 - checkpoint: './checkpoint/epoch_100/model.pdparams' - score_file: './scores.txt' - stats_file: './stats.0.txt' - img_file: './det.png' \ No newline at end of file +########################################### +# Scoring # +########################################### +batch_size: 100 +num_workers: 16 +checkpoint: './checkpoint/epoch_100/model.pdparams' +score_file: './scores.txt' +stats_file: './stats.0.txt' +img_file: './det.png' \ No newline at end of file diff --git a/examples/hey_snips/kws0/local/plot.sh b/examples/hey_snips/kws0/local/plot.sh index 5869e50b..783de98b 100755 --- a/examples/hey_snips/kws0/local/plot.sh +++ b/examples/hey_snips/kws0/local/plot.sh @@ -1,2 +1,25 @@ #!/bin/bash -python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# != 3 ];then + echo "usage: ${0} config_path checkpoint output_file" + exit -1 +fi + +keyword=$1 +stats_file=$2 +img_file=$3 + +python3 ${BIN_DIR}/plot_det_curve.py --keyword_label ${keyword} --stats_file ${stats_file} --img_file ${img_file} diff --git a/examples/hey_snips/kws0/local/score.sh b/examples/hey_snips/kws0/local/score.sh index ed21d08c..916536af 100755 --- a/examples/hey_snips/kws0/local/score.sh +++ b/examples/hey_snips/kws0/local/score.sh @@ -1,5 +1,27 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -python3 ${BIN_DIR}/score.py --cfg_path=$1 +if [ $# != 4 ];then + echo "usage: ${0} checkpoint score_file stats_file" + exit -1 +fi -python3 ${BIN_DIR}/compute_det.py --cfg_path=$1 +cfg_path=$1 +ckpt=$2 +score_file=$3 +stats_file=$4 + +python3 ${BIN_DIR}/score.py --config ${cfg_path} --ckpt ${ckpt} --score_file ${score_file} || exit -1 +python3 ${BIN_DIR}/compute_det.py --config ${cfg_path} --score_file ${score_file} --stats_file ${stats_file} || exit -1 diff --git a/examples/hey_snips/kws0/local/train.sh b/examples/hey_snips/kws0/local/train.sh index 8d0181b8..c403f22a 100755 --- a/examples/hey_snips/kws0/local/train.sh +++ b/examples/hey_snips/kws0/local/train.sh @@ -1,13 +1,31 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# != 2 ];then + echo "usage: ${0} num_gpus config_path" + exit -1 +fi ngpu=$1 cfg_path=$2 if [ ${ngpu} -gt 0 ]; then python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \ - --cfg_path ${cfg_path} + --config ${cfg_path} else echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning." python3 ${BIN_DIR}/train.py \ - --cfg_path ${cfg_path} + --config ${cfg_path} fi diff --git a/examples/hey_snips/kws0/run.sh b/examples/hey_snips/kws0/run.sh index 2cc09a4f..bc25a8e8 100755 --- a/examples/hey_snips/kws0/run.sh +++ b/examples/hey_snips/kws0/run.sh @@ -32,10 +32,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ./local/train.sh ${ngpu} ${cfg_path} || exit -1 fi +ckpt=./checkpoint/epoch_100/model.pdparams +score_file=./scores.txt +stats_file=./stats.0.txt +img_file=./det.png + if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - ./local/score.sh ${cfg_path} || exit -1 + ./local/score.sh ${cfg_path} ${ckpt} ${score_file} ${stats_file} || exit -1 fi +keyword=HeySnips if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - ./local/plot.sh ${cfg_path} || exit -1 + ./local/plot.sh ${keyword} ${stats_file} ${img_file} || exit -1 fi \ No newline at end of file diff --git a/paddlespeech/kws/exps/mdtc/compute_det.py b/paddlespeech/kws/exps/mdtc/compute_det.py index 817846b8..e43a953d 100644 --- a/paddlespeech/kws/exps/mdtc/compute_det.py +++ b/paddlespeech/kws/exps/mdtc/compute_det.py @@ -12,24 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from wekws(https://github.com/wenet-e2e/wekws) -import argparse import os import paddle -import yaml from tqdm import tqdm +from yacs.config import CfgNode +from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -# yapf: disable -parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--cfg_path", type=str, required=True) -parser.add_argument('--keyword_index', type=int, default=0, help='keyword index') -parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score') -parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered') -args = parser.parse_args() -# yapf: enable - def load_label_and_score(keyword_index: int, ds: paddle.io.Dataset, @@ -61,26 +52,52 @@ def load_label_and_score(keyword_index: int, if __name__ == '__main__': - args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) - with open(args.cfg_path, 'r') as f: - config = yaml.safe_load(f) + parser = default_argument_parser() + parser.add_argument( + '--keyword_index', type=int, default=0, help='keyword index') + parser.add_argument( + '--step', + type=float, + default=0.01, + help='threshold step of trigger score') + parser.add_argument( + '--window_shift', + type=int, + default=50, + help='window_shift is used to skip the frames after triggered') + parser.add_argument( + "--score_file", + type=str, + required=True, + help='output file of trigger scores') + parser.add_argument( + '--stats_file', + type=str, + default='./stats.0.txt', + help='output file of detection error tradeoff') + args = parser.parse_args() - data_conf = config['data'] - feat_conf = config['feature'] - scoring_conf = config['scoring'] + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) # Dataset - ds_class = dynamic_import(data_conf['dataset']) - test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) - - score_file = os.path.abspath(scoring_conf['score_file']) - stats_file = os.path.abspath(scoring_conf['stats_file']) + ds_class = dynamic_import(config['dataset']) + test_ds = ds_class( + data_dir=config['data_dir'], + mode='test', + feat_type=config['feat_type'], + sample_rate=config['sample_rate'], + frame_shift=config['frame_shift'], + frame_length=config['frame_length'], + n_mels=config['n_mels'], ) keyword_table, filler_table, filler_duration = load_label_and_score( - args.keyword, test_ds, score_file) + args.keyword_index, test_ds, args.score_file) print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) pbar = tqdm(total=int(1.0 / args.step)) - with open(stats_file, 'w', encoding='utf8') as fout: + with open(args.stats_file, 'w', encoding='utf8') as fout: keyword_index = args.keyword_index threshold = 0.0 while threshold <= 1.0: @@ -113,4 +130,4 @@ if __name__ == '__main__': pbar.update(1) pbar.close() - print('DET saved to: {}'.format(stats_file)) + print('DET saved to: {}'.format(args.stats_file)) diff --git a/paddlespeech/kws/exps/mdtc/plot_det_curve.py b/paddlespeech/kws/exps/mdtc/plot_det_curve.py index ac920358..a3ea21ef 100644 --- a/paddlespeech/kws/exps/mdtc/plot_det_curve.py +++ b/paddlespeech/kws/exps/mdtc/plot_det_curve.py @@ -17,12 +17,12 @@ import os import matplotlib.pyplot as plt import numpy as np -import yaml # yapf: disable parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--cfg_path", type=str, required=True) -parser.add_argument("--keyword", type=str, required=True) +parser.add_argument('--keyword_label', type=str, required=True, help='keyword string shown on image') +parser.add_argument('--stats_file', type=str, required=True, help='output file of detection error tradeoff') +parser.add_argument('--img_file', type=str, default='./det.png', help='output det image') args = parser.parse_args() # yapf: enable @@ -61,14 +61,8 @@ def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim, if __name__ == '__main__': - args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) - with open(args.cfg_path, 'r') as f: - config = yaml.safe_load(f) - - scoring_conf = config['scoring'] - img_file = os.path.abspath(scoring_conf['img_file']) - stats_file = os.path.abspath(scoring_conf['stats_file']) - keywords = [args.keyword] - plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2) + img_file = os.path.abspath(args.img_file) + stats_file = os.path.abspath(args.stats_file) + plot_det_curve([args.keyword_label], stats_file, img_file, 10, 2, 10, 2) print('DET curve image saved to: {}'.format(img_file)) diff --git a/paddlespeech/kws/exps/mdtc/score.py b/paddlespeech/kws/exps/mdtc/score.py index 7fe88ea3..1b5e1e29 100644 --- a/paddlespeech/kws/exps/mdtc/score.py +++ b/paddlespeech/kws/exps/mdtc/score.py @@ -12,55 +12,67 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from wekws(https://github.com/wenet-e2e/wekws) -import argparse -import os - import paddle -import yaml from tqdm import tqdm +from yacs.config import CfgNode from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -# yapf: disable -parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--cfg_path", type=str, required=True) -args = parser.parse_args() -# yapf: enable - if __name__ == '__main__': - args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) - with open(args.cfg_path, 'r') as f: - config = yaml.safe_load(f) + parser = default_argument_parser() + parser.add_argument( + "--ckpt", + type=str, + required=True, + help='model checkpoint for evaluation.') + parser.add_argument( + "--score_file", + type=str, + default='./scores.txt', + help='output file of trigger scores') + args = parser.parse_args() - model_conf = config['model'] - data_conf = config['data'] - feat_conf = config['feature'] - scoring_conf = config['scoring'] + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) # Dataset - ds_class = dynamic_import(data_conf['dataset']) - test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) + ds_class = dynamic_import(config['dataset']) + test_ds = ds_class( + data_dir=config['data_dir'], + mode='test', + feat_type=config['feat_type'], + sample_rate=config['sample_rate'], + frame_shift=config['frame_shift'], + frame_length=config['frame_length'], + n_mels=config['n_mels'], ) test_sampler = paddle.io.BatchSampler( - test_ds, batch_size=scoring_conf['batch_size'], drop_last=False) + test_ds, batch_size=config['batch_size'], drop_last=False) test_loader = paddle.io.DataLoader( test_ds, batch_sampler=test_sampler, - num_workers=scoring_conf['num_workers'], + num_workers=config['num_workers'], return_list=True, use_buffer_reader=True, collate_fn=collate_features, ) # Model - backbone_class = dynamic_import(model_conf['backbone']) - backbone = backbone_class(**model_conf['config']) - model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) - model.set_state_dict(paddle.load(scoring_conf['checkpoint'])) + backbone_class = dynamic_import(config['backbone']) + backbone = backbone_class( + stack_num=config['stack_num'], + stack_size=config['stack_size'], + in_channels=config['in_channels'], + res_channels=config['res_channels'], + kernel_size=config['kernel_size'], ) + model = KWSModel(backbone=backbone, num_keywords=config['num_keywords']) + model.set_state_dict(paddle.load(args.ckpt)) model.eval() - with paddle.no_grad(), open( - scoring_conf['score_file'], 'w', encoding='utf8') as fout: + with paddle.no_grad(), open(args.score_file, 'w', encoding='utf8') as f: for batch_idx, batch in enumerate( tqdm(test_loader, total=len(test_loader))): keys, feats, labels, lengths = batch @@ -73,7 +85,6 @@ if __name__ == '__main__': keyword_scores = score[:, keyword_i] score_frames = ' '.join( ['{:.6f}'.format(x) for x in keyword_scores.tolist()]) - fout.write( - '{} {} {}\n'.format(key, keyword_i, score_frames)) + f.write('{} {} {}\n'.format(key, keyword_i, score_frames)) - print('Result saved to: {}'.format(scoring_conf['score_file'])) + print('Result saved to: {}'.format(args.score_file)) diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py index 99e72871..56082bd7 100644 --- a/paddlespeech/kws/exps/mdtc/train.py +++ b/paddlespeech/kws/exps/mdtc/train.py @@ -11,77 +11,88 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import os import paddle -import yaml +from yacs.config import CfgNode from paddleaudio.utils import logger from paddleaudio.utils import Timer from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.models.loss import max_pooling_loss from paddlespeech.kws.models.mdtc import KWSModel +from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -# yapf: disable -parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--cfg_path", type=str, required=True) -args = parser.parse_args() -# yapf: enable - if __name__ == '__main__': + parser = default_argument_parser() + args = parser.parse_args() + + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + nranks = paddle.distributed.get_world_size() if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() local_rank = paddle.distributed.get_rank() - args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) - with open(args.cfg_path, 'r') as f: - config = yaml.safe_load(f) - - model_conf = config['model'] - data_conf = config['data'] - feat_conf = config['feature'] - training_conf = config['training'] - # Dataset - ds_class = dynamic_import(data_conf['dataset']) + ds_class = dynamic_import(config['dataset']) train_ds = ds_class( - data_dir=data_conf['data_dir'], mode='train', **feat_conf) - dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf) + data_dir=config['data_dir'], + mode='train', + feat_type=config['feat_type'], + sample_rate=config['sample_rate'], + frame_shift=config['frame_shift'], + frame_length=config['frame_length'], + n_mels=config['n_mels'], ) + dev_ds = ds_class( + data_dir=config['data_dir'], + mode='dev', + feat_type=config['feat_type'], + sample_rate=config['sample_rate'], + frame_shift=config['frame_shift'], + frame_length=config['frame_length'], + n_mels=config['n_mels'], ) train_sampler = paddle.io.DistributedBatchSampler( train_ds, - batch_size=training_conf['batch_size'], + batch_size=config['batch_size'], shuffle=True, drop_last=False) train_loader = paddle.io.DataLoader( train_ds, batch_sampler=train_sampler, - num_workers=training_conf['num_workers'], + num_workers=config['num_workers'], return_list=True, use_buffer_reader=True, collate_fn=collate_features, ) # Model - backbone_class = dynamic_import(model_conf['backbone']) - backbone = backbone_class(**model_conf['config']) - model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) + backbone_class = dynamic_import(config['backbone']) + backbone = backbone_class( + stack_num=config['stack_num'], + stack_size=config['stack_size'], + in_channels=config['in_channels'], + res_channels=config['res_channels'], + kernel_size=config['kernel_size'], ) + model = KWSModel(backbone=backbone, num_keywords=config['num_keywords']) model = paddle.DataParallel(model) - clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip']) + clip = paddle.nn.ClipGradByGlobalNorm(config['grad_clip']) optimizer = paddle.optimizer.Adam( - learning_rate=training_conf['learning_rate'], - weight_decay=training_conf['weight_decay'], + learning_rate=config['learning_rate'], + weight_decay=config['weight_decay'], parameters=model.parameters(), grad_clip=clip) criterion = max_pooling_loss steps_per_epoch = len(train_sampler) - timer = Timer(steps_per_epoch * training_conf['epochs']) + timer = Timer(steps_per_epoch * config['epochs']) timer.start() - for epoch in range(1, training_conf['epochs'] + 1): + for epoch in range(1, config['epochs'] + 1): model.train() avg_loss = 0 @@ -107,15 +118,13 @@ if __name__ == '__main__': timer.count() - if (batch_idx + 1 - ) % training_conf['log_freq'] == 0 and local_rank == 0: + if (batch_idx + 1) % config['log_freq'] == 0 and local_rank == 0: lr = optimizer.get_lr() - avg_loss /= training_conf['log_freq'] + avg_loss /= config['log_freq'] avg_acc = num_corrects / num_samples print_msg = 'Epoch={}/{}, Step={}/{}'.format( - epoch, training_conf['epochs'], batch_idx + 1, - steps_per_epoch) + epoch, config['epochs'], batch_idx + 1, steps_per_epoch) print_msg += ' loss={:.4f}'.format(avg_loss) print_msg += ' acc={:.4f}'.format(avg_acc) print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format( @@ -126,17 +135,17 @@ if __name__ == '__main__': num_corrects = 0 num_samples = 0 - if epoch % training_conf[ + if epoch % config[ 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0: dev_sampler = paddle.io.BatchSampler( dev_ds, - batch_size=training_conf['batch_size'], + batch_size=config['batch_size'], shuffle=False, drop_last=False) dev_loader = paddle.io.DataLoader( dev_ds, batch_sampler=dev_sampler, - num_workers=training_conf['num_workers'], + num_workers=config['num_workers'], return_list=True, use_buffer_reader=True, collate_fn=collate_features, ) @@ -159,7 +168,7 @@ if __name__ == '__main__': logger.eval(print_msg) # Save model - save_dir = os.path.join(training_conf['checkpoint_dir'], + save_dir = os.path.join(config['checkpoint_dir'], 'epoch_{}'.format(epoch)) logger.info('Saving model checkpoint to {}'.format(save_dir)) paddle.save(model.state_dict(), -- GitLab