# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place

import paddle

np.random.seed(2022)
paddle.enable_static()


@place(DEVICES)
@parameterize_cls(
    (TEST_CASE_NAME, 'concentration'),
    [('test-one-dim', np.random.rand(89) + 5.0)],
)
class TestDirichlet(unittest.TestCase):
    """Compare ``paddle.distribution.Dirichlet`` in static mode with SciPy.

    Each test fetches one quantity from a freshly built static program and
    checks it against the matching ``scipy.stats.dirichlet`` function, using
    the tolerances stored in ``RTOL`` / ``ATOL`` under the concentration's
    dtype name.

    ``self.concentration`` is read by every method; it is presumably attached
    by ``parameterize_cls`` from the case list above (a length-89 vector of
    ``np.random.rand(89) + 5.0``) -- confirm against ``parameterize``.
    """

    def setUp(self):
        """Build a new program holding a Dirichlet fed by the ``conc`` input."""
        self.program = paddle.static.Program()
        # NOTE(review): the executor is created without an explicit place and
        # nothing in this class reads a place-related attribute; confirm
        # whether @place(DEVICES) is expected to select the device here.
        self.executor = paddle.static.Executor()
        with paddle.static.program_guard(self.program):
            conc = paddle.static.data(
                'conc', self.concentration.shape, self.concentration.dtype
            )
            self._paddle_diric = paddle.distribution.Dirichlet(conc)
            self.feeds = {'conc': self.concentration}

    def _check(self, fetch, expected, feed=None):
        """Run the program for ``fetch`` and assert it is close to ``expected``.

        Args:
            fetch: Static-graph variable to fetch from ``self.program``.
            expected: Reference value (computed with SciPy by the callers).
            feed: Feed dict for the run; defaults to ``self.feeds``.

        Raises:
            AssertionError: If the fetched value differs from ``expected``
                beyond the dtype-specific tolerances.
        """
        [actual] = self.executor.run(
            self.program,
            feed=self.feeds if feed is None else feed,
            fetch_list=[fetch],
        )
        # Tolerance tables are keyed by the dtype's string name.
        dtype_name = str(self.concentration.dtype)
        np.testing.assert_allclose(
            actual,
            expected,
            rtol=RTOL.get(dtype_name),
            atol=ATOL.get(dtype_name),
        )

    def _simplex_value(self):
        """Create a random point that sums to one and its ``value`` input.

        Must be called inside ``program_guard(self.program)`` because it
        declares a static data variable.

        Returns:
            tuple: ``(sample, value, feed)`` where ``sample`` is the NumPy
            array, ``value`` is the static ``'value'`` input variable and
            ``feed`` is ``self.feeds`` extended with ``value=sample``.
        """
        sample = np.random.rand(*self.concentration.shape)
        # Normalise so the sample sums to one.
        sample = sample / sample.sum()
        feed = dict(self.feeds, value=sample)
        value = paddle.static.data('value', sample.shape, sample.dtype)
        return sample, value, feed

    def test_mean(self):
        """``mean`` matches ``scipy.stats.dirichlet.mean``."""
        with paddle.static.program_guard(self.program):
            self._check(
                self._paddle_diric.mean,
                scipy.stats.dirichlet.mean(self.concentration),
            )

    def test_variance(self):
        """``variance`` matches ``scipy.stats.dirichlet.var``."""
        with paddle.static.program_guard(self.program):
            self._check(
                self._paddle_diric.variance,
                scipy.stats.dirichlet.var(self.concentration),
            )

    def test_prob(self):
        """``prob(value)`` matches ``scipy.stats.dirichlet.pdf``."""
        with paddle.static.program_guard(self.program):
            sample, value, feed = self._simplex_value()
            self._check(
                self._paddle_diric.prob(value),
                scipy.stats.dirichlet.pdf(sample, self.concentration),
                feed=feed,
            )

    def test_log_prob(self):
        """``log_prob(value)`` matches ``scipy.stats.dirichlet.logpdf``."""
        with paddle.static.program_guard(self.program):
            sample, value, feed = self._simplex_value()
            self._check(
                self._paddle_diric.log_prob(value),
                scipy.stats.dirichlet.logpdf(sample, self.concentration),
                feed=feed,
            )

    def test_entropy(self):
        """``entropy()`` matches ``scipy.stats.dirichlet.entropy``."""
        with paddle.static.program_guard(self.program):
            self._check(
                self._paddle_diric.entropy(),
                scipy.stats.dirichlet.entropy(self.concentration),
            )