# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import config
import numpy as np
import parameterize as param

import paddle

paddle.enable_static()


@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'transforms'),
    [
        (
            'base_normal',
            paddle.distribution.Normal,
            [paddle.distribution.ExpTransform()],
        )
    ],
)
class TestIndependent(unittest.TestCase):
    """Static-graph tests for ``paddle.distribution.TransformedDistribution``.

    ``param_cls`` injects two class attributes per case: ``base`` (a base
    distribution class) and ``transforms`` (a list of transform instances).
    """

    def setUp(self):
        # Concrete numpy inputs fed into the static graph below.
        value = np.array([0.5])
        loc = np.array([0.0])
        scale = np.array([1.0])
        shape = [5, 10, 8]  # sample shape requested from the distribution
        self.dtype = value.dtype
        exe = paddle.static.Executor()
        sp = paddle.static.Program()
        mp = paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            static_value = paddle.static.data('value', value.shape, value.dtype)
            static_loc = paddle.static.data('loc', loc.shape, loc.dtype)
            static_scale = paddle.static.data('scale', scale.shape, scale.dtype)
            # ``self.base`` arrives from param_cls as a distribution class;
            # rebind it here to a constructed instance.
            self.base = self.base(static_loc, static_scale)
            self._t = paddle.distribution.TransformedDistribution(
                self.base, self.transforms
            )
            actual_log_prob = self._t.log_prob(static_value)
            expected_log_prob = self.transformed_log_prob(
                static_value, self.base, self.transforms
            )
            sample_data = self._t.sample(shape)

        # Run the startup program once to initialize, then fetch all results.
        exe.run(sp)
        [
            self.actual_log_prob,
            self.expected_log_prob,
            self.sample_data,
        ] = exe.run(
            mp,
            feed={'value': value, 'loc': loc, 'scale': scale},
            fetch_list=[actual_log_prob, expected_log_prob, sample_data],
        )

    def test_log_prob(self):
        # TransformedDistribution.log_prob must agree with the reference
        # change-of-variables computation in ``transformed_log_prob``.
        np.testing.assert_allclose(
            self.actual_log_prob,
            self.expected_log_prob,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)),
        )

    def transformed_log_prob(self, value, base, transforms):
        """Reference log_prob via the change-of-variables formula.

        Inverts each transform from last to first, subtracting its forward
        log-det-Jacobian at the inverted point, then adds the base
        distribution's log_prob of the fully inverted value.
        """
        log_prob = 0.0
        y = value
        for t in reversed(transforms):
            x = t.inverse(y)
            log_prob = log_prob - t.forward_log_det_jacobian(x)
            y = x
        log_prob += base.log_prob(y)
        return log_prob

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        # Requested sample shape [5, 10, 8] plus the event shape (1,).
        expected_shape = (5, 10, 8, 1)
        self.assertEqual(tuple(self.sample_data.shape), expected_shape)
        self.assertEqual(self.sample_data.dtype, self.dtype)


if __name__ == '__main__':
    unittest.main()