# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import sys

import test_collective_api_base as test_base

import paddle
import paddle.distributed as dist
from paddle import fluid, framework
from paddle.fluid import data_feeder

paddle.enable_static()


def all_gather_new(tensor_list, tensor, group=None):
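    """Static-graph all_gather built directly with LayerHelper.

    Validates the supported dtypes, appends an ``all_gather`` op to the
    current program, then splits the concatenated output along axis 0 into
    one slice per rank and writes the slices back into ``tensor_list``.
    """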
    op_type = 'all_gather'
    helper = framework.LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    for elem in tensor_list:
        data_feeder.check_variable_and_dtype(
            elem,
            'tensor_list',
            [
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'bool',
                'int8',
                'uint8',
                'uint16',
            ],
            op_type,
        )
    data_feeder.check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'bool',
            'int8',
            'uint8',
            'uint16',
        ],
        op_type,
    )

    # Gather across the default communication ring (id 0) unless an explicit
    # group is given, collecting one slice from every rank in the world.
    ring_id = 0 if group is None else group.id
    nranks = dist.get_world_size()
    helper.append_op(
        type=op_type,
        inputs={'x': [tensor]},
        outputs={'out': [out]},
        attrs={
            'ring_id': ring_id,
            'nranks': nranks,
        },
    )
    tensor_list.clear()
    tensor_list.extend(paddle.split(out, nranks, 0))


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, dtype=None):
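        """Build the static program using the public paddle.distributed.all_gather API."""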
        dtype = "float32" if dtype is None else dtype
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            paddle.distributed.all_gather(tensor_list, tindata)
            return tensor_list

    def get_model_new(
        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
    ):
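        """Build the static program using the LayerHelper-based ``all_gather_new`` helper."""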
        with fluid.program_guard(main_prog, startup_program):
            tensor_list = []
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            all_gather_new(tensor_list, tindata)
            return tensor_list

    def run_trainer(self, args):
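        """Build and run the all_gather program for this rank, writing the pickled results to stdout."""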
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
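        # The comm-context flag selects the newer communication-context
        # initialization; otherwise the standard parallel env setup is used.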
        if args["use_comm_context"]:
            paddle.distributed.collective._init_parallel_env(args["backend"])
        else:
            paddle.distributed.init_parallel_env()
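        # Pick the device for the requested backend: GPU for nccl, XPU for
        # bkcl, CPU otherwise.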
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(
                device_id
            )  # if args.use_gpu else fluid.CPUPlace()
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
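        # Per-rank input data; seeding with the process id gives every rank a
        # distinct tensor to gather.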
        indata = test_base.create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        assert (
            args['static_mode'] == 1
        ), "collective_allgather_api only support static graph mode"
        result = (
            self.get_model_new(
                train_prog, startup_prog, rank, dtype=args["dtype"]
            )
            if args["use_comm_context"]
            else self.get_model(
                train_prog, startup_prog, rank, dtype=args["dtype"]
            )
        )
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        fetch_list = []
        for elem in result:
            fetch_list.append(elem.name)
        out = exe.run(
            train_prog, feed={'tindata': indata}, fetch_list=fetch_list
        )
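        # Pickle the fetched results to stdout so the launching test process
        # can read them back.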
        sys.stdout.buffer.write(pickle.dumps(out))


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")