Commit 3845804c authored by huangyuxin

Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into Setup

@@ -463,7 +463,6 @@ Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](ht
 - [Automatic Speech Recognition](./docs/source/asr/quick_start.md)
 - [Introduction](./docs/source/asr/models_introduction.md)
 - [Data Preparation](./docs/source/asr/data_preparation.md)
-- [Data Augmentation](./docs/source/asr/augmentation.md)
 - [Ngram LM](./docs/source/asr/ngram_lm.md)
 - [Text-to-Speech](./docs/source/tts/quick_start.md)
 - [Introduction](./docs/source/tts/models_introduction.md)
......
@@ -468,7 +468,6 @@ PaddleSpeech's **speech synthesis** mainly consists of three modules: text frontend, acoustic …
 - [Customized ASR Training](./docs/source/asr/quick_start.md)
 - [Introduction](./docs/source/asr/models_introduction.md)
 - [Data Preparation](./docs/source/asr/data_preparation.md)
-- [Data Augmentation](./docs/source/asr/augmentation.md)
 - [Ngram Language Model](./docs/source/asr/ngram_lm.md)
 - [Customized TTS Training](./docs/source/tts/quick_start.md)
 - [Introduction](./docs/source/tts/models_introduction.md)
......
# Data Augmentation Pipeline
Data augmentation is often a highly effective way to boost deep-learning performance. We augment the speech data by synthesizing new audio clips with small random perturbations (label-invariant transformations) applied to the raw audio. You don't need to synthesize these yourself: augmentation is built into the data provider and is performed on the fly, randomly for each epoch during training.
Eight optional augmentation components (six audio-level, two feature-level) are provided, to be selected, configured, and inserted into the processing pipeline:
* Audio
- Volume Perturbation
- Speed Perturbation
- Shifting Perturbation
  - Online Bayesian Normalization
- Noise Perturbation (need background noise audio files)
- Impulse Response (need impulse audio files)
* Feature
- SpecAugment
- Adaptive SpecAugment
To tell the trainer which augmentation components to use, and in what order, prepare an *augmentation configuration file* in [JSON](http://www.json.org/) format in advance. For example:
```
[{
"type": "speed",
"params": {"min_speed_rate": 0.95,
"max_speed_rate": 1.05},
"prob": 0.6
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 0.8
}]
```
When the `augment_conf_file` argument is set to the path of the example configuration above, every audio clip in every epoch is processed as follows: with 60% probability it is first speed-perturbed at a rate sampled uniformly from [0.95, 1.05]; then, with 80% probability, it is shifted in time by an offset sampled uniformly from [-5, 5] ms. The newly synthesized clip is then fed into the feature extractor for training.
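For intuition, here is a minimal sketch of how such a configuration could drive the on-the-fly augmentation described above. This is *not* the PaddleSpeech implementation; `speed_perturb` and `time_shift` are simplified stand-ins for the real augmentors:
```python
import json
import random

import numpy as np


def speed_perturb(samples, rate):
    # Resample by linear interpolation; rate > 1 shortens (speeds up) the clip.
    new_len = int(len(samples) / rate)
    new_idx = np.linspace(0, len(samples) - 1, new_len)
    return np.interp(new_idx, np.arange(len(samples)), samples)


def time_shift(samples, sr, shift_ms):
    # Shift the waveform by shift_ms milliseconds, zero-padding the gap.
    n = int(sr * shift_ms / 1000)
    out = np.zeros_like(samples)
    if n > 0:
        out[n:] = samples[:-n]
    elif n < 0:
        out[:n] = samples[-n:]
    else:
        out = samples.copy()
    return out


def augment(samples, sr, conf_path):
    # Apply each configured augmentor to one clip, with its own probability.
    with open(conf_path) as f:
        pipeline = json.load(f)
    for stage in pipeline:
        if random.random() >= stage["prob"]:
            continue  # this augmentor is skipped for the current clip
        p = stage["params"]
        if stage["type"] == "speed":
            rate = random.uniform(p["min_speed_rate"], p["max_speed_rate"])
            samples = speed_perturb(samples, rate)
        elif stage["type"] == "shift":
            ms = random.uniform(p["min_shift_ms"], p["max_shift_ms"])
            samples = time_shift(samples, sr, ms)
    return samples
```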
For other configuration examples, please refer to `examples/conf/augmentation.example.json`.
Be careful when applying data augmentation: improper augmentation can harm training by enlarging the gap between the training and test distributions.
@@ -27,7 +27,6 @@ Contents
 asr/models_introduction
 asr/data_preparation
-asr/augmentation
 asr/feature_list
 asr/ngram_lm
......
@@ -257,6 +257,7 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
```
#!/bin/bash

train_output_path=$1

stage=0
stop_stage=0

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=fastspeech2_aishell3 \
        --voc=pwgan_aishell3 \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt \
        --speaker_dict=dump/speaker_id_map.txt \
        --spk_id=0
fi
@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=${train_output_path}/test_e2e \
   --phones_dict=dump/phone_id_map.txt \
   --speaker_dict=dump/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=${train_output_path}/inference
@@ -240,13 +240,14 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --am_ckpt=fastspeech2_nosil_vctk_ckpt_0.5/snapshot_iter_66200.pdz \
   --am_stat=fastspeech2_nosil_vctk_ckpt_0.5/speech_stats.npy \
   --voc=pwgan_vctk \
-  --voc_config=pwg_vctk_ckpt_0.5/pwg_default.yaml \
-  --voc_ckpt=pwg_vctk_ckpt_0.5/pwg_snapshot_iter_1000000.pdz \
-  --voc_stat=pwg_vctk_ckpt_0.5/pwg_stats.npy \
+  --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+  --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+  --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
   --lang=en \
   --text=${BIN_DIR}/../sentences_en.txt \
   --output_dir=exp/default/test_e2e \
   --phones_dict=fastspeech2_nosil_vctk_ckpt_0.5/phone_id_map.txt \
   --speaker_dict=fastspeech2_nosil_vctk_ckpt_0.5/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=exp/default/inference
```
#!/bin/bash

train_output_path=$1

stage=0
stop_stage=0

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../inference.py \
        --inference_dir=${train_output_path}/inference \
        --am=fastspeech2_vctk \
        --voc=pwgan_vctk \
        --text=${BIN_DIR}/../sentences_en.txt \
        --output_dir=${train_output_path}/pd_infer_out \
        --phones_dict=dump/phone_id_map.txt \
        --speaker_dict=dump/speaker_id_map.txt \
        --spk_id=0 \
        --lang=en
fi
@@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
   --output_dir=${train_output_path}/test_e2e \
   --phones_dict=dump/phone_id_map.txt \
   --speaker_dict=dump/speaker_id_map.txt \
-  --spk_id=0
+  --spk_id=0 \
+  --inference_dir=${train_output_path}/inference
@@ -14,9 +14,11 @@
 import argparse
 from pathlib import Path

+import numpy
 import soundfile as sf
 from paddle import inference

+from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
@@ -29,20 +31,38 @@ def main():
         '--am',
         type=str,
         default='fastspeech2_csmsc',
-        choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'],
+        choices=[
+            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
+            'fastspeech2_vctk'
+        ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
         "--phones_dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker_dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        '--spk_id',
+        type=int,
+        default=0,
+        help='spk id for multi speaker acoustic model')
     # voc
     parser.add_argument(
         '--voc',
         type=str,
         default='pwgan_csmsc',
-        choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'],
+        choices=[
+            'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
+            'pwgan_vctk'
+        ],
         help='Choose vocoder type of tts task.')
     # other
+    parser.add_argument(
+        '--lang',
+        type=str,
+        default='zh',
+        help='Choose model language. zh or en')
     parser.add_argument(
         "--text",
         type=str,
@@ -53,8 +73,12 @@ def main():
     args, _ = parser.parse_known_args()

-    frontend = Frontend(
-        phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    # frontend
+    if args.lang == 'zh':
+        frontend = Frontend(
+            phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    elif args.lang == 'en':
+        frontend = English(phone_vocab_path=args.phones_dict)
     print("frontend done!")

     # model: {model_name}_{dataset}
@@ -83,30 +107,53 @@ def main():
     print("in new inference")

+    # construct dataset for evaluation
+    sentences = []
     with open(args.text, 'rt') as f:
         for line in f:
             items = line.strip().split()
             utt_id = items[0]
-            sentence = "".join(items[1:])
+            if args.lang == 'zh':
+                sentence = "".join(items[1:])
+            elif args.lang == 'en':
+                sentence = " ".join(items[1:])
             sentences.append((utt_id, sentence))

     get_tone_ids = False
+    get_spk_id = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
+    if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+        get_spk_id = True
+        spk_id = numpy.array([args.spk_id])

     am_input_names = am_predictor.get_input_names()
+    print("am_input_names:", am_input_names)
+
+    merge_sentences = True
     for utt_id, sentence in sentences:
-        input_ids = frontend.get_input_ids(
-            sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
+        if args.lang == 'zh':
+            input_ids = frontend.get_input_ids(
+                sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
+            phone_ids = input_ids["phone_ids"]
+        elif args.lang == 'en':
+            input_ids = frontend.get_input_ids(
+                sentence, merge_sentences=merge_sentences)
+            phone_ids = input_ids["phone_ids"]
+        else:
+            print("lang should in {'zh', 'en'}!")

         if get_tone_ids:
             tone_ids = input_ids["tone_ids"]
             tones = tone_ids[0].numpy()
             tones_handle = am_predictor.get_input_handle(am_input_names[1])
             tones_handle.reshape(tones.shape)
             tones_handle.copy_from_cpu(tones)
+        if get_spk_id:
+            spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
+            spk_id_handle.reshape(spk_id.shape)
+            spk_id_handle.copy_from_cpu(spk_id)
         phones = phone_ids[0].numpy()
         phones_handle = am_predictor.get_input_handle(am_input_names[0])
         phones_handle.reshape(phones.shape)
......
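For readers unfamiliar with the Paddle Inference API used in this hunk, here is a minimal, self-contained sketch of the copy-in / run / copy-out pattern; the model path and dummy inputs are hypothetical, and the real flow above also prepares a frontend and a vocoder. Note that the hunk reuses `am_input_names[1]` for both the tone ids and the speaker id; that relies on the assumption that no supported model declares both auxiliary inputs.
```python
import numpy as np
from paddle import inference

# Hypothetical exported files; synthesize_e2e.py --inference_dir produces them.
config = inference.Config("inference/fastspeech2_aishell3.pdmodel",
                          "inference/fastspeech2_aishell3.pdiparams")
predictor = inference.create_predictor(config)

names = predictor.get_input_names()                # e.g. [phones, spk_id]
phones = np.random.randint(0, 100, (32, )).astype("int64")  # dummy phone ids
handle = predictor.get_input_handle(names[0])
handle.reshape(phones.shape)
handle.copy_from_cpu(phones)                       # copy input to the device

predictor.run()
out_name = predictor.get_output_names()[0]
mel = predictor.get_output_handle(out_name).copy_to_cpu()  # fetch the result
```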
@@ -159,9 +159,16 @@ def evaluate(args):
     # acoustic model
     if am_name == 'fastspeech2':
         if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
-            print(
-                "Haven't test dygraph to static for multi speaker fastspeech2 now!"
-            )
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64)
+                ])
+            paddle.jit.save(am_inference,
+                            os.path.join(args.inference_dir, args.am))
+            am_inference = paddle.jit.load(
+                os.path.join(args.inference_dir, args.am))
         else:
             am_inference = jit.to_static(
                 am_inference,
......
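A quick sketch (dummy inputs, hypothetical path) of what the two `InputSpec`s above commit the exported model to, namely a variable-length phone-id vector plus a one-element speaker id:
```python
import paddle

# Load what paddle.jit.save wrote to os.path.join(args.inference_dir, args.am).
am = paddle.jit.load("exp/default/inference/fastspeech2_aishell3")
phone_ids = paddle.randint(0, 100, [32], dtype="int64")  # shape [-1]: any length
spk_id = paddle.to_tensor([0], dtype="int64")            # shape [1]
mel = am(phone_ids, spk_id)                              # predicted mel spectrogram
```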
@@ -781,7 +781,7 @@ class FastSpeech2(nn.Layer):
         elif self.spk_embed_integration_type == "concat":
             # concat hidden states with spk embeds and then apply projection
             spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
-                shape=[-1, hs.shape[1], -1])
+                shape=[-1, paddle.shape(hs)[1], -1])
             hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
         else:
             raise NotImplementedError("support only add or concat.")
......
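The likely motivation for `paddle.shape(hs)[1]` over `hs.shape[1]` (an inference on our part, consistent with the export change above) is dynamic-to-static conversion: after `jit.to_static` with a variable-length `InputSpec`, `hs.shape[1]` can be a `-1` placeholder at trace time, while `paddle.shape(hs)` is a tensor evaluated at run time. A minimal sketch of the pattern:
```python
import paddle

x = paddle.rand([2, 7, 4])
n = paddle.shape(x)[1]  # runtime length; stays correct under jit.to_static
y = paddle.rand([2, 1, 4]).expand([-1, n, -1])
print(y.shape)  # [2, 7, 4]
```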
@@ -86,11 +86,13 @@ requirements = {

 def write_version_py(filename='paddlespeech/__init__.py'):
     import paddlespeech
-    if hasattr(paddlespeech, "__version__") and paddlespeech.__version__ == VERSION:
+    if hasattr(paddlespeech,
+               "__version__") and paddlespeech.__version__ == VERSION:
         return
     with open(filename, "a") as f:
         f.write(f"\n__version__ = '{VERSION}'\n")

+
 def remove_version_py(filename='paddlespeech/__init__.py'):
     with open(filename, "r") as f:
         lines = f.readlines()
@@ -256,8 +258,4 @@ setup_info = dict(

 setup(**setup_info)
 remove_version_py()
@@ -19,11 +19,13 @@ VERSION = '0.1.0'

 def write_version_py(filename='paddleaudio/__init__.py'):
     import paddleaudio
-    if hasattr(paddleaudio, "__version__") and paddleaudio.__version__ == VERSION:
+    if hasattr(paddleaudio,
+               "__version__") and paddleaudio.__version__ == VERSION:
         return
     with open(filename, "a") as f:
         f.write(f"\n__version__ = '{VERSION}'\n")

+
 def remove_version_py(filename='paddleaudio/__init__.py'):
     with open(filename, "r") as f:
         lines = f.readlines()
@@ -32,6 +34,7 @@ def remove_version_py(filename='paddleaudio/__init__.py'):
             if "__version__" not in line:
                 f.write(line)

+
 write_version_py()

 setuptools.setup(
@@ -58,4 +61,4 @@ setuptools.setup(
         'colorlog',
     ], )

-remove_version_py()
\ No newline at end of file
+remove_version_py()
 #!/usr/bin/env python3
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 '''
 Merge training configs into a single inference config.
 The single inference config is for CLI, which only takes a single config to do inference.
 The training configs include: model config, preprocess config, decode config, vocab file and cmvn file.
 '''
-import yaml
-import json
-import os
 import argparse
+import json
 import math
+import os
+from contextlib import redirect_stdout

 from yacs.config import CfgNode

 from paddlespeech.s2t.frontend.utility import load_dict
-from contextlib import redirect_stdout


 def save(save_path, config):
@@ -29,18 +27,21 @@ def load(save_path):
     config.merge_from_file(save_path)
     return config

+
 def load_json(json_path):
     with open(json_path) as f:
         json_content = json.load(f)
     return json_content

+
 def remove_config_part(config, key_list):
     if len(key_list) == 0:
         return
-    for i in range(len(key_list) -1):
+    for i in range(len(key_list) - 1):
         config = config[key_list[i]]
     config.pop(key_list[-1])

+
 def load_cmvn_from_json(cmvn_stats):
     means = cmvn_stats['mean_stat']
     variance = cmvn_stats['var_stat']
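A tiny illustration of `remove_config_part` from the hunk above, on a hypothetical nested config: the key path is walked down to the parent container, then the final key is popped.
```python
cfg = {"training": {"n_epoch": 50, "lr": 0.001}}
remove_config_part(cfg, ["training", "n_epoch"])  # drop one nested key
print(cfg)  # {'training': {'lr': 0.001}}
```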
@@ -51,17 +52,17 @@ def load_cmvn_from_json(cmvn_stats):
         if variance[i] < 1.0e-20:
             variance[i] = 1.0e-20
         variance[i] = 1.0 / math.sqrt(variance[i])
-    cmvn_stats = {"mean":means, "istd":variance}
+    cmvn_stats = {"mean": means, "istd": variance}
     return cmvn_stats

+
 def merge_configs(
-        conf_path = "conf/conformer.yaml",
-        preprocess_path = "conf/preprocess.yaml",
-        decode_path = "conf/tuning/decode.yaml",
-        vocab_path = "data/vocab.txt",
-        cmvn_path = "data/mean_std.json",
-        save_path = "conf/conformer_infer.yaml",
-):
+        conf_path="conf/conformer.yaml",
+        preprocess_path="conf/preprocess.yaml",
+        decode_path="conf/tuning/decode.yaml",
+        vocab_path="data/vocab.txt",
+        cmvn_path="data/mean_std.json",
+        save_path="conf/conformer_infer.yaml", ):

     # Load the configs
     config = load(conf_path)
@@ -72,17 +73,16 @@ def merge_configs(
     if cmvn_path.split(".")[-1] == 'json':
         cmvn_stats = load_json(cmvn_path)
         if os.path.exists(preprocess_path):
             preprocess_config = load(preprocess_path)
             for idx, process in enumerate(preprocess_config["process"]):
                 if process['type'] == "cmvn_json":
-                    preprocess_config["process"][idx][
-                        "cmvn_path"] = cmvn_stats
+                    preprocess_config["process"][idx]["cmvn_path"] = cmvn_stats
                     break
             config.preprocess_config = preprocess_config
         else:
             cmvn_stats = load_cmvn_from_json(cmvn_stats)
-            config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}]
+            config.mean_std_filepath = [{"cmvn_stats": cmvn_stats}]
             config.augmentation_config = ''
     # the cmvn file ends with .ark
     else:
@@ -95,7 +95,8 @@ def merge_configs(
     # Remove some parts of the config
     if os.path.exists(preprocess_path):
-        remove_train_list = ["train_manifest",
+        remove_train_list = [
+            "train_manifest",
             "dev_manifest",
             "test_manifest",
             "n_epoch",
@@ -124,9 +125,10 @@ def merge_configs(
             "batch_size",
             "maxlen_in",
             "maxlen_out",
-            ]
+        ]
     else:
-        remove_train_list = ["train_manifest",
+        remove_train_list = [
+            "train_manifest",
             "dev_manifest",
             "test_manifest",
             "n_epoch",
@@ -141,43 +143,41 @@ def merge_configs(
             "weight_decay",
             "sortagrad",
             "num_workers",
-            ]
+        ]

     for item in remove_train_list:
         try:
             remove_config_part(config, [item])
         except:
-            print ( item + " " +"can not be removed")
+            print(item + " " + "can not be removed")

     # Save the config
     save(save_path, config)

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        prog='Config merge', add_help=True)
+    parser = argparse.ArgumentParser(prog='Config merge', add_help=True)
     parser.add_argument(
-        '--cfg_pth', type=str, default = 'conf/transformer.yaml', help='origin config file')
+        '--cfg_pth',
+        type=str,
+        default='conf/transformer.yaml',
+        help='origin config file')
     parser.add_argument(
-        '--pre_pth', type=str, default= "conf/preprocess.yaml", help='')
+        '--pre_pth', type=str, default="conf/preprocess.yaml", help='')
     parser.add_argument(
-        '--dcd_pth', type=str, default= "conf/tuning/decode.yaml", help='')
+        '--dcd_pth', type=str, default="conf/tuning/decode.yaml", help='')
     parser.add_argument(
-        '--vb_pth', type=str, default= "data/lang_char/vocab.txt", help='')
+        '--vb_pth', type=str, default="data/lang_char/vocab.txt", help='')
     parser.add_argument(
-        '--cmvn_pth', type=str, default= "data/mean_std.json", help='')
+        '--cmvn_pth', type=str, default="data/mean_std.json", help='')
     parser.add_argument(
-        '--save_pth', type=str, default= "conf/transformer_infer.yaml", help='')
+        '--save_pth', type=str, default="conf/transformer_infer.yaml", help='')
     parser_args = parser.parse_args()

     merge_configs(
-        conf_path = parser_args.cfg_pth,
-        decode_path = parser_args.dcd_pth,
-        preprocess_path = parser_args.pre_pth,
-        vocab_path = parser_args.vb_pth,
-        cmvn_path = parser_args.cmvn_pth,
-        save_path = parser_args.save_pth,
-    )
+        conf_path=parser_args.cfg_pth,
+        decode_path=parser_args.dcd_pth,
+        preprocess_path=parser_args.pre_pth,
+        vocab_path=parser_args.vb_pth,
+        cmvn_path=parser_args.cmvn_pth,
+        save_path=parser_args.save_pth, )
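Putting it together, a hypothetical direct invocation mirroring the function defaults above (equivalently, pass the same values through the CLI flags):
```python
# Merge one recipe's training-time configs into a single inference config.
merge_configs(
    conf_path="conf/conformer.yaml",
    preprocess_path="conf/preprocess.yaml",
    decode_path="conf/tuning/decode.yaml",
    vocab_path="data/vocab.txt",
    cmvn_path="data/mean_std.json",
    save_path="conf/conformer_infer.yaml")
```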