未验证 提交 60c212d5 编写于 作者: X Xiaoxu Chen 提交者: GitHub

[cherry-pick]fix StickBreakingTransform forward error when input rank is over 2 (#41940) (#41983)

上级 23cc4636
......@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1)
return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1)
def _inverse(self, y):
y_crop = y[..., :-1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册