#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from paddle.static import program_guard, Program
from paddle.fluid.framework import _test_eager_guard


def check_randperm_out(n, data_np):
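    """Check that data_np is a permutation of the integers [0, n)."""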
    assert isinstance(
        data_np, np.ndarray
    ), "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Return a single bool: a non-empty list of element-wise comparisons is
    # always truthy, which would make the assertTrue checks vacuous.
    return np.array_equal(gt_sorted, out_sorted)


def error_msg(data_np):
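    """Build the failure message used when the permutation check fails."""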
    return (
        "The sorted ground truth and the sorted output should "
        + "be equal, out = "
        + str(data_np)
    )


def convert_dtype(dtype_str):
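    """Map a dtype string to the corresponding core.VarDesc.VarType value."""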
    dtype_str_list = ["int32", "int64", "float32", "float64"]
    dtype_num_list = [
        core.VarDesc.VarType.INT32,
        core.VarDesc.VarType.INT64,
        core.VarDesc.VarType.FP32,
        core.VarDesc.VarType.FP64,
    ]
    assert dtype_str in dtype_str_list, (
        dtype_str + " should be in " + str(dtype_str_list)
    )
    return dtype_num_list[dtype_str_list.index(dtype_str)]


class TestRandpermOp(OpTest):
    """Test randperm op."""

    def setUp(self):
        self.op_type = "randperm"
        self.python_api = paddle.randperm
        self.n = 200
        self.dtype = "int64"

        self.inputs = {}
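        # The "Out" entry is only a shape/dtype placeholder; the values are
        # verified by verify_output instead of a direct comparison.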
        self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
        self.init_attrs()
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def init_attrs(self):
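        # Hook for subclasses to override n or dtype.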
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
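        # The single fetched output holds the generated permutation.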
        out_np = np.array(outs[0])
        self.assertTrue(
            check_randperm_out(self.n, out_np), msg=error_msg(out_np)
        )

    def test_eager(self):
        with _test_eager_guard():
            self.test_check_output()


class TestRandpermOpN(TestRandpermOp):
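    """Test randperm with a larger permutation length."""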
    def init_attrs(self):
        self.n = 10000


class TestRandpermOpInt32(TestRandpermOp):
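    """Test randperm with int32 output dtype."""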
    def init_attrs(self):
        self.dtype = "int32"


class TestRandpermOpFloat32(TestRandpermOp):
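    """Test randperm with float32 output dtype."""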
    def init_attrs(self):
        self.dtype = "float32"


class TestRandpermOpFloat64(TestRandpermOp):
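    """Test randperm with float64 output dtype."""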
    def init_attrs(self):
        self.dtype = "float64"


class TestRandpermOpError(unittest.TestCase):
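    """Test that invalid arguments to paddle.randperm raise errors."""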
    def test_errors(self):
        with program_guard(Program(), Program()):
            self.assertRaises(ValueError, paddle.randperm, -3)
            self.assertRaises(TypeError, paddle.randperm, 10, 'int8')


class TestRandpermAPI(unittest.TestCase):
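    """Test paddle.randperm in static-graph mode."""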
    def test_out(self):
        n = 10
        place = (
            paddle.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else paddle.CPUPlace()
        )
        with program_guard(Program(), Program()):
            x1 = paddle.randperm(n)
            x2 = paddle.randperm(n, 'float32')

            exe = paddle.static.Executor(place)
            res = exe.run(fetch_list=[x1, x2])

            self.assertEqual(res[0].dtype, np.int64)
            self.assertEqual(res[1].dtype, np.float32)
            self.assertTrue(check_randperm_out(n, res[0]))
            self.assertTrue(check_randperm_out(n, res[1]))


class TestRandpermImperative(unittest.TestCase):
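    """Test paddle.randperm in dynamic-graph (imperative) mode."""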
    def test_out(self):
        paddle.disable_static()
        n = 10
        for dtype in ['int32', np.int64, 'float32', 'float64']:
            data_p = paddle.randperm(n, dtype)
            data_np = data_p.numpy()
            self.assertTrue(
                check_randperm_out(n, data_np), msg=error_msg(data_np)
            )
        paddle.enable_static()


class TestRandpermEager(unittest.TestCase):
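    """Test paddle.randperm in dynamic-graph mode under the eager guard."""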
    def test_out(self):
        paddle.disable_static()
        n = 10
        with _test_eager_guard():
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                data_p = paddle.randperm(n, dtype)
                data_np = data_p.numpy()
                self.assertTrue(
                    check_randperm_out(n, data_np), msg=error_msg(data_np)
                )
        paddle.enable_static()


class TestRandomValue(unittest.TestCase):
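    """Test that GPU random results are reproducible for a fixed seed."""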
    def test_fixed_random_number(self):
        # Test the fixed random sequence on GPU, which is generated by
        # 'curandStatePhilox4_32_10_t'.
        if not paddle.is_compiled_with_cuda():
            return

        print("Test Fixed Random number on GPU------>")
        paddle.disable_static()
        paddle.set_device('gpu')
        paddle.seed(2021)

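        # With the seed fixed above, the permutations generated below must
        # match the recorded slices for every dtype.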
        x = paddle.randperm(30000, dtype='int32').numpy()
        expect = [
            24562,
            8409,
            9379,
            10328,
            20503,
            18059,
            9681,
            21883,
            11783,
            27413,
        ]
        np.testing.assert_array_equal(x[0:10], expect)
        expect = [
            29477,
            27100,
            9643,
            16637,
            8605,
            16892,
            27767,
            2724,
            1612,
            13096,
        ]
        np.testing.assert_array_equal(x[10000:10010], expect)
        expect = [
            298,
            4104,
            16479,
            22714,
            28684,
            7510,
            14667,
            9950,
            15940,
            28343,
        ]
        np.testing.assert_array_equal(x[20000:20010], expect)

        x = paddle.randperm(30000, dtype='int64').numpy()
        expect = [
            6587,
            1909,
            5525,
            23001,
            6488,
            14981,
            14355,
            3083,
            29561,
            8171,
        ]
        np.testing.assert_array_equal(x[0:10], expect)
        expect = [
            23460,
            12394,
            22501,
            5427,
            20185,
            9100,
            5127,
            1651,
            25806,
            4818,
        ]
        np.testing.assert_array_equal(x[10000:10010], expect)
        expect = [5829, 4508, 16193, 24836, 8526, 242, 9984, 9243, 1977, 11839]
        np.testing.assert_array_equal(x[20000:20010], expect)

        x = paddle.randperm(30000, dtype='float32').numpy()
        expect = [
            5154.0,
            10537.0,
            14362.0,
            29843.0,
            27185.0,
            28399.0,
            27561.0,
            4144.0,
            22906.0,
            10705.0,
        ]
        np.testing.assert_array_equal(x[0:10], expect)
        expect = [
            1958.0,
            18414.0,
            20090.0,
            21910.0,
            22746.0,
            27346.0,
            22347.0,
            3002.0,
            4564.0,
            26991.0,
        ]
        np.testing.assert_array_equal(x[10000:10010], expect)
        expect = [
            25580.0,
            12606.0,
            553.0,
            16387.0,
            29536.0,
            4241.0,
            20946.0,
            16899.0,
            16339.0,
            4662.0,
        ]
        np.testing.assert_array_equal(x[20000:20010], expect)

        x = paddle.randperm(30000, dtype='float64').numpy()
        expect = [
            19051.0,
            2449.0,
            21940.0,
            11121.0,
            282.0,
            7330.0,
            13747.0,
            24321.0,
            21147.0,
            9163.0,
        ]
        np.testing.assert_array_equal(x[0:10], expect)
        expect = [
            15483.0,
            1315.0,
            5723.0,
            20954.0,
            13251.0,
            25539.0,
            5074.0,
            1823.0,
            14945.0,
            17624.0,
        ]
        np.testing.assert_array_equal(x[10000:10010], expect)
        expect = [
            10516.0,
            2552.0,
            29970.0,
            5941.0,
            986.0,
            8007.0,
            24805.0,
            26753.0,
            12202.0,
            21404.0,
        ]
        np.testing.assert_array_equal(x[20000:20010], expect)
        paddle.enable_static()


if __name__ == "__main__":
    unittest.main()