test_randperm_op.py 7.3 KB
Newer Older
C
cc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
20
from paddle.static import program_guard, Program
Z
zyfncg 已提交
21
from paddle.fluid.framework import _test_eager_guard
22
import os
C
cc 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38


def check_randperm_out(n, data_np):
    """Return True iff ``data_np`` is a permutation of ``0 .. n-1``.

    The output is sorted and compared with ``np.arange(n)``. Equality means
    every value of the range appears exactly once.

    Args:
        n: expected length of the permutation.
        data_np: randperm output as a numpy array. Float outputs compare
            equal to the integer range value-wise.

    Returns:
        bool: True when sorted ``data_np`` equals ``np.arange(n)``. False on
        any value mismatch or length/shape mismatch.
    """
    assert isinstance(data_np, np.ndarray), \
        "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Reduce to a single bool. A list of elementwise flags would be truthy
    # whenever it is non-empty, so assertTrue(...) on it could never fail.
    return bool(np.array_equal(gt_sorted, out_sorted))


def error_msg(data_np):
    """Build the assertion message shown when a randperm result is wrong.

    The offending output is embedded via ``str()`` so it appears in the
    failure report.
    """
    prefix = "The sorted ground truth and sorted out should be equal, out = "
    return prefix + str(data_np)


def convert_dtype(dtype_str):
    """Map a dtype name to the matching ``core.VarDesc.VarType`` member.

    Accepted names are "int32", "int64", "float32" and "float64"; any other
    value trips the assertion.
    """
    names = ["int32", "int64", "float32", "float64"]
    var_type = core.VarDesc.VarType
    enums = [var_type.INT32, var_type.INT64, var_type.FP32, var_type.FP64]
    assert dtype_str in names, dtype_str + " should in " + str(names)
    position = names.index(dtype_str)
    return enums[position]


class TestRandpermOp(OpTest):
    """Test the randperm op through OpTest.

    Subclasses override ``init_attrs`` to change ``self.n`` and/or
    ``self.dtype``; everything derived from those values is built after the
    hook runs.
    """

    def setUp(self):
        self.op_type = "randperm"
        self.python_api = paddle.randperm
        self.n = 200
        self.dtype = "int64"

        # Apply subclass overrides *before* deriving outputs/attrs from
        # n and dtype; otherwise the "Out" entry would keep the default
        # 200-element int64 shape even when a subclass changed n or dtype.
        self.init_attrs()

        self.inputs = {}
        # Values are not compared directly: test_check_output validates the
        # result with verify_output (a permutation check) instead.
        self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)}
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def init_attrs(self):
        """Hook for subclasses to override ``self.n`` / ``self.dtype``."""
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        """Assert that the first op output is a permutation of range(n)."""
        out_np = np.array(outs[0])
        self.assertTrue(check_randperm_out(self.n, out_np),
                        msg=error_msg(out_np))

    def test_eager(self):
        # Re-run the same output check under the eager-mode guard.
        with _test_eager_guard():
            self.test_check_output()

C
cc 已提交
81

82
class TestRandpermOpN(TestRandpermOp):
    """Randperm op test with a larger permutation length (n = 10000)."""

    def init_attrs(self):
        # Only the length changes; dtype keeps the base-class value.
        self.n = 10000


88
class TestRandpermOpInt32(TestRandpermOp):
    """Randperm op test producing an int32 permutation."""

    def init_attrs(self):
        # Only the dtype changes; n keeps the base-class value.
        self.dtype = "int32"


94
class TestRandpermOpFloat32(TestRandpermOp):
    """Randperm op test producing a float32 permutation."""

    def init_attrs(self):
        # Only the dtype changes; n keeps the base-class value.
        self.dtype = "float32"
C
cc 已提交
98 99


100
class TestRandpermOpFloat64(TestRandpermOp):
    """Randperm op test producing a float64 permutation."""

    def init_attrs(self):
        # Only the dtype changes; n keeps the base-class value.
        self.dtype = "float64"
C
cc 已提交
104 105 106


class TestRandpermOpError(unittest.TestCase):
    """Check that paddle.randperm rejects invalid arguments."""

    def test_errors(self):
        with program_guard(Program(), Program()):
            # A negative length must raise ValueError.
            with self.assertRaises(ValueError):
                paddle.randperm(-3)
            # 'int8' is not an accepted dtype and must raise TypeError.
            with self.assertRaises(TypeError):
                paddle.randperm(10, 'int8')
C
cc 已提交
112 113


114
class TestRandpermAPI(unittest.TestCase):
    """Check paddle.randperm in static-graph mode via an Executor."""

    def test_out(self):
        n = 10
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        else:
            place = paddle.CPUPlace()

        with program_guard(Program(), Program()):
            out_default = paddle.randperm(n)
            out_float32 = paddle.randperm(n, 'float32')

            exe = paddle.static.Executor(place)
            res_default, res_float32 = exe.run(
                fetch_list=[out_default, out_float32])

            # The call without a dtype is expected to yield int64; the
            # explicit 'float32' request must be honored.
            self.assertEqual(res_default.dtype, np.int64)
            self.assertEqual(res_float32.dtype, np.float32)
            self.assertTrue(check_randperm_out(n, res_default))
            self.assertTrue(check_randperm_out(n, res_float32))
C
cc 已提交
131 132


133
class TestRandpermImperative(unittest.TestCase):
    """Check paddle.randperm in dynamic-graph (imperative) mode."""

    def test_out(self):
        paddle.disable_static()
        try:
            n = 10
            for dtype in ['int32', np.int64, 'float32', 'float64']:
                # subTest reports each dtype separately instead of
                # stopping at the first failing one.
                with self.subTest(dtype=dtype):
                    data_np = paddle.randperm(n, dtype).numpy()
                    self.assertTrue(check_randperm_out(n, data_np),
                                    msg=error_msg(data_np))
        finally:
            # Restore static mode even when an assertion fails, so later
            # tests in the same process do not run in dynamic mode.
            paddle.enable_static()
C
cc 已提交
144 145


Z
zyfncg 已提交
146
class TestRandpermEager(unittest.TestCase):
    """Check paddle.randperm in dynamic mode under the eager-mode guard."""

    def test_out(self):
        paddle.disable_static()
        try:
            n = 10
            with _test_eager_guard():
                for dtype in ['int32', np.int64, 'float32', 'float64']:
                    # subTest reports each dtype separately instead of
                    # stopping at the first failing one.
                    with self.subTest(dtype=dtype):
                        data_np = paddle.randperm(n, dtype).numpy()
                        self.assertTrue(check_randperm_out(n, data_np),
                                        msg=error_msg(data_np))
        finally:
            # Restore static mode even when an assertion fails, so later
            # tests in the same process do not run in dynamic mode.
            paddle.enable_static()


160
class TestRandomValue(unittest.TestCase):
    """Pin the exact GPU randperm output for a fixed seed."""

    # (dtype, [(start offset, expected 10 values), ...]) listed in call
    # order. The order matters: every randperm call consumes generator
    # state after paddle.seed(2021), so later results depend on earlier
    # calls.
    _EXPECTED = [
        ('int32', [
            (0, [
                24562, 8409, 9379, 10328, 20503, 18059, 9681, 21883, 11783,
                27413
            ]),
            (10000, [
                29477, 27100, 9643, 16637, 8605, 16892, 27767, 2724, 1612,
                13096
            ]),
            (20000, [
                298, 4104, 16479, 22714, 28684, 7510, 14667, 9950, 15940,
                28343
            ]),
        ]),
        ('int64', [
            (0, [
                6587, 1909, 5525, 23001, 6488, 14981, 14355, 3083, 29561,
                8171
            ]),
            (10000, [
                23460, 12394, 22501, 5427, 20185, 9100, 5127, 1651, 25806,
                4818
            ]),
            (20000, [
                5829, 4508, 16193, 24836, 8526, 242, 9984, 9243, 1977, 11839
            ]),
        ]),
        ('float32', [
            (0, [
                5154., 10537., 14362., 29843., 27185., 28399., 27561., 4144.,
                22906., 10705.
            ]),
            (10000, [
                1958., 18414., 20090., 21910., 22746., 27346., 22347., 3002.,
                4564., 26991.
            ]),
            (20000, [
                25580., 12606., 553., 16387., 29536., 4241., 20946., 16899.,
                16339., 4662.
            ]),
        ]),
        ('float64', [
            (0, [
                19051., 2449., 21940., 11121., 282., 7330., 13747., 24321.,
                21147., 9163.
            ]),
            (10000, [
                15483., 1315., 5723., 20954., 13251., 25539., 5074., 1823.,
                14945., 17624.
            ]),
            (20000, [
                10516., 2552., 29970., 5941., 986., 8007., 24805., 26753.,
                12202., 21404.
            ]),
        ]),
    ]

    def test_fixed_random_number(self):
        # Test GPU Fixed random number, which is generated by
        # 'curandStatePhilox4_32_10_t'
        if not paddle.is_compiled_with_cuda():
            return

        print("Test Fixed Random number on GPU------>")
        paddle.disable_static()
        try:
            paddle.set_device('gpu')
            paddle.seed(2021)

            for dtype, checks in self._EXPECTED:
                x = paddle.randperm(30000, dtype=dtype).numpy()
                for start, expect in checks:
                    np.testing.assert_array_equal(
                        x[start:start + len(expect)],
                        expect,
                        err_msg="dtype=%s, offset=%d" % (dtype, start))
        finally:
            # Restore static mode even when a comparison fails, so later
            # tests in the same process do not run in dynamic mode.
            paddle.enable_static()


C
cc 已提交
234 235
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")