# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.fluid.dygraph.parallel import ParallelEnv


def init_process_group(strategy=None):
    """Initialize the default parallel environment and return its process group.

    Args:
        strategy: Unused; accepted only so existing callers keep working.

    Returns:
        The ``process_group`` attribute of the group returned by
        ``paddle.distributed.init_parallel_env()``.
    """
    pg_group = dist.init_parallel_env()
    return pg_group.process_group


class TestProcessGroupFp32(unittest.TestCase):
    """Smoke test of the NCCL process-group collectives; assumes two ranks.

    ``setUp`` seeds every RNG with the same value, so each rank draws the
    same "rank 0" (``x``) and "rank 1" (``y``) inputs. That lets a rank
    compute the expected result of a collective locally and compare it
    with what the collective actually produced.

    The ``_run_*`` helpers are invoked in a fixed order from
    ``test_create_process_group_nccl`` so that both ranks issue the same
    sequence of collectives and consume random numbers in the same order.
    """

    def setUp(self):
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        self.dtype = "float32"
        self.shape = (2, 10, 5)

    # ------------------------------------------------------------------
    # Input / reference helpers
    # ------------------------------------------------------------------

    def _rand(self, shape):
        """Return a seeded random NumPy array of ``shape`` in ``self.dtype``."""
        return np.random.random(shape).astype(self.dtype)

    def _rand_pair(self, shape):
        """Draw the rank-0 input, then the rank-1 input, as paddle tensors.

        Returns:
            ``(tensor_x, tensor_y)``: the rank-0 and rank-1 inputs.
        """
        x = self._rand(shape)
        y = self._rand(shape)
        return paddle.to_tensor(x), paddle.to_tensor(y)

    @staticmethod
    def _sum_reference(a, b):
        """Expected result of a SUM reduction over two tensors."""
        return a + b

    @staticmethod
    def _prod_reference(a, b):
        """Expected result of a PROD reduction, computed with NumPy."""
        return np.multiply(a.numpy(), b.numpy())

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def test_create_process_group_nccl(self):
        device_id = paddle.distributed.ParallelEnv().dev_id
        paddle.set_device(f"gpu:{device_id}")

        pg = init_process_group()
        print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
        print("test new group api ok")

        # Order matters: every rank must run these in the same sequence.
        self._run_all_reduce(pg)
        self._run_broadcast(pg)
        self._run_barrier(pg, device_id)
        self._run_all_gather(pg)
        self._run_all_gather_0d(pg)
        self._run_alltoall(pg)
        self._run_reduce(pg)
        self._run_scatter(pg)
        self._run_send_recv(pg)

    # ------------------------------------------------------------------
    # all_reduce
    # ------------------------------------------------------------------

    def _all_reduce_case(self, pg, shape, op, reference, message):
        """All-reduce each rank's input and compare it with ``reference``.

        Args:
            pg: Process group; only ``pg.rank()`` is used here.
            shape: Shape of the per-rank input (``[]`` for a 0-d tensor).
            op: A ``dist.ReduceOp`` member issued with ``sync_op=False`` and
                then waited on, or ``None`` for the plain default call that
                is checked immediately without a wait.
            reference: ``f(tensor_x, tensor_y)`` returning the expected value;
                evaluated before the collective overwrites the local input.
            message: Line printed once the check passes.
        """
        tensor_x, tensor_y = self._rand_pair(shape)
        expected = reference(tensor_x, tensor_y)
        local = tensor_x if pg.rank() == 0 else tensor_y
        if op is None:
            dist.all_reduce(local)
        else:
            task = dist.all_reduce(local, op, sync_op=False)
            task.wait()
        assert np.array_equal(local, expected)
        print(message)

    def _run_all_reduce(self, pg):
        """Check all_reduce SUM/MAX/MIN/PROD on full-shape and 0-d inputs."""
        ReduceOp = dist.ReduceOp
        cases = [
            (self.shape, None, self._sum_reference,
             "test allreduce sum api ok"),
            ([], None, self._sum_reference,
             "test allreduce sum api with = [] ok"),
            (self.shape, ReduceOp.MAX, paddle.maximum,
             "test allreduce max api ok"),
            ([], ReduceOp.MAX, paddle.maximum,
             "test allreduce max api with shape = [] ok"),
            (self.shape, ReduceOp.MIN, paddle.minimum,
             "test allreduce min api ok"),
            ([], ReduceOp.MIN, paddle.minimum,
             "test allreduce min api with shape [] ok"),
            (self.shape, ReduceOp.PROD, self._prod_reference,
             "test allreduce prod api ok"),
            ([], ReduceOp.PROD, self._prod_reference,
             "test allreduce prod api with shape = [] ok"),
        ]
        for shape, op, reference, message in cases:
            self._all_reduce_case(pg, shape, op, reference, message)

    # ------------------------------------------------------------------
    # broadcast / barrier
    # ------------------------------------------------------------------

    def _run_broadcast(self, pg):
        """Check broadcast from rank 0 for full-shape and 0-d tensors.

        Rank 0 uses the async form and also checks ``task.is_completed()``;
        rank 1 uses the default call.
        """
        for shape, message in (
            (self.shape, "test broadcast api ok"),
            ([], "test broadcast api with shape=[] ok"),
        ):
            tensor_x, tensor_y = self._rand_pair(shape)
            broadcast_result = paddle.assign(tensor_x)
            if pg.rank() == 0:
                task = dist.broadcast(tensor_x, 0, sync_op=False)
                task.synchronize()
                paddle.device.cuda.synchronize()
                assert task.is_completed()
                assert np.array_equal(broadcast_result, tensor_x)
            else:
                dist.broadcast(tensor_y, 0)
                paddle.device.cuda.synchronize()
                assert np.array_equal(broadcast_result, tensor_y)
            # The receive buffer must keep its shape (including 0-d).
            assert tensor_y.shape == list(shape)
            print(message)

    def _run_barrier(self, pg, device_id):
        """Check ``pg.barrier``; only rank 1 waits on the returned task."""
        if pg.rank() == 0:
            pg.barrier(device_id)
        else:
            task = pg.barrier(device_id)
            task.wait()
        print("test barrier api ok\n")

    # ------------------------------------------------------------------
    # all_gather
    # ------------------------------------------------------------------

    def _run_all_gather(self, pg):
        """Check all_gather for full-shape inputs.

        Rank 0 always gathers into one preallocated tensor via
        ``pg.all_gather``. Rank 1 uses ``dist.all_gather``, first with a
        list of preallocated tensors and then with an empty list.
        """
        tensor_x, tensor_y = self._rand_pair(self.shape)
        out_shape = list(self.shape)
        out_shape[0] *= 2
        half = out_shape[0] // 2
        # Randomly filled so stale contents cannot pass the check by accident.
        tensor_out = paddle.to_tensor(self._rand(out_shape))
        for preallocate, message in (
            (True, "test allgather api ok\n"),
            (False, "test allgather api2 ok\n"),
        ):
            if pg.rank() == 0:
                task = pg.all_gather(tensor_x, tensor_out)
                task.wait()
                paddle.device.cuda.synchronize()
            else:
                if preallocate:
                    tensor_out_list = [
                        paddle.empty_like(tensor_x),
                        paddle.empty_like(tensor_x),
                    ]
                else:
                    tensor_out_list = []
                dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
                paddle.device.cuda.synchronize()
                tensor_out = paddle.concat(tensor_out_list)
            out_1 = paddle.slice(tensor_out, [0], [0], [half])
            out_2 = paddle.slice(tensor_out, [0], [half], [out_shape[0]])
            assert np.array_equal(tensor_x, out_1)
            assert np.array_equal(tensor_y, out_2)
            print(message)

    def _run_all_gather_0d(self, pg):
        """Check all_gather of 0-d tensors into an initially empty list."""
        tensor_x, tensor_y = self._rand_pair([])
        tensor_out_list = []
        if pg.rank() == 0:
            task = dist.all_gather(tensor_out_list, tensor_x)
            task.wait()
            paddle.device.cuda.synchronize()
        else:
            dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
            paddle.device.cuda.synchronize()
        assert np.array_equal(tensor_x, tensor_out_list[0])
        assert np.array_equal(tensor_y, tensor_out_list[1])
        print("test allgather api with shape [] ok\n")

    # ------------------------------------------------------------------
    # alltoall
    # ------------------------------------------------------------------

    def _run_alltoall(self, pg):
        """Check alltoall with preallocated, then empty, outputs on rank 1.

        Rank 0 checks that the second half of its output equals the first
        half of rank 1's input. Rank 1 checks that the first half of its
        output equals the second half of rank 0's input.
        """
        n = self.shape[0]
        half = n // 2
        for preallocate, message in (
            (True, "test alltoall api ok\n"),
            (False, "test alltoall api2 ok\n"),
        ):
            # Draw order (x, y, out1, out2) must match on every rank.
            tensor_x, tensor_y = self._rand_pair(self.shape)
            tensor_out1 = paddle.to_tensor(self._rand(self.shape))
            tensor_out2 = paddle.to_tensor(self._rand(self.shape))
            raw_tensor_x_2 = paddle.slice(tensor_x, [0], [half], [n])
            raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0], [half])
            if pg.rank() == 0:
                task = pg.alltoall(tensor_x, tensor_out1)
                task.wait()
                out1_2 = paddle.slice(tensor_out1, [0], [half], [n])
                assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
            else:
                in_1, in_2 = paddle.split(tensor_y, 2)
                if preallocate:
                    out_tensor_list = list(paddle.split(tensor_out2, 2))
                else:
                    out_tensor_list = []
                dist.alltoall([in_1, in_2], out_tensor_list)
                paddle.device.cuda.synchronize()
                tensor_out2 = paddle.concat(out_tensor_list)
                out2_1 = paddle.slice(tensor_out2, [0], [0], [half])
                assert np.array_equal(out2_1, raw_tensor_x_2)
            print(message)

    # ------------------------------------------------------------------
    # reduce
    # ------------------------------------------------------------------

    def _reduce_case(self, pg, op, reference, message):
        """Reduce to rank 0 with ``op`` (async + wait) and check it there.

        Args:
            pg: Process group; only ``pg.rank()`` is used here.
            op: The ``dist.ReduceOp`` member under test.
            reference: ``f(tensor_x, tensor_y)`` giving the expected value on
                rank 0; evaluated before the collective runs.
            message: Line printed after the (rank-0) check.
        """
        tensor_x, tensor_y = self._rand_pair(self.shape)
        expected = reference(tensor_x, tensor_y)
        if pg.rank() == 0:
            task = dist.reduce(tensor_x, 0, op, sync_op=False)
            task.wait()
            assert np.array_equal(tensor_x, expected)
        else:
            task = dist.reduce(tensor_y, 0, op, sync_op=False)
            task.wait()
        print(message)

    def _run_reduce(self, pg):
        """Check reduce-to-rank-0 for SUM, MAX, MIN and PROD."""
        # SUM mixes call styles: rank 0 is synchronous, rank 1 asynchronous.
        tensor_x, tensor_y = self._rand_pair(self.shape)
        sum_result = self._sum_reference(tensor_x, tensor_y)
        if pg.rank() == 0:
            dist.reduce(tensor_x, 0, sync_op=True)
            paddle.device.cuda.synchronize()
            assert np.array_equal(tensor_x, sum_result)
        else:
            task = dist.reduce(tensor_y, 0, sync_op=False)
            task.wait()
            paddle.device.cuda.synchronize()
        print("test reduce sum api ok\n")

        self._reduce_case(
            pg, dist.ReduceOp.MAX, paddle.maximum, "test reduce max api ok"
        )
        self._reduce_case(
            pg, dist.ReduceOp.MIN, paddle.minimum, "test reduce min api ok"
        )
        self._reduce_case(
            pg,
            dist.ReduceOp.PROD,
            self._prod_reference,
            "test reduce prod api ok",
        )

    # ------------------------------------------------------------------
    # scatter
    # ------------------------------------------------------------------

    def _run_scatter(self, pg):
        """Check scatter from rank 0 for full-shape and 0-d tensors."""
        n = self.shape[0]
        in_shape = list(self.shape)
        in_shape[0] *= 2
        tensor_x = paddle.to_tensor(self._rand(in_shape))
        tensor_y = paddle.to_tensor(self._rand(self.shape))
        if pg.rank() == 0:
            in_1, in_2 = paddle.split(tensor_x, 2)
            dist.scatter(tensor_y, [in_1, in_2], 0, sync_op=True)
            paddle.device.cuda.synchronize()
        else:
            task = dist.scatter(tensor_y, [], 0, sync_op=False)
            task.wait()
            paddle.device.cuda.synchronize()
        if pg.rank() == 0:
            expected = paddle.slice(tensor_x, [0], [0], [n])
        else:
            expected = paddle.slice(tensor_x, [0], [n], [n * 2])
        assert np.array_equal(tensor_y, expected)
        print("test scatter api ok\n")

        # 0-d: rank 0 scatters x to itself and x + 1 to rank 1.
        tensor_x, tensor_y = self._rand_pair([])
        if pg.rank() == 0:
            dist.scatter(tensor_y, [tensor_x, tensor_x + 1], 0, sync_op=True)
            paddle.device.cuda.synchronize()
        else:
            task = dist.scatter(tensor_y, [], 0, sync_op=True)
            task.wait()
            paddle.device.cuda.synchronize()
        if pg.rank() == 0:
            assert np.array_equal(tensor_y, paddle.assign(tensor_x))
        else:
            out2 = paddle.assign(tensor_x + 1)
            assert np.array_equal(tensor_y, out2), f"{tensor_y}, {out2}"
        assert tensor_y.shape == []
        print("test scatter api with shape=[] ok\n")

    # ------------------------------------------------------------------
    # send / recv
    # ------------------------------------------------------------------

    def _run_send_recv(self, pg):
        """Check rank 0 -> rank 1 send/recv: async, sync, then a 0-d tensor."""
        for sync_op in (False, True):
            tensor_x, tensor_y = self._rand_pair(self.shape)
            if pg.rank() == 0:
                task = dist.send(tensor_x, 1, sync_op=sync_op)
                if not sync_op:
                    task.wait()
            else:
                task = dist.recv(tensor_y, 0, sync_op=sync_op)
                if not sync_op:
                    task.wait()
                assert np.array_equal(tensor_y, tensor_x)
            print("test send api ok")

        # 0-d: rank 1's buffer starts at a fixed value and must be replaced
        # by rank 0's random scalar.
        x = np.random.uniform(-1, 1, []).astype(self.dtype)
        tensor_x = paddle.to_tensor(x)
        tensor_y = paddle.to_tensor(np.array(0.2022).astype(self.dtype))
        if pg.rank() == 0:
            dist.send(tensor_x, 1, sync_op=True)
        else:
            dist.recv(tensor_y, 0, sync_op=True)
            assert np.array_equal(tensor_y, tensor_x)
            assert tensor_y.shape == []
        print("test send & recv 0-d tensor ok")

class TestProcessGroupFp16(TestProcessGroupFp32):
    """Re-run every inherited collective check with float16 inputs.

    The seeding ``setUp`` is inherited unchanged; it calls ``self.config()``,
    so overriding ``config`` is enough to switch the dtype and input shape.
    """

    def config(self):
        self.dtype = "float16"
        self.shape = (4, 20, 20)


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module.
    unittest.main(module="__main__")