Commit 862150b5 authored by: H Hui Zhang

u2 kaldi can train, but ctc loss is high

Parent 48438066
......@@ -393,6 +393,7 @@ class U2Tester(U2Trainer):
texts,
texts_len,
fout=None):
logger.info(f"Input: {audio.shape} {audio_len}, {texts} {texts_len}")
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
......@@ -430,8 +431,9 @@ class U2Tester(U2Trainer):
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
......
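For context on the `errors_sum` / `len_refs` bookkeeping above: `error_rate.char_errors` (or `word_errors`) returns the edit distance and the reference length for one utterance, and the corpus-level CER/WER is `errors_sum / len_refs`. A toy, self-contained sketch of that accumulation (the `char_errors` helper below is an illustrative stand-in, not the repo's `error_rate` module):

```python
# Toy sketch of the error-rate accumulation used by U2Tester.test().
def char_errors(ref, hyp):
    """Return (character-level edit distance, reference length)."""
    r, h = list(ref), list(hyp)
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return float(d[len(r)][len(h)]), len(r)

errors_sum, len_refs, num_ins = 0.0, 0, 0
for target, result in [("hello world", "helo world"), ("good", "good")]:
    errors, len_ref = char_errors(target, result)
    errors_sum += errors
    len_refs += len_ref
    num_ins += 1
print("CER =", errors_sum / len_refs)   # 1 error / 15 reference chars ~= 0.067
```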
......@@ -297,10 +297,12 @@ class U2BaseModel(nn.Layer):
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
# logger.info(f"att:maxlen {maxlen}")
encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
# logger.info(f"att: encoder_mask {encoder_mask}")
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
......@@ -314,6 +316,7 @@ class U2BaseModel(nn.Layer):
device) # (B*N, 1)
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
cache: Optional[List[paddle.Tensor]] = None
# logger.info(f"att: hyps {hyps} eos: {self.eos}")
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
......@@ -323,6 +326,7 @@ class U2BaseModel(nn.Layer):
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logger.info(f"att: {i} {hyps_mask}")
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
......@@ -332,7 +336,7 @@ class U2BaseModel(nn.Layer):
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Seconde beam prune: select topk score with history
# 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
......@@ -356,9 +360,10 @@ class U2BaseModel(nn.Layer):
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1)
# logger.info(f"att: hyps {hyps}")
# 2.6 Update end flag
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
# logger.info(f"att: end_flag {end_flag}")
# 3. Select best of best
scores = scores.view(batch_size, beam_size)
......@@ -368,6 +373,7 @@ class U2BaseModel(nn.Layer):
batch_size, dtype=paddle.long) * beam_size
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
best_hyps = best_hyps[:, 1:]
# logger.info(f"att: best_hyps {best_hyps}")
return best_hyps
def ctc_greedy_search(
......@@ -802,6 +808,7 @@ class U2BaseModel(nn.Layer):
else:
raise ValueError(f"Not support decoding method: {decoding_method}")
logger.info(f"hyps: {hyps}")
res = [text_feature.defeaturize(hyp) for hyp in hyps]
return res
......
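The "select best of best" step near the end of the attention beam search above relies on hypotheses being laid out as `(batch * beam, len)`, so the winning beam of sample `b` sits at row `b * beam_size + best_index[b]`. A minimal sketch of that indexing, with made-up shapes and dummy values:

```python
import paddle

batch_size, beam_size, length = 2, 3, 5
hyps = paddle.arange(batch_size * beam_size * length).reshape(
    [batch_size * beam_size, length])                  # (B*N, len), dummy values
best_index = paddle.to_tensor([2, 0], dtype='int64')   # best beam per sample
best_hyps_index = best_index + paddle.arange(
    batch_size, dtype='int64') * beam_size             # -> rows [2, 3]
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
print(best_hyps.shape)                                 # [2, 5]
```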
......@@ -49,7 +49,7 @@ class CTCLoss(nn.Layer):
# (TODO:Hui Zhang) ctc loss does not support int64 labels
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss(
logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
logits, ys_pad, hlens, ys_lens, norm_by_batchsize=self.batch_average)
if self.batch_average:
# Batch-size average
loss = loss / B
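On the CTC change: `norm_by_times` is the flag exposed by stock `paddle.nn.functional.ctc_loss`, while `norm_by_batchsize` looks like it comes from a patched loss in this repo, so the rename is not guaranteed to exist in every Paddle build. The intended effect, averaging the loss over the batch, can be reproduced with the stock API as in this sketch (shapes and values are made up):

```python
import paddle
import paddle.nn.functional as F

B, T, V = 4, 50, 30                                   # batch, time steps, vocab (incl. blank=0)
logits = paddle.randn([T, B, V])                      # (T, B, V) acoustic outputs
log_probs = F.log_softmax(logits, axis=-1)
labels = paddle.randint(1, V, shape=[B, 10], dtype='int32')
input_lens = paddle.full([B], T, dtype='int64')
label_lens = paddle.full([B], 10, dtype='int64')

loss_sum = F.ctc_loss(log_probs, labels, input_lens, label_lens,
                      blank=0, reduction='sum')       # summed over the batch
loss = loss_sum / B                                   # "batch average"
```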
......@@ -90,8 +90,8 @@ class LabelSmoothingLoss(nn.Layer):
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
True, normalize loss by sequence length;
normalize_length (bool):
True, normalize loss by sequence length;
False, normalize loss by batch size.
Defaults to False.
"""
......@@ -108,7 +108,7 @@ class LabelSmoothingLoss(nn.Layer):
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (paddle.Tensor): prediction (batch, seqlen, class)
target (paddle.Tensor):
......
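A small sketch of the `normalize_length` switch documented above (shapes and the `-1` padding id are assumptions, not the repo's exact implementation): the summed token-level loss is divided either by the number of non-padding tokens or by the batch size.

```python
import paddle

def normalize_loss(loss_sum, target, padding_idx, batch_size, normalize_length):
    """Divide the summed token-level loss by #non-pad tokens or by batch size."""
    if normalize_length:
        denom = (target != padding_idx).astype('float32').sum()   # non-pad tokens
    else:
        denom = float(batch_size)                                  # sentences
    return loss_sum / denom

target = paddle.to_tensor([[5, 7, 9, -1], [3, -1, -1, -1]])        # -1 = padding_idx
loss_sum = paddle.to_tensor(8.0)
print(normalize_loss(loss_sum, target, -1, 2, True))    # 8 / 4 tokens    -> 2.0
print(normalize_loss(loss_sum, target, -1, 2, False))   # 8 / 2 sentences -> 4.0
```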
......@@ -163,8 +163,8 @@ class Trainer():
checkpoint_path=self.args.checkpoint_path)
if infos:
# restore from ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] + 1
scratch = False
else:
self.iteration = 0
......
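The `+ 1` fix assumes the checkpoint records the last *finished* step and epoch, so training should resume at the next one rather than redoing the saved one. In sketch form:

```python
# Assumed resume semantics behind the "+ 1" change above.
infos = {"step": 12000, "epoch": 37}     # as restored from the checkpoint
iteration = infos["step"] + 1            # first step to run after resume
epoch = infos["epoch"] + 1               # first epoch to run after resume
print(iteration, epoch)                  # 12001 38
```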
......@@ -6,7 +6,7 @@ gpus=0,1,2,3
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
avg_num=5
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......
......@@ -16,7 +16,7 @@ collator:
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
batch_size: 32
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
......@@ -73,13 +73,13 @@ model:
training:
n_epoch: 120
accum_grad: 2
accum_grad: 4
global_grad_clip: 5.0
optim: adam
optim: noam
optim_conf:
lr: 0.004
lr: 10.0
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler: noam # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
......
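The switch from `adam`/`warmuplr` with `lr: 0.004` to `noam` with `lr: 10.0` is consistent because in the Noam schedule `lr` acts as a scale factor, not the peak rate. Assuming the standard Transformer formula with `d_model = 256` (an assumption; the repo's scheduler may scale slightly differently), the peak rate works out to roughly the old 0.004:

```python
def noam_lr(step, d_model=256, warmup_steps=25000, factor=10.0):
    """Noam schedule: linear warmup to the peak, then ~1/sqrt(step) decay."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5,
                                          step * warmup_steps ** -1.5)

for s in (1, 12500, 25000, 100000):
    print(s, noam_lr(s))     # peak at step == warmup_steps, ~0.004 here
```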
......@@ -14,11 +14,6 @@
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer
......@@ -33,9 +28,6 @@
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 9.27137279510498 | 0.038421 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 9.27137279510498 | 0.120112 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 9.27137279510498 | 0.116441 |
......@@ -12,7 +12,7 @@ collator:
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
batch_size: 30
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
......@@ -22,7 +22,7 @@ collator:
batch_frames_out: 0
batch_frames_inout: 0
augmentation_config: conf/augmentation.json
num_workers: 2
num_workers: 0
subsampling_factor: 1
num_encs: 1
......@@ -81,7 +81,7 @@ scheduler_conf:
lr_decay: 1.0
decoding:
batch_size: 64
batch_size: 1
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
......
......@@ -30,13 +30,15 @@ echo "chunk mode ${chunk_mode}"
# exit 1
#fi
#for type in attention ctc_greedy_search; do
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
#batch_size=64
batch_size=1
fi
python3 -u ${BIN_DIR}/test.py \
--model-name u2_kaldi \
......
......@@ -19,8 +19,8 @@ echo "using ${device}..."
mkdir -p exp
seed=1024
if [ ${seed} ]; then
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
......@@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} ]; then
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
......
......@@ -6,7 +6,7 @@ stage=0
stop_stage=100
conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt
avg_num=5
avg_num=10
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
......@@ -20,12 +20,12 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh latest exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
......
......@@ -80,8 +80,8 @@ def main(args):
data = json.dumps({
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
"epoch": selected_epochs,
"val_loss": beat_val_scores,
})
f.write(data + "\n")
......
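For reference, `avg.sh latest exp/${ckpt}/checkpoints ${avg_num}` averages the parameters of the selected checkpoints element-wise; the JSON written above only records which ones were used. A generic sketch of that averaging (paths are hypothetical, and the repo's avg.py differs in how it selects and loads checkpoints):

```python
import paddle

def average_checkpoints(ckpt_paths):
    """Element-wise mean of the parameters stored in several .pdparams files."""
    avg = None
    for path in ckpt_paths:
        state = paddle.load(path)
        if avg is None:
            avg = {k: v.astype('float64') for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].astype('float64')
    return {k: (v / len(ckpt_paths)).astype('float32') for k, v in avg.items()}

# e.g. average_checkpoints(["exp/ckpt/epoch_116.pdparams", ...])  # hypothetical paths
```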