# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
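
"""Test runner for the paddle.distributed.all_gather API.

Driven by test_collective_api_base.runtime_main: builds a static-graph
all_gather program, executes it on the device chosen for the test backend,
and writes the fetched results to stdout as a pickle.
"""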

import os
import pickle
import sys

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

import test_collective_api_base as test_base

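# collective_allgather_api only supports static graph mode
# (see the assert in run_trainer below).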
paddle.enable_static()


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
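    """Builds and runs an all_gather program for a single trainer process."""
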
    def __init__(self):
        self.global_ring_id = 0  # ring 0 is the default global communication ring

    def get_model(self, main_prog, startup_program, rank, dtype=None):
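        """Build a program that all_gathers a [10, 1000] tensor across ranks."""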
        dtype = "float32" if dtype is None else dtype
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype)
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list

    def run_trainer(self, args):
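        """Build the program, run it on this trainer's device, and dump results."""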
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2  # two trainer processes take part in this test
        paddle.distributed.init_parallel_env()
        # Select the device for this trainer according to the test backend.
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
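        # Seed with the pid so every trainer feeds distinct input data.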
        indata = test_base.create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        assert (
            args['static_mode'] == 1
        ), "collective_allgather_api only support static mode"
        result = self.get_model(
            train_prog, startup_prog, rank, dtype=args["dtype"]
        )
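        # Run the startup program once, then fetch each gathered tensor by name.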
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        fetch_list = []
        for elem in result:
            fetch_list.append(elem.name)
        out = exe.run(
            train_prog, feed={'tindata': indata}, fetch_list=fetch_list
        )
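        # Emit the fetched results on stdout as a pickle stream.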
        sys.stdout.buffer.write(pickle.dumps(out))

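# runtime_main is expected to parse the arguments passed in by the collective
# test harness and call run_trainer on an instance of the class above.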

if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")