# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import sys

import test_collective_api_base as test_base

import paddle
import paddle.fluid as fluid

paddle.enable_static()


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
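    """Static-graph runner for the paddle.distributed.all_gather API test."""
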
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, dtype=None):
        dtype = "float32" if dtype is None else dtype
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
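            # Placeholder input; the leading -1 dimension leaves the batch
            # size to be determined by the feed at run time.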
            tindata = paddle.static.data(
                name="tindata", shape=[-1, 10, 1000], dtype=dtype
            )
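            # all_gather fills tensor_list with one tensor per rank.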
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list

    def run_trainer(self, args):
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        paddle.distributed.init_parallel_env()
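        # Pick the device for the requested backend: nccl -> GPU,
        # bkcl -> XPU, anything else -> CPU.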
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(
                device_id
            )  # if args.use_gpu else fluid.CPUPlace()
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
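        # Seed with the process id so every rank feeds distinct data.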
        indata = test_base.create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        assert (
            args['static_mode'] == 1
        ), "collective_allgather_api only supports static graph mode"
        result = self.get_model(
            train_prog, startup_prog, rank, dtype=args["dtype"]
        )
        exe = fluid.Executor(place)
        exe.run(startup_prog)
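        # Fetch each gathered tensor by name.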
        fetch_list = []
        for elem in result:
            fetch_list.append(elem.name)
        out = exe.run(
            train_prog, feed={'tindata': indata}, fetch_list=fetch_list
        )
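        # Write the pickled results to stdout for the launching test
        # harness to read back.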
        sys.stdout.buffer.write(pickle.dumps(out))


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")