#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F

# The OpTest-based cases below run in static-graph mode; dygraph tests switch
# modes locally and switch back when they finish.
paddle.enable_static()


class TestGumbelSoftmaxOp(OpTest):
    """OpTest for ``gumbel_softmax`` with ``hard=True``.

    The op is stochastic, so ``test_check_output`` does not compare against a
    fixed reference; ``verify_output`` instead checks the structure of the
    sampled output. Subclasses override ``init_attrs`` to vary shape/axis.
    """

    def init_attrs(self):
        """Set input shape, op attributes, expected total and dtype."""
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": -1}
        # 20 slices along the last axis -> 20 selected entries in total.
        self.count_expected = 20
        self.dtype = "float64"

    def verify_output(self, outs):
        """Validate a fetched hard gumbel-softmax output.

        Checks that the output has as many elements as ``self.shape``, that
        each slice along ``self.attrs["axis"]`` sums to exactly 1, and that
        the grand total equals ``self.count_expected``.
        """
        out_np = np.array(outs[0])
        # Compare element counts explicitly so a mismatch is reported as a
        # test failure instead of surfacing as a reshape error.
        self.assertEqual(out_np.size, int(np.prod(self.shape)))
        out_np = out_np.reshape(self.shape)
        # With hard=True each slice along ``axis`` is expected to hold a
        # single 1; checking per slice catches errors that cancel out in the
        # grand total.
        np.testing.assert_array_equal(out_np.sum(axis=self.attrs["axis"]), 1)
        self.assertEqual(out_np.sum(), self.count_expected)

    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.init_attrs()
        np.random.seed(0)
        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        # Output values are validated by ``verify_output``; this zero array
        # presumably only declares the output slot for OpTest.
        out = np.zeros(self.shape).astype(self.dtype)
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmax_ZeroDim(OpTest):
    """``gumbel_softmax`` on a 0-D input; the hard output must be 1.0."""

    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.dtype = "float64"
        # Seed like the sibling tests so the gradient check is reproducible.
        np.random.seed(0)
        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
        # With a single element the hard sample can only select that element.
        out = np.array(1.0).astype(self.dtype)

        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"hard": True, "axis": -1}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp):
    """2-D input sampled along the first axis."""

    def init_attrs(self):
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": 0}
        # 10 slices of length 20 along axis 0.
        self.count_expected = 10
        self.dtype = "float64"


class TestGumbelSoftmaxOp3(TestGumbelSoftmaxOp):
    """1-D input: a single slice along the last axis."""

    def init_attrs(self):
        self.dtype = "float64"
        self.shape = [100]
        self.attrs = dict(hard=True, axis=-1)
        # One slice of 100 elements -> exactly one selected entry.
        self.count_expected = 1


class TestGumbelSoftmaxOp4(TestGumbelSoftmaxOp):
    """3-D input sampled along the last axis."""

    def init_attrs(self):
        self.dtype = "float64"
        self.shape = [20, 10, 5]
        self.attrs = dict(hard=True, axis=-1)
        # 20 * 10 slices of length 5 along the last axis.
        self.count_expected = 20 * 10


class TestGumbelSoftmaxOp5(TestGumbelSoftmaxOp):
    """3-D input sampled along the middle axis."""

    def init_attrs(self):
        self.dtype = "float64"
        self.shape = [20, 10, 5]
        self.attrs = dict(hard=True, axis=1)
        # 20 * 5 slices of length 10 along axis 1.
        self.count_expected = 20 * 5


class TestGumbelSoftmaxOpSampleDistribution(OpTest):
    """Statistical check that hard gumbel-softmax samples follow softmax(x).

    Every row of the batch holds the same logits ``[0.2, 0.3, 0.5]``, so the
    number of rows selecting class ``k`` should be close to
    ``softmax(logits)[k] * batch_size``.
    """

    def softmax(self, x):
        """Numerically stable softmax over the last axis of ``x``."""
        shifted = x - x.max(axis=-1, keepdims=True)
        x_exp = np.exp(shifted)
        return x_exp / x_exp.sum(axis=-1, keepdims=True)

    def init_attrs(self):
        self.shape = [100, 3]
        self.attrs = {"hard": True, "axis": -1}
        # Set explicitly (as the sibling op tests do) rather than relying on
        # the OpTest base class to provide ``self.dtype``.
        self.dtype = "float64"
        self.counts = np.zeros(self.shape).astype(self.dtype)
        self._cpu_only = True

    def accumulate_output(self, outs):
        """Store per-class selection counts (sum over the batch axis)."""
        out_np = np.array(outs).reshape(self.shape)
        self.counts = np.sum(out_np, axis=0)

    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.init_attrs()
        single_x = np.array([0.2, 0.3, 0.5])
        batch_x = np.ones(self.shape) * single_x
        out = np.zeros(self.shape).astype(self.dtype)
        self.probs = self.softmax(single_x)
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(batch_x)}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output_customized(self.accumulate_output)
        # The test expects one selected class per row, so the counts must
        # add up to the batch size.
        self.assertEqual(self.counts.sum(), self.shape[0])

        # Treat each class's softmax probability as the success probability
        # of a binomial distribution that the per-class sample counts should
        # follow. The statistic z below is then approximately N(0, 1) for an
        # unbiased count.
        expected = self.probs * self.shape[0]
        z = (self.counts - expected) / np.sqrt(expected * (1 - self.probs))
        # A (lazy) approximate 99% two-sided test: a spurious failure occurs
        # with probability alpha ~>= 0.01 if the sampler is unbiased.
        self.assertLess(np.max(np.abs(z)).item(), 2.58)

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmaxOpGrad(unittest.TestCase):
    """Hard and soft gumbel-softmax must give the same gradient of ``out.sum()``.

    Runs in dynamic-graph mode and restores static mode afterwards.
    """

    def init_attrs(self):
        self.shape = [20, 10]
        self.dtype = "float64"

    def setUp(self):
        self.init_attrs()
        np.random.seed(0)
        self.x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)

    def test_dygraph_check(self):
        paddle.disable_static()
        # Restore static mode even if the assertion fails, so tests that run
        # afterwards are not left in dygraph mode.
        try:
            x_hard = paddle.to_tensor(self.x_np, stop_gradient=False)
            x_soft = paddle.to_tensor(self.x_np, stop_gradient=False)
            out_hard = F.gumbel_softmax(x_hard, hard=True)
            out_soft = F.gumbel_softmax(x_soft, hard=False)

            out_hard.sum().backward()
            out_soft.sum().backward()

            np.testing.assert_allclose(
                x_hard.grad.numpy(), x_soft.grad.numpy(), rtol=1e-5, atol=1e-8
            )
        finally:
            paddle.enable_static()


class TestGumbelSoftmaxAPI(unittest.TestCase):
    """Exercise the ``gumbel_softmax`` Python API in static and dygraph mode."""

    def setUp(self):
        self.x_shape = [2, 3, 4, 5]
        self.x = np.random.uniform(-1.0, 1.0, self.x_shape).astype(np.float32)
        # hard=True along the last axis: one selected entry per slice, i.e.
        # the product of all leading dimensions (2 * 3 * 4 = 24 here).
        self.count_expected = int(np.prod(self.x_shape[:-1]))
        self.place = (
            paddle.CUDAPlace(0)
            if paddle.fluid.core.is_compiled_with_cuda()
            else paddle.CPUPlace()
        )

    def test_check_api(self):
        # test static api
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data(name='x', shape=self.x_shape)
            y = F.gumbel_softmax(x, hard=True)
            exe = paddle.static.Executor(self.place)
            out = exe.run(feed={'x': self.x}, fetch_list=[y])
            out_np = np.array(out[0])
        self.assertEqual(out_np.sum(), self.count_expected)

        # test dygraph api
        with paddle.fluid.dygraph.base.guard():
            x = paddle.to_tensor(self.x)
            y = F.gumbel_softmax(x, hard=True)
            out_np = np.array(y)
            self.assertEqual(out_np.sum(), self.count_expected)

class TestGumbelSoftmaxOpError(unittest.TestCase):
    """Invalid inputs or arguments to ``gumbel_softmax`` must raise."""

    def test_errors(self):
        paddle.disable_static()
        # Re-enable static mode even if an assertion fails, so later tests
        # (and the static dtype check below) start in static mode.
        try:
            # A raw LoDTensor is not an accepted input.
            with self.assertRaises(ValueError):
                x1 = fluid.create_lod_tensor(
                    np.zeros((100, 784)), [[10, 10, 10, 70]], fluid.CPUPlace()
                )
                F.gumbel_softmax(x1)

            # A plain numpy array is not an accepted input.
            with self.assertRaises(ValueError):
                F.gumbel_softmax(np.zeros((100, 784)))

            # A non-positive temperature is rejected.
            with self.assertRaises(ValueError):
                x = paddle.to_tensor([0.2, 0.3, 0.4])
                F.gumbel_softmax(x, temperature=-1)

            # A non-integer axis is rejected.
            with self.assertRaises(ValueError):
                x = paddle.to_tensor([0.2, 0.3, 0.4])
                F.gumbel_softmax(x, axis=1.1)
        finally:
            paddle.enable_static()

        # Integer inputs are rejected with a TypeError in static mode.
        with self.assertRaises(TypeError):
            with paddle.static.program_guard(paddle.static.Program()):
                x_int32 = paddle.fluid.data(
                    name='x_int32', shape=[2, 3], dtype='int32'
                )
                F.gumbel_softmax(x_int32)


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')