Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
862150b5
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
862150b5
编写于
10月 08, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
u2 kaldi can train, but ctc loss high
上级
48438066
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
42 addition
and
39 deletion
+42
-39
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+4
-2
deepspeech/models/u2.py
deepspeech/models/u2.py
+9
-2
deepspeech/modules/loss.py
deepspeech/modules/loss.py
+4
-4
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+2
-2
examples/aishell/s0/run.sh
examples/aishell/s0/run.sh
+1
-1
examples/librispeech/s1/conf/transformer.yaml
examples/librispeech/s1/conf/transformer.yaml
+5
-5
examples/librispeech/s2/README.md
examples/librispeech/s2/README.md
+3
-11
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+3
-3
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+3
-1
examples/librispeech/s2/local/train.sh
examples/librispeech/s2/local/train.sh
+3
-3
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+3
-3
utils/avg_model.py
utils/avg_model.py
+2
-2
未找到文件。
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
862150b5
...
@@ -393,6 +393,7 @@ class U2Tester(U2Trainer):
...
@@ -393,6 +393,7 @@ class U2Tester(U2Trainer):
texts
,
texts
,
texts_len
,
texts_len
,
fout
=
None
):
fout
=
None
):
logger
.
info
(
f
"Input:
{
audio
.
shape
}
{
audio_len
}
,
{
texts
}
{
texts_len
}
"
)
cfg
=
self
.
config
.
decoding
cfg
=
self
.
config
.
decoding
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_func
=
error_rate
.
char_errors
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
errors_func
=
error_rate
.
char_errors
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
...
@@ -430,8 +431,9 @@ class U2Tester(U2Trainer):
...
@@ -430,8 +431,9 @@ class U2Tester(U2Trainer):
num_ins
+=
1
num_ins
+=
1
if
fout
:
if
fout
:
fout
.
write
(
utt
+
" "
+
result
+
"
\n
"
)
fout
.
write
(
utt
+
" "
+
result
+
"
\n
"
)
logger
.
info
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
logger
.
info
(
f
"Utt:
{
utt
}
"
)
(
target
,
result
))
logger
.
info
(
f
"Ref:
{
target
}
"
)
logger
.
info
(
f
"Hyp:
{
result
}
"
)
logger
.
info
(
"One example error rate [%s] = %f"
%
logger
.
info
(
"One example error rate [%s] = %f"
%
(
cfg
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
(
cfg
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
...
...
deepspeech/models/u2.py
浏览文件 @
862150b5
...
@@ -297,10 +297,12 @@ class U2BaseModel(nn.Layer):
...
@@ -297,10 +297,12 @@ class U2BaseModel(nn.Layer):
num_decoding_left_chunks
,
num_decoding_left_chunks
,
simulate_streaming
)
# (B, maxlen, encoder_dim)
simulate_streaming
)
# (B, maxlen, encoder_dim)
maxlen
=
encoder_out
.
size
(
1
)
maxlen
=
encoder_out
.
size
(
1
)
# logger.info(f"att:maxlen {maxlen}")
encoder_dim
=
encoder_out
.
size
(
2
)
encoder_dim
=
encoder_out
.
size
(
2
)
running_size
=
batch_size
*
beam_size
running_size
=
batch_size
*
beam_size
encoder_out
=
encoder_out
.
unsqueeze
(
1
).
repeat
(
1
,
beam_size
,
1
,
1
).
view
(
encoder_out
=
encoder_out
.
unsqueeze
(
1
).
repeat
(
1
,
beam_size
,
1
,
1
).
view
(
running_size
,
maxlen
,
encoder_dim
)
# (B*N, maxlen, encoder_dim)
running_size
,
maxlen
,
encoder_dim
)
# (B*N, maxlen, encoder_dim)
# logger.info(f"att: encoder_mask {encoder_mask}")
encoder_mask
=
encoder_mask
.
unsqueeze
(
1
).
repeat
(
encoder_mask
=
encoder_mask
.
unsqueeze
(
1
).
repeat
(
1
,
beam_size
,
1
,
1
).
view
(
running_size
,
1
,
1
,
beam_size
,
1
,
1
).
view
(
running_size
,
1
,
maxlen
)
# (B*N, 1, max_len)
maxlen
)
# (B*N, 1, max_len)
...
@@ -314,6 +316,7 @@ class U2BaseModel(nn.Layer):
...
@@ -314,6 +316,7 @@ class U2BaseModel(nn.Layer):
device
)
# (B*N, 1)
device
)
# (B*N, 1)
end_flag
=
paddle
.
zeros_like
(
scores
,
dtype
=
paddle
.
bool
)
# (B*N, 1)
end_flag
=
paddle
.
zeros_like
(
scores
,
dtype
=
paddle
.
bool
)
# (B*N, 1)
cache
:
Optional
[
List
[
paddle
.
Tensor
]]
=
None
cache
:
Optional
[
List
[
paddle
.
Tensor
]]
=
None
# logger.info(f"att: hyps {hyps} eos: {self.eos}")
# 2. Decoder forward step by step
# 2. Decoder forward step by step
for
i
in
range
(
1
,
maxlen
+
1
):
for
i
in
range
(
1
,
maxlen
+
1
):
# Stop if all batch and all beam produce eos
# Stop if all batch and all beam produce eos
...
@@ -323,6 +326,7 @@ class U2BaseModel(nn.Layer):
...
@@ -323,6 +326,7 @@ class U2BaseModel(nn.Layer):
# 2.1 Forward decoder step
# 2.1 Forward decoder step
hyps_mask
=
subsequent_mask
(
i
).
unsqueeze
(
0
).
repeat
(
hyps_mask
=
subsequent_mask
(
i
).
unsqueeze
(
0
).
repeat
(
running_size
,
1
,
1
).
to
(
device
)
# (B*N, i, i)
running_size
,
1
,
1
).
to
(
device
)
# (B*N, i, i)
# logger.info(f"att: {i} {hyps_mask}")
# logp: (B*N, vocab)
# logp: (B*N, vocab)
logp
,
cache
=
self
.
decoder
.
forward_one_step
(
logp
,
cache
=
self
.
decoder
.
forward_one_step
(
encoder_out
,
encoder_mask
,
hyps
,
hyps_mask
,
cache
)
encoder_out
,
encoder_mask
,
hyps
,
hyps_mask
,
cache
)
...
@@ -332,7 +336,7 @@ class U2BaseModel(nn.Layer):
...
@@ -332,7 +336,7 @@ class U2BaseModel(nn.Layer):
top_k_logp
=
mask_finished_scores
(
top_k_logp
,
end_flag
)
top_k_logp
=
mask_finished_scores
(
top_k_logp
,
end_flag
)
top_k_index
=
mask_finished_preds
(
top_k_index
,
end_flag
,
self
.
eos
)
top_k_index
=
mask_finished_preds
(
top_k_index
,
end_flag
,
self
.
eos
)
# 2.3 Second
e
beam prune: select topk score with history
# 2.3 Second beam prune: select topk score with history
scores
=
scores
+
top_k_logp
# (B*N, N), broadcast add
scores
=
scores
+
top_k_logp
# (B*N, N), broadcast add
scores
=
scores
.
view
(
batch_size
,
beam_size
*
beam_size
)
# (B, N*N)
scores
=
scores
.
view
(
batch_size
,
beam_size
*
beam_size
)
# (B, N*N)
scores
,
offset_k_index
=
scores
.
topk
(
k
=
beam_size
)
# (B, N)
scores
,
offset_k_index
=
scores
.
topk
(
k
=
beam_size
)
# (B, N)
...
@@ -356,9 +360,10 @@ class U2BaseModel(nn.Layer):
...
@@ -356,9 +360,10 @@ class U2BaseModel(nn.Layer):
hyps
=
paddle
.
cat
(
hyps
=
paddle
.
cat
(
(
last_best_k_hyps
,
best_k_pred
.
view
(
-
1
,
1
)),
(
last_best_k_hyps
,
best_k_pred
.
view
(
-
1
,
1
)),
dim
=
1
)
# (B*N, i+1)
dim
=
1
)
# (B*N, i+1)
# logger.info(f"att: hyps {hyps}")
# 2.6 Update end flag
# 2.6 Update end flag
end_flag
=
paddle
.
eq
(
hyps
[:,
-
1
],
self
.
eos
).
view
(
-
1
,
1
)
end_flag
=
paddle
.
eq
(
hyps
[:,
-
1
],
self
.
eos
).
view
(
-
1
,
1
)
# logger.info(f"att: end_flag {end_flag}")
# 3. Select best of best
# 3. Select best of best
scores
=
scores
.
view
(
batch_size
,
beam_size
)
scores
=
scores
.
view
(
batch_size
,
beam_size
)
...
@@ -368,6 +373,7 @@ class U2BaseModel(nn.Layer):
...
@@ -368,6 +373,7 @@ class U2BaseModel(nn.Layer):
batch_size
,
dtype
=
paddle
.
long
)
*
beam_size
batch_size
,
dtype
=
paddle
.
long
)
*
beam_size
best_hyps
=
paddle
.
index_select
(
hyps
,
index
=
best_hyps_index
,
axis
=
0
)
best_hyps
=
paddle
.
index_select
(
hyps
,
index
=
best_hyps_index
,
axis
=
0
)
best_hyps
=
best_hyps
[:,
1
:]
best_hyps
=
best_hyps
[:,
1
:]
# logger.info(f"att: best_hyps {best_hyps}")
return
best_hyps
return
best_hyps
def
ctc_greedy_search
(
def
ctc_greedy_search
(
...
@@ -802,6 +808,7 @@ class U2BaseModel(nn.Layer):
...
@@ -802,6 +808,7 @@ class U2BaseModel(nn.Layer):
else
:
else
:
raise
ValueError
(
f
"Not support decoding method:
{
decoding_method
}
"
)
raise
ValueError
(
f
"Not support decoding method:
{
decoding_method
}
"
)
logger
.
info
(
f
"hyps:
{
hyps
}
"
)
res
=
[
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
res
=
[
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
return
res
return
res
...
...
deepspeech/modules/loss.py
浏览文件 @
862150b5
...
@@ -49,7 +49,7 @@ class CTCLoss(nn.Layer):
...
@@ -49,7 +49,7 @@ class CTCLoss(nn.Layer):
# (TODO:Hui Zhang) ctc loss does not support int64 labels
# (TODO:Hui Zhang) ctc loss does not support int64 labels
ys_pad
=
ys_pad
.
astype
(
paddle
.
int32
)
ys_pad
=
ys_pad
.
astype
(
paddle
.
int32
)
loss
=
self
.
loss
(
loss
=
self
.
loss
(
logits
,
ys_pad
,
hlens
,
ys_lens
,
norm_by_
times
=
self
.
batch_average
)
logits
,
ys_pad
,
hlens
,
ys_lens
,
norm_by_
batchsize
=
self
.
batch_average
)
if
self
.
batch_average
:
if
self
.
batch_average
:
# Batch-size average
# Batch-size average
loss
=
loss
/
B
loss
=
loss
/
B
...
@@ -90,8 +90,8 @@ class LabelSmoothingLoss(nn.Layer):
...
@@ -90,8 +90,8 @@ class LabelSmoothingLoss(nn.Layer):
size (int): the number of class
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
normalize_length (bool):
True, normalize loss by sequence length;
True, normalize loss by sequence length;
False, normalize loss by batch size.
False, normalize loss by batch size.
Defaults to False.
Defaults to False.
"""
"""
...
@@ -108,7 +108,7 @@ class LabelSmoothingLoss(nn.Layer):
...
@@ -108,7 +108,7 @@ class LabelSmoothingLoss(nn.Layer):
The model outputs and data labels tensors are flatten to
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
padding part which should not be calculated for loss.
Args:
Args:
x (paddle.Tensor): prediction (batch, seqlen, class)
x (paddle.Tensor): prediction (batch, seqlen, class)
target (paddle.Tensor):
target (paddle.Tensor):
...
...
deepspeech/training/trainer.py
浏览文件 @
862150b5
...
@@ -163,8 +163,8 @@ class Trainer():
...
@@ -163,8 +163,8 @@ class Trainer():
checkpoint_path
=
self
.
args
.
checkpoint_path
)
checkpoint_path
=
self
.
args
.
checkpoint_path
)
if
infos
:
if
infos
:
# restore from ckpt
# restore from ckpt
self
.
iteration
=
infos
[
"step"
]
self
.
iteration
=
infos
[
"step"
]
+
1
self
.
epoch
=
infos
[
"epoch"
]
self
.
epoch
=
infos
[
"epoch"
]
+
1
scratch
=
False
scratch
=
False
else
:
else
:
self
.
iteration
=
0
self
.
iteration
=
0
...
...
examples/aishell/s0/run.sh
浏览文件 @
862150b5
...
@@ -6,7 +6,7 @@ gpus=0,1,2,3
...
@@ -6,7 +6,7 @@ gpus=0,1,2,3
stage
=
0
stage
=
0
stop_stage
=
100
stop_stage
=
100
conf_path
=
conf/deepspeech2.yaml
conf_path
=
conf/deepspeech2.yaml
avg_num
=
1
avg_num
=
5
model_type
=
offline
model_type
=
offline
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
...
...
examples/librispeech/s1/conf/transformer.yaml
浏览文件 @
862150b5
...
@@ -16,7 +16,7 @@ collator:
...
@@ -16,7 +16,7 @@ collator:
spm_model_prefix
:
'
data/bpe_unigram_5000'
spm_model_prefix
:
'
data/bpe_unigram_5000'
mean_std_filepath
:
"
"
mean_std_filepath
:
"
"
augmentation_config
:
conf/augmentation.json
augmentation_config
:
conf/augmentation.json
batch_size
:
64
batch_size
:
32
raw_wav
:
True
# use raw_wav or kaldi feature
raw_wav
:
True
# use raw_wav or kaldi feature
specgram_type
:
fbank
#linear, mfcc, fbank
specgram_type
:
fbank
#linear, mfcc, fbank
feat_dim
:
80
feat_dim
:
80
...
@@ -73,13 +73,13 @@ model:
...
@@ -73,13 +73,13 @@ model:
training
:
training
:
n_epoch
:
120
n_epoch
:
120
accum_grad
:
2
accum_grad
:
4
global_grad_clip
:
5.0
global_grad_clip
:
5.0
optim
:
adam
optim
:
noam
optim_conf
:
optim_conf
:
lr
:
0.004
lr
:
10.0
weight_decay
:
1e-06
weight_decay
:
1e-06
scheduler
:
warmuplr
# pytorch v1.1.0+ required
scheduler
:
noam
# pytorch v1.1.0+ required
scheduler_conf
:
scheduler_conf
:
warmup_steps
:
25000
warmup_steps
:
25000
lr_decay
:
1.0
lr_decay
:
1.0
...
...
examples/librispeech/s2/README.md
浏览文件 @
862150b5
...
@@ -14,11 +14,6 @@
...
@@ -14,11 +14,6 @@
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer
## Chunk Conformer
...
@@ -33,9 +28,6 @@
...
@@ -33,9 +28,6 @@
## Transformer
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 9.27137279510498, | 0.038421 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 9.27137279510498, | 0.120112 |
### Test w/o length filter
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 9.27137279510498, | 0.116441 |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
862150b5
...
@@ -12,7 +12,7 @@ collator:
...
@@ -12,7 +12,7 @@ collator:
stride_ms
:
10.0
stride_ms
:
10.0
window_ms
:
25.0
window_ms
:
25.0
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size
:
3
2
batch_size
:
3
0
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
minibatches
:
0
# for debug
minibatches
:
0
# for debug
...
@@ -22,7 +22,7 @@ collator:
...
@@ -22,7 +22,7 @@ collator:
batch_frames_out
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
batch_frames_inout
:
0
augmentation_config
:
conf/augmentation.json
augmentation_config
:
conf/augmentation.json
num_workers
:
2
num_workers
:
0
subsampling_factor
:
1
subsampling_factor
:
1
num_encs
:
1
num_encs
:
1
...
@@ -81,7 +81,7 @@ scheduler_conf:
...
@@ -81,7 +81,7 @@ scheduler_conf:
lr_decay
:
1.0
lr_decay
:
1.0
decoding
:
decoding
:
batch_size
:
64
batch_size
:
1
error_rate_type
:
wer
error_rate_type
:
wer
decoding_method
:
attention
# 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
decoding_method
:
attention
# 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path
:
data/lm/common_crawl_00.prune01111.trie.klm
lang_model_path
:
data/lm/common_crawl_00.prune01111.trie.klm
...
...
examples/librispeech/s2/local/test.sh
浏览文件 @
862150b5
...
@@ -30,13 +30,15 @@ echo "chunk mode ${chunk_mode}"
...
@@ -30,13 +30,15 @@ echo "chunk mode ${chunk_mode}"
# exit 1
# exit 1
#fi
#fi
#for type in attention ctc_greedy_search; do
for
type
in
attention ctc_greedy_search
;
do
for
type
in
attention ctc_greedy_search
;
do
echo
"decoding
${
type
}
"
echo
"decoding
${
type
}
"
if
[
${
chunk_mode
}
==
true
]
;
then
if
[
${
chunk_mode
}
==
true
]
;
then
# stream decoding only support batchsize=1
# stream decoding only support batchsize=1
batch_size
=
1
batch_size
=
1
else
else
batch_size
=
64
#batch_size=64
batch_size
=
1
fi
fi
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--model-name
u2_kaldi
\
--model-name
u2_kaldi
\
...
...
examples/librispeech/s2/local/train.sh
浏览文件 @
862150b5
...
@@ -19,8 +19,8 @@ echo "using ${device}..."
...
@@ -19,8 +19,8 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
seed
=
0
if
[
${
seed
}
]
;
then
if
[
${
seed
}
!=
0
]
;
then
export
FLAGS_cudnn_deterministic
=
True
export
FLAGS_cudnn_deterministic
=
True
fi
fi
...
@@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
...
@@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
--output
exp/
${
ckpt_name
}
\
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
--seed
${
seed
}
if
[
${
seed
}
]
;
then
if
[
${
seed
}
!=
0
]
;
then
unset
FLAGS_cudnn_deterministic
unset
FLAGS_cudnn_deterministic
fi
fi
...
...
examples/librispeech/s2/run.sh
浏览文件 @
862150b5
...
@@ -6,7 +6,7 @@ stage=0
...
@@ -6,7 +6,7 @@ stage=0
stop_stage
=
100
stop_stage
=
100
conf_path
=
conf/transformer.yaml
conf_path
=
conf/transformer.yaml
dict_path
=
data/train_960_unigram5000_units.txt
dict_path
=
data/train_960_unigram5000_units.txt
avg_num
=
5
avg_num
=
10
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
avg_ckpt
=
avg_
${
avg_num
}
avg_ckpt
=
avg_
${
avg_num
}
...
@@ -20,12 +20,12 @@ fi
...
@@ -20,12 +20,12 @@ fi
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# train model, all `ckpt` under `exp` dir
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES
=
0,1,2,3 ./local/train.sh
${
conf_path
}
${
ckpt
}
CUDA_VISIBLE_DEVICES
=
0,1,2,3
,4,5,6,7
./local/train.sh
${
conf_path
}
${
ckpt
}
fi
fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# avg n best model
# avg n best model
avg.sh exp/
${
ckpt
}
/checkpoints
${
avg_num
}
avg.sh
latest
exp/
${
ckpt
}
/checkpoints
${
avg_num
}
fi
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
...
...
utils/avg_model.py
浏览文件 @
862150b5
...
@@ -80,8 +80,8 @@ def main(args):
...
@@ -80,8 +80,8 @@ def main(args):
data
=
json
.
dumps
({
data
=
json
.
dumps
({
"avg_ckpt"
:
args
.
dst_model
,
"avg_ckpt"
:
args
.
dst_model
,
"ckpt"
:
path_list
,
"ckpt"
:
path_list
,
"epoch"
:
selected_epochs
.
tolist
()
,
"epoch"
:
selected_epochs
,
"val_loss"
:
beat_val_scores
.
tolist
()
,
"val_loss"
:
beat_val_scores
,
})
})
f
.
write
(
data
+
"
\n
"
)
f
.
write
(
data
+
"
\n
"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录