diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index bb2e181d7bb98bd819b5285a3f077e593bdbdf52..880bab7d6e3c861404917cb6657610d76b678a03 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -77,7 +77,7 @@ class TransformedDistribution(distribution.Distribution): max(len(base.event_shape)-chain._domain.event_rank, 0) super(TransformedDistribution, self).__init__( transformed_shape[:len(transformed_shape) - transformed_event_rank], - transformed_shape[:len(transformed_shape) - transformed_event_rank]) + transformed_shape[len(transformed_shape) - transformed_event_rank:]) def sample(self, shape=()): """Sample from ``TransformedDistribution``.