提交 abd2f589 编写于 作者: H Hui Zhang

fix bugs, add coverage, add scripts

上级 22fce191
#!/bin/bash #!/bin/bash
abort(){
echo "Run unittest failed" 1>&2
echo "Please check your code" 1>&2
exit 1
}
unittest(){ unittest(){
cd $1 > /dev/null cd $1 > /dev/null
find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \ find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \
...@@ -8,17 +17,22 @@ unittest(){ ...@@ -8,17 +17,22 @@ unittest(){
cd - > /dev/null cd - > /dev/null
} }
abort(){ coverage(){
echo "Run unittest failed" 1>&2 cd $1 > /dev/null
echo "Please check your code" 1>&2 find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \
exit 1 xargs -0 -I{} -n1 bash -c \
'python3 -m coverage run --branch {}'
python3 -m coverage report -m
python3 -m coverage html
cd - > /dev/null
} }
trap 'abort' 0 trap 'abort' 0
set -e set -e
source tools/venv/bin/activate source tools/venv/bin/activate
pip3 install pytest #pip3 install pytest
unittest . #unittest .
coverage .
trap : 0 trap : 0
...@@ -74,7 +74,7 @@ class U2Trainer(Trainer): ...@@ -74,7 +74,7 @@ class U2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
def train_batch(self, batch_data): def train_batch(self, batch_data, msg):
train_conf = self.config.training train_conf = self.config.training
self.model.train() self.model.train()
...@@ -93,12 +93,9 @@ class U2Trainer(Trainer): ...@@ -93,12 +93,9 @@ class U2Trainer(Trainer):
'train_att_loss': float(attention_loss), 'train_att_loss': float(attention_loss),
'train_ctc_loss': float(ctc_loss), 'train_ctc_loss': float(ctc_loss),
} }
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s, ".format(iteration_time) msg += "time: {:>.3f}s, ".format(iteration_time)
msg += f"batch size: {self.config.data.batch_size}, " msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += f"accum: {train_config.accum_grad}, " msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items()) for k, v in losses_np.items())
if self.iteration % train_conf.log_interval == 0: if self.iteration % train_conf.log_interval == 0:
......
...@@ -16,6 +16,7 @@ import logging ...@@ -16,6 +16,7 @@ import logging
from typing import Optional from typing import Optional
from yacs.config import CfgNode from yacs.config import CfgNode
import paddle
from paddle import nn from paddle import nn
from deepspeech.modules.conv import ConvStack from deepspeech.modules.conv import ConvStack
......
...@@ -761,8 +761,10 @@ class U2Model(U2BaseModel): ...@@ -761,8 +761,10 @@ class U2Model(U2BaseModel):
Returns: Returns:
DeepSpeech2Model: The model built from pretrained result. DeepSpeech2Model: The model built from pretrained result.
""" """
config.input_dim = self.dataset.feature_size config.defrost()
config.output_dim = self.dataset.vocab_size config.input_dim = dataset.feature_size
config.output_dim = dataset.vocab_size
config.freeze()
model = cls.from_config(config) model = cls.from_config(config)
if checkpoint_path: if checkpoint_path:
......
...@@ -181,13 +181,12 @@ class Trainer(): ...@@ -181,13 +181,12 @@ class Trainer():
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
msg += "dataloader time: {:>.3f}s, ".format(dataload_time) msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.logger.info(msg)
self.iteration += 1 self.iteration += 1
self.train_batch(batch) self.train_batch(batch, msg)
data_start_time = time.time() data_start_time = time.time()
except Exception as e: except Exception as e:
self.logger.error(e) self.logger.error(e)
pass raise e
self.valid() self.valid()
self.save() self.save()
......
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: export ckpt_path jit_model_path"
exit -1
fi
python3 -u ${BIN_DIR}/export.py \
--config conf/conformer.yaml \
--checkpoint_path ${1} \
--export_path ${2}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
#! /usr/bin/env bash
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/test.py \
--device 'gpu' \
--nproc 1 \
--config conf/conformer.yaml \
--output ckpt
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
#! /usr/bin/env bash #! /usr/bin/env bash
export FLAGS_sync_nccl_allreduce=0
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \ --device 'gpu' \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册