Commit 8c5b8e35 authored by Hui Zhang

add rtf

Parent 9ad706e2
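This commit adds real-time factor (RTF) reporting to the U2 test loop: `compute_metrics` now times the decode call and returns the elapsed time together with the total number of input frames, and `test()` accumulates both and divides total decode time by total audio duration (frames × `stride_ms` / 1000). The sketch below illustrates the same bookkeeping in isolation; it is a simplified stand-in, and `batches`, `decode_fn`, and the 10 ms default stride are assumptions rather than anything taken from the patch.

```python
import time
import numpy as np

def compute_rtf(batches, decode_fn, stride_ms=10.0):
    """Return the real-time factor: seconds spent decoding per second of audio.

    `batches` yields (audio, audio_len) pairs, where audio_len holds the
    number of feature frames per utterance; `decode_fn` wraps the decoder
    (in the patch this role is played by `self.model.decode(...)`).
    """
    total_decode_s = 0.0
    total_frames = 0
    for audio, audio_len in batches:
        start = time.time()
        decode_fn(audio, audio_len)
        total_decode_s += time.time() - start        # wall time spent decoding
        total_frames += int(np.sum(audio_len))       # feature frames in this batch
    audio_seconds = total_frames * stride_ms / 1000.0
    return total_decode_s / audio_seconds
```

An RTF below 1.0 means decoding runs faster than real time; for example, spending 25 s to decode 100 s of audio gives RTF = 0.25.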
@@ -362,8 +362,8 @@ class U2Tester(U2Trainer):
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.dataset.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode(
audio,
@@ -381,7 +381,8 @@ class U2Tester(U2Trainer):
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
@@ -397,9 +398,11 @@ class U2Tester(U2Trainer):
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins,
num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
error_rate_type=cfg.error_rate_type,
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
@mp_tools.rank_zero_only
@paddle.no_grad()
@@ -410,10 +413,13 @@ class U2Tester(U2Trainer):
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
@@ -421,11 +427,13 @@ class U2Tester(U2Trainer):
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
rtf = num_time / (num_frames * self.test_loader.dataset.stride_ms / 1000.0)
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += ", Final error rate [%s] (%d/%d) = %f" % (
msg += "RTF: {}, ".format(rtf)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
@@ -105,6 +105,10 @@ class AudioFeaturizer(object):
# extract spectrogram
return self._compute_specgram(audio_segment)
@property
def stride_ms(self):
return self._stride_ms
@property
def feature_size(self):
"""audio feature size"""
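The new `stride_ms` property on `AudioFeaturizer` (mirrored on `SpeechFeaturizer` and `ManifestDataset` below) exposes the hop between successive feature frames in milliseconds, which is what the RTF denominator needs to turn a frame count back into seconds of audio. A rough illustration of the relationship, using a 16 kHz sample rate and 10 ms hop as assumed values rather than anything read from this patch:

```python
sample_rate = 16000                                  # Hz, assumed
stride_ms = 10.0                                     # frame hop in ms, assumed default
hop_samples = int(sample_rate * stride_ms / 1000)    # 160 samples between frames
num_frames = 3000                                    # e.g. frames in one utterance
duration_s = num_frames * stride_ms / 1000.0         # 30.0 seconds of audio
```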
@@ -63,7 +63,8 @@ class SpeechFeaturizer(object):
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
target_dB=-20,
dither=1.0):
self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type,
feat_dim=feat_dim,
@@ -74,7 +75,8 @@ class SpeechFeaturizer(object):
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
target_dB=target_dB,
dither=dither)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix)
@@ -138,6 +140,15 @@ class SpeechFeaturizer(object):
"""
return self._audio_featurizer.feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
@property
def text_feature(self):
"""Return the text feature object.
@@ -63,6 +63,7 @@ class ManifestDataset(Dataset):
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
dither=1.0, # feature dither
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
@@ -123,6 +124,7 @@ class ManifestDataset(Dataset):
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta,
dither=config.data.dither,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
@@ -150,6 +152,7 @@ class ManifestDataset(Dataset):
specgram_type='linear',
feat_dim=None,
delta_delta=False,
dither=1.0,
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
@@ -183,13 +186,10 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
"""
super().__init__()
self._max_input_len = max_input_len,
self._min_input_len = min_input_len,
self._max_output_len = max_output_len,
self._min_output_len = min_output_len,
self._max_output_input_ratio = max_output_input_ratio,
self._min_output_input_ratio = min_output_input_ratio,
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._augmentation_pipeline = AugmentationPipeline(
@@ -207,7 +207,8 @@ class ManifestDataset(Dataset):
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
target_dB=target_dB,
dither=dither)
self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text
@@ -250,6 +251,10 @@ class ManifestDataset(Dataset):
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
@@ -18,6 +18,7 @@ data:
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
@@ -19,6 +19,7 @@ data:
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
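The new `dither: 1.0` entry, threaded from the config through `ManifestDataset` and `SpeechFeaturizer` into `AudioFeaturizer`, controls dithering of the waveform: a small amount of random noise is added to the samples before the filterbank is computed, mainly so that stretches of exact digital silence do not produce log(0) in the log-mel features. Setting it to 0.0 disables it and makes feature extraction deterministic. The snippet below is only a sketch of the idea, not the front end's actual implementation; `apply_dither` and its Gaussian-noise choice are assumptions.

```python
import numpy as np

def apply_dither(waveform, dither=1.0, rng=None):
    """Add low-level Gaussian noise scaled by `dither` to float samples.

    Illustrative only: the real feature extractor applies dithering
    internally when its `dither` parameter is greater than zero.
    """
    if dither <= 0.0:
        return waveform
    rng = rng if rng is not None else np.random.RandomState(0)
    return waveform + dither * rng.standard_normal(waveform.shape)
```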