未验证 提交 2b55290e 编写于 作者: X Xiaoxu Chen 提交者: GitHub

fix StickBreakingTransform forward error when input rank is over 2 (#41940)

上级 469e3198
......@@ -113,7 +113,7 @@ class Transform(object):
* _forward_shape
* _inverse_shape
"""
_type = Type.INJECTION
......@@ -669,7 +669,7 @@ class IndependentTransform(Transform):
base (Transform): The base transformation.
reinterpreted_batch_rank (int): The num of rightmost batch rank that
will be reinterpreted as event rank.
Examples:
.. code-block:: python
......@@ -743,7 +743,7 @@ class PowerTransform(Transform):
Args:
power (Tensor): The power parameter.
Examples:
.. code-block:: python
......@@ -1017,7 +1017,7 @@ class StackTransform(Transform):
Examples:
.. code-block:: python
import paddle
......@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1)
return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1)
def _inverse(self, y):
y_crop = y[..., :-1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册