提交 6ee67785 编写于 作者: H Hui Zhang

fix ctc alignment

上级 7ec623f7
......@@ -39,6 +39,7 @@ from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
......@@ -280,7 +281,15 @@ class U2Trainer(Trainer):
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test Dataloader!")
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
......@@ -507,16 +516,17 @@ class U2Tester(U2Trainer):
sys.exit(1)
# xxx.align
assert self.args.result_file
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.test_loader.collate_fn.stride_ms
token_dict = self.test_loader.collate_fn.vocab_list
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.test_loader):
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
......@@ -527,36 +537,36 @@ class U2Tester(U2Trainer):
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
# print(ctc_probs.size(1))
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
print(kye[0], alignment)
logger.info("align ids", key[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
print(kye[0], align_segs)
logger.info("align tokens", key[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = get_subsample(self.config)
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
tier_path = os.path.join(
os.path.dirname(args.result_file), key[0] + ".tier")
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = s.path.join(
os.path.dirname(args.result_file), key[0] + ".TextGrid")
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
lines=tierformat,
intervals=tierformat,
output=textgrid_path)
def run_align(self):
......
......@@ -86,13 +86,15 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)): # T
for s in range(len(y_insert_blank)): # 2L+1
......@@ -108,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
# TODO(Hui Zhang): zeros not support paddle.int16
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
......
......@@ -110,7 +110,7 @@ def generate_textgrid(maxtime: float,
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average duration per {name}: {avg_interval}")
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
......
......@@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
def get_subsample(config):
"""Subsample rate from config.
Args:
config (yacs.config.CfgNode): yaml config
Returns:
int: subsample rate.
"""
input_layer = config["model"]["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefxi})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0
......@@ -19,7 +19,7 @@ kenlm.done:
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
cd kenlm && python setup.py install
source venv/bin/activate; cd kenlm && python setup.py install
touch kenlm.done
sox.done:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册