提交 f9221b4b 编写于 作者: H Hui Zhang

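fix ctc align: pass config into ctc_utils.ctc_align() so the module-level function no longer references self (self.config, self.args.result_file)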

上级 56d06f2a
......@@ -554,7 +554,7 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
......
......@@ -527,7 +527,7 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
......
......@@ -543,10 +543,10 @@ class U2STTester(U2STTrainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import List
from pathlib import Path
import numpy as np
import paddle
......@@ -139,26 +139,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
return output_alignment
def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
result_file):
"""ctc alignment.
Args:
config (cfgNode): config
model (nn.Layer): U2 Model.
dataloader (io.DataLoader): dataloader.
batch_size (int): decoding batchsize.
stride_ms (int): audio feature stride in ms unit.
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
result_file (str): alignment output file, e.g. xxx.align.
result_file (str): alignment output file, e.g. /path/to/xxx.align.
"""
if batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
assert result_file and result_file.endswith('.align')
model.eval()
# conv subsampling rate
subsample = utility.get_subsample(config)
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
with open(result_file, 'w') as fout:
......@@ -187,13 +188,11 @@ def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path = Path(result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册