未验证 提交 df3be4ac 编写于 作者: H Hui Zhang 提交者: GitHub

[s2t] move s2t data preprocess into paddlespeech.dataset (#3189)

* move s2t data preprocess into paddlespeech.dataset

* avg model, compute wer, format rsl into paddlespeech.dataset

* fix format rsl

* fix avg ckpts
上级 8c7859d3
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
set -e
stage=0
stop_stage=100
source utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
......@@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
echo "using sclite to compute cer..."
# format the reference test file for sclite
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
......
......@@ -28,6 +28,7 @@ import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
......@@ -139,7 +140,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
......
......@@ -28,6 +28,7 @@ import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
......@@ -205,7 +206,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False):
def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# s2t utils binaries.
from .avg_model import main as avg_ckpts_main
from .build_vocab import main as build_vocab_main
from .compute_mean_std import main as compute_mean_std_main
from .compute_wer import main as compute_wer_main
from .format_data import main as format_data_main
from .format_rsl import main as format_rsl_main
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import numpy as np
import paddle
def define_argparse():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
parser.add_argument(
'--val_best', action="store_true", help='averaged model')
parser.add_argument(
'--num', default=5, type=int, help='nums for averaged model')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
return args
def average_checkpoints(dst_model="",
ckpt_dir="",
val_best=True,
num=5,
min_epoch=0,
max_epoch=65536):
paddle.set_device('cpu')
val_scores = []
jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= min_epoch and epoch <= max_epoch:
val_scores.append((epoch, loss))
assert val_scores, f"Not find any valid checkpoints: {val_scores}"
val_scores = np.array(val_scores)
if val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
else:
sorted_val_scores = val_scores
beat_val_scores = sorted_val_scores[:num, 1]
selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
avg_val_score = np.mean(beat_val_scores)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
print("averaged val score = " + str(avg_val_score))
path_list = [
ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:num, 0]
]
print(path_list)
avg = None
num = args.num
assert num == len(path_list)
for path in path_list:
print(f'Processing {path}')
states = paddle.load(path)
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
avg[k] /= num
paddle.save(avg, args.dst_model)
print(f'Saving to {args.dst_model}')
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"val_loss_mean": avg_val_score,
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")
def main():
args = define_argparse()
average_checkpoints(args)
if __name__ == '__main__':
main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build vocabulary from manifest files.
Each item in vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
if isinstance(line_json['text'], str):
tokens = text_feature.tokenize(
line_json['text'], replace_space=False)
counter.update(tokens)
else:
assert isinstance(line_json['text'], list)
for text in line_json['text']:
tokens = text_feature.tokenize(text, replace_space=False)
counter.update(tokens)
def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
if isinstance(line_json[key], str):
fileobj.write(line_json[key] + "\n")
else:
assert isinstance(line_json[key], list)
for line in line_json[key]:
fileobj.write(line + "\n")
def build_vocab(manifest_paths="",
vocab_path="examples/librispeech/data/vocab.txt",
unit_type="char",
count_threshold=0,
text_keys='text',
spm_mode="unigram",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1
if unit_type == 'spm':
# tools/spm_train --input=$wave_data/lang_char/input.txt
# --vocab_size=${nbpe} --model_type=${bpemode}
# --model_prefix=${bpemodel} --input_sentence_size=100000000
import sentencepiece as spm
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in manifest_paths:
_text_keys = [text_keys] if type(
text_keys) is not list else text_keys
for text_key in _text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
input=fp.name,
vocab_size=spm_vocab_size,
model_type=spm_mode,
model_prefix=spm_model_prefix,
input_sentence_size=100000000,
character_coverage=spm_character_coverage)
os.unlink(fp.name)
# encode
text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
counter = Counter()
for manifest_path in manifest_paths:
count_manifest(counter, text_feature, manifest_path)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
if count < count_threshold:
break
# replace space by `<space>`
token = SPACE if token == ' ' else token
tokens.append(token)
tokens = sorted(tokens)
for token in tokens:
fout.write(token + '\n')
fout.write(SOS + "\n") # <sos/eos>
fout.close()
def define_argparse():
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('count_threshold', int, 0,
"Truncation threshold for char/word counts.Default 0, no truncate.")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath to write the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('text_keys', str,
'text',
"keys of the text in manifest for building vocabulary. "
"You can provide multiple k.",
nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
# yapf: disable
args = parser.parse_args()
return args
def main():
args = define_argparse()
print_arguments(args, globals())
build_vocab(**vars(args))
if __name__ == '__main__':
main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute mean and std for feature normalizer, and save to file."""
import argparse
import functools
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer
from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
def compute_cmvn(manifest_path="data/librispeech/manifest.train",
output_path="data/librispeech/mean_std.npz",
num_samples=2000,
num_workers=0,
spectrum_type="linear",
feat_dim=13,
delta_delta=False,
stride_ms=10,
window_ms=20,
sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
augmentation_pipeline = AugmentationPipeline('{}')
audio_featurizer = AudioFeaturizer(
spectrum_type=spectrum_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=float(stride_ms),
window_ms=float(window_ms),
n_fft=None,
max_freq=None,
target_sample_rate=sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=0.0)
def augment_and_featurize(audio_segment):
augmentation_pipeline.transform_audio(audio_segment)
return audio_featurizer.featurize(audio_segment)
normalizer = FeatureNormalizer(
mean_std_filepath=None,
manifest_path=manifest_path,
featurize_func=augment_and_featurize,
num_samples=num_samples,
num_workers=num_workers)
normalizer.write_to_file(output_path)
def define_argparse():
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('output_path', str,
'data/librispeech/mean_std.npz',
"Filepath of write mean and stddev to (.npz).")
add_arg('num_samples', int, 2000, "# of samples to for statistics.")
add_arg('num_workers',
default=0,
type=int,
help='num of subprocess workers for processing')
add_arg('spectrum_type', str,
'linear',
"Audio feature type. Options: linear, mfcc, fbank.",
choices=['linear', 'mfcc', 'fbank'])
add_arg('feat_dim', int, 13, "Audio feature dim.")
add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
add_arg('stride_ms', int, 10, "stride length in ms.")
add_arg('window_ms', int, 20, "stride length in ms.")
add_arg('sample_rate', int, 16000, "target sample rate.")
add_arg('use_dB_normalization', bool, True, "do dB normalization.")
add_arg('target_dB', int, -20, "target dB.")
# yapf: disable
args = parser.parse_args()
return args
def main():
args = define_argparse()
print_arguments(args, globals())
compute_cmvn(**vars(args))
if __name__ == '__main__':
main()
此差异已折叠。
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest with more metadata."""
import argparse
import functools
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
def define_argparse():
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('output_path', str, None, "filepath of formated manifest.", required=True)
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath of the vocabulary.")
# bpe
add_arg('spm_model_prefix', str, None,
"spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm")
# yapf: disable
args = parser.parse_args()
return args
def format_data(
manifest_paths="",
output_path="",
cmvn_path="examples/librispeech/data/mean_std.json",
unit_type="char",
vocab_path="examples/librispeech/data/vocab.txt",
spm_model_prefix=""):
fout = open(output_path, 'w', encoding='utf-8')
# get feat dim
filetype = cmvn_path.split(".")[-1]
mean, istd = load_cmvn(cmvn_path, filetype=filetype)
feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}")
text_feature = TextFeaturizer(unit_type, vocab_path, spm_model_prefix)
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
# josnline like this
# {
# "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
# "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
# "utt2spk": "111-2222",
# "utt": "111-2222-333"
# }
count = 0
for manifest_path in manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons:
output_json = {
"input": [],
"output": [],
'utt': line_json['utt'],
'utt2spk': line_json.get('utt2spk', 'global'),
}
# output
line = line_json['text']
if isinstance(line, str):
# only one target
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
output_json['output'].append({
'name': 'target1',
'shape': (len(tokenids), vocab_size),
'text': line,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
else:
# isinstance(line, list), multi target in one vocab
for i, item in enumerate(line, 1):
tokens = text_feature.tokenize(item)
tokenids = text_feature.featurize(item)
output_json['output'].append({
'name': f'target{i}',
'shape': (len(tokenids), vocab_size),
'text': item,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
# input
line = line_json['feat']
if isinstance(line, str):
# only one input
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
filetype = feat_type(line)
if filetype == 'sound':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
output_json['input'].append({
"name": "input1",
"shape": feat_shape,
"feat": line,
"filetype": filetype,
})
else:
# isinstance(line, list), multi input
raise NotImplementedError("not support multi input now!")
fout.write(json.dumps(output_json) + '\n')
count += 1
print(f"{manifest_paths} Examples number: {count}")
fout.close()
def main():
args = define_argparse()
print_arguments(args, globals())
format_data(**vars(args))
if __name__ == '__main__':
main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
format ref/hyp file for `utt text` format to compute CER/WER/MER.
norm:
BAC009S0764W0196 明确了发展目标和重点任务
BAC009S0764W0186 实现我国房地产市场的平稳运行
sclite:
加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
"""
import argparse
import jsonlines
from paddlespeech.utils.argparse import print_arguments
def transform_hyp(origin, trans, trans_sclite):
"""
Args:
origin: The input json file which contains the model output
trans: The output file for caculate CER/WER
trans_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0]
if trans:
with open(trans, "w+", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
print(f"transform_hyp output: {trans}")
if trans_sclite:
with open(trans_sclite, "w+") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
print(f"transform_hyp output: {trans_sclite}")
def transform_ref(origin, trans, trans_sclite):
"""
Args:
origin: The input json file which contains the model output
trans: The output file for caculate CER/WER
trans_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin, "r", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["text"]
if trans:
with open(trans, "w", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
print(f"transform_hyp output: {trans}")
if trans_sclite:
with open(trans_sclite, "w") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
print(f"transform_hyp output: {trans_sclite}")
def define_argparse():
parser = argparse.ArgumentParser(
prog='format ref/hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp', type=str, default="", help='origin hyp file')
parser.add_argument(
'--trans_hyp',
type=str,
default="",
help='hyp file for caculating CER/WER')
parser.add_argument(
'--trans_hyp_sclite',
type=str,
default="",
help='hyp file for caculating CER/WER by sclite')
parser.add_argument(
'--origin_ref', type=str, default="", help='origin ref file')
parser.add_argument(
'--trans_ref',
type=str,
default="",
help='ref file for caculating CER/WER')
parser.add_argument(
'--trans_ref_sclite',
type=str,
default="",
help='ref file for caculating CER/WER by sclite')
parser_args = parser.parse_args()
return parser_args
def format_result(origin_hyp="",
trans_hyp="",
trans_hyp_sclite="",
origin_ref="",
trans_ref="",
trans_ref_sclite=""):
if origin_hyp:
transform_hyp(
origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
if origin_ref:
transform_ref(
origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
def main():
args = define_argparse()
print_arguments(args, globals())
format_result(**vars(args))
if __name__ == "__main__":
main()
......@@ -28,8 +28,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.socket_server import AsrRequestHandler
from paddlespeech.s2t.utils.socket_server import AsrTCPServer
from paddlespeech.s2t.utils.socket_server import warm_up_test
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
def init_predictor(args):
......
......@@ -26,8 +26,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.socket_server import AsrRequestHandler
from paddlespeech.s2t.utils.socket_server import AsrTCPServer
from paddlespeech.s2t.utils.socket_server import warm_up_test
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
def start_server(config, args):
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -27,8 +27,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.utils.argparse import print_arguments
logger = Log(__name__).getlog()
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -18,7 +18,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer
......
......@@ -18,7 +18,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
model_test_alias = {
"u2": "paddlespeech.s2t.exps.u2.model:U2Tester",
......
......@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
model_train_alias = {
"u2": "paddlespeech.s2t.exps.u2.model:U2Trainer",
......
......@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -18,7 +18,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -18,7 +18,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import print_arguments
def main_sp(config, args):
......
......@@ -48,13 +48,16 @@ class TextFeaturizer():
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
self.vocab_path_or_list = vocab
if vocab:
if self.vocab_path_or_list:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
logger.warning("TextFeaturizer: not have vocab file or vocab list.")
logger.warning(
"TextFeaturizer: not have vocab file or vocab list. Only Tokenizer can use, can not convert to token idx"
)
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
......@@ -62,6 +65,7 @@ class TextFeaturizer():
self.sp.Load(spm_model)
def tokenize(self, text, replace_space=True):
"""tokenizer split text into text tokens"""
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
......@@ -71,6 +75,7 @@ class TextFeaturizer():
return tokens
def detokenize(self, tokens):
"""tokenizer convert text tokens back to text"""
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
......@@ -88,6 +93,7 @@ class TextFeaturizer():
Returns:
List[int]: List of token indices.
"""
assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
tokens = self.tokenize(text)
ids = []
for token in tokens:
......@@ -107,6 +113,7 @@ class TextFeaturizer():
Returns:
str: Text.
"""
assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
tokens = []
for idx in idxs:
if idx == self.eos_id:
......@@ -127,10 +134,10 @@ class TextFeaturizer():
"""
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
tokens = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
tokens = list(text)
return tokens
def char_detokenize(self, tokens):
"""Character detokenizer.
......
......@@ -29,10 +29,7 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"all_version", "UpdateConfig", "seed_all", 'print_arguments',
'add_arguments', "log_add"
]
__all__ = ["all_version", "UpdateConfig", "seed_all", "log_add"]
def all_version():
......@@ -60,51 +57,6 @@ def seed_all(seed: int=20210329):
paddle.seed(seed)
def print_arguments(args, info=None):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
filename = ""
if info:
filename = info["__file__"]
filename = os.path.basename(filename)
print(f"----------- {filename} Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("-----------------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def log_add(args: List[int]) -> float:
"""Stable log add
......
......@@ -16,6 +16,8 @@ import os
import sys
from typing import Text
import distutils
__all__ = ["print_arguments", "add_arguments", "get_commandline_args"]
......
......@@ -12,105 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import numpy as np
import paddle
def main(args):
paddle.set_device('cpu')
val_scores = []
beat_val_scores = None
selected_epochs = None
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss))
val_scores = np.array(val_scores)
if args.val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
else:
sorted_val_scores = val_scores
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
avg_val_score = np.mean(beat_val_scores)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
print("averaged val score = " + str(avg_val_score))
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
print(path_list)
avg = None
num = args.num
assert num == len(path_list)
for path in path_list:
print(f'Processing {path}')
states = paddle.load(path)
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
avg[k] /= num
paddle.save(avg, args.dst_model)
print(f'Saving to {args.dst_model}')
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"val_loss_mean": avg_val_score,
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")
from paddlespeech.dataset.s2t import avg_ckpts_main
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
parser.add_argument(
'--val_best', action="store_true", help='averaged model')
parser.add_argument(
'--num', default=5, type=int, help='nums for averaged model')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
print(args)
main(args)
avg_ckpts_main()
......@@ -15,134 +15,7 @@
"""Build vocabulary from manifest files.
Each item in vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('count_threshold', int, 0,
"Truncation threshold for char/word counts.Default 0, no truncate.")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath to write the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('text_keys', str,
'text',
"keys of the text in manifest for building vocabulary. "
"You can provide multiple k.",
nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
# yapf: disable
args = parser.parse_args()
def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
if isinstance(line_json['text'], str):
line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line)
else:
assert isinstance(line_json['text'], list)
for text in line_json['text']:
line = text_feature.tokenize(text, replace_space=False)
counter.update(line)
def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
if isinstance(line_json[key], str):
fileobj.write(line_json[key] + "\n")
else:
assert isinstance(line_json[key], list)
for line in line_json[key]:
fileobj.write(line + "\n")
def main():
print_arguments(args, globals())
fout = open(args.vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1
if args.unit_type == 'spm':
# tools/spm_train --input=$wave_data/lang_char/input.txt
# --vocab_size=${nbpe} --model_type=${bpemode}
# --model_prefix=${bpemodel} --input_sentence_size=100000000
import sentencepiece as spm
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in args.manifest_paths:
text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys
for text_key in text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
input=fp.name,
vocab_size=args.spm_vocab_size,
model_type=args.spm_mode,
model_prefix=args.spm_model_prefix,
input_sentence_size=100000000,
character_coverage=args.spm_character_coverage)
os.unlink(fp.name)
# encode
text_feature = TextFeaturizer(args.unit_type, "", args.spm_model_prefix)
counter = Counter()
for manifest_path in args.manifest_paths:
count_manifest(counter, text_feature, manifest_path)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
if count < args.count_threshold:
break
# replace space by `<space>`
token = SPACE if token == ' ' else token
tokens.append(token)
tokens = sorted(tokens)
for token in tokens:
fout.write(token + '\n')
fout.write(SOS + "\n") # <sos/eos>
fout.close()
from paddlespeech.dataset.s2t import build_vocab_main
if __name__ == '__main__':
main()
build_vocab_main()
此差异已折叠。
......@@ -13,75 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute mean and std for feature normalizer, and save to file."""
import argparse
import functools
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer
from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples', int, 2000, "# of samples to for statistics.")
add_arg('spectrum_type', str,
'linear',
"Audio feature type. Options: linear, mfcc, fbank.",
choices=['linear', 'mfcc', 'fbank'])
add_arg('feat_dim', int, 13, "Audio feature dim.")
add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
add_arg('stride_ms', int, 10, "stride length in ms.")
add_arg('window_ms', int, 20, "stride length in ms.")
add_arg('sample_rate', int, 16000, "target sample rate.")
add_arg('use_dB_normalization', bool, True, "do dB normalization.")
add_arg('target_dB', int, -20, "target dB.")
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('num_workers',
default=0,
type=int,
help='num of subprocess workers for processing')
add_arg('output_path', str,
'data/librispeech/mean_std.npz',
"Filepath of write mean and stddev to (.npz).")
# yapf: disable
args = parser.parse_args()
def main():
print_arguments(args, globals())
augmentation_pipeline = AugmentationPipeline('{}')
audio_featurizer = AudioFeaturizer(
spectrum_type=args.spectrum_type,
feat_dim=args.feat_dim,
delta_delta=args.delta_delta,
stride_ms=float(args.stride_ms),
window_ms=float(args.window_ms),
n_fft=None,
max_freq=None,
target_sample_rate=args.sample_rate,
use_dB_normalization=args.use_dB_normalization,
target_dB=args.target_dB,
dither=0.0)
def augment_and_featurize(audio_segment):
augmentation_pipeline.transform_audio(audio_segment)
return audio_featurizer.featurize(audio_segment)
normalizer = FeatureNormalizer(
mean_std_filepath=None,
manifest_path=args.manifest_path,
featurize_func=augment_and_featurize,
num_samples=args.num_samples,
num_workers=args.num_workers)
normalizer.write_to_file(args.output_path)
from paddlespeech.dataset.s2t import compute_mean_std_main
if __name__ == '__main__':
main()
compute_mean_std_main()
......@@ -13,130 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""format manifest with more metadata."""
import argparse
import functools
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath of the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
# bpe
add_arg('spm_model_prefix', str, None,
"spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm")
add_arg('output_path', str, None, "filepath of formated manifest.", required=True)
# yapf: disable
args = parser.parse_args()
def main():
print_arguments(args, globals())
fout = open(args.output_path, 'w', encoding='utf-8')
# get feat dim
filetype = args.cmvn_path.split(".")[-1]
mean, istd = load_cmvn(args.cmvn_path, filetype=filetype)
feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}")
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
# josnline like this
# {
# "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
# "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
# "utt2spk": "111-2222",
# "utt": "111-2222-333"
# }
count = 0
for manifest_path in args.manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons:
output_json = {
"input": [],
"output": [],
'utt': line_json['utt'],
'utt2spk': line_json.get('utt2spk', 'global'),
}
# output
line = line_json['text']
if isinstance(line, str):
# only one target
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
output_json['output'].append({
'name': 'target1',
'shape': (len(tokenids), vocab_size),
'text': line,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
else:
# isinstance(line, list), multi target in one vocab
for i, item in enumerate(line, 1):
tokens = text_feature.tokenize(item)
tokenids = text_feature.featurize(item)
output_json['output'].append({
'name': f'target{i}',
'shape': (len(tokenids), vocab_size),
'text': item,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
# input
line = line_json['feat']
if isinstance(line, str):
# only one input
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
filetype = feat_type(line)
if filetype == 'sound':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
output_json['input'].append({
"name": "input1",
"shape": feat_shape,
"feat": line,
"filetype": filetype,
})
else:
# isinstance(line, list), multi input
raise NotImplementedError("not support multi input now!")
fout.write(json.dumps(output_json) + '\n')
count += 1
print(f"{args.manifest_paths} Examples number: {count}")
fout.close()
from paddlespeech.dataset.s2t import format_data_main
if __name__ == '__main__':
main()
format_data_main()
......@@ -11,96 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.dataset.s2t import format_rsl_main
import jsonlines
def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_hyp, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0]
if trans_hyp is not None:
with open(trans_hyp, "w+", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_hyp_sclite is not None:
with open(trans_hyp_sclite, "w+") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_ref, "r", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["text"]
if trans_ref is not None:
with open(trans_ref, "w", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_ref_sclite is not None:
with open(trans_ref_sclite, "w") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='format hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp', type=str, default=None, help='origin hyp file')
parser.add_argument(
'--trans_hyp',
type=str,
default=None,
help='hyp file for caculating CER/WER')
parser.add_argument(
'--trans_hyp_sclite',
type=str,
default=None,
help='hyp file for caculating CER/WER by sclite')
parser.add_argument(
'--origin_ref', type=str, default=None, help='origin ref file')
parser.add_argument(
'--trans_ref',
type=str,
default=None,
help='ref file for caculating CER/WER')
parser.add_argument(
'--trans_ref_sclite',
type=str,
default=None,
help='ref file for caculating CER/WER by sclite')
parser_args = parser.parse_args()
if parser_args.origin_hyp is not None:
trans_hyp(
origin_hyp=parser_args.origin_hyp,
trans_hyp=parser_args.trans_hyp,
trans_hyp_sclite=parser_args.trans_hyp_sclite, )
if parser_args.origin_ref is not None:
trans_ref(
origin_ref=parser_args.origin_ref,
trans_ref=parser_args.trans_ref,
trans_ref_sclite=parser_args.trans_ref_sclite, )
if __name__ == '__main__':
format_rsl_main()
......@@ -22,8 +22,8 @@ import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册