提交 e66da76d 编写于 作者: H huangyuxin

fix the bug of choosing the dataloader, remove the log of downloading the LM, change the epoch count in tiny

上级 34178893
...@@ -153,8 +153,12 @@ class DeepSpeech2Trainer(Trainer): ...@@ -153,8 +153,12 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config.clone() config = self.config.clone()
with UpdateConfig(config): with UpdateConfig(config):
if self.train:
config.model.feat_size = self.train_loader.collate_fn.feature_size config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size config.model.dict_size = self.train_loader.collate_fn.vocab_size
else:
config.model.feat_size = self.test_loader.collate_fn.feature_size
config.model.dict_size = self.test_loader.collate_fn.vocab_size
if self.args.model_type == 'offline': if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model) model = DeepSpeech2Model.from_config(config.model)
...@@ -189,7 +193,6 @@ class DeepSpeech2Trainer(Trainer): ...@@ -189,7 +193,6 @@ class DeepSpeech2Trainer(Trainer):
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
logger.info("Setup optimizer/lr_scheduler!") logger.info("Setup optimizer/lr_scheduler!")
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
config.defrost() config.defrost()
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_ch.sh bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ jit_model_export_path=$2 ...@@ -13,7 +13,7 @@ jit_model_export_path=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_ch.sh bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_en.sh bash local/download_lm_en.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_ch.sh bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_en.sh bash local/download_lm_en.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_en.sh bash local/download_lm_en.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -45,7 +45,7 @@ model: ...@@ -45,7 +45,7 @@ model:
ctc_grad_norm_type: null ctc_grad_norm_type: null
training: training:
n_epoch: 10 n_epoch: 5
accum_grad: 1 accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 0.8 lr_decay: 0.8
......
...@@ -47,7 +47,7 @@ model: ...@@ -47,7 +47,7 @@ model:
ctc_grad_norm_type: null ctc_grad_norm_type: null
training: training:
n_epoch: 10 n_epoch: 5
accum_grad: 1 accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 1.0
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_en.sh bash local/download_lm_en.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -83,7 +83,7 @@ model: ...@@ -83,7 +83,7 @@ model:
training: training:
n_epoch: 20 n_epoch: 5
accum_grad: 1 accum_grad: 1
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
......
...@@ -76,7 +76,7 @@ model: ...@@ -76,7 +76,7 @@ model:
training: training:
n_epoch: 20 n_epoch: 5
accum_grad: 1 accum_grad: 1
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
......
...@@ -79,7 +79,7 @@ model: ...@@ -79,7 +79,7 @@ model:
training: training:
n_epoch: 20 n_epoch: 5
accum_grad: 4 accum_grad: 4
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
......
...@@ -73,7 +73,7 @@ model: ...@@ -73,7 +73,7 @@ model:
training: training:
n_epoch: 21 n_epoch: 5
accum_grad: 1 accum_grad: 1
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册