Commit 7bae32f3 authored by: H Haoxin Ma

revise example/tiny/s1

Parent b9110af9
......@@ -72,7 +72,7 @@ _C.collator =CN(
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=True
keep_transcription_text=False
))
DeepSpeech2Model.params(_C.model)
......
......@@ -336,13 +336,14 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
config.collator.keep_transcription_text = True
# return text ord id
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(config=config, keep_transcription_text=True))
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup test Dataloader!")
def setup_output_dir(self):
......
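The tester hunk above replaces the direct `SpeechCollator(config=config, keep_transcription_text=True)` construction with `SpeechCollator.from_config(config)` after flipping the flag on the config itself. A minimal sketch of the resulting pattern; the import paths and the shape of the config object are assumptions, not part of this diff:

```python
# Sketch only: import paths are assumed from the repo layout.
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset

def build_test_loader(config):
    """Build the test DataLoader the way the tester hunk above does."""
    test_dataset = ManifestDataset.from_config(config)

    # Decoding compares hypotheses against raw transcription text,
    # so the flag is forced on the collator section before the collator is built.
    config.collator.keep_transcription_text = True

    return DataLoader(
        test_dataset,
        batch_size=config.decoding.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=SpeechCollator.from_config(config))
```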
......@@ -22,6 +22,13 @@ _C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator =CfgNode(
dict(
augmentation_config="",
unit_type="char",
keep_transcription_text=False
))
_C.model = U2Model.params()
_C.training = U2Trainer.params()
......
......@@ -221,7 +221,7 @@ class U2Trainer(Trainer):
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False)
collate_fn = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
......@@ -266,12 +266,13 @@ class U2Trainer(Trainer):
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test Dataloader!")
def setup_model(self):
......@@ -375,7 +376,7 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.dataset.text_feature
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode(
audio,
......@@ -423,7 +424,7 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.dataset.stride_ms
stride_ms = self.config.collator.stride_ms
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
......
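After this change the tester reads the frame shift from `config.collator.stride_ms` and the text featurizer from `self.test_loader.collate_fn.text_feature` instead of from the dataset. A small, self-contained sketch of the unit conversion that `stride_ms` (milliseconds per frame) is used for; the numbers are illustrative only:

```python
def frames_to_seconds(num_frames: int, stride_ms: float) -> float:
    """Convert a decoder frame count to audio duration, using the frame
    shift that now lives under config.collator.stride_ms."""
    return num_frames * stride_ms / 1000.0

# e.g. 1500 frames at a 10 ms stride correspond to 15 seconds of audio
assert frames_to_seconds(1500, 10.0) == 15.0
```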
......@@ -82,7 +82,7 @@ def read_manifest(
]
if all(conditions):
manifest.append(json_data)
return manifest, json_data["feat_shape"][-1]
return manifest
def rms_to_db(rms: float):
......
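The utility hunk drops the second return value: `read_manifest` now returns only the filtered manifest list, and the feature dimension is read from a manifest entry instead (see the `ManifestDataset` hunk further down). A hedged usage sketch; the import path, manifest path, and length limits are illustrative:

```python
from deepspeech.frontend.utility import read_manifest  # import path assumed

# Assumes each manifest entry is a JSON object carrying a "feat_shape" field,
# as the dataset hunk below relies on.
manifest = read_manifest(
    manifest_path="data/manifest.dev",  # illustrative path
    max_input_len=27.0,
    min_input_len=0.0)

# read_manifest no longer returns the feature size; it is taken
# from the first (already filtered) manifest entry instead.
feature_size = manifest[0]["feat_shape"][-1]
```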
......@@ -56,7 +56,7 @@ class SpeechCollator():
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=True
keep_transcription_text=False
))
if config is not None:
......@@ -75,7 +75,7 @@ class SpeechCollator():
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'mean_std_filepath' in config.data
assert 'vocab_filepath' in config.data
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
......@@ -94,7 +94,7 @@ class SpeechCollator():
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
mean_std_filepath=config.data.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.data.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
......@@ -282,26 +282,11 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
......
......@@ -161,7 +161,7 @@ class ManifestDataset(Dataset):
# self._rng = np.random.RandomState(random_seed)
# read manifest
self._manifest, self._feature_size = read_manifest(
self._manifest = read_manifest(
manifest_path=manifest_path,
max_input_len=max_input_len,
min_input_len=min_input_len,
......@@ -213,16 +213,8 @@ class ManifestDataset(Dataset):
Returns:
int: audio feature size.
"""
return self._feature_size
return self._manifest[0]["feat_shape"][-1]
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
def __len__(self):
......
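With `read_manifest` returning only the list, `feature_size` becomes a property computed from the first manifest entry rather than a cached value, and `stride_ms` moves off the dataset entirely. A minimal sketch of the new `feature_size` contract, with an illustrative `feat_shape`; this is not the repo class, only the relevant behaviour:

```python
# Minimal sketch: only the feature_size behaviour of the dataset is modelled.
class ManifestDatasetSketch:
    def __init__(self, manifest):
        self._manifest = manifest  # list of dicts, as returned by read_manifest

    @property
    def feature_size(self):
        # Read the feature dimension off the first entry's feat_shape.
        return self._manifest[0]["feat_shape"][-1]

dataset = ManifestDatasetSketch([{"feat_shape": [830, 161]}])  # illustrative values
assert dataset.feature_size == 161
```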
......@@ -6,7 +6,6 @@ data:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.0
max_input_len: 27.0
......@@ -14,18 +13,6 @@ data:
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
......@@ -33,7 +20,6 @@ data:
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
mean_std_filepath: data/mean_std.json
spm_model_prefix:
specgram_type: linear
feat_dim:
......@@ -46,7 +32,7 @@ collator:
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: True
keep_transcription_text: False
model:
num_conv_layers: 2
......
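In the YAML above, the feature-extraction options move from the `data:` section into `collator:`, while `mean_std_filepath` and `vocab_filepath` stay under `data:` (matching the asserts in the `SpeechCollator.from_config` hunk). A hedged sketch of loading such a config and reading fields from the new sections; the file path is illustrative and the CfgNode usage assumes the yacs-style configs seen elsewhere in this diff:

```python
import yaml
from yacs.config import CfgNode  # the repo's configs are yacs-style CfgNodes

with open("conf/deepspeech2.yaml") as f:  # illustrative path
    config = CfgNode(yaml.safe_load(f))

# Feature-extraction options now live under the collator section ...
print(config.collator.specgram_type, config.collator.stride_ms)
# ... while the vocab and mean/std files remain under data.
print(config.data.vocab_filepath, config.data.mean_std_filepath)
```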
......@@ -7,7 +7,6 @@ data:
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
......@@ -16,23 +15,26 @@ data:
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0 #2
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
random_seed: 0
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
......@@ -70,7 +72,7 @@ model:
training:
n_epoch: 2
n_epoch: 3
accum_grad: 1
global_grad_clip: 5.0
optim: adam
......