Commit 6ee67785 authored by: H Hui Zhang

fix ctc alignment

Parent 7ec623f7
@@ -39,6 +39,7 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils import text_grid
+from deepspeech.utils import utility
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()

@@ -280,7 +281,15 @@ class U2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup train/valid/test Dataloader!")
+        # return text token id
+        config.collator.keep_transcription_text = False
+        self.align_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        logger.info("Setup train/valid/test/align Dataloader!")

     def setup_model(self):
         config = self.config
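For orientation, here is a minimal sketch (not part of the commit) of what the new `align_loader` yields. It assumes `exp` is a trainer/tester instance whose `setup_dataloader()` has already run; the unpacking mirrors the `align()` loop further down, and `decoding.batch_size` is 1 in this recipe.

```python
# Illustrative only: peek at one batch from the align_loader built above.
for batch in exp.align_loader:
    key, feat, feats_length, target, target_length = batch
    # key:    list with a single utterance id (batch_size == 1)
    # feat:   (1, T, D) padded acoustic features
    # target: (1, L) token ids, since keep_transcription_text is False
    print(key[0], feat.shape, target.shape)
    break
```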
@@ -507,16 +516,17 @@ class U2Tester(U2Trainer):
             sys.exit(1)

         # xxx.align
-        assert self.args.result_file
+        assert self.args.result_file and self.args.result_file.endswith(
+            '.align')

         self.model.eval()
-        logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
+        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")

-        stride_ms = self.test_loader.collate_fn.stride_ms
-        token_dict = self.test_loader.collate_fn.vocab_list
+        stride_ms = self.align_loader.collate_fn.stride_ms
+        token_dict = self.align_loader.collate_fn.vocab_list
         with open(self.args.result_file, 'w') as fout:
             # one example in batch
-            for i, batch in enumerate(self.test_loader):
+            for i, batch in enumerate(self.align_loader):
                 key, feat, feats_length, target, target_length = batch

                 # 1. Encoder
@@ -527,36 +537,36 @@ class U2Tester(U2Trainer):
                     encoder_out)  # (1, maxlen, vocab_size)

                 # 2. alignment
-                # print(ctc_probs.size(1))
                 ctc_probs = ctc_probs.squeeze(0)
                 target = target.squeeze(0)
                 alignment = ctc_utils.forced_align(ctc_probs, target)
-                print(kye[0], alignment)
+                logger.info(f"align ids: {key[0]} {alignment}")
                 fout.write('{} {}\n'.format(key[0], alignment))

                 # 3. gen praat
                 # segment alignment
                 align_segs = text_grid.segment_alignment(alignment)
-                print(kye[0], align_segs)
+                logger.info(f"align tokens: {key[0]} {align_segs}")
                 # IntervalTier, List["start end token\n"]
-                subsample = get_subsample(self.config)
+                subsample = utility.get_subsample(self.config)
                 tierformat = text_grid.align_to_tierformat(
                     align_segs, subsample, token_dict)

                 # write tier
-                tier_path = os.path.join(
-                    os.path.dirname(args.result_file), key[0] + ".tier")
+                align_output_path = os.path.join(
+                    os.path.dirname(self.args.result_file), "align")
+                os.makedirs(align_output_path, exist_ok=True)
+                tier_path = os.path.join(align_output_path, key[0] + ".tier")
                 with open(tier_path, 'w') as f:
                     f.writelines(tierformat)

                 # write textgrid
-                textgrid_path = s.path.join(
-                    os.path.dirname(args.result_file), key[0] + ".TextGrid")
+                textgrid_path = os.path.join(align_output_path,
+                                             key[0] + ".TextGrid")
                 second_per_frame = 1. / (1000. /
                                          stride_ms)  # 25ms window, 10ms stride
                 second_per_example = (
                     len(alignment) + 1) * subsample * second_per_frame
                 text_grid.generate_textgrid(
                     maxtime=second_per_example,
-                    lines=tierformat,
+                    intervals=tierformat,
                     output=textgrid_path)

     def run_align(self):
......
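As a quick sanity check of the time bookkeeping above, with assumed values (a 10 ms feature stride and a conv2d front end, so `subsample = 4`), each CTC output frame covers 40 ms, and an alignment of 100 frames maps to roughly 4 s:

```python
# Worked example of the second_per_frame / second_per_example arithmetic,
# using assumed values: stride_ms = 10.0, subsample = 4, 100 output frames.
stride_ms = 10.0
subsample = 4
num_frames = 100  # stands in for len(alignment)

second_per_frame = 1. / (1000. / stride_ms)  # 0.01 s per raw feature frame
second_per_example = (num_frames + 1) * subsample * second_per_frame
print(second_per_frame, second_per_example)  # 0.01 4.04
```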
@@ -86,13 +86,15 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
     log_alpha = paddle.zeros(
         (ctc_probs.size(0), len(y_insert_blank)))  #(T, 2L+1)
     log_alpha = log_alpha - float('inf')  # log of zero
+    # TODO(Hui Zhang): zeros not support paddle.int16
     state_path = (paddle.zeros(
-        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
+        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
     )  # state path, Tuple((T, 2L+1))

     # init start state
-    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]  # State-b, Sb
-    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]  # State-nb, Snb
+    # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
+    log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])]  # State-b, Sb
+    log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])]  # State-nb, Snb

     for t in range(1, ctc_probs.size(0)):  # T
         for s in range(len(y_insert_blank)):  # 2L+1
@@ -108,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                     log_alpha[t - 1, s - 2],
                 ])
                 prev_state = [s, s - 1, s - 2]
-            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
-                y_insert_blank[s]]
+            # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
+            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
+                y_insert_blank[s])]
             state_path[t, s] = prev_state[paddle.argmax(candidates)]

-    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
+    # TODO(Hui Zhang): zeros not support paddle.int16
+    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)

     candidates = paddle.to_tensor([
         log_alpha[-1, len(y_insert_blank) - 1],  # Sb
......
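For reference, the same trellis can be written compactly without the Paddle dtype and `__getitem__` workarounds. The sketch below mirrors the structure of `forced_align` above (blank-interleaved targets, a (T, 2L+1) Viterbi pass, then a backtrack) but is illustrative NumPy code, not the library implementation; `log_probs` is assumed to be per-frame log-softmax output.

```python
# Minimal NumPy sketch of CTC forced alignment: Viterbi over a
# blank-interleaved label sequence, then a backtrack. Illustrative only.
import numpy as np


def forced_align_np(log_probs: np.ndarray, y: np.ndarray, blank_id: int = 0):
    """log_probs: (T, V) per-frame log-probabilities; y: (L,) token ids."""
    y_blank = np.full(2 * len(y) + 1, blank_id, dtype=np.int64)
    y_blank[1::2] = y  # blank, y1, blank, y2, ..., blank
    T, S = log_probs.shape[0], len(y_blank)

    log_alpha = np.full((T, S), -np.inf)              # best score per state
    state_path = np.full((T, S), -1, dtype=np.int64)  # backpointers
    log_alpha[0, 0] = log_probs[0, y_blank[0]]        # start in blank
    log_alpha[0, 1] = log_probs[0, y_blank[1]]        # or in the first token

    for t in range(1, T):
        for s in range(S):
            # allowed predecessors: stay, advance by one, or skip a blank
            cands = [log_alpha[t - 1, s]]
            prev = [s]
            if s >= 1:
                cands.append(log_alpha[t - 1, s - 1])
                prev.append(s - 1)
            if s >= 2 and y_blank[s] != blank_id and y_blank[s] != y_blank[s - 2]:
                cands.append(log_alpha[t - 1, s - 2])
                prev.append(s - 2)
            best = int(np.argmax(cands))
            log_alpha[t, s] = cands[best] + log_probs[t, y_blank[s]]
            state_path[t, s] = prev[best]

    # backtrack from the better of the two admissible final states
    s = S - 1 if log_alpha[-1, S - 1] >= log_alpha[-1, S - 2] else S - 2
    states = [s]
    for t in range(T - 1, 0, -1):
        s = state_path[t, s]
        states.append(s)
    return [int(y_blank[i]) for i in reversed(states)]
```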
@@ -110,7 +110,7 @@ def generate_textgrid(maxtime: float,
     """
     # Download Praat: https://www.fon.hum.uva.nl/praat/
     avg_interval = maxtime / (len(intervals) + 1)
-    print(f"average duration per {name}: {avg_interval}")
+    print(f"average second/token: {avg_interval}")
     margin = 0.0001

     tg = textgrid.TextGrid(maxTime=maxtime)
......
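`generate_textgrid` is built on the `textgrid` Python package; a minimal sketch of the object model it drives is shown below. The tier name, times, and tokens are made-up example values, not output of the commit's code.

```python
# Minimal sketch of producing a Praat TextGrid with the `textgrid` package;
# the tier name, interval times and tokens are illustrative values.
import textgrid

tg = textgrid.TextGrid(maxTime=1.0)
tier = textgrid.IntervalTier(name="token", maxTime=1.0)
tier.add(0.0, 0.4, "ni")   # minTime, maxTime, mark
tier.add(0.4, 1.0, "hao")
tg.append(tier)
tg.write("example.TextGrid")  # open with Praat: https://www.fon.hum.uva.nl/praat/
```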
@@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
     a_max = max(args)
     lsp = math.log(sum(math.exp(a - a_max) for a in args))
     return a_max + lsp
+
+
+def get_subsample(config):
+    """Subsample rate from config.
+
+    Args:
+        config (yacs.config.CfgNode): yaml config
+
+    Returns:
+        int: subsample rate.
+    """
+    input_layer = config["model"]["encoder_conf"]["input_layer"]
+    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
+    if input_layer == "conv2d":
+        return 4
+    elif input_layer == "conv2d6":
+        return 6
+    elif input_layer == "conv2d8":
+        return 8
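A quick usage sketch: since `get_subsample` only indexes into the config, a plain nested dict can stand in for the yacs `CfgNode` here.

```python
# get_subsample only indexes the config, so a nested dict is enough to
# exercise it; a conv2d front end subsamples frames by a factor of 4.
from deepspeech.utils.utility import get_subsample

fake_config = {"model": {"encoder_conf": {"input_layer": "conv2d"}}}
assert get_subsample(fake_config) == 4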
#! /usr/bin/env bash

if [ $# != 2 ];then
    echo "usage: ${0} config_path ckpt_path_prefix"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

device=gpu
if [ ${ngpu} == 0 ];then
    device=cpu
fi

config_path=$1
ckpt_prefix=$2

ckpt_name=$(basename ${ckpt_prefix})

mkdir -p exp

batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}

# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
    --device ${device} \
    --nproc 1 \
    --config ${config_path} \
    --result_file ${output_dir}/${type}.align \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.batch_size ${batch_size}

if [ $? -ne 0 ]; then
    echo "Failed in ctc alignment!"
    exit 1
fi

exit 0
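A typical invocation, with an illustrative script name and paths, would be `bash align.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20`. The token-id alignment is written to `${ckpt_prefix}/${type}.align` (note that `${type}` is expected to be defined by the calling environment), and the per-utterance `.tier`/`.TextGrid` files go to the `align/` directory created next to it.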
@@ -19,7 +19,7 @@ kenlm.done:
        apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
        test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
        mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
-       cd kenlm && python setup.py install
+       source venv/bin/activate; cd kenlm && python setup.py install
        touch kenlm.done

 sox.done:

@@ -32,4 +32,4 @@ sox.done:
 soxbindings.done:
        test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git
        source venv/bin/activate; cd soxbindings && python setup.py install
-       touch soxbindings.done
\ No newline at end of file
+       touch soxbindings.done
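Once the kenlm target has run inside the venv, a quick check (the language-model path below is hypothetical) confirms the Python bindings are usable:

```python
# Sanity check that the kenlm bindings installed into the venv import and
# can score text; replace the path with a real ARPA or binary LM file.
import kenlm

lm = kenlm.Model("lm/example.arpa")  # hypothetical language-model path
print(lm.score("hello world", bos=True, eos=True))
```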