# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle


class TestNewGroupAPI:
    """Smoke test for ``paddle.distributed.new_group`` and collectives on it.

    Builds a communication group over ranks 0 and 1, then exercises scatter,
    broadcast, reduce, all_reduce, wait, all_gather and barrier. It asserts
    the per-rank results and prints a progress line after each step.

    This is a plain class (not a ``unittest.TestCase``) and is driven by
    calling :meth:`test_all`. NOTE(review): it presumably has to be started
    as (at least) two processes, ranks 0 and 1, by a distributed launcher.
    Confirm against the harness that runs this file.
    """

    def __init__(self):
        """Initialise the parallel environment and build the input tensors.

        ``self.tensor1`` holds ``[1, 2, 3]`` and ``self.tensor2`` holds
        ``[2, 3, 4]``; every later check is expressed in terms of these.
        """
        paddle.distributed.init_parallel_env()
        d1 = np.array([1, 2, 3])
        d2 = np.array([2, 3, 4])
        self.tensor1 = paddle.to_tensor(d1)
        self.tensor2 = paddle.to_tensor(d2)

    def test_all(self):
        """Run every collective check on a ``[0, 1]`` group, in order.

        The scatter/broadcast/reduce/all_reduce steps share the ``result``
        tensor. Each step's expected value depends on what the previous step
        left in it, so the order of the steps matters.

        Raises:
            AssertionError: if any collective produces an unexpected value
                (raised by ``np.testing.assert_array_equal``).
        """
        gp = paddle.distributed.new_group([0, 1])
        print("gp info:", gp)
        print("test new group api ok")

        # scatter from rank 0 of [tensor2, tensor1]: group rank 0 is expected
        # to receive tensor2 and group rank 1 to receive tensor1.
        result = paddle.to_tensor(np.array([0, 0, 0]))
        paddle.distributed.scatter(
            result, [self.tensor2, self.tensor1], src=0, group=gp, sync_op=True
        )
        if gp.rank == 0:
            np.testing.assert_array_equal(result, self.tensor2)
        elif gp.rank == 1:
            np.testing.assert_array_equal(result, self.tensor1)
        print("test scatter api ok")

        # broadcast from rank 1, which holds tensor1 after the scatter:
        # every participant ends up with tensor1.
        paddle.distributed.broadcast(result, src=1, group=gp, sync_op=True)
        np.testing.assert_array_equal(result, self.tensor1)
        print("test broadcast api ok")

        # reduce (sum) onto rank 0: rank 0 is expected to hold 2 * tensor1,
        # while rank 1 keeps tensor1.
        paddle.distributed.reduce(result, dst=0, group=gp, sync_op=True)
        if gp.rank == 0:
            np.testing.assert_array_equal(
                result, paddle.add(self.tensor1, self.tensor1)
            )
        elif gp.rank == 1:
            np.testing.assert_array_equal(result, self.tensor1)
        print("test reduce api ok")

        # all_reduce (sum): 2 * tensor1 from rank 0 plus tensor1 from rank 1
        # gives 3 * tensor1 everywhere.
        # NOTE(review): no ``group=`` is passed, so this presumably runs on the
        # default (global) group. The expected value only holds if that group
        # is exactly ranks 0 and 1.
        paddle.distributed.all_reduce(result, sync_op=True)
        np.testing.assert_array_equal(
            result,
            paddle.add(paddle.add(self.tensor1, self.tensor1), self.tensor1),
        )
        print("test all_reduce api ok")

        # Exercise wait() in both stream modes; no value is checked here.
        paddle.distributed.wait(result, gp, use_calc_stream=True)
        paddle.distributed.wait(result, gp, use_calc_stream=False)
        print("test wait api ok")

        # all_gather fills a fresh list. Entries 0 and 1 are both expected to
        # be tensor1, since every member contributes tensor1.
        gathered = []
        paddle.distributed.all_gather(
            gathered, self.tensor1, group=gp, sync_op=True
        )
        np.testing.assert_array_equal(gathered[0], self.tensor1)
        np.testing.assert_array_equal(gathered[1], self.tensor1)
        print("test all_gather api ok")

        paddle.distributed.barrier(group=gp)
        print("test barrier api ok")


if __name__ == "__main__":
    # Build the fixture (this initialises the parallel env) and run every check.
    TestNewGroupAPI().test_all()