# test_switch_case.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest

import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from functools import partial


class TestAPISwitchCase(unittest.TestCase):
    def test_return_single_var(self):
        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)

            # call fn_1
44 45 46
            out_0 = layers.switch_case(
                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )
L
liym27 已提交
47 48

            # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
49 50 51
            out_1 = layers.switch_case(
                branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
            )
L
liym27 已提交
52 53

            # call default fn_3
54 55 56 57 58
            out_2 = layers.switch_case(
                branch_index=index_5,
                branch_fns=((1, fn_1), (2, fn_2)),
                default=fn_3,
            )
L
liym27 已提交
59 60

            # no default, call fn_2
61 62 63
            out_3 = layers.switch_case(
                branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
            )
L
liym27 已提交
64 65

            # no default, call fn_2 but branch_index is 5
66 67 68 69 70 71 72 73 74 75
            out_4 = layers.switch_case(
                branch_index=index_5,
                branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
            )

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
L
liym27 已提交
76 77
            exe = fluid.Executor(place)

78 79 80
            res = exe.run(
                main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4]
            )
L
liym27 已提交
81

82 83 84 85
            np.testing.assert_allclose(
                res[0],
                1,
                rtol=1e-05,
86 87
                err_msg='result is {} but answer is {}'.format(res[0], 1),
            )
88 89 90 91
            np.testing.assert_allclose(
                res[1],
                2,
                rtol=1e-05,
92 93
                err_msg='result is {} but answer is {}'.format(res[0], 2),
            )
94 95 96 97
            np.testing.assert_allclose(
                res[2],
                3,
                rtol=1e-05,
98 99
                err_msg='result is {} but answer is {}'.format(res[0], 3),
            )
100 101 102 103
            np.testing.assert_allclose(
                res[3],
                2,
                rtol=1e-05,
104 105
                err_msg='result is {} but answer is {}'.format(res[0], 2),
            )
106 107 108 109
            np.testing.assert_allclose(
                res[4],
                2,
                rtol=1e-05,
110 111
                err_msg='result is {} but answer is {}'.format(res[0], 2),
            )
L
liym27 已提交
112 113 114

    def test_return_var_tuple(self):
        def fn_1():
115 116 117
            return layers.fill_constant(
                shape=[1, 2], dtype='int32', value=1
            ), layers.fill_constant(shape=[2, 3], dtype='float32', value=2)
L
liym27 已提交
118 119

        def fn_2():
120 121 122
            return layers.fill_constant(
                shape=[3, 4], dtype='int32', value=3
            ), layers.fill_constant(shape=[4, 5], dtype='float32', value=4)
L
liym27 已提交
123 124

        def fn_3():
125 126 127
            return layers.fill_constant(
                shape=[5], dtype='int32', value=5
            ), layers.fill_constant(shape=[5, 6], dtype='float32', value=6)
L
liym27 已提交
128 129 130 131 132 133 134 135

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)

            out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)

136 137 138 139 140
            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
L
liym27 已提交
141 142 143
            exe = fluid.Executor(place)
            ret = exe.run(main_program, fetch_list=out)

144 145 146 147 148 149
            np.testing.assert_allclose(
                np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
            )
            np.testing.assert_allclose(
                np.asarray(ret[1]), np.full((2, 3), 2, np.float32), rtol=1e-05
            )
L
liym27 已提交
150 151 152 153 154


class TestAPISwitchCase_Nested(unittest.TestCase):
    def test_nested_switch_case(self):
        def fn_1(x=1):
155 156 157 158 159 160 161 162 163 164 165 166 167
            out = layers.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=x
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant, shape=[1], dtype='int32', value=1
                    ),
                    x: partial(
                        layers.fill_constant, shape=[2], dtype='int32', value=x
                    ),
                },
            )
L
liym27 已提交
168 169 170
            return out

        def fn_2(x=2):
171 172 173 174 175 176 177 178 179 180 181 182 183 184
            out = layers.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=2
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant,
                        shape=[4, 3],
                        dtype='int32',
                        value=1,
                    ),
                    2: partial(fn_1, x=x),
                },
            )
L
liym27 已提交
185 186 187
            return out

        def fn_3():
188 189 190 191 192 193 194 195 196 197 198 199 200 201
            out = layers.switch_case(
                branch_index=layers.fill_constant(
                    shape=[1], dtype='int32', value=3
                ),
                branch_fns={
                    1: partial(
                        layers.fill_constant,
                        shape=[4, 3],
                        dtype='int32',
                        value=1,
                    ),
                    3: partial(fn_2, x=3),
                },
            )
L
liym27 已提交
202 203 204 205 206 207 208 209 210
            return out

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)

211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
            out_1 = layers.switch_case(
                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )
            out_2 = layers.switch_case(
                branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )

            out_3 = layers.switch_case(
                branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
            )

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
L
liym27 已提交
227 228
            exe = fluid.Executor(place)

229 230 231 232 233
            res = exe.run(
                main_program,
                feed={"index_1": np.array([1], dtype="uint8")},
                fetch_list=[out_1, out_2, out_3],
            )
L
liym27 已提交
234

235 236 237 238
            np.testing.assert_allclose(
                res[0],
                1,
                rtol=1e-05,
239 240
                err_msg='result is {} but answer is {}'.format(res[0], 1),
            )
241 242 243 244
            np.testing.assert_allclose(
                res[1],
                2,
                rtol=1e-05,
245 246
                err_msg='result is {} but answer is {}'.format(res[1], 2),
            )
247 248 249 250
            np.testing.assert_allclose(
                res[2],
                3,
                rtol=1e-05,
251 252
                err_msg='result is {} but answer is {}'.format(res[2], 3),
            )
L
liym27 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269


# test TypeError and ValueError of api switch_case
class TestAPISwitchCase_Error(unittest.TestCase):
    def test_error(self):
        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
270 271 272 273 274 275
            key_float32 = layers.fill_constant(
                shape=[1], dtype='float32', value=0.23
            )
            key_int32 = layers.fill_constant(
                shape=[1], dtype='int32', value=0.23
            )
L
liym27 已提交
276 277 278

            # The type of 'branch_index' in Op(switch_case) must be Variable
            def type_error_branch_index():
279 280 281
                layers.switch_case(
                    branch_index=1, branch_fns=[(1, fn_1)], default=fn_3
                )
L
liym27 已提交
282 283 284 285 286

            self.assertRaises(TypeError, type_error_branch_index)

            # The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8
            def dtype_error_branch_index():
287 288 289 290 291
                layers.switch_case(
                    branch_index=key_float32,
                    branch_fns=[(1, fn_1)],
                    default=fn_3,
                )
L
liym27 已提交
292 293 294 295 296

            self.assertRaises(TypeError, dtype_error_branch_index)

            # The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict
            def type_error_branch_fns():
297 298 299
                layers.switch_case(
                    branch_index=key_int32, branch_fns=1, default=fn_3
                )
L
liym27 已提交
300 301 302 303 304

            self.assertRaises(TypeError, type_error_branch_fns)

            # The elements' type of 'branch_fns' in Op(switch_case) must be tuple
            def type_error_index_fn_pair_1():
305 306 307
                layers.switch_case(
                    branch_index=key_int32, branch_fns=[1], default=fn_3
                )
L
liym27 已提交
308 309 310 311 312

            self.assertRaises(TypeError, type_error_index_fn_pair_1)

            # The tuple's size of 'branch_fns' in Op(switch_case) must be 2
            def type_error_index_fn_pair_2():
313 314 315
                layers.switch_case(
                    branch_index=key_int32, branch_fns=[(1, 2, 3)], default=fn_3
                )
L
liym27 已提交
316 317 318 319 320

            self.assertRaises(TypeError, type_error_index_fn_pair_2)

            # The key's type of 'branch_fns' in Op(switch_case) must be int
            def type_error_key():
321 322 323
                layers.switch_case(
                    branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3
                )
L
liym27 已提交
324 325 326 327 328

            self.assertRaises(TypeError, type_error_key)

            # The key in 'branch_fns' must be unique
            def value_error_key():
329 330 331 332 333
                layers.switch_case(
                    branch_index=key_int32,
                    branch_fns=[(2, fn_1), (2, fn_2)],
                    default=fn_3,
                )
L
liym27 已提交
334 335 336 337 338

            self.assertRaises(ValueError, value_error_key)

            # The type of function in 'branch_fns' must be callable
            def type_error_fn():
339 340 341 342 343
                layers.switch_case(
                    branch_index=key_int32,
                    branch_fns=[(1, 1), (2, fn_2)],
                    default=fn_3,
                )
L
liym27 已提交
344 345 346 347 348

            self.assertRaises(TypeError, type_error_fn)

            # The default in Op(case) must be callable
            def type_error_default():
349 350 351 352 353
                layers.switch_case(
                    branch_index=key_int32,
                    branch_fns=[(1, fn_1), (2, fn_2)],
                    default=1,
                )
L
liym27 已提交
354 355 356 357 358 359

            self.assertRaises(TypeError, type_error_default)


if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()