# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

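# This test exercises the NCCL-backed ProcessGroup collectives and is written
# for a two-rank GPU launch; a typical invocation (launcher flags vary across
# Paddle versions) is:
#   python -m paddle.distributed.launch --gpus=0,1 process_group_nccl.py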
import unittest
import random
import numpy as np

import paddle
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
import paddle.distributed as dist


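# Helper: bring up the dynamic-graph parallel environment and return the
# underlying ProcessGroup, which the tests below drive both directly
# (pg.all_gather, pg.barrier, ...) and through the paddle.distributed API.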
def init_process_group(strategy=None):
    nranks = ParallelEnv().nranks
    rank = ParallelEnv().local_rank
    is_master = rank == 0
    pg_group = dist.init_parallel_env()

    return pg_group.process_group


class TestProcessGroupFp32(unittest.TestCase):
    def setUp(self):
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        self.dtype = "float32"
        self.shape = (2, 10, 5)

    def test_create_process_group_nccl(self):
        with _test_eager_guard():
            device_id = paddle.distributed.ParallelEnv().dev_id
            paddle.set_device('gpu:%d' % device_id)

            pg = init_process_group()
            print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
            print("test new group api ok")

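            # Both ranks seed identically in setUp, so the x/y drawn below are
            # the same on every rank; each rank can therefore compute the
            # expected collective result locally and check its own branch.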
            # test allreduce sum
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            sum_result = tensor_x + tensor_y
            if pg.rank() == 0:
                task = dist.all_reduce(tensor_x)
                assert np.array_equal(tensor_x, sum_result)
            else:
                task = dist.all_reduce(tensor_y)
                assert np.array_equal(tensor_y, sum_result)

            print("test allreduce sum api ok")

            # test allreduce max
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            max_result = paddle.maximum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.all_reduce(
                    tensor_x, dist.ReduceOp.MAX, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, max_result)
            else:
                task = dist.all_reduce(
                    tensor_y, dist.ReduceOp.MAX, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_y, max_result)

            print("test allreduce max api ok")

            # test allreduce min
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            min_result = paddle.minimum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.all_reduce(
                    tensor_x, dist.ReduceOp.MIN, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, min_result)
            else:
                task = dist.all_reduce(
                    tensor_y, dist.ReduceOp.MIN, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_y, min_result)

            print("test allreduce min api ok")

            # test allreduce prod
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            prod_result = np.multiply(x, y)

            if pg.rank() == 0:
                task = dist.all_reduce(
                    tensor_x, dist.ReduceOp.PROD, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, prod_result)
            else:
                task = dist.all_reduce(
                    tensor_y, dist.ReduceOp.PROD, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_y, prod_result)

            print("test allreduce prod api ok")

            # test broadcast
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            broadcast_result = paddle.assign(tensor_x)
            if pg.rank() == 0:
                task = dist.broadcast(tensor_x, 0, sync_op=False)
                task.synchronize()
                paddle.device.cuda.synchronize()
                assert task.is_completed()
                assert np.array_equal(broadcast_result, tensor_x)
            else:
                task = dist.broadcast(tensor_y, 0)
                paddle.device.cuda.synchronize()
                assert np.array_equal(broadcast_result, tensor_y)

            print("test broadcast api ok")

            # test barrier
            # rank 0
            if pg.rank() == 0:
                pg.barrier(device_id)
            # rank 1
            else:
                task = pg.barrier(device_id)
                task.wait()

            print("test barrier api ok\n")

            # test allgather
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            out_shape = list(self.shape)
            out_shape[0] *= 2
            out = np.random.random(out_shape).astype(self.dtype)
            tensor_out = paddle.to_tensor(out)
            if pg.rank() == 0:
                task = pg.all_gather(tensor_x, tensor_out)
                task.wait()
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                tensor_out_list = [
                    paddle.empty_like(tensor_x),
                    paddle.empty_like(tensor_x),
                ]
                task = dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
                paddle.device.cuda.synchronize()
                tensor_out = paddle.concat(tensor_out_list)
            out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
            out_2 = paddle.slice(
                tensor_out, [0], [out_shape[0] // 2], [out_shape[0]]
            )
            assert np.array_equal(tensor_x, out_1)
            assert np.array_equal(tensor_y, out_2)
            print("test allgather api ok\n")

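            # Allgather again; this time the rank-1 branch hands dist.all_gather
            # an empty output list to populate.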
            if pg.rank() == 0:
                task = pg.all_gather(tensor_x, tensor_out)
                task.wait()
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                tensor_out_list = []
                task = dist.all_gather(tensor_out_list, tensor_y, sync_op=False)
                paddle.device.cuda.synchronize()
                tensor_out = paddle.concat(tensor_out_list)
            out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
            out_2 = paddle.slice(
                tensor_out, [0], [out_shape[0] // 2], [out_shape[0]]
            )
            assert np.array_equal(tensor_x, out_1)
            assert np.array_equal(tensor_y, out_2)
            print("test allgather api2 ok\n")

            # test alltoall
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            out1 = np.random.random(self.shape).astype(self.dtype)
            out2 = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            tensor_out1 = paddle.to_tensor(out1)
            tensor_out2 = paddle.to_tensor(out2)
            raw_tensor_x_2 = paddle.slice(
                tensor_x, [0], [self.shape[0] // 2], [self.shape[0]]
            )
            raw_tensor_y_1 = paddle.slice(
                tensor_y, [0], [0], [self.shape[0] // 2]
            )
            if pg.rank() == 0:
                task = pg.alltoall(tensor_x, tensor_out1)
                task.wait()
            # rank 1
            else:
                in_1, in_2 = paddle.split(tensor_y, 2)
                out_1, out_2 = paddle.split(tensor_out2, 2)
                out_tensor_list = [out_1, out_2]
                task = dist.alltoall([in_1, in_2], out_tensor_list)
                paddle.device.cuda.synchronize()
                tensor_out2 = paddle.concat(out_tensor_list)
            out1_2 = paddle.slice(
                tensor_out1, [0], [self.shape[0] // 2], [self.shape[0]]
            )
            out1_2 = paddle.slice(
                tensor_out1, [0], [self.shape[0] // 2], [self.shape[0]]
            )
            out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
            if pg.rank() == 0:
                assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
            else:
                assert np.array_equal(out2_1, raw_tensor_x_2)
            print("test alltoall api ok\n")

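            # Alltoall again; this time the rank-1 branch hands dist.alltoall an
            # empty output list to populate.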
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            out1 = np.random.random(self.shape).astype(self.dtype)
            out2 = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            tensor_out1 = paddle.to_tensor(out1)
            tensor_out2 = paddle.to_tensor(out2)
            raw_tensor_x_2 = paddle.slice(
                tensor_x, [0], [self.shape[0] // 2], [self.shape[0]]
            )
            raw_tensor_y_1 = paddle.slice(
                tensor_y, [0], [0], [self.shape[0] // 2]
            )
            if pg.rank() == 0:
                task = pg.alltoall(tensor_x, tensor_out1)
                task.wait()
            # rank 1
            else:
                in_1, in_2 = paddle.split(tensor_y, 2)
                out_1, out_2 = paddle.split(tensor_out2, 2)
                out_tensor_list = []
                task = dist.alltoall([in_1, in_2], out_tensor_list)
                paddle.device.cuda.synchronize()
                tensor_out2 = paddle.concat(out_tensor_list)
            out1_2 = paddle.slice(
                tensor_out1, [0], [self.shape[0] // 2], [self.shape[0]]
            )
            out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
            if pg.rank() == 0:
                assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
            else:
                assert np.array_equal(out2_1, raw_tensor_x_2)
            print("test alltoall api2 ok\n")

            # test Reduce
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            sum_result = tensor_x + tensor_y
            if pg.rank() == 0:
                task = dist.reduce(tensor_x, 0, sync_op=True)
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                task = dist.reduce(tensor_y, 0, sync_op=False)
                task.wait()
                paddle.device.cuda.synchronize()
            if pg.rank() == 0:
                assert np.array_equal(tensor_x, sum_result)
            print("test reduce sum api ok\n")

            # test reduce max
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            max_result = paddle.maximum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.reduce(
                    tensor_x, 0, dist.ReduceOp.MAX, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, max_result)
            else:
                task = dist.reduce(
                    tensor_y, 0, dist.ReduceOp.MAX, sync_op=False
                )
                task.wait()

            print("test reduce max api ok")

            # test reduce min
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            min_result = paddle.minimum(tensor_x, tensor_y)

            if pg.rank() == 0:
                task = dist.reduce(
                    tensor_x, 0, dist.ReduceOp.MIN, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, min_result)
            else:
                task = dist.reduce(
                    tensor_y, 0, dist.ReduceOp.MIN, sync_op=False
                )
                task.wait()

            print("test reduce min api ok")

            # test reduce product
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            prod_result = np.multiply(x, y)

            if pg.rank() == 0:
                task = dist.reduce(
                    tensor_x, 0, dist.ReduceOp.PROD, sync_op=False
                )
                task.wait()
                assert np.array_equal(tensor_x, prod_result)
            else:
                task = dist.reduce(
                    tensor_y, 0, dist.ReduceOp.PROD, sync_op=False
                )
                task.wait()

            print("test reduce prod api ok")
            # test Scatter
            # rank 0
            in_shape = list(self.shape)
            in_shape[0] *= 2
            x = np.random.random(in_shape).astype(self.dtype)
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            tensor_y = paddle.to_tensor(y)
            if pg.rank() == 0:
                in_1, in_2 = paddle.split(tensor_x, 2)
                task = dist.scatter(tensor_y, [in_1, in_2], 0, sync_op=True)
                # task.wait()
                paddle.device.cuda.synchronize()
            # rank 1
            else:
                task = dist.scatter(tensor_y, [], 0, sync_op=False)
                task.wait()
                paddle.device.cuda.synchronize()
            out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
            out2 = paddle.slice(
                tensor_x, [0], [self.shape[0]], [self.shape[0] * 2]
            )
            if pg.rank() == 0:
                assert np.array_equal(tensor_y, out1)
            else:
                assert np.array_equal(tensor_y, out2)
            print("test scatter api ok\n")

            # test send / recv (async)
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            if pg.rank() == 0:
                task = dist.send(tensor_x, 1, sync_op=False)
                task.wait()
            else:
                task = dist.recv(tensor_y, 0, sync_op=False)
                task.wait()
                assert np.array_equal(tensor_y, tensor_x)

            print("test send api ok")

            # test send / recv (sync)
            # rank 0
            x = np.random.random(self.shape).astype(self.dtype)
            tensor_x = paddle.to_tensor(x)
            # rank 1
            y = np.random.random(self.shape).astype(self.dtype)
            tensor_y = paddle.to_tensor(y)

            if pg.rank() == 0:
                task = dist.send(tensor_x, 1, sync_op=True)
            else:
                task = dist.recv(tensor_y, 0, sync_op=True)
                assert np.array_equal(tensor_y, tensor_x)

            print("test send api ok")


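# Re-run the whole collective suite with float16 tensors and a different shape.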
class TestProcessGroupFp16(TestProcessGroupFp32):
    def setUp(self):
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        self.dtype = "float16"
        self.shape = (4, 20, 20)


if __name__ == "__main__":
    unittest.main()