#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard

# Every test below builds a static-graph Program and runs it through an
# Executor, so switch Paddle into static-graph mode at import time.
paddle.enable_static()


class TestAPISwitchCase(unittest.TestCase):
    """Checks which branch ``paddle.static.nn.switch_case`` selects for the
    different accepted forms of ``branch_fns`` and ``default``."""

    def test_return_single_var(self):
        """Each branch returns a single Tensor; verify the selected branch."""

        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)

            # dict form, index 1 -> call fn_1
            out_0 = paddle.static.nn.switch_case(
                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )

            # A tuple of bare callables is keyed by position,
            # i.e. branch_fns={0: fn_1, 1: fn_2, 2: fn_3}; index 1 -> fn_2
            out_1 = paddle.static.nn.switch_case(
                branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
            )

            # index 5 matches no key -> call the explicit default fn_3
            out_2 = paddle.static.nn.switch_case(
                branch_index=index_5,
                branch_fns=((1, fn_1), (2, fn_2)),
                default=fn_3,
            )

            # no default, index 2 matches a key -> call fn_2
            out_3 = paddle.static.nn.switch_case(
                branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
            )

            # No default and index 5 matches no key. The expected result (2)
            # is fn_2's, i.e. the branch registered under the largest key (3).
            out_4 = paddle.static.nn.switch_case(
                branch_index=index_5,
                branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
            )

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
            exe = fluid.Executor(place)

            res = exe.run(
                main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4]
            )

            # Pair each fetched result with its own expected value so a
            # failure message reports the tensor that actually mismatched.
            expected = [1, 2, 3, 2, 2]
            self.assertEqual(len(res), len(expected))
            for actual, answer in zip(res, expected):
                np.testing.assert_allclose(
                    actual,
                    answer,
                    rtol=1e-05,
                    err_msg='result is {} but answer is {}'.format(
                        actual, answer
                    ),
                )

    def test_return_var_tuple(self):
        """Each branch returns a tuple of Tensors with differing shapes and
        dtypes; verify every element of the selected branch's output."""

        def fn_1():
            return layers.fill_constant(
                shape=[1, 2], dtype='int32', value=1
            ), layers.fill_constant(shape=[2, 3], dtype='float32', value=2)

        def fn_2():
            return layers.fill_constant(
                shape=[3, 4], dtype='int32', value=3
            ), layers.fill_constant(shape=[4, 5], dtype='float32', value=4)

        def fn_3():
            return layers.fill_constant(
                shape=[5], dtype='int32', value=5
            ), layers.fill_constant(shape=[5, 6], dtype='float32', value=6)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)

            # positional args: branch_index, branch_fns, default; index 1 -> fn_1
            out = paddle.static.nn.switch_case(
                index_1, ((1, fn_1), (2, fn_2)), fn_3
            )

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
            exe = fluid.Executor(place)
            ret = exe.run(main_program, fetch_list=out)

            np.testing.assert_allclose(
                np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
            )
            np.testing.assert_allclose(
                np.asarray(ret[1]), np.full((2, 3), 2, np.float32), rtol=1e-05
            )


class TestAPISwitchCase_Nested(unittest.TestCase):
    """Branch functions that themselves call ``switch_case``."""

    def test_nested_switch_case(self):
        """Three levels of nesting; each path ends in a Tensor filled with the
        outer branch index (1, 2 or 3)."""

        def fn_1(x=1):
            # Selects on the constant ``x``. When x == 1 both dict entries use
            # key 1 and the second one (shape [2], value 1) overrides the
            # first; in every case the selected Tensor is filled with ``x``.
            out = paddle.static.nn.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=x
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant, shape=[1], dtype='int32', value=1
                    ),
                    x: partial(
                        layers.fill_constant, shape=[2], dtype='int32', value=x
                    ),
                },
            )
            return out

        def fn_2(x=2):
            # Constant index 2 always routes to fn_1(x=x).
            out = paddle.static.nn.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=2
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant,
                        shape=[4, 3],
                        dtype='int32',
                        value=1,
                    ),
                    2: partial(fn_1, x=x),
                },
            )
            return out

        def fn_3():
            # Constant index 3 always routes to fn_2(x=3) -> fn_1(x=3).
            out = paddle.static.nn.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=3
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant,
                        shape=[4, 3],
                        dtype='int32',
                        value=1,
                    ),
                    3: partial(fn_2, x=3),
                },
            )
            return out

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            # One index per supported dtype: uint8 (fed), int32 and int64.
            index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)

            out_1 = paddle.static.nn.switch_case(
                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )
            out_2 = paddle.static.nn.switch_case(
                branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )
            out_3 = paddle.static.nn.switch_case(
                branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
            exe = fluid.Executor(place)

            res = exe.run(
                main_program,
                feed={"index_1": np.array([1], dtype="uint8")},
                fetch_list=[out_1, out_2, out_3],
            )

            # Pair each fetched result with its own expected value.
            expected = [1, 2, 3]
            self.assertEqual(len(res), len(expected))
            for actual, answer in zip(res, expected):
                np.testing.assert_allclose(
                    actual,
                    answer,
                    rtol=1e-05,
                    err_msg='result is {} but answer is {}'.format(
                        actual, answer
                    ),
                )


# test TypeError and ValueError of api switch_case
class TestAPISwitchCase_Error(unittest.TestCase):
    """Argument validation of ``paddle.static.nn.switch_case``."""

    def test_error(self):
        """Each malformed call must raise the expected exception type."""

        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            # Only the dtypes of these variables matter here: no executor is
            # run in this test, so their values are never computed.
            key_float32 = layers.fill_constant(
                shape=[1], dtype='float32', value=0.23
            )
            key_int32 = layers.fill_constant(
                shape=[1], dtype='int32', value=0.23
            )

            # Each case: (name, expected exception, switch_case kwargs).
            cases = [
                # 'branch_index' must be a Variable
                (
                    'type_error_branch_index',
                    TypeError,
                    dict(branch_index=1, branch_fns=[(1, fn_1)], default=fn_3),
                ),
                # the dtype of 'branch_index' must be int32, int64 or uint8
                (
                    'dtype_error_branch_index',
                    TypeError,
                    dict(
                        branch_index=key_float32,
                        branch_fns=[(1, fn_1)],
                        default=fn_3,
                    ),
                ),
                # 'branch_fns' must be a list, tuple or dict
                (
                    'type_error_branch_fns',
                    TypeError,
                    dict(branch_index=key_int32, branch_fns=1, default=fn_3),
                ),
                # each element of 'branch_fns' must be a tuple
                (
                    'type_error_index_fn_pair_1',
                    TypeError,
                    dict(branch_index=key_int32, branch_fns=[1], default=fn_3),
                ),
                # each tuple in 'branch_fns' must have exactly 2 items
                (
                    'type_error_index_fn_pair_2',
                    TypeError,
                    dict(
                        branch_index=key_int32,
                        branch_fns=[(1, 2, 3)],
                        default=fn_3,
                    ),
                ),
                # each key in 'branch_fns' must be an int
                (
                    'type_error_key',
                    TypeError,
                    dict(
                        branch_index=key_int32,
                        branch_fns=[(2.3, 2)],
                        default=fn_3,
                    ),
                ),
                # keys in 'branch_fns' must be unique
                (
                    'value_error_key',
                    ValueError,
                    dict(
                        branch_index=key_int32,
                        branch_fns=[(2, fn_1), (2, fn_2)],
                        default=fn_3,
                    ),
                ),
                # every function in 'branch_fns' must be callable
                (
                    'type_error_fn',
                    TypeError,
                    dict(
                        branch_index=key_int32,
                        branch_fns=[(1, 1), (2, fn_2)],
                        default=fn_3,
                    ),
                ),
                # 'default' in Op(switch_case) must be callable
                (
                    'type_error_default',
                    TypeError,
                    dict(
                        branch_index=key_int32,
                        branch_fns=[(1, fn_1), (2, fn_2)],
                        default=1,
                    ),
                ),
            ]

            # subTest reports every failing case by name instead of stopping
            # at the first one.
            for name, expected_error, kwargs in cases:
                with self.subTest(name):
                    self.assertRaises(
                        expected_error, paddle.static.nn.switch_case, **kwargs
                    )


if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()