diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index dd0e63f048ea5d942b97060b3ec24f6daac68362..31b1dedbc5fb36420b07266cd03fcb8ef9f9ad7a 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -113,7 +113,7 @@ class Transform(object): * _forward_shape * _inverse_shape - + """ _type = Type.INJECTION @@ -669,7 +669,7 @@ class IndependentTransform(Transform): base (Transform): The base transformation. reinterpreted_batch_rank (int): The num of rightmost batch rank that will be reinterpreted as event rank. - + Examples: .. code-block:: python @@ -743,7 +743,7 @@ class PowerTransform(Transform): Args: power (Tensor): The power parameter. - + Examples: .. code-block:: python @@ -1017,7 +1017,7 @@ class StackTransform(Transform): Examples: .. code-block:: python - + import paddle @@ -1141,7 +1141,8 @@ class StickBreakingTransform(Transform): offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) z = F.sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) + return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \ + F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1) def _inverse(self, y): y_crop = y[..., :-1]