# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import test_collective_api_base as test_base

import paddle
import paddle.distributed as dist
from paddle import fluid, framework
from paddle.fluid import data_feeder

paddle.enable_static()


def all_gather_new(tensor_list, tensor, group=None):
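    """Manually append an `all_gather` op to the current static program.

    Validates the dtypes of `tensor` and `tensor_list`, gathers `tensor`
    from all ranks (concatenated along axis 0), and rewrites `tensor_list`
    in place with one slice per rank.
    """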
    op_type = 'all_gather'
    helper = framework.LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    for elem in tensor_list:
        data_feeder.check_variable_and_dtype(
            elem,
            'tensor_list',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'uint16',
            ],
            op_type,
        )
    data_feeder.check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'bool',
            'int8',
            'uint8',
            'uint16',
        ],
        op_type,
    )

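    # Use the default communication ring unless an explicit group is given;
    # gather across every rank in the world.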
    ring_id = 0 if group is None else group.id
    nranks = dist.get_world_size()
    helper.append_op(
        type=op_type,
        inputs={'x': [tensor]},
        outputs={'out': [out]},
        attrs={
            'ring_id': ring_id,
            'nranks': nranks,
        },
    )
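    # The op output is the concatenation along axis 0; split it back into
    # one tensor per rank and return them through tensor_list.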
    tensor_list.clear()
    tensor_list.extend(paddle.split(out, nranks, 0))


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
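    """Static-graph runner exercising the all_gather collective API."""
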
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, dtype=None):
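        """Build a program that calls paddle.distributed.all_gather.

        Creates a [10, 1000] input of the given dtype (float32 by default)
        and returns the list of gathered tensors.
        """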
        dtype = "float32" if dtype is None else dtype
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list

    def get_model_new(
        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
    ):
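        """Build the same program through the raw all_gather_new helper.

        `reduce_type` is accepted only to match the shared runner signature
        and is not used here.
        """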
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            all_gather_new(tensor_list, tindata)
            return tensor_list

    def run_trainer(self, args):
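        """Initialize the parallel env, build and run the program, and dump the output."""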
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        if args["use_comm_context"]:
            paddle.distributed.collective._init_parallel_env(args["backend"])
        else:
            paddle.distributed.init_parallel_env()
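        # Pick the device for this rank based on the backend under test.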
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(
                device_id
            )  # if args.use_gpu else fluid.CPUPlace()
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
        indata = test_base.create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        assert (
            args['static_mode'] == 1
        ), "collective_allgather_api only supports static graph mode"
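        # Use the raw all_gather op when the comm-context path is requested;
        # otherwise go through the public paddle.distributed.all_gather API.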
        result = (
            self.get_model_new(
                train_prog, startup_prog, rank, dtype=args["dtype"]
            )
            if args["use_comm_context"]
            else self.get_model(
                train_prog, startup_prog, rank, dtype=args["dtype"]
            )
        )
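        # Run startup, then fetch every gathered tensor from the main program.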
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        fetch_list = []
        for elem in result:
            fetch_list.append(elem.name)
        out = exe.run(
            train_prog, feed={'tindata': indata}, fetch_list=fetch_list
        )
        test_base.dump_output(out)


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")