test_prune_gate_by_capacity_op.py 4.6 KB
Newer Older
R
Roc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import numpy as np
from paddle.distributed.models.moe import utils
from paddle.fluid import core


def count(x, upper_num):
    res = np.zeros((upper_num, )).astype(int)
    for i in x.reshape(-1):
        if i >= 0 and i < len(res):
            res[i] += 1
    return res


def limit_by_capacity(expert_count, _capacity, n_worker):
    capacity = np.copy(_capacity)
    old_shape = expert_count.shape
    expert_count = np.reshape(expert_count, (n_worker, len(capacity)))
    output = np.zeros_like(expert_count)
    for wid in range(len(expert_count)):
        for eid in range(len(expert_count[wid])):
            last_cap = capacity[eid]
            if last_cap >= 0:
                capacity[eid] -= expert_count[wid][eid]
            if last_cap >= expert_count[wid][eid]:
                output[wid][eid] = expert_count[wid][eid]
            elif last_cap >= 0:
                output[wid][eid] = last_cap
    return output.reshape(old_shape)


def prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker):
    new_gate_idx = np.copy(gate_idx)
    expert_count = np.copy(expert_count)
    for i in range(len(gate_idx)):
        idx = gate_idx[i]
        last_cap = expert_count[idx]
        if last_cap > 0:
            expert_count[idx] -= 1
        else:
            new_gate_idx[i] = -1
    return new_gate_idx


def assert_allclose(output, expected, n_expert):
    c1 = count(output, n_expert)
    c2 = count(expected, n_expert)
    assert np.allclose(c1, c2)


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestPruneGateByCapacityAPI1(unittest.TestCase):
    def init_test_case(self):
        self.gate_idx = np.random.randint(
            0, self.n_expert, size=(200, )).astype(self.dtype)
        expert_count = count(self.gate_idx, self.n_expert * self.n_worker)
        capacity = np.random.randint(10, 200, size=(self.n_expert, ))
        self.expert_count = limit_by_capacity(expert_count, capacity,
                                              self.n_worker).astype(self.dtype)
        self.out = prune_gate_by_capacity(self.gate_idx, self.expert_count,
                                          self.n_expert,
                                          self.n_worker).astype(self.dtype)
        self.place = paddle.CUDAPlace(0)

    def setUp(self):
        self.n_expert = 24
        self.n_worker = 2
        self.dtype = "int64"
        self.init_test_case()

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            gate_idx_tensor = paddle.static.data(
                'GateIdx', shape=self.gate_idx.shape, dtype="int64")
            expert_count_tensor = paddle.static.data(
                'ExpertCount', shape=self.expert_count.shape, dtype="int64")
            out = utils._prune_gate_by_capacity(gate_idx_tensor,
                                                expert_count_tensor,
                                                self.n_expert, self.n_worker)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={
                'GateIdx': self.gate_idx,
                'ExpertCount': self.expert_count,
            },
                          fetch_list=out)
        assert_allclose(res[0], self.out, self.n_expert)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        gate_idx_tensor = paddle.to_tensor(self.gate_idx)
        expert_count_tensor = paddle.to_tensor(self.expert_count)
        out = utils._prune_gate_by_capacity(
            gate_idx_tensor, expert_count_tensor, self.n_expert, self.n_worker)
        assert_allclose(out.numpy(), self.out, self.n_expert)


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestPruneGateByCapacityAPI2(TestPruneGateByCapacityAPI1):
    def setUp(self):
        self.n_expert = 12
        self.n_worker = 1
        self.dtype = "int64"
        self.init_test_case()


if __name__ == '__main__':
    unittest.main()