提交 ffb62624 编写于 作者: M Megvii Engine Team

ci(imperative): fix imperative test_ctc_loss test

GitOrigin-RevId: d48dccc30c994ef29123d83dc151ae5f4c32510f
上级 fa9d719f
......@@ -134,7 +134,6 @@ def _ctc_npy_single_seq(pred, label, blank):
x, y = np.maximum(x, y), np.minimum(x, y)
return x + np.log1p(np.exp(y - x))
assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
len_pred, alphabet_size = pred.shape
(len_label,) = label.shape
......@@ -166,6 +165,8 @@ def test_ctc_loss():
def test_func(T, C, N):
input = np.random.randn(T, N, C)
input = F.softmax(Tensor(input), axis=-1).numpy()
# replace nan to 0.2
input = np.nan_to_num(input, copy=True, nan=0.2)
input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册