diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 8e8a1824500d035907a61a78894853a373a06384..468bc65216cf6256ca96d1c4eef0663f673edaf7 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -43,7 +43,8 @@ class DeepSpeech2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): start = time.time() - loss = self.model(*batch_data) + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) loss.backward() layer_tools.print_grads(self.model, print_func=None) self.optimizer.step() @@ -73,9 +74,10 @@ class DeepSpeech2Trainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - loss = self.model(*batch) + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) if paddle.isfinite(loss): - num_utts = batch[0].shape[0] + num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) @@ -191,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len): + def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -213,11 +215,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for target, result in zip(target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) logger.info("Current error rate [%s] = %f" % @@ -238,15 +242,16 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - - for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) # logging msg = "Test: " diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index f166a071e317286694c438288e23c5161a051ca9..334d6bc8e94a47d3fe4ba644f24965df3ea45579 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -76,8 +76,9 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() + utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(*batch_data) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -119,9 +120,10 @@ class U2Trainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - loss, attention_loss, ctc_loss = self.model(*batch) + utt, audio, audio_len, text, text_len = batch + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) if paddle.isfinite(loss): - num_utts = batch[0].shape[0] + num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) @@ -366,7 +368,7 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -393,13 +395,13 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for target, result in zip(target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 if fout: - fout.write(result + "\n") + fout.write(utt + " " + result + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) logger.info("One example error rate [%s] = %f" % diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 7f019039c93471670011b52c66c98041328e2ea4..3bec9875f43b8242d4e90ce921b5edeb19f10414 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -51,7 +51,10 @@ class SpeechCollator(): audio_lens = [] texts = [] text_lens = [] - for audio, text in batch: + utts = [] + for utt, audio, text in batch: + #utt + utts.append(utt) # audio audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) @@ -75,4 +78,4 @@ class SpeechCollator(): padded_texts = pad_sequence( texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) - return padded_audios, audio_lens, padded_texts, text_lens + return utts, padded_audios, audio_lens, padded_texts, text_lens diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index fba5f7c66890aeaa2d9650fcd1da11be99e18f75..1cf3827d344ad57bfa18d0b3ce227cc5b2f6e6f8 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -347,4 +347,6 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["feat"], instance["text"]) + feat, text =self.process_utterance(instance["feat"], + instance["text"]) + return instance["utt"], feat, text diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 238e2d35c5492097868c2ff8ea1ff941bd27dc9e..bcfddaef0e4f397a901f916b59f1a31c30bf0ac8 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -905,6 +905,7 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) + def forward(self, feats, feats_lengths, diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index 111f5d3b490c0355cf6c40240f6f75bb71b72b5e..8bf48b2c80de27f6270f5858d03a90098ddb18f1 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -114,7 +114,8 @@ class ConvBn(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + # masks = masks.type_as(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py index 29bd28839f711088ddb7a67ea1696aa62b64fa3f..01b55c4a2f55e4bb3f61ab22fbd00be0f290fcd3 100644 --- a/deepspeech/modules/rnn.py +++ b/deepspeech/modules/rnn.py @@ -309,6 +309,6 @@ class RNNStack(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index d4961adb2a36df578388d1575fa1f6f205fed55d..4073c81b9eba69a8ee7af21d5d4595c9819357df 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 016502298df6131c458edfdb82bf8993ee3044a8..4cf09553bfece6b425799a4e44ccc78cdbc3fc6a 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/dataset/librispeech/.gitignore b/examples/dataset/librispeech/.gitignore index a8d8eb76d28008139709d46469f4eeb31c400099..dfd5c67b593408b61a6fc6f5cd446483c702ab2f 100644 --- a/examples/dataset/librispeech/.gitignore +++ b/examples/dataset/librispeech/.gitignore @@ -1,7 +1,7 @@ -dev-clean/ -dev-other/ -test-clean/ -test-other/ -train-clean-100/ -train-clean-360/ -train-other-500/ +dev-clean +dev-other +test-clean +test-other +train-clean-100 +train-clean-360 +train-other-500 diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index 3e536bd79ac8e21c3a28e0f206c57996cd0a6dc9..6553e073ded9ba4105fd9b69cac520a9e87fa963 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -24,7 +24,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index 472e6ebfbd4d939bddb41771b5938c2b7595d6ff..65194d902e7b2a8553f996020e9485ca58d7402a 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -24,7 +24,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index d4961adb2a36df578388d1575fa1f6f205fed55d..d7e153e8d2346ed57fd16c4ddd3bed176bcde03c 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -11,7 +11,7 @@ avg_num=1 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} -ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 35c11731cc3acbc17fe4cf4c410f731b3b384e3c..0a7cf3be845b68a87799904f7fdf167813fb1794 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -70,7 +70,7 @@ model: training: - n_epoch: 20 + n_epoch: 2 accum_grad: 1 global_grad_clip: 5.0 optim: adam @@ -85,7 +85,7 @@ training: decoding: - batch_size: 64 + batch_size: 8 #64 error_rate_type: wer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index f7e41a338e82b88f01ba1c238082ad686f9f83ef..b148869b7d6aaecc7f9181818be315846ee11012 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -20,12 +20,12 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + ./local/train.sh ${conf_path} ${ckpt} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then @@ -35,5 +35,5 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi