# process_group_nccl.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
from paddle.fluid import core
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
import paddle.distributed as dist


def init_process_group(strategy=None):
    """Initialize the parallel environment and return its process group.

    Args:
        strategy: Accepted for backward compatibility; currently ignored.

    Returns:
        The ``process_group`` attribute of the group object returned by
        ``dist.init_parallel_env()``.
    """
    pg_group = dist.init_parallel_env()
    return pg_group.process_group


class TestProcessGroupFp32(unittest.TestCase):
    """NCCL collective-communication tests run in eager mode with float32.

    The test assumes exactly two ranks (e.g. it sends to rank 1 and doubles
    the leading dimension for all_gather/scatter). Every rank runs the same
    sequence of collectives in the same order. Because every rank seeds its
    RNGs identically in ``setUp``, the "rank 0" and "rank 1" inputs drawn
    below are the same on both ranks. This lets each rank compute the
    expected result locally and compare it with what the collective produced.
    """

    def setUp(self):
        # Identical seeds on every rank -> identical random inputs per rank.
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        """Set the dtype and tensor shape used by the test (subclass hook)."""
        self.dtype = "float32"
        self.shape = (2, 10, 5)

    def test_create_process_group_nccl(self):
        """Exercise allreduce, broadcast, barrier, allgather, alltoall,
        reduce, scatter and send/recv on the default NCCL process group."""
        with _test_eager_guard():
            paddle.set_device('gpu:%d' %
                              paddle.distributed.ParallelEnv().dev_id)

            pg = init_process_group()
            print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
            print("test new group api ok")

            # test allreduce sum (default: runs on the calculation stream)
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            sum_result = tensor_x + tensor_y
            if pg.rank() == 0:
                task = dist.all_reduce(tensor_x)
                assert np.array_equal(tensor_x, sum_result)
            else:
                task = dist.all_reduce(tensor_y)
                assert np.array_equal(tensor_y, sum_result)

            print("test allreduce sum api ok")

            # test allreduce max; with use_calc_stream=False the returned
            # task is waited on before the result is read
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            max_result = paddle.maximum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.all_reduce(
                    tensor_x, dist.ReduceOp.MAX, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_x, max_result)
            else:
                task = dist.all_reduce(
                    tensor_y, dist.ReduceOp.MAX, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_y, max_result)

            print("test allreduce max api ok")

            # test allreduce min
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            min_result = paddle.minimum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.all_reduce(
                    tensor_x, dist.ReduceOp.MIN, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_x, min_result)
            else:
                task = dist.all_reduce(
                    tensor_y, dist.ReduceOp.MIN, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_y, min_result)

            print("test allreduce min api ok")

            # test broadcast from rank 0
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            # Copy taken before the broadcast: it is the expected value on
            # every rank.
            broadcast_result = paddle.assign(tensor_x)
            if pg.rank() == 0:
                task = dist.broadcast(tensor_x, 0, use_calc_stream=False)
                task.synchronize()
                paddle.device.cuda.synchronize()
                assert task.is_completed()
                assert np.array_equal(broadcast_result, tensor_x)
            else:
                task = dist.broadcast(tensor_y, 0)
                paddle.device.cuda.synchronize()
                assert np.array_equal(broadcast_result, tensor_y)

            print("test broadcast api ok")

            # test barrier: rank 0 uses the dist API, rank 1 the
            # process-group API directly
            # rank 0
            if pg.rank() == 0:
                dist.barrier()
            # rank 1
            else:
                task = pg.barrier()
                task.wait()

            print("test barrier api ok\n")

            # test allgather
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            out_shape = list(self.shape)
            out_shape[0] *= 2
            out = np.random.random(out_shape).astype(self.dtype)
            tensor_out = paddle.to_tensor(out)
            if pg.rank() == 0:
                task = pg.all_gather(tensor_x, tensor_out)
                task.wait()
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                tensor_out_list = [
                    paddle.empty_like(tensor_x), paddle.empty_like(tensor_x)
                ]
                task = dist.all_gather(
                    tensor_out_list, tensor_y, use_calc_stream=False)
                paddle.device.cuda.synchronize()
                tensor_out = paddle.concat(tensor_out_list)
            # First half must come from rank 0's input, second from rank 1's.
            out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
            out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
                                 [out_shape[0]])
            assert np.array_equal(tensor_x, out_1)
            assert np.array_equal(tensor_y, out_2)
            print("test allgather api ok\n")

            # test alltoall
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            out1 = np.random.random(self.shape).astype(self.dtype)
            out2 = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            tensor_out1 = paddle.to_tensor(out1)
            tensor_out2 = paddle.to_tensor(out2)
            # Expected exchange: rank 0 receives rank 1's first half, and
            # rank 1 receives rank 0's second half.
            raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
                                          [self.shape[0]])
            raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
                                          [self.shape[0] // 2])
            if pg.rank() == 0:
                task = pg.alltoall(tensor_x, tensor_out1)
                task.wait()
            # rank 1
            else:
                in_1, in_2 = paddle.split(tensor_y, 2)
                out_1, out_2 = paddle.split(tensor_out2, 2)
                out_tensor_list = [out_1, out_2]
                task = dist.alltoall([in_1, in_2], out_tensor_list)
                paddle.device.cuda.synchronize()
                tensor_out2 = paddle.concat(out_tensor_list)
            out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
                                  [self.shape[0]])
            out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
            if pg.rank() == 0:
                assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
            else:
                assert np.array_equal(out2_1, raw_tensor_x_2)
            print("test alltoall api ok\n")

            # test reduce sum to rank 0
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            sum_result = tensor_x + tensor_y
            if pg.rank() == 0:
                task = dist.reduce(tensor_x, 0, use_calc_stream=True)
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                task = dist.reduce(tensor_y, 0, use_calc_stream=False)
                task.wait()
                paddle.device.cuda.synchronize()
            # Only the destination rank is checked for the reduced value.
            if pg.rank() == 0:
                assert np.array_equal(tensor_x, sum_result)
            print("test reduce sum api ok\n")

            # test reduce max
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            max_result = paddle.maximum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.reduce(
                    tensor_x, 0, dist.ReduceOp.MAX, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_x, max_result)
            else:
                task = dist.reduce(
                    tensor_y, 0, dist.ReduceOp.MAX, use_calc_stream=False)
                task.wait()

            print("test reduce max api ok")

            # test reduce min
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            min_result = paddle.minimum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.reduce(
                    tensor_x, 0, dist.ReduceOp.MIN, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_x, min_result)
            else:
                task = dist.reduce(
                    tensor_y, 0, dist.ReduceOp.MIN, use_calc_stream=False)
                task.wait()

            print("test reduce min api ok")

            # test scatter from rank 0
            # rank 0
            in_shape = list(self.shape)
            in_shape[0] *= 2
            x = np.random.random(in_shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            if pg.rank() == 0:
                in_1, in_2 = paddle.split(tensor_x, 2)
                task = dist.scatter(
                    tensor_y, [in_1, in_2], 0, use_calc_stream=True)
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                task = dist.scatter(tensor_y, [], 0, use_calc_stream=False)
                task.wait()
                paddle.device.cuda.synchronize()
            # Rank i must receive the i-th chunk of rank 0's input.
            out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
            out2 = paddle.slice(tensor_x, [0], [self.shape[0]],
                                [self.shape[0] * 2])
            if pg.rank() == 0:
                assert np.array_equal(tensor_y, out1)
            else:
                assert np.array_equal(tensor_y, out2)
            print("test scatter api ok\n")

            # test send/recv (rank 0 -> rank 1) with use_calc_stream=False
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            if pg.rank() == 0:
                task = dist.send(tensor_x, 1, use_calc_stream=False)
                task.wait()
            else:
                task = dist.recv(tensor_y, 0, use_calc_stream=False)
                task.wait()
                assert np.array_equal(tensor_y, tensor_x)

            print("test send api ok")

            # test send/recv (rank 0 -> rank 1) with use_calc_stream=True
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            if pg.rank() == 0:
                task = dist.send(tensor_x, 1, use_calc_stream=True)
            else:
                task = dist.recv(tensor_y, 0, use_calc_stream=True)
                assert np.array_equal(tensor_y, tensor_x)

            print("test send api ok")


class TestProcessGroupFp16(TestProcessGroupFp32):
    """Run the same NCCL collective tests with float16 tensors.

    RNG seeding is inherited unchanged from ``TestProcessGroupFp32.setUp``;
    only the dtype and tensor shape differ.
    """

    def config(self):
        """Use float16 tensors with a larger shape than the fp32 case."""
        self.dtype = "float16"
        self.shape = (4, 20, 20)


if __name__ == "__main__":
    # Script entry point: discover and run the TestCases in this module.
    unittest.main(verbosity=1)