# test_switch_case.py (13.5 KB) -- last committed by liym27
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest

import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from functools import partial


class TestAPISwitchCase(unittest.TestCase):
    """Tests ``layers.switch_case`` with the different ``branch_fns`` forms."""

    def test_return_single_var(self):
        """Each branch returns a single variable; check the fetched values.

        ``branch_fns`` is passed as a dict, a tuple of callables, a tuple of
        ``(index, fn)`` pairs and a list of ``(index, fn)`` pairs, with and
        without an explicit ``default``.
        """

        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)

            # call fn_1
            out_0 = layers.switch_case(branch_index=index_1,
                                       branch_fns={
                                           1: fn_1,
                                           2: fn_2,
                                           3: fn_3
                                       })

            # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
            out_1 = layers.switch_case(branch_index=index_1,
                                       branch_fns=(fn_1, fn_2, fn_3))

            # call default fn_3
            out_2 = layers.switch_case(branch_index=index_5,
                                       branch_fns=((1, fn_1), (2, fn_2)),
                                       default=fn_3)

            # no default, call fn_2
            out_3 = layers.switch_case(branch_index=index_2,
                                       branch_fns=[(1, fn_1), (2, fn_2)])

            # no default and branch_index is 5, which matches no key; the
            # test expects fn_2 (registered under key 3) to be called.
            out_4 = layers.switch_case(branch_index=index_5,
                                       branch_fns=[(1, fn_1), (3, fn_2),
                                                   (2, fn_3)])

            place = fluid.CUDAPlace(
                0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
            exe = fluid.Executor(place)

            res = exe.run(main_program,
                          fetch_list=[out_0, out_1, out_2, out_3, out_4])

            # Pair every fetched result with its own expected value so that a
            # failure message reports the tensor that actually mismatched
            # (previously every message printed res[0]).
            answers = (1, 2, 3, 2, 2)
            for result, answer in zip(res, answers):
                self.assertTrue(
                    np.allclose(result, answer),
                    "result is {} but answer is {}".format(result, answer))

    def test_return_var_tuple(self):
        """Each branch returns a tuple of two variables; check both outputs."""

        def fn_1():
            return layers.fill_constant(shape=[1, 2], dtype='int32',
                                        value=1), layers.fill_constant(
                                            shape=[2, 3],
                                            dtype='float32',
                                            value=2)

        def fn_2():
            return layers.fill_constant(shape=[3, 4], dtype='int32',
                                        value=3), layers.fill_constant(
                                            shape=[4, 5],
                                            dtype='float32',
                                            value=4)

        def fn_3():
            return layers.fill_constant(shape=[5], dtype='int32',
                                        value=5), layers.fill_constant(
                                            shape=[5, 6],
                                            dtype='float32',
                                            value=6)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)

            # index 1 selects fn_1; fn_3 is passed as the default.
            out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)

            place = fluid.CUDAPlace(
                0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
            exe = fluid.Executor(place)
            ret = exe.run(main_program, fetch_list=out)

            self.assertTrue(
                np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32)))
            self.assertTrue(
                np.allclose(np.asarray(ret[1]), np.full((2, 3), 2, np.float32)))
L
liym27 已提交
130 131 132


class TestAPISwitchCase_Nested(unittest.TestCase):
    """Tests ``layers.switch_case`` calls nested inside branch functions."""

    def test_nested_switch_case(self):
        """Chain three levels of switch_case and check the fetched values."""

        def int32_filler(shape, value):
            # Deferred fill_constant call, usable as a branch function.
            return partial(layers.fill_constant,
                           shape=shape,
                           dtype='int32',
                           value=value)

        def level_one(x=1):
            selector = layers.fill_constant(shape=[1], dtype='int32', value=x)
            # When x == 1 the second entry replaces the first one.
            return layers.switch_case(branch_index=selector,
                                      branch_fns={
                                          1: int32_filler([1], 1),
                                          x: int32_filler([2], x)
                                      })

        def level_two(x=2):
            selector = layers.fill_constant(shape=[1], dtype='int32', value=2)
            return layers.switch_case(branch_index=selector,
                                      branch_fns={
                                          1: int32_filler([4, 3], 1),
                                          2: partial(level_one, x=x)
                                      })

        def level_three():
            selector = layers.fill_constant(shape=[1], dtype='int32', value=3)
            return layers.switch_case(branch_index=selector,
                                      branch_fns={
                                          1: int32_filler([4, 3], 1),
                                          3: partial(level_two, x=3)
                                      })

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            # One index per dtype under test: uint8 (fed at run time),
            # int32 and int64.
            index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
            index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
            index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)

            out_1 = layers.switch_case(branch_index=index_1,
                                       branch_fns={
                                           1: level_one,
                                           2: level_two,
                                           3: level_three
                                       })
            out_2 = layers.switch_case(branch_index=index_2,
                                       branch_fns={
                                           1: level_one,
                                           2: level_two,
                                           3: level_three
                                       })
            out_3 = layers.switch_case(branch_index=index_3,
                                       branch_fns={
                                           1: level_one,
                                           2: level_two,
                                           3: level_three
                                       })

            if core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)

            res = exe.run(main_program,
                          feed={"index_1": np.array([1], dtype="uint8")},
                          fetch_list=[out_1, out_2, out_3])

            for position, answer in enumerate((1, 2, 3)):
                self.assertTrue(
                    np.allclose(res[position], answer),
                    "result is {} but answer is {}".format(
                        res[position], answer))
L
liym27 已提交
222 223 224 225


# test TypeError and ValueError of api switch_case
class TestAPISwitchCase_Error(unittest.TestCase):
    """Checks the TypeError / ValueError raised for bad switch_case input."""

    def test_error(self):
        """Pass one invalid argument at a time and expect the matching error."""

        def fn_1():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)

        def fn_2():
            return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)

        def fn_3():
            return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)

        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            key_float32 = layers.fill_constant(shape=[1],
                                               dtype='float32',
                                               value=0.23)
            key_int32 = layers.fill_constant(shape=[1],
                                             dtype='int32',
                                             value=0.23)

            # The type of 'branch_index' in Op(switch_case) must be Variable
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=1,
                                   branch_fns=[(1, fn_1)],
                                   default=fn_3)

            # The data type of 'branch_index' in Op(switch_case) must be
            # int32, int64 or uint8
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_float32,
                                   branch_fns=[(1, fn_1)],
                                   default=fn_3)

            # The type of 'branch_fns' in Op(switch_case) must be list,
            # tuple or dict
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=1,
                                   default=fn_3)

            # The elements' type of 'branch_fns' in Op(switch_case) must be
            # tuple
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[1],
                                   default=fn_3)

            # The tuple's size of 'branch_fns' in Op(switch_case) must be 2
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[(1, 2, 3)],
                                   default=fn_3)

            # The key's type of 'branch_fns' in Op(switch_case) must be int
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[(2.3, 2)],
                                   default=fn_3)

            # The key in 'branch_fns' must be unique
            with self.assertRaises(ValueError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[(2, fn_1), (2, fn_2)],
                                   default=fn_3)

            # The type of function in 'branch_fns' must be callable
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[(1, 1), (2, fn_2)],
                                   default=fn_3)

            # The default in Op(switch_case) must be callable
            with self.assertRaises(TypeError):
                layers.switch_case(branch_index=key_int32,
                                   branch_fns=[(1, fn_1), (2, fn_2)],
                                   default=1)


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()