# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

from op_test import OpTest


# Correct: general case with non-negative axes.
class TestUnsqueezeOp(OpTest):
    """Unsqueeze a (3, 5) input at axes (0, 2); expected shape (1, 3, 1, 5)."""

    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is simply the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: a single (negative) axis.
class TestUnsqueezeOp1(OpTest):
    """Unsqueeze a (3, 5) input at axis -1; expected shape (3, 5, 1)."""

    def setUp(self):
        ori_shape = (3, 5)
        axes = (-1, )
        new_shape = (3, 5, 1)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        # new_shape is ori_shape with a trailing size-1 dim, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: mixed non-negative and negative axes.
class TestUnsqueezeOp2(OpTest):
    """Unsqueeze a (3, 5) input at axes (0, -1); expected shape (1, 3, 5, 1)."""

    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, -1)
        new_shape = (1, 3, 5, 1)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: a duplicated axis (3 appears twice).
class TestUnsqueezeOp3(OpTest):
    """Unsqueeze a (3, 2, 5) input at axes (0, 3, 3); expected (1, 3, 2, 1, 1, 5)."""

    def setUp(self):
        ori_shape = (3, 2, 5)
        axes = (0, 3, 3)
        # The repeated axis is expected to yield two adjacent size-1 dims.
        new_shape = (1, 3, 2, 1, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: in-place variant, non-negative axes.
class TestUnsqueezeOpInplace1(OpTest):
    """In-place unsqueeze of a (3, 5) input at axes (0, 2); expected (1, 3, 1, 5)."""

    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # "inplace": True selects the in-place mode of the operator.
        self.attrs = {"axes": axes, "inplace": True}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: in-place variant with a negative axis.
class TestUnsqueezeOpInplace2(OpTest):
    """In-place unsqueeze of a (3, 5) input at axes (0, -2); expected (1, 3, 1, 5)."""

    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, -2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # "inplace": True selects the in-place mode of the operator.
        self.attrs = {"axes": axes, "inplace": True}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# Correct: in-place variant with a duplicated axis (3 appears twice).
class TestUnsqueezeOpInplace3(OpTest):
    """In-place unsqueeze of (3, 2, 5) at axes (0, 3, 3); expected (1, 3, 2, 1, 1, 5)."""

    def setUp(self):
        ori_shape = (3, 2, 5)
        axes = (0, 3, 3)
        # The repeated axis is expected to yield two adjacent size-1 dims.
        new_shape = (1, 3, 2, 1, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # "inplace": True selects the in-place mode of the operator.
        self.attrs = {"axes": axes, "inplace": True}
        # new_shape is ori_shape with size-1 dims inserted, so the reference
        # output is the input data reshaped to new_shape.
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient of "Out" with respect to input "X".
        self.check_grad(["X"], "Out")


# NOTE(review): The error cases below are intentionally disabled.
# - Each one describes an unsqueeze configuration that should fail; the
#   comment above each class says why.
# - They live inside a module-level string literal, so unittest never
#   defines or collects them.
# - Enabling them would presumably require asserting that the op raises,
#   rather than calling check_output()/check_grad() as the passing cases do.
'''
# Error: The expected output shape is wrong.
class TestUnsqueezeOp4(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 3)
        new_shape = (1, 3, 1, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Error: An input axis is larger than the output rank allows.
class TestUnsqueezeOp5(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 4)
        new_shape = (1, 3, 5, 1)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Error: An input axis is larger than the Eigen limit.
class TestUnsqueezeOp6(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2, 10)
        new_shape = (1, 3, 1, 5, 1)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Error: The number of input axes is larger than the Eigen limit.
class TestUnsqueezeOp7(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2, 2, 2, 2, 2)
        new_shape = (1, 3, 1, 1, 5, 1)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
'''

if __name__ == "__main__":
    # Load and run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")