# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import config
import mock_data as mock
import numpy as np
import parameterize

import paddle

# Fix NumPy's global seed so NumPy-based random test data (presumably
# including ``parameterize.xrand``) is reproducible across runs.
np.random.seed(2022)
# Every test below builds and executes programs via the ``paddle.static``
# API, so Paddle must be in static-graph mode.
paddle.enable_static()


@parameterize.place(config.DEVICES)
class TestExponentialFamily(unittest.TestCase):
    """Static-graph tests for ``paddle.distribution.ExponentialFamily.entropy``.

    A mock Exponential distribution (from the local ``mock_data`` module) is
    used to compare ``self.mock_dist.entropy()`` with the generic base-class
    ``ExponentialFamily.entropy(self.mock_dist)``. The class is wrapped by
    ``parameterize.place(config.DEVICES)``, presumably to generate one test
    variant per configured device — TODO confirm.
    """

    def setUp(self):
        """Build a fresh static program containing the mock distribution.

        The ``rate`` parameter is a ``paddle.static.data`` placeholder whose
        shape and dtype come from an array produced by
        ``parameterize.xrand((100, 200, 99))``. That array is stored in
        ``self.feeds`` and fed in when the program is run.
        """
        self.program = paddle.static.Program()
        # NOTE(review): Executor() uses Paddle's default place; if the
        # ``parameterize.place`` decorator exposes a per-device attribute,
        # it may be intended to be passed here — confirm.
        self.executor = paddle.static.Executor()
        with paddle.static.program_guard(self.program):
            rate_np = parameterize.xrand((100, 200, 99))
            rate = paddle.static.data('rate', rate_np.shape, rate_np.dtype)
            self.mock_dist = mock.Exponential(rate)
            self.feeds = {'rate': rate_np}

    def test_entropy(self):
        """Check that the base-class entropy matches the mock's ``entropy()``.

        Both entropy ops are added to ``self.program``, evaluated in a single
        executor run, and compared with the dtype-specific tolerances from
        ``config``.
        """
        with paddle.static.program_guard(self.program):
            out1, out2 = self.executor.run(
                self.program,
                feed=self.feeds,
                fetch_list=[
                    self.mock_dist.entropy(),
                    paddle.distribution.ExponentialFamily.entropy(
                        self.mock_dist
                    ),
                ],
            )

            np.testing.assert_allclose(
                out1,
                out2,
                rtol=config.RTOL.get(config.DEFAULT_DTYPE),
                atol=config.ATOL.get(config.DEFAULT_DTYPE),
            )

    def test_entropy_exception(self):
        """Check that ``ExponentialFamily.entropy`` raises ``NotImplementedError``.

        The exception is expected when the method is called on
        ``mock.DummyExpFamily(0.5, 0.5)``.
        """
        with paddle.static.program_guard(self.program):
            with self.assertRaises(NotImplementedError):
                paddle.distribution.ExponentialFamily.entropy(
                    mock.DummyExpFamily(0.5, 0.5)
                )