# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
from paddle.fluid import core
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
import paddle.distributed as dist


def init_process_group(strategy=None):
    """Initialize the parallel environment and return its process group.

    Args:
        strategy: Unused. Kept only so existing callers that pass it keep
            working.

    Returns:
        The ``process_group`` attribute of the group object returned by
        ``dist.init_parallel_env()``.
    """
    pg_group = dist.init_parallel_env()
    return pg_group.process_group


class TestProcessGroupFp32(unittest.TestCase):
    """Exercise the eager-mode NCCL process group with float32 tensors.

    The checks are written for exactly two ranks (rank 0 and "everything
    else"). Output buffers are sized ``2 * shape[0]`` and inputs are split
    into two halves.

    ``setUp`` seeds every RNG with the same value on every rank, and every
    rank then draws the same sequence of random arrays. Each rank can
    therefore compute locally the value a collective should produce:
    rank 0 contributes the first array of each pair (``x``) and rank 1 the
    second (``y``).
    """

    def setUp(self):
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        """Select the dtype and input shape; subclasses override this."""
        self.dtype = "float32"
        self.shape = (2, 10, 5)

    def test_create_process_group_nccl(self):
        with _test_eager_guard():
            paddle.set_device('gpu:%d' %
                              paddle.distributed.ParallelEnv().dev_id)

            pg = init_process_group()
            print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
            print("test new group api ok")

            # NOTE: every helper below draws inputs from the identically
            # seeded NumPy RNG. All ranks must call them in this exact order
            # so that they all see the same random data.
            self._check_all_reduce(pg)
            self._check_broadcast(pg)
            self._check_barrier(pg)
            self._check_all_gather(pg)
            self._check_all_to_all(pg)
            self._check_reduce(pg)
            self._check_scatter(pg)
            self._check_send_recv(pg)

    def _check_all_reduce(self, pg):
        """Check all_reduce with SUM (sync) and MAX / MIN / PROD (async)."""
        # SUM, synchronous. Rank 0 contributes x, rank 1 contributes y.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        sum_result = tensor_x + tensor_y
        if pg.rank() == 0:
            dist.all_reduce(tensor_x)
            assert np.array_equal(tensor_x, sum_result)
        else:
            dist.all_reduce(tensor_y)
            assert np.array_equal(tensor_y, sum_result)

        print("test allreduce sum api ok")

        # MAX, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        max_result = paddle.maximum(tensor_x, tensor_y)

        if pg.rank() == 0:
            task = dist.all_reduce(tensor_x,
                                   dist.ReduceOp.MAX,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, max_result)
        else:
            task = dist.all_reduce(tensor_y,
                                   dist.ReduceOp.MAX,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_y, max_result)

        print("test allreduce max api ok")

        # MIN, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        min_result = paddle.minimum(tensor_x, tensor_y)

        if pg.rank() == 0:
            task = dist.all_reduce(tensor_x,
                                   dist.ReduceOp.MIN,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, min_result)
        else:
            task = dist.all_reduce(tensor_y,
                                   dist.ReduceOp.MIN,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_y, min_result)

        print("test allreduce min api ok")

        # PROD, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        prod_result = np.multiply(x, y)

        if pg.rank() == 0:
            task = dist.all_reduce(tensor_x,
                                   dist.ReduceOp.PROD,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, prod_result)
        else:
            task = dist.all_reduce(tensor_y,
                                   dist.ReduceOp.PROD,
                                   sync_op=False)
            task.wait()
            assert np.array_equal(tensor_y, prod_result)

        print("test allreduce prod api ok")

    def _check_broadcast(self, pg):
        """Check broadcast from rank 0 (async on rank 0, sync on rank 1)."""
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        # Copy of rank 0's data, taken before the collective runs.
        broadcast_result = paddle.assign(tensor_x)
        if pg.rank() == 0:
            task = dist.broadcast(tensor_x, 0, sync_op=False)
            task.synchronize()
            paddle.device.cuda.synchronize()
            assert task.is_completed()
            assert np.array_equal(broadcast_result, tensor_x)
        else:
            dist.broadcast(tensor_y, 0)
            paddle.device.cuda.synchronize()
            assert np.array_equal(broadcast_result, tensor_y)

        print("test broadcast api ok")

    def _check_barrier(self, pg):
        """Check barrier via ``dist.barrier`` (rank 0) and ``pg.barrier``."""
        if pg.rank() == 0:
            dist.barrier()
        else:
            task = pg.barrier()
            task.wait()

        print("test barrier api ok\n")

    def _check_all_gather(self, pg):
        """Check all_gather: rank 0 uses the group API, rank 1 ``dist``.

        The check runs twice. The first time rank 1 passes a pre-sized
        output list; the second time it passes an empty list.
        """
        x = np.random.random(self.shape).astype(self.dtype)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(y)
        out_shape = list(self.shape)
        out_shape[0] *= 2
        out = np.random.random(out_shape).astype(self.dtype)
        tensor_out = paddle.to_tensor(out)

        # Variant 1: rank 1 supplies one pre-allocated slot per rank.
        if pg.rank() == 0:
            task = pg.all_gather(tensor_x, tensor_out)
            task.wait()
            paddle.device.cuda.synchronize()
        else:
            tensor_out_list = [
                paddle.empty_like(tensor_x),
                paddle.empty_like(tensor_x)
            ]
            dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
            paddle.device.cuda.synchronize()
            tensor_out = paddle.concat(tensor_out_list)
        out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
        out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
                             [out_shape[0]])
        assert np.array_equal(tensor_x, out_1)
        assert np.array_equal(tensor_y, out_2)
        print("test allgather api ok\n")

        # Variant 2: rank 1 passes an empty list, which this test expects
        # dist.all_gather to fill. Reset the output buffer first. Otherwise
        # rank 0's buffer still holds the variant-1 result, and the checks
        # below would pass even if this gather wrote nothing.
        tensor_out = paddle.zeros_like(tensor_out)
        if pg.rank() == 0:
            task = pg.all_gather(tensor_x, tensor_out)
            task.wait()
            paddle.device.cuda.synchronize()
        else:
            tensor_out_list = []
            dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
            paddle.device.cuda.synchronize()
            tensor_out = paddle.concat(tensor_out_list)
        out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
        out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
                             [out_shape[0]])
        assert np.array_equal(tensor_x, out_1)
        assert np.array_equal(tensor_y, out_2)
        print("test allgather api2 ok\n")

    def _check_all_to_all(self, pg):
        """Check alltoall: rank 0 uses the group API, rank 1 ``dist``.

        Each rank's input is split into two halves along dim 0. The
        assertions expect rank 0's second input half to arrive as rank 1's
        first output half, and rank 1's first input half to arrive as
        rank 0's second output half. The check runs twice: the first time
        rank 1 passes pre-sized output slots, the second time an empty list.
        """
        # Variant 1: rank 1 supplies pre-allocated output slots.
        x = np.random.random(self.shape).astype(self.dtype)
        y = np.random.random(self.shape).astype(self.dtype)
        out1 = np.random.random(self.shape).astype(self.dtype)
        out2 = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(y)
        tensor_out1 = paddle.to_tensor(out1)
        tensor_out2 = paddle.to_tensor(out2)
        raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
                                      [self.shape[0]])
        raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
                                      [self.shape[0] // 2])
        if pg.rank() == 0:
            task = pg.alltoall(tensor_x, tensor_out1)
            task.wait()
        else:
            in_1, in_2 = paddle.split(tensor_y, 2)
            out_1, out_2 = paddle.split(tensor_out2, 2)
            out_tensor_list = [out_1, out_2]
            dist.alltoall([in_1, in_2], out_tensor_list)
            paddle.device.cuda.synchronize()
            tensor_out2 = paddle.concat(out_tensor_list)
        out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
                              [self.shape[0]])
        out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
        if pg.rank() == 0:
            assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
        else:
            assert np.array_equal(out2_1, raw_tensor_x_2)
        print("test alltoall api ok\n")

        # Variant 2: rank 1 passes an empty output list.
        x = np.random.random(self.shape).astype(self.dtype)
        y = np.random.random(self.shape).astype(self.dtype)
        out1 = np.random.random(self.shape).astype(self.dtype)
        out2 = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(y)
        tensor_out1 = paddle.to_tensor(out1)
        tensor_out2 = paddle.to_tensor(out2)
        raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
                                      [self.shape[0]])
        raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
                                      [self.shape[0] // 2])
        if pg.rank() == 0:
            task = pg.alltoall(tensor_x, tensor_out1)
            task.wait()
        else:
            in_1, in_2 = paddle.split(tensor_y, 2)
            out_tensor_list = []
            dist.alltoall([in_1, in_2], out_tensor_list)
            paddle.device.cuda.synchronize()
            tensor_out2 = paddle.concat(out_tensor_list)
        out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
                              [self.shape[0]])
        out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
        if pg.rank() == 0:
            assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
        else:
            assert np.array_equal(out2_1, raw_tensor_x_2)
        print("test alltoall api2 ok\n")

    def _check_reduce(self, pg):
        """Check reduce to rank 0 with SUM, MAX, MIN and PROD.

        Only the destination rank (0) asserts on the reduced value.
        """
        # SUM: synchronous on rank 0, asynchronous on rank 1.
        x = np.random.random(self.shape).astype(self.dtype)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(y)
        sum_result = tensor_x + tensor_y
        if pg.rank() == 0:
            dist.reduce(tensor_x, 0, sync_op=True)
            paddle.device.cuda.synchronize()
        else:
            task = dist.reduce(tensor_y, 0, sync_op=False)
            task.wait()
            paddle.device.cuda.synchronize()
        if pg.rank() == 0:
            assert np.array_equal(tensor_x, sum_result)
        print("test reduce sum api ok\n")

        # MAX, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        max_result = paddle.maximum(tensor_x, tensor_y)

        if pg.rank() == 0:
            task = dist.reduce(tensor_x,
                               0,
                               dist.ReduceOp.MAX,
                               sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, max_result)
        else:
            task = dist.reduce(tensor_y,
                               0,
                               dist.ReduceOp.MAX,
                               sync_op=False)
            task.wait()

        print("test reduce max api ok")

        # MIN, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        min_result = paddle.minimum(tensor_x, tensor_y)

        if pg.rank() == 0:
            task = dist.reduce(tensor_x,
                               0,
                               dist.ReduceOp.MIN,
                               sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, min_result)
        else:
            task = dist.reduce(tensor_y,
                               0,
                               dist.ReduceOp.MIN,
                               sync_op=False)
            task.wait()

        print("test reduce min api ok")

        # PROD, asynchronous.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        prod_result = np.multiply(x, y)

        if pg.rank() == 0:
            task = dist.reduce(tensor_x,
                               0,
                               dist.ReduceOp.PROD,
                               sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, prod_result)
        else:
            task = dist.reduce(tensor_y,
                               0,
                               dist.ReduceOp.PROD,
                               sync_op=False)
            task.wait()

        print("test reduce prod api ok")

    def _check_scatter(self, pg):
        """Check scatter from rank 0: rank i receives the i-th half of x."""
        in_shape = list(self.shape)
        in_shape[0] *= 2
        x = np.random.random(in_shape).astype(self.dtype)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(y)
        if pg.rank() == 0:
            in_1, in_2 = paddle.split(tensor_x, 2)
            # Requested as synchronous (sync_op=True), so the returned task
            # is not waited on here.
            dist.scatter(tensor_y, [in_1, in_2], 0, sync_op=True)
            paddle.device.cuda.synchronize()
        else:
            task = dist.scatter(tensor_y, [], 0, sync_op=False)
            task.wait()
            paddle.device.cuda.synchronize()
        out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
        out2 = paddle.slice(tensor_x, [0], [self.shape[0]],
                            [self.shape[0] * 2])
        if pg.rank() == 0:
            assert np.array_equal(tensor_y, out1)
        else:
            assert np.array_equal(tensor_y, out2)
        print("test scatter api ok\n")

    def _check_send_recv(self, pg):
        """Check send (rank 0) / recv (rank 1), first async then sync.

        Rank 1 compares what it received against its own locally drawn
        ``x``. That copy is identical to rank 0's ``x`` because both ranks
        draw from the same seeded RNG.
        """
        # Asynchronous send/recv.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        if pg.rank() == 0:
            task = dist.send(tensor_x, 1, sync_op=False)
            task.wait()
        else:
            task = dist.recv(tensor_y, 0, sync_op=False)
            task.wait()
            assert np.array_equal(tensor_y, tensor_x)

        print("test send api ok")

        # Synchronous send/recv.
        x = np.random.random(self.shape).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        y = np.random.random(self.shape).astype(self.dtype)
        tensor_y = paddle.to_tensor(y)

        if pg.rank() == 0:
            dist.send(tensor_x, 1, sync_op=True)
        else:
            dist.recv(tensor_y, 0, sync_op=True)
            assert np.array_equal(tensor_y, tensor_x)

        print("test send api ok")

class TestProcessGroupFp16(TestProcessGroupFp32):
    """Run the inherited NCCL process-group checks with float16 tensors.

    ``setUp`` is inherited unchanged. It seeds the RNGs and then calls the
    ``config`` override below.
    """

    def config(self):
        """Use float16 inputs of shape (4, 20, 20)."""
        self.dtype = "float16"
        self.shape = (4, 20, 20)


if "__main__" == __name__:
    # Executed as a script (not imported): run every TestCase defined here.
    unittest.main()