diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 8dad5074896b7cdc78061d6bd090b670af55af7d..7eed9391088b3d8ed0d7b13644da5225bbb5e058 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -575,10 +575,10 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index 6c4365b860af3d3ddd4d585f1da5a3de06c570c5..d82034c8234df7bb621f2756dd047f40e0475e5c 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -528,10 +528,10 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 9141b3613deb2ece3680ad9fa45f0ef613986c66..91390afe5c6dbcff599b9d8cdd4ddc42e622f760 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -543,10 +543,10 @@ class U2STTester(U2STTrainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/s2t/transform/cmvn.py index dc9ea87e5cf00776fb368ef80f8fbf61564b4548..aa1e6b4450f41103c1f0c9f2e723bfcdbe2cde9d 100644 --- a/paddlespeech/s2t/transform/cmvn.py +++ b/paddlespeech/s2t/transform/cmvn.py @@ -14,10 +14,12 @@ # Modified from espnet(https://github.com/espnet/espnet) import io import json + import h5py import kaldiio import numpy as np + class CMVN(): "Apply Global/Spk CMVN/iverserCMVN." @@ -158,11 +160,14 @@ class UtteranceCMVN(): return x - class GlobalCMVN(): "Apply Global CMVN" - def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20): + def __init__(self, + cmvn_path, + norm_means=True, + norm_vars=True, + std_floor=1.0e-20): self.cmvn_path = cmvn_path self.norm_means = norm_means self.norm_vars = norm_vars @@ -189,4 +194,4 @@ class GlobalCMVN(): if self.norm_vars: x = np.divide(x, self.std) - return x \ No newline at end of file + return x diff --git a/paddlespeech/s2t/transform/perturb.py b/paddlespeech/s2t/transform/perturb.py index ee4c7ce02ea2be062d7ea59345b937e393c569be..873adb0b8ab2e5d67bb434fb6c1ab907114bc35d 100644 --- a/paddlespeech/s2t/transform/perturb.py +++ b/paddlespeech/s2t/transform/perturb.py @@ -17,6 +17,7 @@ import numpy import scipy import soundfile import soxbindings as sox + from paddlespeech.s2t.io.reader import SoundHDF5File @@ -171,6 +172,7 @@ class SpeedPerturbationSox(): upper={self.upper}, keep_length={self.keep_length}, sample_rate={self.sr})""" + else: return f"""{self.__class__.__name__}( utt2ratio={self.utt2ratio_file}, diff --git a/paddlespeech/s2t/transform/transformation.py b/paddlespeech/s2t/transform/transformation.py index bfe6c53d0c8776b643ec0baf763a2dff2f40befa..381b0cdc9d92c9d583bf357935dcf8ac9759c9aa 100644 --- a/paddlespeech/s2t/transform/transformation.py +++ b/paddlespeech/s2t/transform/transformation.py @@ -46,8 +46,7 @@ import_alias = dict( wpe="paddlespeech.s2t.transform.wpe:WPE", channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", - cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN" -) + cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") class Transformation(): diff --git a/paddlespeech/s2t/utils/ctc_utils.py b/paddlespeech/s2t/utils/ctc_utils.py index f5822e5dd7454d381af35bdfb82fb5f741eb7024..886b72033605e9080ebc7ae06e0a32054325be71 100644 --- a/paddlespeech/s2t/utils/ctc_utils.py +++ b/paddlespeech/s2t/utils/ctc_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from wenet(https://github.com/wenet-e2e/wenet) -from typing import List from pathlib import Path +from typing import List + import numpy as np import paddle diff --git a/utils/remove_longshortdata.py b/utils/remove_longshortdata.py index dcc05b234e0ea875cde9830c05eadff4cd49d3cc..131b4a5828bee7dc3e2520ed1694e80230c57f56 100755 --- a/utils/remove_longshortdata.py +++ b/utils/remove_longshortdata.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """remove longshort data from manifest""" -import logging import argparse +import logging + import jsonlines from paddlespeech.s2t.utils.cli_utils import get_commandline_args @@ -23,17 +24,19 @@ def get_parser(): parser.add_argument( "--verbose", "-V", default=0, type=int, help="Verbose option") parser.add_argument( - "--iaxis", default=0, type=int, help="multi inputs index, 0 is the first") - parser.add_argument( - "--oaxis", default=0, type=int, help="multi outputs index, 0 is the first") - parser.add_argument( - "--maxframes", default=2000, type=int, help="maxframes") - parser.add_argument( - "--minframes", default=10, type=int, help="minframes") + "--iaxis", + default=0, + type=int, + help="multi inputs index, 0 is the first") parser.add_argument( - "--maxchars", default=200, type=int, help="max tokens") - parser.add_argument( - "--minchars", default=0, type=int, help="min tokens") + "--oaxis", + default=0, + type=int, + help="multi outputs index, 0 is the first") + parser.add_argument("--maxframes", default=2000, type=int, help="maxframes") + parser.add_argument("--minframes", default=10, type=int, help="minframes") + parser.add_argument("--maxchars", default=200, type=int, help="max tokens") + parser.add_argument("--minchars", default=0, type=int, help="min tokens") parser.add_argument( "--stride_ms", default=10, type=int, help="stride in ms unit.") parser.add_argument( @@ -54,7 +57,7 @@ def filter_input(args, line): nframe = tmp['shape'][0] * 1000 / args.stride_ms else: nframe = tmp['shape'][0] - + if nframe < args.minframes or nframe > args.maxframes: return True else: @@ -67,7 +70,7 @@ def filter_output(args, line): return True else: return False - + def main(): args = get_parser().parse_args() @@ -78,15 +81,15 @@ def main(): else: logging.basicConfig(level=logging.WARN, format=logfmt) logging.info(get_commandline_args()) - + with jsonlines.open(args.rspecifier, 'r') as reader: lines = list(reader) logging.info(f"Example: {len(lines)}") feat = lines[0]['input'][args.iaxis]['feat'] - args.soud = False + args.soud = False if feat.split('.')[-1] not in 'ark, scp': args.sound = True - + count = 0 filter = 0 with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer: @@ -98,5 +101,6 @@ def main(): count += 1 logging.info(f"Example after filter: {count}\{filter}") + if __name__ == '__main__': - main() \ No newline at end of file + main()