提交 7382050e 编写于 作者: H Hui Zhang

fix bug on win

上级 d25871a7
......@@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
......@@ -279,7 +279,8 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, eos)
_eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, _eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
......
......@@ -600,7 +600,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
logger.info(
logger.debug(
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册