未验证 提交 2b55290e 编写于 作者: X Xiaoxu Chen 提交者: GitHub

fix StickBreakingTransform forward error when input rank is over 2 (#41940)

上级 469e3198
...@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform): ...@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log()) z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1) z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1)
def _inverse(self, y): def _inverse(self, y):
y_crop = y[..., :-1] y_crop = y[..., :-1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册