# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import sys

import test_collective_api_base as test_base

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

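# The collective API tests below build and run static-graph programs.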
paddle.enable_static()


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
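    """Test runner that exercises paddle.distributed.all_gather in static graph mode."""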
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, dtype=None):
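        # Build main/startup programs that all_gather a [10, 1000] input tensor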
        dtype = "float32" if dtype is None else dtype
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype)
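            # all_gather fills tensor_list with one gathered tensor per rank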
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list

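    # Invoked via runtime_main in each spawned trainer process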
    def run_trainer(self, args):
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        paddle.distributed.init_parallel_env()
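        # Pick the execution place for the requested backend:
        # 'nccl' targets CUDA GPUs, 'bkcl' targets XPUs, anything else runs on CPU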
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
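        # Seed with the PID so every trainer process feeds distinct input data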
        indata = test_base.create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        assert (
            args['static_mode'] == 1
        ), "collective_allgather_api only supports static graph mode"
        result = self.get_model(
            train_prog, startup_prog, rank, dtype=args["dtype"]
        )
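        # Run the startup program once, then fetch every gathered tensor by name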
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        fetch_list = [elem.name for elem in result]
        out = exe.run(
            train_prog, feed={'tindata': indata}, fetch_list=fetch_list
        )
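        # Pickle the fetched results to stdout so the launching test can read them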
        sys.stdout.buffer.write(pickle.dumps(out))


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")