From 60c212d5e53db552a8777ac0a6bec2a056630ab1 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Wed, 20 Apr 2022 19:43:52 +0800 Subject: [PATCH] [cherry-pick]fix StickBreakingTransform forward error when input rank is over 2 (#41940) (#41983) --- python/paddle/distribution/transform.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index dd0e63f048e..31b1dedbc5f 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -113,7 +113,7 @@ class Transform(object): * _forward_shape * _inverse_shape - + """ _type = Type.INJECTION @@ -669,7 +669,7 @@ class IndependentTransform(Transform): base (Transform): The base transformation. reinterpreted_batch_rank (int): The num of rightmost batch rank that will be reinterpreted as event rank. - + Examples: .. code-block:: python @@ -743,7 +743,7 @@ class PowerTransform(Transform): Args: power (Tensor): The power parameter. - + Examples: .. code-block:: python @@ -1017,7 +1017,7 @@ class StackTransform(Transform): Examples: .. code-block:: python - + import paddle @@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform): offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) z = F.sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) + return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \ + F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1) def _inverse(self, y): y_crop = y[..., :-1] -- GitLab