#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.backward import append_backward
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard
from paddle.static.nn.control_flow import select_input, select_output

paddle.enable_static()


class TestSplitMergeSelectedVarOps(unittest.TestCase):
    def test_forward_backward_list_output(self):
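        # Round-trip check: select_output routes x into one of branch_num
        # placeholder variables chosen by mask, and select_input gathers it
        # back, so the fetched output must equal x and gradients must flow.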
        for branch_num in range(2, 10):
            program = Program()
            with program_guard(program):
                x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32')
                x.stop_gradient = False  # needed to compute x's gradient
                mask = paddle.static.data(
                    name='mask', shape=[-1, 1], dtype='int32'
                )

                outputs = []
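                # one placeholder variable per branch to receive the routed x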
                for i in range(branch_num):
                    out = program.current_block().create_var(
                        dtype='float32',
                        shape=[2],
                        type=core.VarDesc.VarType.LOD_TENSOR,
                    )
                    outputs.append(out)

                # Scatter x into outputs[mask], then gather it back; the
                # round trip should reproduce x exactly.
                select_output(x, outputs, mask)
                y = select_input(outputs, mask)
                mean = paddle.mean(y)
                append_backward(mean)

            place = (
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )
            exe = Executor(place)

            feed_x = np.asarray([1.3, -1.4]).astype(np.float32)
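            # Exercise every branch index as the mask value.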
            for i in range(branch_num):
                feed_mask = np.asarray([i]).astype(np.int32)
                ret = exe.run(
                    program,
                    feed={'x': feed_x, 'mask': feed_mask},
                    fetch_list=[y.name, x.grad_name],
                )
                # d(mean)/dx gives 1/2 for each of the two input elements.
                x_grad = np.asarray([0.5, 0.5]).astype(np.float32)
                np.testing.assert_allclose(
                    np.asarray(ret[0]), feed_x, rtol=1e-05
                )
                np.testing.assert_allclose(
                    np.asarray(ret[1]), x_grad, rtol=1e-05
                )


class TestSelectInputOpError(unittest.TestCase):
    def test_errors(self):
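        # Each nested function exercises one invalid argument and must
        # raise TypeError.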
        with program_guard(Program(), Program()):
            mask = paddle.static.data(name='mask', shape=[-1, 1], dtype='int32')
            in1 = paddle.static.data(name='in1', shape=[-1, 1], dtype='int32')

            # 1. The type of inputs in select_input must be list or tuple.
            def test_inputs_type():
                select_input(1, mask)

            self.assertRaises(TypeError, test_inputs_type)

            # 2. The type of mask in select_input must be Variable.
            def test_mask_type():
                select_input([in1], mask=1)

            self.assertRaises(TypeError, test_mask_type)

            # 3. The dtype of mask in select_input must be int32 or int64.
            def test_mask_dtype():
                mask = paddle.static.data(
                    name='mask2', shape=[-1, 1], dtype='float32'
                )
                select_input([in1], mask)

            self.assertRaises(TypeError, test_mask_dtype)


class TestSelectOutput_Error(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            in1 = paddle.static.data(name='in1', shape=[-1, 1], dtype='int32')
            mask_int32 = paddle.static.data(
                name='mask_int32', shape=[-1, 1], dtype='int32'
            )
            mask_float32 = paddle.static.data(
                name='mask_float32', shape=[-1, 1], dtype='float32'
            )
            out1 = paddle.static.data(name='out1', shape=[-1, 1], dtype='int32')

            # 1. The type of input in select_output must be Variable.
            def test_input_type():
                select_output(1, [out1], mask_int32)

            self.assertRaises(TypeError, test_input_type)

            # 2. The type of mask in select_output must be Variable.
            def test_mask_type():
                select_output(in1, [out1], mask=1)

            self.assertRaises(TypeError, test_mask_type)

            # 3. The dtype of mask in select_output must be int32 or int64.
            def test_mask_dtype():
                select_output(in1, [out1], mask=mask_float32)

            self.assertRaises(TypeError, test_mask_dtype)

            # 4. The type of outputs in select_output must be list or tuple.
            def test_outputs_type():
                select_output(in1, out1, mask=mask_int32)

            self.assertRaises(TypeError, test_outputs_type)


if __name__ == '__main__':
    unittest.main()