提交 90788b11 编写于 作者: H Hui Zhang

more comment; fix datapipe of align

上级 1e2a5887
......@@ -355,7 +355,7 @@ class U2Tester(U2Trainer):
decoding_chunk_size=-1, # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1.
simulate_streaming=False, # simulate streaming inference. Defaults to False.
))
......@@ -512,11 +512,13 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.dataset.stride_ms
token_dict = self.test_loader.dataset.vocab_list
stride_ms = self.test_loader.collate_fn.stride_ms
token_dict = self.test_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.test_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
......@@ -529,28 +531,31 @@ class U2Tester(U2Trainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
print(alignment)
print(kye[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
print(align_segs)
print(kye[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
tier_path = os.path.join(
os.path.dirname(args.result_file), key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = s.path.join(
os.path.dirname(args.result_file), key[0] + ".TextGrid")
second_per_frame = 1. / (1000. / stride_ms
) # 25ms window, 10ms stride
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=(len(alignment) + 1) * subsample * second_per_frame,
maxtime=second_per_example,
lines=tierformat,
output=textgrid_path)
......
......@@ -38,8 +38,10 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
# add non-blank into new_hyp
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
# skip repeat label
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
......@@ -52,7 +54,7 @@ def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"abcdefg" -> "-a-b-c-d-e-f-g-"
Args:
label ([np.ndarray]): label ids, (L).
label ([np.ndarray]): label ids, List[int], (L).
blank_id (int, optional): blank id. Defaults to 0.
Returns:
......@@ -61,8 +63,8 @@ def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L]
label = np.append(label, label[0]) #[2L + 1]
label = label.reshape(-1) #[2L], -l-l-l
label = np.append(label, label[0]) #[2L + 1], -l-l-l-
return label
......@@ -79,21 +81,21 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
Returns:
List[int]: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
) # state path
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
for t in range(1, ctc_probs.size(0)): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor(
......
......@@ -22,11 +22,13 @@ def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence. e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: segment aligment id sequence. e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
List[List[int]]: token alignment, i.e. the segmented alignment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert the alignment to the Praat format; Praat is a tool for doing phonetics
# by computer, which helps with analyzing the alignment
......@@ -61,7 +63,7 @@ def align_to_tierformat(align_segs: List[List[int]],
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval.
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册