# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import config
import numpy as np
import parameterize as param

import paddle


@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'transforms'),
    [
        (
            'base_normal',
            paddle.distribution.Normal(0.0, 1.0),
            [paddle.distribution.ExpTransform()],
        )
    ],
)
class TestIndependent(unittest.TestCase):
    """Tests for ``paddle.distribution.TransformedDistribution``.

    Parameterized over (base distribution, transform list) pairs and over
    the devices declared in ``config.DEVICES``.  Each parameterized case
    injects ``self.base`` and ``self.transforms`` class attributes.

    NOTE(review): the class name says "Independent" but the subject under
    test is ``TransformedDistribution`` — presumably a copy-paste leftover;
    kept as-is because the name is part of the test-discovery interface.
    """

    def setUp(self):
        # Build the distribution under test from the parameterized
        # base distribution and transform chain.
        self._t = paddle.distribution.TransformedDistribution(
            self.base, self.transforms
        )

    def _np_sum_rightmost(self, value, n):
        """Sum *value* over its ``n`` rightmost axes (no-op when n <= 0)."""
        return np.sum(value, tuple(range(-n, 0))) if n > 0 else value

    def test_log_prob(self):
        """log_prob must match a hand-rolled change-of-variables reference."""
        value = paddle.to_tensor(0.5)
        np.testing.assert_allclose(
            self.simple_log_prob(value, self.base, self.transforms),
            self._t.log_prob(value),
            rtol=config.RTOL.get(str(value.numpy().dtype)),
            atol=config.ATOL.get(str(value.numpy().dtype)),
        )

    def simple_log_prob(self, value, base, transforms):
        """Reference implementation of the transformed log-probability.

        Walks the transform chain backwards, accumulating
        ``-forward_log_det_jacobian`` at each step, then adds the base
        distribution's log_prob at the fully inverted point.
        """
        log_prob = 0.0
        y = value
        for t in reversed(transforms):
            x = t.inverse(y)
            log_prob = log_prob - t.forward_log_det_jacobian(x)
            y = x
        log_prob += base.log_prob(y)
        return log_prob

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        """sample() must honor the requested shape and the base dtype."""
        shape = [5, 10, 8]
        expected_shape = (5, 10, 8)
        data = self._t.sample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        # assumes the base distribution exposes a `loc` tensor — true for
        # the Normal case parameterized above
        self.assertEqual(data.dtype, self.base.loc.dtype)

    def test_rsample(self):
        """rsample() must honor the requested shape and the base dtype."""
        shape = [5, 10, 8]
        expected_shape = (5, 10, 8)
        data = self._t.rsample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        self.assertEqual(data.dtype, self.base.loc.dtype)

# Allow running this test module directly (not only via a test runner).
if __name__ == '__main__':
    unittest.main()