# test_random_routing_op.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.models.moe import utils
from paddle.fluid.framework import _test_eager_guard


def random_routing(topk_idx, topk_value, prob, topk=2):
    """NumPy reference implementation of the random-routing op.

    For every row ``i``, the second routing choice (column 1) is dropped by
    setting its index to ``-1`` when twice its value is strictly smaller
    than the row's probability, i.e. ``topk_value[i, 1] * 2 < prob[i]``.
    Column 0 is never modified.

    Args:
        topk_idx: array-like of integer indices, shape ``(N, 2)`` as used
            by the tests in this file.
        topk_value: array-like of top-k values, same shape as ``topk_idx``.
        prob: array-like of per-row probabilities, shape ``(N,)``.
        topk: number of routing choices; only ``2`` is supported.

    Returns:
        A new ``numpy.ndarray`` copy of ``topk_idx`` with dropped entries set
        to ``-1``; the caller's ``topk_idx`` is not modified.

    Raises:
        RuntimeError: if ``topk`` is not 2.
    """
    if topk != 2:
        raise RuntimeError("only topk=2 is supported now")
    new_topk_idx = np.copy(topk_idx)
    # One boolean mask over all rows replaces the per-row Python loop; the
    # comparison is done in the inputs' own dtype, exactly as before.
    drop = np.asarray(topk_value)[:, 1] * 2 < np.asarray(prob)
    new_topk_idx[drop, 1] = -1
    return new_topk_idx


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestNumberCountAPIFp32(unittest.TestCase):
    """Compare ``utils._random_routing`` with the NumPy reference (float32).

    Despite the class name (kept for compatibility with the float16
    subclass), this case exercises random routing, not number counting.
    """

    def setUp(self):
        self.dtype = "float32"
        self.init()

    def init(self):
        """Build random inputs in ``self.dtype`` and the expected output."""
        self.upper_range = 8
        # Lower bound -1 means some inputs already hold the -1 value that
        # random_routing writes for dropped slots.
        self.x = np.random.randint(-1, self.upper_range,
                                   size=(200, 2)).astype('int64')
        self.prob = np.random.random((self.x.shape[0], )).astype(self.dtype)
        self.topk_value = np.random.random(self.x.shape).astype(self.dtype)
        # Keep the reference as integer ids so it can be compared exactly.
        self.out = random_routing(self.x, self.topk_value, self.prob)
        self.place = paddle.CUDAPlace(0)

    def func_api_dygraph(self):
        """Run the op in dygraph mode and compare with the reference."""
        # The __main__ entry enables static mode, so switch explicitly.
        paddle.disable_static()
        x = paddle.to_tensor(self.x)
        value = paddle.to_tensor(self.topk_value)
        prob = paddle.to_tensor(self.prob)
        out = utils._random_routing(x, value, prob)
        # Routing results are integer ids: require exact equality. Unlike a
        # bare ``assert``, this survives ``python -O`` and reports a diff.
        np.testing.assert_array_equal(out.numpy(), self.out)

    def test_api_dygraph(self):
        """Check once inside ``_test_eager_guard()`` and once outside it."""
        with _test_eager_guard():
            self.func_api_dygraph()
        self.func_api_dygraph()


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestNumberCountAPIFp16(TestNumberCountAPIFp32):
    """Same random-routing check as the float32 case, with float16 inputs."""

    def setUp(self):
        # Only the dtype differs; input generation is inherited from init().
        self.dtype = "float16"
        self.init()


if __name__ == "__main__":
    # Start in static-graph mode; the dygraph checks switch modes themselves.
    paddle.enable_static()
    unittest.main()