#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.static import Program, program_guard


def check_randperm_out(n, data_np):
    """Return True if ``data_np`` holds exactly the values ``0 .. n-1``.

    Args:
        n: the permutation length that was requested.
        data_np: the array produced by randperm.

    Returns:
        A single ``bool``: True when the sorted output equals ``np.arange(n)``.
        A wrong length also yields False.

    Raises:
        AssertionError: if ``data_np`` is not an ``np.ndarray``.
    """
    assert isinstance(
        data_np, np.ndarray
    ), "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Callers pass the result straight to assertTrue(). A list of
    # element-wise comparison results is truthy whenever it is non-empty, so
    # it could never fail. np.array_equal collapses the check into one bool
    # and returns False (rather than raising) on a shape mismatch.
    return bool(np.array_equal(gt_sorted, out_sorted))


def error_msg(data_np):
    """Build the failure message shown when a randperm output check fails.

    The offending output is rendered with ``str()`` after a fixed prefix.
    """
    prefix = "The sorted ground truth and sorted out should be equal, out = "
    return "".join((prefix, str(data_np)))


def convert_dtype(dtype_str):
    """Translate a dtype name into the matching ``core.VarDesc.VarType`` value.

    Accepted names are "int32", "int64", "float32" and "float64"; any other
    value trips the assertion below.
    """
    # (name, VarType) pairs; the position of each name selects its VarType.
    supported = (
        ("int32", core.VarDesc.VarType.INT32),
        ("int64", core.VarDesc.VarType.INT64),
        ("float32", core.VarDesc.VarType.FP32),
        ("float64", core.VarDesc.VarType.FP64),
    )
    names = [name for name, _ in supported]
    assert dtype_str in names, dtype_str + " should in " + str(names)
    return supported[names.index(dtype_str)][1]


class TestRandpermOp(OpTest):
    """Test randperm op."""

    def setUp(self):
        self.op_type = "randperm"
        self.python_api = paddle.randperm
        # Defaults; subclasses override these in init_attrs().
        self.n = 200
        self.dtype = "int64"
        # Apply subclass overrides before anything is derived from n/dtype,
        # so the declared output below has the shape and dtype the op is
        # actually asked to produce.
        self.init_attrs()

        self.inputs = {}
        # Placeholder only: test_check_output() uses a custom checker
        # (verify_output) instead of comparing against these zeros.
        self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)}
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def init_attrs(self):
        """Hook for subclasses to override ``self.n`` / ``self.dtype``."""
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        """Assert that the first fetched output is a permutation of range(n)."""
        out_np = np.array(outs[0])
        self.assertTrue(
            check_randperm_out(self.n, out_np), msg=error_msg(out_np)
        )

    def test_eager(self):
        """Re-run the output check under the eager-mode guard."""
        with _test_eager_guard():
            self.test_check_output()


class TestRandpermOpN(TestRandpermOp):
    """Run the randperm op test with a much larger permutation length."""

    def init_attrs(self):
        super().init_attrs()
        self.n = 10000


class TestRandpermOpInt32(TestRandpermOp):
    """Run the randperm op test with an int32 output."""

    def init_attrs(self):
        super().init_attrs()
        self.dtype = "int32"


class TestRandpermOpFloat32(TestRandpermOp):
    """Run the randperm op test with a float32 output."""

    def init_attrs(self):
        super().init_attrs()
        self.dtype = "float32"


class TestRandpermOpFloat64(TestRandpermOp):
    """Run the randperm op test with a float64 output."""

    def init_attrs(self):
        super().init_attrs()
        self.dtype = "float64"


class TestRandpermOpError(unittest.TestCase):
    """paddle.randperm must reject bad arguments inside a program_guard."""

    def test_errors(self):
        main_prog, startup_prog = Program(), Program()
        with program_guard(main_prog, startup_prog):
            # A negative length is expected to raise ValueError.
            with self.assertRaises(ValueError):
                paddle.randperm(-3)
            # int8 is expected to be rejected as an output dtype.
            with self.assertRaises(TypeError):
                paddle.randperm(10, 'int8')


class TestRandpermAPI(unittest.TestCase):
    """Static-graph API test for the default dtype and an explicit float32."""

    def test_out(self):
        n = 10
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        else:
            place = paddle.CPUPlace()

        with program_guard(Program(), Program()):
            default_out = paddle.randperm(n)
            float_out = paddle.randperm(n, 'float32')

            exe = paddle.static.Executor(place)
            fetched = exe.run(fetch_list=[default_out, float_out])
            int_res, float_res = fetched[0], fetched[1]

            # Without an explicit dtype the result is expected to be int64.
            self.assertEqual(int_res.dtype, np.int64)
            self.assertEqual(float_res.dtype, np.float32)
            for result in (int_res, float_res):
                self.assertTrue(check_randperm_out(n, result))


class TestRandpermImperative(unittest.TestCase):
    """Dynamic-graph API test over every supported output dtype."""

    def test_out(self):
        paddle.disable_static()
        # Re-enable static mode even if an assertion fails; otherwise the
        # process would stay in dynamic mode for every test that runs later.
        try:
            n = 10
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                with self.subTest(dtype=dtype):
                    data_np = paddle.randperm(n, dtype).numpy()
                    self.assertTrue(
                        check_randperm_out(n, data_np), msg=error_msg(data_np)
                    )
        finally:
            paddle.enable_static()


class TestRandpermEager(unittest.TestCase):
    """Dynamic-graph API test over every dtype, run under the eager guard."""

    def test_out(self):
        paddle.disable_static()
        # Re-enable static mode even if an assertion fails; otherwise the
        # process would stay in dynamic mode for every test that runs later.
        try:
            n = 10
            with _test_eager_guard():
                for dtype in ['int32', np.int64, 'float32', 'float64']:
                    with self.subTest(dtype=dtype):
                        data_np = paddle.randperm(n, dtype).numpy()
                        self.assertTrue(
                            check_randperm_out(n, data_np),
                            msg=error_msg(data_np),
                        )
        finally:
            paddle.enable_static()


class TestRandomValue(unittest.TestCase):
    """Pin the exact GPU output of paddle.randperm for a fixed seed."""

    # Length of every permutation drawn by the test.
    _N = 30000
    # Number of consecutive values compared at each checked offset.
    _SLICE_LEN = 10
    # (dtype, {start: expected x[start:start + _SLICE_LEN]}) for successive
    # paddle.randperm(_N, dtype) calls made after paddle.seed(2021). Keep
    # this order: the values differ per dtype under one seed, which suggests
    # each call advances the generator, so reordering would likely
    # invalidate everything after the first entry.
    _EXPECTED = (
        (
            'int32',
            {
                0: [24562, 8409, 9379, 10328, 20503,
                    18059, 9681, 21883, 11783, 27413],
                10000: [29477, 27100, 9643, 16637, 8605,
                        16892, 27767, 2724, 1612, 13096],
                20000: [298, 4104, 16479, 22714, 28684,
                        7510, 14667, 9950, 15940, 28343],
            },
        ),
        (
            'int64',
            {
                0: [6587, 1909, 5525, 23001, 6488,
                    14981, 14355, 3083, 29561, 8171],
                10000: [23460, 12394, 22501, 5427, 20185,
                        9100, 5127, 1651, 25806, 4818],
                20000: [5829, 4508, 16193, 24836, 8526,
                        242, 9984, 9243, 1977, 11839],
            },
        ),
        (
            'float32',
            {
                0: [5154.0, 10537.0, 14362.0, 29843.0, 27185.0,
                    28399.0, 27561.0, 4144.0, 22906.0, 10705.0],
                10000: [1958.0, 18414.0, 20090.0, 21910.0, 22746.0,
                        27346.0, 22347.0, 3002.0, 4564.0, 26991.0],
                20000: [25580.0, 12606.0, 553.0, 16387.0, 29536.0,
                        4241.0, 20946.0, 16899.0, 16339.0, 4662.0],
            },
        ),
        (
            'float64',
            {
                0: [19051.0, 2449.0, 21940.0, 11121.0, 282.0,
                    7330.0, 13747.0, 24321.0, 21147.0, 9163.0],
                10000: [15483.0, 1315.0, 5723.0, 20954.0, 13251.0,
                        25539.0, 5074.0, 1823.0, 14945.0, 17624.0],
                20000: [10516.0, 2552.0, 29970.0, 5941.0, 986.0,
                        8007.0, 24805.0, 26753.0, 12202.0, 21404.0],
            },
        ),
    )

    def test_fixed_random_number(self):
        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
        if not paddle.is_compiled_with_cuda():
            # Report a skip rather than a silent pass on CPU-only builds.
            self.skipTest("fixed randperm values are only pinned for CUDA")

        print("Test Fixed Random number on GPU------>")
        paddle.disable_static()
        try:
            paddle.set_device('gpu')
            paddle.seed(2021)
            for dtype, expected_by_start in self._EXPECTED:
                x = paddle.randperm(self._N, dtype=dtype).numpy()
                for start, expect in expected_by_start.items():
                    stop = start + self._SLICE_LEN
                    np.testing.assert_array_equal(
                        x[start:stop],
                        expect,
                        err_msg="dtype=%s, x[%d:%d]" % (dtype, start, stop),
                    )
        finally:
            # Restore static mode even if an assertion above fails.
            paddle.enable_static()


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()