提交 b4bda290 编写于 作者: H Haoxin Ma

fix bugs

上级 8781ab58
......@@ -368,7 +368,7 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None, fref=None):
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
......@@ -402,8 +402,6 @@ class U2Tester(U2Trainer):
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
if fref:
fref.write(utt + " " + target + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("One example error rate [%s] = %f" %
......@@ -432,7 +430,6 @@ class U2Tester(U2Trainer):
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
# utt, audio, audio_len, text, text_len = batch
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
......
......@@ -284,7 +284,7 @@ class ManifestDataset(Dataset):
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
def process_utterance(self, utt, audio_file, transcript):
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
......@@ -323,7 +323,7 @@ class ManifestDataset(Dataset):
specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return utt, specgram, transcript_part
return specgram, transcript_part
def _instance_reader_creator(self, manifest):
"""
......@@ -336,9 +336,7 @@ class ManifestDataset(Dataset):
def reader():
for instance in manifest:
# inst = self.process_utterance(instance["feat"],
# instance["text"])
inst = self.process_utterance(instance["utt"], instance["feat"],
inst = self.process_utterance(instance["feat"],
instance["text"])
yield inst
......@@ -349,6 +347,6 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
    """Return one utterance example by manifest index.

    :param idx: Integer index into the manifest list.
    :return: Tuple ``(utt, feat, text)`` — the utterance id taken straight
             from the manifest entry, plus the (feature, transcript) pair
             produced by ``process_utterance``.
    """
    instance = self._manifest[idx]
    # NOTE(review): process_utterance is assumed to load/augment/featurize
    # the audio and tokenize the transcript — confirm against its definition.
    feat, text = self.process_utterance(instance["feat"],
                                        instance["text"])
    return instance["utt"], feat, text
......@@ -8,7 +8,7 @@ data:
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
......@@ -31,7 +31,7 @@ data:
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
......
......@@ -30,12 +30,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # test ckpt avg_n — pin the test run to a single GPU (device 7)
    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # export ckpt avg_n to a jit model; CUDA_VISIBLE_DEVICES is set empty
    # on purpose so the export runs on CPU
    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册