#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
import paddle.fluid.core as core
from paddle.static import Program, program_guard


def check_randperm_out(n, data_np):
    """Return whether ``data_np`` is a permutation of ``0, 1, ..., n - 1``.

    Args:
        n: expected permutation length.
        data_np: np.ndarray to check (any numeric dtype; float outputs such
            as 3.0 compare equal to the integer 3).

    Returns:
        True if sorting ``data_np`` yields exactly ``np.arange(n)`` (same
        length and same values), otherwise False.
    """
    assert isinstance(
        data_np, np.ndarray
    ), "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Return one bool rather than a per-element list: a non-empty list is
    # always truthy, which would make every assertTrue(...) on it vacuous.
    # array_equal also returns False (instead of raising) on length mismatch.
    return bool(np.array_equal(gt_sorted, out_sorted))


def error_msg(data_np):
    """Build the failure message for a randperm permutation check.

    Args:
        data_np: the offending output; rendered with ``str()``.

    Returns:
        A human-readable message that embeds ``data_np``.
    """
    prefix = "The sorted ground truth and sorted out should be equal, out = "
    return f"{prefix}{data_np!s}"


def convert_dtype(dtype_str):
    """Translate a dtype name into its ``core.VarDesc.VarType`` member.

    Args:
        dtype_str: one of "int32", "int64", "float32" or "float64".

    Returns:
        The matching ``core.VarDesc.VarType`` value; this file passes it as
        the ``dtype`` attribute of the randperm op.

    Raises:
        AssertionError: if ``dtype_str`` is not one of the supported names.
    """
    var_type = core.VarDesc.VarType
    name_to_type = {
        "int32": var_type.INT32,
        "int64": var_type.INT64,
        "float32": var_type.FP32,
        "float64": var_type.FP64,
    }
    supported = list(name_to_type)
    assert dtype_str in supported, (
        dtype_str + " should in " + str(supported)
    )
    return name_to_type[dtype_str]


class TestRandpermOp(OpTest):
    """Test randperm op.

    Subclasses change the permutation length / output dtype by overriding
    ``init_attrs`` and assigning ``self.n`` / ``self.dtype``.
    """

    def setUp(self):
        self.op_type = "randperm"
        self.python_api = paddle.randperm
        self.n = 200
        self.dtype = "int64"
        # Apply subclass overrides of n/dtype before anything is derived from
        # them; otherwise the placeholder output below would always be built
        # with the base defaults (200 elements, int64).
        self.init_attrs()

        self.inputs = {}
        # The op's output is random, so this zeros array is only a
        # placeholder; the actual values are validated in verify_output.
        self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)}
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def init_attrs(self):
        """Hook for subclasses to override ``self.n`` and ``self.dtype``."""

    def test_check_output(self):
        # Random output: validate with verify_output rather than comparing
        # against the placeholder in self.outputs.
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        """Assert that the first output is a permutation of ``0 .. n-1``."""
        out_np = np.array(outs[0])
        self.assertTrue(
            check_randperm_out(self.n, out_np), msg=error_msg(out_np)
        )


class TestRandpermOpN(TestRandpermOp):
    """randperm op with a larger permutation length."""

    def init_attrs(self):
        """Use 10000 elements instead of the default 200."""
        super().init_attrs()
        self.n = 10000


class TestRandpermOpInt32(TestRandpermOp):
    """randperm op producing an int32 permutation."""

    def init_attrs(self):
        """Request int32 output instead of the default int64."""
        super().init_attrs()
        self.dtype = "int32"


class TestRandpermOpFloat32(TestRandpermOp):
    """randperm op producing a float32 permutation."""

    def init_attrs(self):
        """Request float32 output instead of the default int64."""
        super().init_attrs()
        self.dtype = "float32"


class TestRandpermOpFloat64(TestRandpermOp):
    """randperm op producing a float64 permutation."""

    def init_attrs(self):
        """Request float64 output instead of the default int64."""
        super().init_attrs()
        self.dtype = "float64"


class TestRandpermOpError(unittest.TestCase):
    """Argument validation of ``paddle.randperm`` inside a static program."""

    def test_errors(self):
        main_program, startup_program = Program(), Program()
        with program_guard(main_program, startup_program):
            # A negative permutation length must be rejected.
            with self.assertRaises(ValueError):
                paddle.randperm(-3)
            # int8 is expected to be an unsupported output dtype.
            with self.assertRaises(TypeError):
                paddle.randperm(10, 'int8')


class TestRandpermAPI(unittest.TestCase):
    """Static-graph ``paddle.randperm`` with the default and float32 dtypes."""

    def test_out(self):
        n = 10
        # Run on the first GPU when Paddle is built with CUDA, else on CPU.
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        else:
            place = paddle.CPUPlace()

        with program_guard(Program(), Program()):
            out_default = paddle.randperm(n)
            out_fp32 = paddle.randperm(n, 'float32')

            executor = paddle.static.Executor(place)
            fetched = executor.run(fetch_list=[out_default, out_fp32])
            default_np = fetched[0]
            fp32_np = fetched[1]

            # Omitting dtype is expected to produce int64.
            self.assertEqual(default_np.dtype, np.int64)
            self.assertEqual(fp32_np.dtype, np.float32)
            self.assertTrue(check_randperm_out(n, default_np))
            self.assertTrue(check_randperm_out(n, fp32_np))


class TestRandpermImperative(unittest.TestCase):
    """Dynamic-graph ``paddle.randperm`` for every supported dtype."""

    def test_out(self):
        paddle.disable_static()
        try:
            n = 10
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                with self.subTest(dtype=dtype):
                    data_np = paddle.randperm(n, dtype).numpy()
                    # The requested dtype should be honoured in dynamic mode,
                    # as the static-graph test checks for int64/float32.
                    self.assertEqual(data_np.dtype, np.dtype(dtype))
                    self.assertTrue(
                        check_randperm_out(n, data_np), msg=error_msg(data_np)
                    )
        finally:
            # Restore static mode even if an assertion fails, so the mode
            # switch does not leak into other tests in this module.
            paddle.enable_static()


class TestRandpermEager(unittest.TestCase):
    """Eager-mode ``paddle.randperm`` for every supported dtype."""

    def test_out(self):
        paddle.disable_static()
        try:
            n = 10
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                with self.subTest(dtype=dtype):
                    data_np = paddle.randperm(n, dtype).numpy()
                    # The requested dtype should be honoured in eager mode,
                    # as the static-graph test checks for int64/float32.
                    self.assertEqual(data_np.dtype, np.dtype(dtype))
                    self.assertTrue(
                        check_randperm_out(n, data_np), msg=error_msg(data_np)
                    )
        finally:
            # Restore static mode even if an assertion fails, so the mode
            # switch does not leak into other tests in this module.
            paddle.enable_static()


class TestRandomValue(unittest.TestCase):
    """Regression test pinning GPU randperm outputs for a fixed seed."""

    # Each entry is (dtype, {start: expected}); ``expected`` holds the values
    # of x[start:start + 10] where x = paddle.randperm(30000, dtype). All four
    # permutations are drawn in this order after a single paddle.seed(2021),
    # so the entries presumably must stay in this order to keep matching.
    _EXPECTED = (
        (
            'int32',
            {
                0: [24562, 8409, 9379, 10328, 20503,
                    18059, 9681, 21883, 11783, 27413],
                10000: [29477, 27100, 9643, 16637, 8605,
                        16892, 27767, 2724, 1612, 13096],
                20000: [298, 4104, 16479, 22714, 28684,
                        7510, 14667, 9950, 15940, 28343],
            },
        ),
        (
            'int64',
            {
                0: [6587, 1909, 5525, 23001, 6488,
                    14981, 14355, 3083, 29561, 8171],
                10000: [23460, 12394, 22501, 5427, 20185,
                        9100, 5127, 1651, 25806, 4818],
                20000: [5829, 4508, 16193, 24836, 8526,
                        242, 9984, 9243, 1977, 11839],
            },
        ),
        (
            'float32',
            {
                0: [5154.0, 10537.0, 14362.0, 29843.0, 27185.0,
                    28399.0, 27561.0, 4144.0, 22906.0, 10705.0],
                10000: [1958.0, 18414.0, 20090.0, 21910.0, 22746.0,
                        27346.0, 22347.0, 3002.0, 4564.0, 26991.0],
                20000: [25580.0, 12606.0, 553.0, 16387.0, 29536.0,
                        4241.0, 20946.0, 16899.0, 16339.0, 4662.0],
            },
        ),
        (
            'float64',
            {
                0: [19051.0, 2449.0, 21940.0, 11121.0, 282.0,
                    7330.0, 13747.0, 24321.0, 21147.0, 9163.0],
                10000: [15483.0, 1315.0, 5723.0, 20954.0, 13251.0,
                        25539.0, 5074.0, 1823.0, 14945.0, 17624.0],
                20000: [10516.0, 2552.0, 29970.0, 5941.0, 986.0,
                        8007.0, 24805.0, 26753.0, 12202.0, 21404.0],
            },
        ),
    )

    def test_fixed_random_number(self):
        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
        if not paddle.is_compiled_with_cuda():
            # Report a skip rather than a silent pass when nothing was run.
            self.skipTest("fixed-value randperm test requires a CUDA build")

        print("Test Fixed Random number on GPU------>")
        paddle.disable_static()
        try:
            paddle.set_device('gpu')
            paddle.seed(2021)
            for dtype, slices in self._EXPECTED:
                x = paddle.randperm(30000, dtype=dtype).numpy()
                for start, expect in slices.items():
                    with self.subTest(dtype=dtype, start=start):
                        np.testing.assert_array_equal(
                            x[start : start + len(expect)], expect
                        )
        finally:
            # Restore static mode even if an assertion fails, so the mode
            # switch does not leak into other tests in this module.
            paddle.enable_static()


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")