test_randperm_op.py 7.2 KB
Newer Older
C
cc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
20
from paddle.static import program_guard, Program
Z
zyfncg 已提交
21
from paddle.fluid.framework import _test_eager_guard
C
cc 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37


def check_randperm_out(n, data_np):
    """Return True if ``data_np`` is a permutation of ``0 .. n-1``.

    Args:
        n: number of elements the permutation should contain.
        data_np: np.ndarray produced by randperm (integer or float dtype).

    Returns:
        bool: True when sorting ``data_np`` yields exactly ``np.arange(n)``
        (same shape, equal values); False otherwise.
    """
    assert isinstance(data_np, np.ndarray), \
        "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Reduce to one bool: a list of per-element comparison results is truthy
    # whenever it is non-empty, which would make assertTrue() unable to fail.
    # array_equal also returns False (instead of raising) on a shape mismatch.
    return bool(np.array_equal(gt_sorted, out_sorted))


def error_msg(data_np):
    """Build the failure message used when a randperm result is invalid.

    Args:
        data_np: the offending output; it is rendered with ``str()``.

    Returns:
        str: a message ending with the string form of ``data_np``.
    """
    prefix = ("The sorted ground truth and sorted out should "
              "be equal, out = ")
    return prefix + str(data_np)


def convert_dtype(dtype_str):
    """Map a dtype name to the matching ``core.VarDesc.VarType`` value.

    Args:
        dtype_str: one of "int32", "int64", "float32", "float64".

    Returns:
        The ``core.VarDesc.VarType`` member for ``dtype_str``.

    Raises:
        AssertionError: if ``dtype_str`` is not one of the supported names.
    """
    dtype_map = {
        "int32": core.VarDesc.VarType.INT32,
        "int64": core.VarDesc.VarType.INT64,
        "float32": core.VarDesc.VarType.FP32,
        "float64": core.VarDesc.VarType.FP64,
    }
    # Keys in insertion order, so the message lists them as before.
    dtype_str_list = list(dtype_map)
    assert dtype_str in dtype_map, dtype_str + \
        " should in " + str(dtype_str_list)
    return dtype_map[dtype_str]


class TestRandpermOp(OpTest):
    """Test the randperm op through the OpTest harness.

    Subclasses override ``init_attrs`` to change ``self.n`` or
    ``self.dtype``. The output placeholder and the op attrs are built after
    that hook runs, so both reflect the overrides.
    """

    def setUp(self):
        self.op_type = "randperm"
        self.python_api = paddle.randperm
        # Defaults; subclasses may override them in init_attrs().
        self.n = 200
        self.dtype = "int64"
        self.init_attrs()

        self.inputs = {}
        # Zero-filled placeholder: the real output is random and is validated
        # by verify_output via check_output_customized.
        self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)}
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def init_attrs(self):
        """Hook for subclasses to override ``self.n`` / ``self.dtype``."""
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        """Assert that the first output is a permutation of 0..n-1."""
        out_np = np.array(outs[0])
        self.assertTrue(check_randperm_out(self.n, out_np),
                        msg=error_msg(out_np))

    def test_eager(self):
        # Re-run the same output check under the eager-mode guard.
        with _test_eager_guard():
            self.test_check_output()

C
cc 已提交
80

81
class TestRandpermOpN(TestRandpermOp):
    """Run the randperm op test with a larger permutation length."""

    def init_attrs(self):
        self.n = 10000


87
class TestRandpermOpInt32(TestRandpermOp):
    """Run the randperm op test with an int32 output."""

    def init_attrs(self):
        self.dtype = "int32"


93
class TestRandpermOpFloat32(TestRandpermOp):
    """Run the randperm op test with a float32 output."""

    def init_attrs(self):
        self.dtype = "float32"
C
cc 已提交
97 98


99
class TestRandpermOpFloat64(TestRandpermOp):
    """Run the randperm op test with a float64 output."""

    def init_attrs(self):
        self.dtype = "float64"
C
cc 已提交
103 104 105


class TestRandpermOpError(unittest.TestCase):
    """Invalid arguments to paddle.randperm must raise."""

    def test_errors(self):
        with program_guard(Program(), Program()):
            # A negative length is expected to raise ValueError.
            self.assertRaises(ValueError, paddle.randperm, -3)
            # 'int8' is expected to be rejected with TypeError.
            self.assertRaises(TypeError, paddle.randperm, 10, 'int8')
C
cc 已提交
111 112


113
class TestRandpermAPI(unittest.TestCase):
    """Check paddle.randperm in static-graph mode through an Executor."""

    def test_out(self):
        n = 10
        # Use the first GPU when Paddle is built with CUDA, else the CPU.
        place = paddle.CUDAPlace(
            0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
        with program_guard(Program(), Program()):
            x1 = paddle.randperm(n)
            x2 = paddle.randperm(n, 'float32')

            exe = paddle.static.Executor(place)
            res = exe.run(fetch_list=[x1, x2])

            # With no dtype argument, the result is expected to be int64.
            self.assertEqual(res[0].dtype, np.int64)
            self.assertEqual(res[1].dtype, np.float32)
            self.assertTrue(check_randperm_out(n, res[0]),
                            msg=error_msg(res[0]))
            self.assertTrue(check_randperm_out(n, res[1]),
                            msg=error_msg(res[1]))
C
cc 已提交
130 131


132
class TestRandpermImperative(unittest.TestCase):
    """Check paddle.randperm in dynamic-graph (imperative) mode."""

    def test_out(self):
        n = 10
        paddle.disable_static()
        try:
            # Both string names and numpy dtype objects are accepted as dtype.
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                data_np = paddle.randperm(n, dtype).numpy()
                self.assertTrue(check_randperm_out(n, data_np),
                                msg=error_msg(data_np))
        finally:
            # Restore static mode even when an assertion fails, so later
            # tests in the same process do not inherit dygraph mode.
            paddle.enable_static()
C
cc 已提交
143 144


Z
zyfncg 已提交
145
class TestRandpermEager(unittest.TestCase):
    """Check paddle.randperm in dynamic-graph mode under the eager guard."""

    def test_out(self):
        n = 10
        paddle.disable_static()
        try:
            with _test_eager_guard():
                for dtype in ['int32', np.int64, 'float32', 'float64']:
                    data_np = paddle.randperm(n, dtype).numpy()
                    self.assertTrue(check_randperm_out(n, data_np),
                                    msg=error_msg(data_np))
        finally:
            # Restore static mode even when an assertion fails.
            paddle.enable_static()


159
class TestRandomValue(unittest.TestCase):
    """Pin randperm's seeded GPU output so RNG changes are detected."""

    def test_fixed_random_number(self):
        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
        if not paddle.is_compiled_with_cuda():
            return

        # (dtype, [(offset, expected x[offset:offset + 10]), ...]).
        # The dtypes are drawn in this order after a single seed, so the
        # order of this table matters: each call advances the generator.
        expected_table = [
            ('int32', [
                (0, [24562, 8409, 9379, 10328, 20503,
                     18059, 9681, 21883, 11783, 27413]),
                (10000, [29477, 27100, 9643, 16637, 8605,
                         16892, 27767, 2724, 1612, 13096]),
                (20000, [298, 4104, 16479, 22714, 28684,
                         7510, 14667, 9950, 15940, 28343]),
            ]),
            ('int64', [
                (0, [6587, 1909, 5525, 23001, 6488,
                     14981, 14355, 3083, 29561, 8171]),
                (10000, [23460, 12394, 22501, 5427, 20185,
                         9100, 5127, 1651, 25806, 4818]),
                (20000, [5829, 4508, 16193, 24836, 8526,
                         242, 9984, 9243, 1977, 11839]),
            ]),
            ('float32', [
                (0, [5154., 10537., 14362., 29843., 27185.,
                     28399., 27561., 4144., 22906., 10705.]),
                (10000, [1958., 18414., 20090., 21910., 22746.,
                         27346., 22347., 3002., 4564., 26991.]),
                (20000, [25580., 12606., 553., 16387., 29536.,
                         4241., 20946., 16899., 16339., 4662.]),
            ]),
            ('float64', [
                (0, [19051., 2449., 21940., 11121., 282.,
                     7330., 13747., 24321., 21147., 9163.]),
                (10000, [15483., 1315., 5723., 20954., 13251.,
                         25539., 5074., 1823., 14945., 17624.]),
                (20000, [10516., 2552., 29970., 5941., 986.,
                         8007., 24805., 26753., 12202., 21404.]),
            ]),
        ]

        print("Test Fixed Random number on GPU------>")
        paddle.disable_static()
        try:
            paddle.set_device('gpu')
            paddle.seed(2021)
            for dtype, checks in expected_table:
                x = paddle.randperm(30000, dtype=dtype).numpy()
                for offset, expect in checks:
                    np.testing.assert_array_equal(
                        x[offset:offset + 10], expect,
                        err_msg="dtype=%s, offset=%d" % (dtype, offset))
        finally:
            # Restore static mode even when a comparison fails.
            paddle.enable_static()


C
cc 已提交
233 234
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()