diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index ce386971e5fcce9d300e966e0e147bdc51478a0d..160af5e4870af48992afd0eeb785f5c8a69d93f3 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -77,7 +77,7 @@ class TransformedDistribution(distribution.Distribution): max(len(base.event_shape)-chain._domain.event_rank, 0) super(TransformedDistribution, self).__init__( transformed_shape[:len(transformed_shape) - transformed_event_rank], - transformed_shape[:len(transformed_shape) - transformed_event_rank]) + transformed_shape[len(transformed_shape) - transformed_event_rank:]) def sample(self, shape=()): """Sample from ``TransformedDistribution``.