Commit 789471bf authored by Hui Zhang

test wav for u2

Parent f598df0c
#!/bin/bash

if [ $# != 3 ]; then
    echo "usage: ${0} config_path ckpt_path_prefix audio_file"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_prefix=$2
audio_file=$3

chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]]; then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

for type in attention_rescoring; do
    echo "decoding ${type}"
    batch_size=1
    output_dir=${ckpt_prefix}
    mkdir -p ${output_dir}
    python3 -u ${BIN_DIR}/test_wav.py \
        --nproc ${ngpu} \
        --config ${config_path} \
        --result_file ${output_dir}/${type}.rsl \
        --checkpoint_path ${ckpt_prefix} \
        --opts decoding.decoding_method ${type} \
        --opts decoding.batch_size ${batch_size} \
        --audio_file ${audio_file}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done

exit 0
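For reference, a hypothetical invocation of the script above could look like the following. The script name local/test_wav.sh, the config, the checkpoint prefix, and the wav path are all placeholders; BIN_DIR must already be exported so that ${BIN_DIR}/test_wav.py resolves. CUDA_VISIBLE_DEVICES is what the awk line counts to derive ngpu, and a config whose file name matches chunk_*.yaml would also flip chunk_mode to true:

    # all names and paths below are placeholders; adjust them to your setup
    export CUDA_VISIBLE_DEVICES=0
    bash local/test_wav.sh \
        conf/chunk_conformer.yaml \
        exp/chunk_conformer/checkpoints/avg_10 \
        data/demo.wav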
@@ -12,125 +12,107 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Evaluation for U2 model."""
-import cProfile
 import os
 import sys
+from pathlib import Path

 import paddle
 import soundfile

 from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.io.collator import SpeechCollator
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.training.trainer import Trainer
-from paddlespeech.s2t.utils import layer_tools
-from paddlespeech.s2t.utils import mp_tools
+from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.utility import print_arguments
 from paddlespeech.s2t.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()


 # TODO(hui zhang): dynamic load
-class U2Tester_Hub(Trainer):
+class U2Infer():
     def __init__(self, config, args):
-        # super().__init__(config, args)
         self.args = args
         self.config = config
         self.audio_file = args.audio_file
-        self.collate_fn_test = SpeechCollator.from_config(config)
-        self._text_featurizer = TextFeaturizer(
+        self.sr = config.collator.target_sample_rate
+
+        self.preprocess_conf = config.collator.augmentation_config
+        self.preprocess_args = {"train": False}
+        self.preprocessing = Transformation(self.preprocess_conf)
+
+        self.text_feature = TextFeaturizer(
             unit_type=config.collator.unit_type,
-            vocab_filepath=None,
+            vocab_filepath=config.collator.vocab_filepath,
             spm_model_prefix=config.collator.spm_model_prefix)

-    def setup_model(self):
-        config = self.config
-        model_conf = config.model
+        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
+
+        # model
+        model_conf = config.model
         with UpdateConfig(model_conf):
-            model_conf.input_dim = self.collate_fn_test.feature_size
-            model_conf.output_dim = self.collate_fn_test.vocab_size
+            model_conf.input_dim = config.collator.feat_dim
+            model_conf.output_dim = self.text_feature.vocab_size
         model = U2Model.from_config(model_conf)
-        if self.parallel:
-            model = paddle.DataParallel(model)
-        logger.info(f"{model}")
-        layer_tools.print_params(model, logger.info)
         self.model = model
-        logger.info("Setup model")
-
-    @mp_tools.rank_zero_only
-    @paddle.no_grad()
-    def test(self):
         self.model.eval()
-        cfg = self.config.decoding
-        audio_file = self.audio_file
-        collate_fn_test = self.collate_fn_test
-        audio, _ = collate_fn_test.process_utterance(
-            audio_file=audio_file, transcript="Hello")
-        audio_len = audio.shape[0]
-        audio = paddle.to_tensor(audio, dtype='float32')
-        audio_len = paddle.to_tensor(audio_len)
-        audio = paddle.unsqueeze(audio, axis=0)
-        vocab_list = collate_fn_test.vocab_list
-        text_feature = self.collate_fn_test.text_feature
-        result_transcripts = self.model.decode(
-            audio,
-            audio_len,
-            text_feature=text_feature,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch,
-            ctc_weight=cfg.ctc_weight,
-            decoding_chunk_size=cfg.decoding_chunk_size,
-            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-            simulate_streaming=cfg.simulate_streaming)
-        logger.info("The result_transcripts: " + result_transcripts[0][0])
-
-    def run_test(self):
-        self.resume()
-        try:
-            self.test()
-        except KeyboardInterrupt:
-            sys.exit(-1)
-
-    def setup(self):
-        """Setup the experiment.
-        """
-        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
-
-        #self.setup_output_dir()
-        #self.setup_checkpointer()
-        #self.setup_dataloader()
-        self.setup_model()
-
-        self.iteration = 0
-        self.epoch = 0
-
-    def resume(self):
-        """Resume from the checkpoint at checkpoints in the output
-        directory or load a specified checkpoint.
-        """
+        # load model
         params_path = self.args.checkpoint_path + ".pdparams"
         model_dict = paddle.load(params_path)
         self.model.set_state_dict(model_dict)

+    def run(self):
+        check(args.audio_file)
+
+        with paddle.no_grad():
+            # read
+            audio, sample_rate = soundfile.read(
+                self.audio_file, dtype="int16", always_2d=True)
+            if sample_rate != self.sr:
+                logger.error(
+                    f"sample rate error: {sample_rate}, need {self.sr} ")
+                sys.exit(-1)
+            audio = audio[:, 0]
+            logger.info(f"audio shape: {audio.shape}")
+
+            # fbank
+            feat = self.preprocessing(audio, **self.preprocess_args)
+            logger.info(f"feat shape: {feat.shape}")
+
+            ilen = paddle.to_tensor(feat.shape[0])
+            xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
+
+            cfg = self.config.decoding
+            result_transcripts = self.model.decode(
+                xs,
+                ilen,
+                text_feature=self.text_feature,
+                decoding_method=cfg.decoding_method,
+                lang_model_path=cfg.lang_model_path,
+                beam_alpha=cfg.alpha,
+                beam_beta=cfg.beta,
+                beam_size=cfg.beam_size,
+                cutoff_prob=cfg.cutoff_prob,
+                cutoff_top_n=cfg.cutoff_top_n,
+                num_processes=cfg.num_proc_bsearch,
+                ctc_weight=cfg.ctc_weight,
+                decoding_chunk_size=cfg.decoding_chunk_size,
+                num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                simulate_streaming=cfg.simulate_streaming)
+            rsl = result_transcripts[0][0]
+            utt = Path(self.audio_file).name
+            logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
+            return rsl


 def check(audio_file):
+    if not os.path.isfile(audio_file):
+        print("Please input the right audio file path")
+        sys.exit(-1)
+
     logger.info("checking the audio file format......")
     try:
         sig, sample_rate = soundfile.read(audio_file)

@@ -144,15 +126,8 @@ def check(audio_file):
     logger.info("The audio file format is right")


-def main_sp(config, args):
-    exp = U2Tester_Hub(config, args)
-    with exp.eval():
-        exp.setup()
-        exp.run_test()
-
-
 def main(config, args):
-    main_sp(config, args)
+    U2Infer(config, args).run()


 if __name__ == "__main__":

@@ -163,25 +138,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--audio_file", type=str, help="path of the input audio file")
     args = parser.parse_args()
-    print_arguments(args, globals())
-    if not os.path.isfile(args.audio_file):
-        print("Please input the right audio file path")
-        sys.exit(-1)
-    check(args.audio_file)
-    # https://yaml.org/type/float.html
     config = get_cfg_defaults()
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
         config.merge_from_list(args.opts)
     config.freeze()
-    print(config)
-    if args.dump_config:
-        with open(args.dump_config, 'w') as f:
-            print(config, file=f)
-
-    # Setting for profiling
-    pr = cProfile.Profile()
-    pr.runcall(main, config, args)
-    pr.dump_stats('test.profile')
+    main(config, args)
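Because __main__ merges the --opts key/value pairs into the config via merge_from_list before freezing it, the decoding method and batch size can be overridden without editing the YAML; that is exactly how the wrapper script selects attention_rescoring. A hypothetical direct run of test_wav.py that bypasses the wrapper, with every path below a placeholder and BIN_DIR standing for whatever directory contains test_wav.py:

    # placeholders throughout; BIN_DIR, config, checkpoint prefix and wav are not real paths
    export BIN_DIR=/path/to/dir/containing/test_wav.py
    python3 -u ${BIN_DIR}/test_wav.py \
        --nproc 1 \
        --config conf/chunk_conformer.yaml \
        --result_file exp/demo/attention_rescoring.rsl \
        --checkpoint_path exp/demo/checkpoints/avg_10 \
        --opts decoding.decoding_method attention_rescoring \
        --opts decoding.batch_size 1 \
        --audio_file data/demo.wav

Note that --checkpoint_path takes a prefix without the .pdparams extension (the script appends it when calling paddle.load), and a --nproc value of 0 selects the CPU branch of paddle.set_device.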