提交 862150b5 编写于 作者: H Hui Zhang

u2 kaldi can train, but the CTC loss is high

上级 48438066
...@@ -393,6 +393,7 @@ class U2Tester(U2Trainer): ...@@ -393,6 +393,7 @@ class U2Tester(U2Trainer):
texts, texts,
texts_len, texts_len,
fout=None): fout=None):
logger.info(f"Input: {audio.shape} {audio_len}, {texts} {texts_len}")
cfg = self.config.decoding cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
...@@ -430,8 +431,9 @@ class U2Tester(U2Trainer): ...@@ -430,8 +431,9 @@ class U2Tester(U2Trainer):
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write(utt + " " + result + "\n") fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % logger.info(f"Utt: {utt}")
(target, result)) logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result))) (cfg.error_rate_type, error_rate_func(target, result)))
......
...@@ -297,10 +297,12 @@ class U2BaseModel(nn.Layer): ...@@ -297,10 +297,12 @@ class U2BaseModel(nn.Layer):
num_decoding_left_chunks, num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim) simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.size(1)
# logger.info(f"att:maxlen {maxlen}")
encoder_dim = encoder_out.size(2) encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
# logger.info(f"att: encoder_mask {encoder_mask}")
encoder_mask = encoder_mask.unsqueeze(1).repeat( encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1, 1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len) maxlen) # (B*N, 1, max_len)
...@@ -314,6 +316,7 @@ class U2BaseModel(nn.Layer): ...@@ -314,6 +316,7 @@ class U2BaseModel(nn.Layer):
device) # (B*N, 1) device) # (B*N, 1)
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
cache: Optional[List[paddle.Tensor]] = None cache: Optional[List[paddle.Tensor]] = None
# logger.info(f"att: hyps {hyps} eos: {self.eos}")
# 2. Decoder forward step by step # 2. Decoder forward step by step
for i in range(1, maxlen + 1): for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos # Stop if all batch and all beam produce eos
...@@ -323,6 +326,7 @@ class U2BaseModel(nn.Layer): ...@@ -323,6 +326,7 @@ class U2BaseModel(nn.Layer):
# 2.1 Forward decoder step # 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i) running_size, 1, 1).to(device) # (B*N, i, i)
# logger.info(f"att: {i} {hyps_mask}")
# logp: (B*N, vocab) # logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step( logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache) encoder_out, encoder_mask, hyps, hyps_mask, cache)
...@@ -332,7 +336,7 @@ class U2BaseModel(nn.Layer): ...@@ -332,7 +336,7 @@ class U2BaseModel(nn.Layer):
top_k_logp = mask_finished_scores(top_k_logp, end_flag) top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Seconde beam prune: select topk score with history # 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N) scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
...@@ -356,9 +360,10 @@ class U2BaseModel(nn.Layer): ...@@ -356,9 +360,10 @@ class U2BaseModel(nn.Layer):
hyps = paddle.cat( hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)), (last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1) dim=1) # (B*N, i+1)
# logger.info(f"att: hyps {hyps}")
# 2.6 Update end flag # 2.6 Update end flag
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1) end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
# logger.info(f"att: end_flag {end_flag}")
# 3. Select best of best # 3. Select best of best
scores = scores.view(batch_size, beam_size) scores = scores.view(batch_size, beam_size)
...@@ -368,6 +373,7 @@ class U2BaseModel(nn.Layer): ...@@ -368,6 +373,7 @@ class U2BaseModel(nn.Layer):
batch_size, dtype=paddle.long) * beam_size batch_size, dtype=paddle.long) * beam_size
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0) best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
best_hyps = best_hyps[:, 1:] best_hyps = best_hyps[:, 1:]
# logger.info(f"att: best_hyps {best_hyps}")
return best_hyps return best_hyps
def ctc_greedy_search( def ctc_greedy_search(
...@@ -802,6 +808,7 @@ class U2BaseModel(nn.Layer): ...@@ -802,6 +808,7 @@ class U2BaseModel(nn.Layer):
else: else:
raise ValueError(f"Not support decoding method: {decoding_method}") raise ValueError(f"Not support decoding method: {decoding_method}")
logger.info(f"hyps: {hyps}")
res = [text_feature.defeaturize(hyp) for hyp in hyps] res = [text_feature.defeaturize(hyp) for hyp in hyps]
return res return res
......
...@@ -49,7 +49,7 @@ class CTCLoss(nn.Layer): ...@@ -49,7 +49,7 @@ class CTCLoss(nn.Layer):
# (TODO:Hui Zhang) ctc loss does not support int64 labels # (TODO:Hui Zhang) ctc loss does not support int64 labels
ys_pad = ys_pad.astype(paddle.int32) ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss( loss = self.loss(
logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average) logits, ys_pad, hlens, ys_lens, norm_by_batchsize=self.batch_average)
if self.batch_average: if self.batch_average:
# Batch-size average # Batch-size average
loss = loss / B loss = loss / B
......
...@@ -163,8 +163,8 @@ class Trainer(): ...@@ -163,8 +163,8 @@ class Trainer():
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
if infos: if infos:
# restore from ckpt # restore from ckpt
self.iteration = infos["step"] self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] self.epoch = infos["epoch"] + 1
scratch = False scratch = False
else: else:
self.iteration = 0 self.iteration = 0
......
...@@ -6,7 +6,7 @@ gpus=0,1,2,3 ...@@ -6,7 +6,7 @@ gpus=0,1,2,3
stage=0 stage=0
stop_stage=100 stop_stage=100
conf_path=conf/deepspeech2.yaml conf_path=conf/deepspeech2.yaml
avg_num=1 avg_num=5
model_type=offline model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......
...@@ -16,7 +16,7 @@ collator: ...@@ -16,7 +16,7 @@ collator:
spm_model_prefix: 'data/bpe_unigram_5000' spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: "" mean_std_filepath: ""
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
batch_size: 64 batch_size: 32
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80 feat_dim: 80
...@@ -73,13 +73,13 @@ model: ...@@ -73,13 +73,13 @@ model:
training: training:
n_epoch: 120 n_epoch: 120
accum_grad: 2 accum_grad: 4
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: noam
optim_conf: optim_conf:
lr: 0.004 lr: 10.0
weight_decay: 1e-06 weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required scheduler: noam # pytorch v1.1.0+ required
scheduler_conf: scheduler_conf:
warmup_steps: 25000 warmup_steps: 25000
lr_decay: 1.0 lr_decay: 1.0
......
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | | | conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | | | conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer ## Chunk Conformer
...@@ -33,9 +28,6 @@ ...@@ -33,9 +28,6 @@
## Transformer ## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | | Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | | | transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 9.27137279510498 | 0.038421 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 9.27137279510498 | 0.120112 |
### Test w/o length filter | transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 9.27137279510498 | 0.116441 |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
...@@ -12,7 +12,7 @@ collator: ...@@ -12,7 +12,7 @@ collator:
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32 batch_size: 30
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug minibatches: 0 # for debug
...@@ -22,7 +22,7 @@ collator: ...@@ -22,7 +22,7 @@ collator:
batch_frames_out: 0 batch_frames_out: 0
batch_frames_inout: 0 batch_frames_inout: 0
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
num_workers: 2 num_workers: 0
subsampling_factor: 1 subsampling_factor: 1
num_encs: 1 num_encs: 1
...@@ -81,7 +81,7 @@ scheduler_conf: ...@@ -81,7 +81,7 @@ scheduler_conf:
lr_decay: 1.0 lr_decay: 1.0
decoding: decoding:
batch_size: 64 batch_size: 1
error_rate_type: wer error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
......
...@@ -30,13 +30,15 @@ echo "chunk mode ${chunk_mode}" ...@@ -30,13 +30,15 @@ echo "chunk mode ${chunk_mode}"
# exit 1 # exit 1
#fi #fi
#for type in attention ctc_greedy_search; do
for type in attention ctc_greedy_search; do for type in attention ctc_greedy_search; do
echo "decoding ${type}" echo "decoding ${type}"
if [ ${chunk_mode} == true ];then if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1 # stream decoding only support batchsize=1
batch_size=1 batch_size=1
else else
batch_size=64 #batch_size=64
batch_size=1
fi fi
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--model-name u2_kaldi \ --model-name u2_kaldi \
......
...@@ -19,8 +19,8 @@ echo "using ${device}..." ...@@ -19,8 +19,8 @@ echo "using ${device}..."
mkdir -p exp mkdir -p exp
seed=1024 seed=0
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
fi fi
...@@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \ ...@@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
--output exp/${ckpt_name} \ --output exp/${ckpt_name} \
--seed ${seed} --seed ${seed}
if [ ${seed} ]; then if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic unset FLAGS_cudnn_deterministic
fi fi
......
...@@ -6,7 +6,7 @@ stage=0 ...@@ -6,7 +6,7 @@ stage=0
stop_stage=100 stop_stage=100
conf_path=conf/transformer.yaml conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt dict_path=data/train_960_unigram5000_units.txt
avg_num=5 avg_num=10
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num} avg_ckpt=avg_${avg_num}
...@@ -20,12 +20,12 @@ fi ...@@ -20,12 +20,12 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir # train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt} CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
fi fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model # avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num} avg.sh latest exp/${ckpt}/checkpoints ${avg_num}
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
...@@ -80,8 +80,8 @@ def main(args): ...@@ -80,8 +80,8 @@ def main(args):
data = json.dumps({ data = json.dumps({
"avg_ckpt": args.dst_model, "avg_ckpt": args.dst_model,
"ckpt": path_list, "ckpt": path_list,
"epoch": selected_epochs.tolist(), "epoch": selected_epochs,
"val_loss": beat_val_scores.tolist(), "val_loss": beat_val_scores,
}) })
f.write(data + "\n") f.write(data + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册