未验证 提交 2b55290e 编写于 作者: X Xiaoxu Chen 提交者: GitHub

fix StickBreakingTransform forward error when input rank is over 2 (#41940)

上级 469e3198
...@@ -113,7 +113,7 @@ class Transform(object): ...@@ -113,7 +113,7 @@ class Transform(object):
* _forward_shape * _forward_shape
* _inverse_shape * _inverse_shape
""" """
_type = Type.INJECTION _type = Type.INJECTION
...@@ -669,7 +669,7 @@ class IndependentTransform(Transform): ...@@ -669,7 +669,7 @@ class IndependentTransform(Transform):
base (Transform): The base transformation. base (Transform): The base transformation.
reinterpreted_batch_rank (int): The num of rightmost batch rank that reinterpreted_batch_rank (int): The num of rightmost batch rank that
will be reinterpreted as event rank. will be reinterpreted as event rank.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -743,7 +743,7 @@ class PowerTransform(Transform): ...@@ -743,7 +743,7 @@ class PowerTransform(Transform):
Args: Args:
power (Tensor): The power parameter. power (Tensor): The power parameter.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1017,7 +1017,7 @@ class StackTransform(Transform): ...@@ -1017,7 +1017,7 @@ class StackTransform(Transform):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
...@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform): ...@@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log()) z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1) z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1)
def _inverse(self, y): def _inverse(self, y):
y_crop = y[..., :-1] y_crop = y[..., :-1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册