提交 b4bda290 编写于 作者: H Haoxin Ma

fix bugs

上级 8781ab58
...@@ -368,7 +368,7 @@ class U2Tester(U2Trainer): ...@@ -368,7 +368,7 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids])) trans.append(''.join([chr(i) for i in ids]))
return trans return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None, fref=None): def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
cfg = self.config.decoding cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
...@@ -402,8 +402,6 @@ class U2Tester(U2Trainer): ...@@ -402,8 +402,6 @@ class U2Tester(U2Trainer):
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write(utt + " " + result + "\n")
if fref:
fref.write(utt + " " + target + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result)) (target, result))
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
...@@ -432,7 +430,6 @@ class U2Tester(U2Trainer): ...@@ -432,7 +430,6 @@ class U2Tester(U2Trainer):
num_time = 0.0 num_time = 0.0
with open(self.args.result_file, 'w') as fout: with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
# utt, audio, audio_len, text, text_len = batch
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames'] num_frames += metrics['num_frames']
num_time += metrics["decode_time"] num_time += metrics["decode_time"]
......
...@@ -284,7 +284,7 @@ class ManifestDataset(Dataset): ...@@ -284,7 +284,7 @@ class ManifestDataset(Dataset):
return self._local_data.tar2object[tarpath].extractfile( return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename]) self._local_data.tar2info[tarpath][filename])
def process_utterance(self, utt, audio_file, transcript): def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data. """Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file. :param audio_file: Filepath or file object of audio file.
...@@ -323,7 +323,7 @@ class ManifestDataset(Dataset): ...@@ -323,7 +323,7 @@ class ManifestDataset(Dataset):
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}") #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return utt, specgram, transcript_part return specgram, transcript_part
def _instance_reader_creator(self, manifest): def _instance_reader_creator(self, manifest):
""" """
...@@ -336,9 +336,7 @@ class ManifestDataset(Dataset): ...@@ -336,9 +336,7 @@ class ManifestDataset(Dataset):
def reader(): def reader():
for instance in manifest: for instance in manifest:
# inst = self.process_utterance(instance["feat"], inst = self.process_utterance(instance["feat"],
# instance["text"])
inst = self.process_utterance(instance["utt"], instance["feat"],
instance["text"]) instance["text"])
yield inst yield inst
...@@ -349,6 +347,6 @@ class ManifestDataset(Dataset): ...@@ -349,6 +347,6 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
instance = self._manifest[idx] instance = self._manifest[idx]
return self.process_utterance(instance["utt"], instance["feat"], feat, text =self.process_utterance(instance["feat"],
instance["text"]) instance["text"])
# return self.process_utterance(instance["feat"], instance["text"]) return instance["utt"], feat, text
...@@ -8,7 +8,7 @@ data: ...@@ -8,7 +8,7 @@ data:
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: "" mean_std_filepath: ""
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
batch_size: 2 #4 batch_size: 4
min_input_len: 0.5 # second min_input_len: 0.5 # second
max_input_len: 20.0 # second max_input_len: 20.0 # second
min_output_len: 0.0 # tokens min_output_len: 0.0 # tokens
...@@ -31,7 +31,7 @@ data: ...@@ -31,7 +31,7 @@ data:
keep_transcription_text: False keep_transcription_text: False
sortagrad: True sortagrad: True
shuffle_method: batch_shuffle shuffle_method: batch_shuffle
num_workers: 0 #2 num_workers: 2
# network architecture # network architecture
......
...@@ -30,12 +30,10 @@ fi ...@@ -30,12 +30,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n # test ckpt avg_n
# CUDA_VISIBLE_DEVICES=7 CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n # export ckpt avg_n
# CUDA_VISIBLE_DEVICES= CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi fi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册