# test_distribution_independent.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import paddle

import config
import parameterize as param

# Module-level setup shared by the test cases below.
# Fix NumPy's global seed so that values drawn through np.random in this
# module (e.g. the input used by test_log_prob) repeat across runs.
np.random.seed(2022)

# Test cases.

@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank'),
    [('base_beta',
      paddle.distribution.Beta(paddle.rand([1, 2]), paddle.rand([1, 2])), 1)])
class TestIndependent(unittest.TestCase):
    """Compare ``paddle.distribution.Independent`` with its base distribution.

    The tests read ``self.base`` and ``self.reinterpreted_batch_rank``; these
    are presumably injected per parameter row by ``param.param_cls`` (its
    implementation lives outside this file -- confirm there).
    """

    def setUp(self):
        """Wrap the base distribution in the ``Independent`` under test."""
        self._t = paddle.distribution.Independent(self.base,
                                                  self.reinterpreted_batch_rank)

    def _tolerances(self):
        """Return the ``rtol``/``atol`` keyword arguments for the base dtype.

        The tolerances are looked up in ``config.RTOL`` / ``config.ATOL``
        using the dtype name of ``self.base.alpha`` as the key.
        """
        dtype_name = str(self.base.alpha.numpy().dtype)
        return {
            'rtol': config.RTOL.get(dtype_name),
            'atol': config.ATOL.get(dtype_name),
        }

    def _np_sum_rightmost(self, value, n):
        """Sum ``value`` over its last ``n`` axes; return it as-is if n <= 0."""
        if n > 0:
            return np.sum(value, tuple(range(-n, 0)))
        return value

    def test_mean(self):
        """The wrapped distribution reports the same mean as its base."""
        np.testing.assert_allclose(self.base.mean, self._t.mean,
                                   **self._tolerances())

    def test_variance(self):
        """The wrapped distribution reports the same variance as its base."""
        np.testing.assert_allclose(self.base.variance, self._t.variance,
                                   **self._tolerances())

    def test_entropy(self):
        """Entropy equals the base entropy summed over the rightmost dims."""
        expected = self._np_sum_rightmost(self.base.entropy().numpy(),
                                          self.reinterpreted_batch_rank)
        np.testing.assert_allclose(expected, self._t.entropy(),
                                   **self._tolerances())

    def test_log_prob(self):
        """log_prob equals the base log_prob summed over the rightmost dims."""
        value = np.random.rand(1)
        expected = self._np_sum_rightmost(
            self.base.log_prob(paddle.to_tensor(value)).numpy(),
            self.reinterpreted_batch_rank)
        actual = self._t.log_prob(paddle.to_tensor(value)).numpy()
        np.testing.assert_allclose(expected, actual, **self._tolerances())

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        """Samples have the requested leading shape and the base dtype."""
        shape = (5, 10, 8)
        # Requested sample shape followed by the [1, 2] shape of the Beta
        # parameters used in the parameterization above.
        expected_shape = (5, 10, 8, 1, 2)
        data = self._t.sample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        self.assertEqual(data.dtype, self.base.alpha.dtype)


@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank',
     'expected_exception'),
    [('base_not_transform', '', 1, TypeError),
     ('rank_less_than_zero', paddle.distribution.Transform(), -1, ValueError)])
class TestIndependentException(unittest.TestCase):
    """Check that invalid constructor arguments raise the expected exception.

    Each parameter row supplies ``base``, ``reinterpreted_batch_rank`` and the
    ``expected_exception`` type that construction must raise.

    NOTE(review): this module tests ``paddle.distribution.Independent``, but
    the constructor exercised here is ``IndependentTransform`` and one row
    passes a ``Transform`` as ``base``. This looks carried over from the
    transform tests -- confirm whether ``Independent`` was intended.
    """

    def test_init(self):
        """Constructing with the parameterized bad arguments must raise."""
        with self.assertRaises(self.expected_exception):
            paddle.distribution.IndependentTransform(
                self.base, self.reinterpreted_batch_rank)


# Run the unittest command-line runner when this file is executed directly
# (as opposed to being imported).
if __name__ == '__main__':
    unittest.main()