/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/distributed/collective/reducer.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/process_group_utils.h"
#include "paddle/phi/api/all.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif

#if defined(PADDLE_WITH_MPI)
#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h"
#endif

#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"
#endif

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/distributed/collective/ProcessGroupBKCL.h"
#endif

#include "paddle/phi/kernels/sync_batch_norm_kernel.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {

using Tensor = paddle::experimental::Tensor;

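// Helper for the "EagerReducer" binding below: converts the incoming Python
// tensor list into paddle::experimental::Tensor objects and constructs the
// reducer from the given bucket indices and options.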
std::shared_ptr<distributed::EagerReducer> CreateEagerReducer(
    py::handle py_tensors,
    const std::vector<std::vector<size_t>> &group_indices,
    const std::vector<bool> &is_sparse_gradient,
    std::shared_ptr<distributed::ProcessGroup> process_group,
    const std::vector<size_t> &group_size_limits,
    bool find_unused_parameters) {
  auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
  return std::make_shared<distributed::EagerReducer>(params,
                                                     group_indices,
                                                     is_sparse_gradient,
                                                     process_group,
                                                     group_size_limits,
                                                     find_unused_parameters);
}

#if defined(PADDLE_WITH_GLOO)
using ProcessGroupGloo = paddle::distributed::ProcessGroupGloo;
using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore;
using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions;
#endif

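// Marked UNUSED; referenced only so that phi::detail::GetCCLComm is touched at
// static-initialization time (presumably to keep the CCL symbol linked in).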
static UNUSED void *use_ccl_comm_func =
    phi::detail::GetCCLComm(phi::CPUPlace());

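// Registers the distributed types with the given pybind11 module: ReduceOp,
// the *Options structs, ProcessGroup and its backend subclasses, Task, and the
// EagerReducer utilities.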
void BindDistributed(py::module *m) {
  py::enum_<distributed::ReduceOp>(*m, "ReduceOp")
      .value("SUM", distributed::ReduceOp::SUM)
      .value("AVG", distributed::ReduceOp::AVG)
      .value("MAX", distributed::ReduceOp::MAX)
      .value("MIN", distributed::ReduceOp::MIN)
      .value("PRODUCT", distributed::ReduceOp::PRODUCT);

  py::class_<distributed::AllreduceOptions>(*m, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduce_op", &distributed::AllreduceOptions::reduce_op);

  py::class_<distributed::BroadcastOptions>(*m, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("source_rank", &distributed::BroadcastOptions::source_rank)
      .def_readwrite("source_root",
                     &distributed::BroadcastOptions::source_root);

  py::class_<distributed::BarrierOptions>(*m, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("device_id", &distributed::BarrierOptions::device_id);

  py::class_<distributed::ReduceOptions>(*m, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op)
      .def_readwrite("source_root", &distributed::ReduceOptions::root_rank);

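  // Base ProcessGroup bindings. Each collective below unwraps the Python
  // tensor(s) into phi::DenseTensor, forwards to the corresponding C++
  // collective, and releases the GIL for the duration of the call.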
  auto ProcessGroup =
      py::class_<distributed::ProcessGroup,
                 std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
          .def("rank", &distributed::ProcessGroup::GetRank)
          .def("size", &distributed::ProcessGroup::GetSize)
          .def("name", &distributed::ProcessGroup::GetBackendName)
          .def(
              "all_reduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::AllreduceOptions opts{op};
                return self.AllReduce(out_dense, in_dense, opts, sync_op);
              },
              py::arg("tensor"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::BroadcastOptions opts{src};
                return self.Broadcast(out_dense, in_dense, opts, sync_op);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                // numel == -1 indicates sending the whole tensor
                return self.Send(
                    out_dense, dst, /*offset*/ 0, /*numel*/ -1, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst_rank,
                 int nranks,
                 int rank_id,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();

                int64_t numel = p_dense->numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;

                return self.Send(
                    out_dense, dst_rank, offset, send_numel, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("num"),
              py::arg("id"),
              py::arg("sync_op") = true,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *in_dense = p_dense.get();
                // numel == -1 indicates receiving the whole tensor
                return self.Recv(
                    in_dense, src, /*offset*/ 0, /*numel*/ -1, sync_op);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src_rank,
                 int nranks,
                 int rank_id,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();

                int64_t numel = p_dense->numel();
                int64_t recv_numel = numel / nranks;
                int64_t offset = recv_numel * rank_id;

                return self.Recv(
                    out_dense, src_rank, offset, recv_numel, sync_op);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("num"),
              py::arg("id"),
              py::arg("sync_op") = true,
              py::call_guard<py::gil_scoped_release>())

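          // all_gather with a Python list output: the output list is
          // concatenated into one dense tensor for the collective and split
          // back into the list on the op's device context afterwards.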
          .def(
              "all_gather",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor_list,
                 py::handle py_in_tensor,
                 bool sync_op) {
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
                auto task = self.AllGather(out_dense,
                                           in_dense,
                                           /*offset*/ 0,
                                           /*numel*/ -1,
                                           sync_op);
                SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
                task->UpdateWaitChain(*dev_ctx);
                return task;
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather_into_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllGather(out_dense,
                                      in_dense,
                                      /*offset*/ 0,
                                      /*numel*/ -1,
                                      sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

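          // all_to_all over tensor lists: inputs/outputs are concatenated,
          // split sizes default to an even world-size split, and the result is
          // split back into the output list.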
          .def(
              "all_to_all",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor_list,
                 py::handle py_in_tensor_list,
                 bool sync_op) {
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                // in_tensor_list should not be empty
                auto *dev_ctx =
                    self.GetDeviceContext(in_tensor_list.back().place());
                int world_size = self.GetSize();
                auto task =
                    self.AllToAll(out_dense,
                                  in_dense,
                                  GetDefaultSplitSizes(*out_dense, world_size),
                                  GetDefaultSplitSizes(in_dense, world_size),
                                  sync_op);
                SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
                task->UpdateWaitChain(*dev_ctx);
                return task;
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_to_all_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                int world_size = self.GetSize();
                return self.AllToAll(
                    out_dense,
                    in_dense,
                    GetDefaultSplitSizes(*out_dense, world_size),
                    GetDefaultSplitSizes(in_dense, world_size),
                    sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_to_all_single",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 const std::vector<int64_t> &out_sizes,
                 const std::vector<int64_t> &in_sizes,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllToAll(
                    out_dense, in_dense, out_sizes, in_sizes, sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("out_sizes"),
              py::arg("in_sizes"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::ReduceOptions opts{op, dst};
                return self.Reduce(out_dense, in_dense, opts, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor_list,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor_list,
                 int src,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ScatterOptions opts{src};
                return self.Scatter(out_dense, in_dense, opts, sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 int src,
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ScatterOptions opts{src};
                return self.Scatter(out_dense, in_dense, opts, sync_op);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "barrier",
              [](distributed::ProcessGroup &self, int8_t device_id) {
                distributed::BarrierOptions opts;
                opts.device_id = device_id;
                return self.Barrier(opts);
              },
              py::arg("device_id") = -1,
              py::call_guard<py::gil_scoped_release>())

          // TODO(liyurui): Interface below will be removed in the future.
          .def(
              "allreduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                distributed::AllreduceOptions opts;
                opts.reduce_op = op;
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.AllReduce(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("op") = distributed::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int source_rank) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                distributed::BroadcastOptions opts;
                opts.source_rank = source_rank;
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Broadcast(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("source_rank"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Send(tensors, dst);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Recv(tensors, src);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.AllGather(in_tensors, out_tensors);
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 int nranks,
                 int rank_id) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                int64_t numel = in_dense.numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
                return self.AllGather(
                    out_dense, in_dense, offset, send_numel, /*sync_op*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.AllToAll(in_tensors, out_tensors);
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall_single",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 const std::vector<int64_t> in_sizes,
                 const std::vector<int64_t> out_sizes) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllToAll(
                    out_dense, in_dense, out_sizes, in_sizes, /*sync_op*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("in_sizes"),
              py::arg("out_sizes"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 int dst,
                 distributed::ReduceOp op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                distributed::ReduceOptions opts;
                opts.reduce_op = op;
                opts.root_rank = dst;
                auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Reduce(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op") = distributed::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int src) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                distributed::ScatterOptions opts;
                opts.root_rank = src;
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.Scatter(in_tensors, out_tensors, opts);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>());

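  // ProcessGroupStream adds *_on_calc_stream variants of the collectives
  // above; they run with sync_op=true and use_calc_stream=true, i.e. directly
  // on the calculation stream.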
  auto ProcessGroupStream =
      py::class_<distributed::ProcessGroupStream,
                 std::shared_ptr<distributed::ProcessGroupStream>>(
          *m, "ProcessGroupStream", ProcessGroup)
          .def(
              "all_gather_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor_list,
                 py::handle py_in_tensor) {
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
                auto task = self.AllGather(out_dense,
                                           in_dense,
                                           /*offset*/ 0,
                                           /*numel*/ -1,
                                           /*sync_op*/ true,
                                           /*use_calc_stream*/ true);
                SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
                return task;
              },
              py::arg("out"),
              py::arg("in"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather_into_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllGather(out_dense,
                                      in_dense,
                                      /*offset*/ 0,
                                      /*numel*/ -1,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 int nranks,
                 int rank_id) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                int64_t numel = in_dense.numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;

                return self.AllGather(out_dense,
                                      in_dense,
                                      offset,
                                      send_numel,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_reduce_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto in_dense = *p_dense;
                auto *out_dense = p_dense.get();
                distributed::AllreduceOptions opts{op};
                return self.AllReduce(out_dense,
                                      in_dense,
                                      opts,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("op") = distributed::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_to_all_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor_list,
                 py::handle py_in_tensor_list) {
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                // in_tensor_list should not be empty
                auto *dev_ctx = self.GetDeviceContext(
                    in_tensor_list.back().place(), /*use_calc_stream*/ true);
                int world_size = self.GetSize();
                auto task =
                    self.AllToAll(out_dense,
                                  in_dense,
                                  GetDefaultSplitSizes(*out_dense, world_size),
                                  GetDefaultSplitSizes(in_dense, world_size),
                                  /*sync_op*/ true,
                                  /*use_calc_stream*/ true);
                SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
                return task;
              },
              py::arg("out"),
              py::arg("in"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_to_all_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                int world_size = self.GetSize();
                return self.AllToAll(
                    out_dense,
                    in_dense,
                    GetDefaultSplitSizes(*out_dense, world_size),
                    GetDefaultSplitSizes(in_dense, world_size),
                    /*sync_op*/ true,
                    /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_to_all_single_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 const std::vector<int64_t> &out_sizes,
                 const std::vector<int64_t> &in_sizes) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllToAll(out_dense,
                                     in_dense,
                                     out_sizes,
                                     in_sizes,
                                     /*sync_op*/ true,
                                     /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("out_sizes"),
              py::arg("in_sizes"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::BroadcastOptions opts{src};
                return self.Broadcast(out_dense,
                                      in_dense,
                                      opts,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int dst,
                 distributed::ReduceOp op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::ReduceOptions opts{op, dst};
                return self.Reduce(out_dense,
                                   in_dense,
                                   opts,
                                   /*sync_op*/ true,
                                   /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor_list,
                 distributed::ReduceOp op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(out_dense,
                                          in_dense,
                                          opts,
                                          /*sync_op*/ true,
                                          /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 distributed::ReduceOp op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(out_dense,
                                          in_dense,
                                          opts,
                                          /*sync_op*/ true,
                                          /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor_list,
                 int src) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ScatterOptions opts{src};
                return self.Scatter(out_dense,
                                    in_dense,
                                    opts,
                                    /*sync_op*/ true,
                                    /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_out_tensor,
                 py::handle py_in_tensor,
                 int src) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                distributed::ScatterOptions opts{src};
                return self.Scatter(out_dense,
                                    in_dense,
                                    opts,
                                    /*sync_op*/ true,
                                    /*use_calc_stream*/ true);
              },
              py::arg("out"),
              py::arg("in"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int dst) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();
                // numel == -1 indicates sending the whole tensor
                return self.Send(out_dense,
                                 dst,
                                 /*offset*/ 0,
                                 /*numel*/ -1,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int dst_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();

                int64_t numel = p_dense->numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;

                return self.Send(out_dense,
                                 dst_rank,
                                 offset,
                                 send_numel,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *in_dense = p_dense.get();
                // numel == -1 indicates receiving the whole tensor
                return self.Recv(in_dense,
                                 src,
                                 /*offset*/ 0,
                                 /*numel*/ -1,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto p_dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                auto *out_dense = p_dense.get();

                int64_t numel = p_dense->numel();
                int64_t recv_numel = numel / nranks;
                int64_t offset = recv_numel * rank_id;

                return self.Recv(out_dense,
                                 src_rank,
                                 offset,
                                 recv_numel,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>());

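  // Concrete backends; each is registered only when the matching build flag is
  // defined.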
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
  py::class_<distributed::ProcessGroupNCCL,
             std::shared_ptr<distributed::ProcessGroupNCCL>>(
      *m, "ProcessGroupNCCL", ProcessGroupStream)
      .def_static("create",
                  distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
                  py::arg("store"),
                  py::arg("rank"),
                  py::arg("world_size"),
                  py::arg("group_id") = 0,
                  py::call_guard<py::gil_scoped_release>())
      .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);

#endif

#if defined(PADDLE_WITH_MPI)
  py::class_<distributed::ProcessGroupMPI,
             std::shared_ptr<distributed::ProcessGroupMPI>>(
      *m, "ProcessGroupMPI", ProcessGroup)
      .def_static(
          "create",
          [](const std::vector<int> &ranks,
             int gid) -> std::shared_ptr<distributed::ProcessGroupMPI> {
            return paddle::distributed::ProcessGroupMPI::CreateProcessGroupMPI(
                ranks, gid);
          })
      .def("get_rank",
           &distributed::ProcessGroup::GetRank,
           py::call_guard<py::gil_scoped_release>())
      .def("get_world_size",
           &distributed::ProcessGroup::GetSize,
           py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_CUSTOM_DEVICE)
  py::class_<distributed::ProcessGroupCustom,
             std::shared_ptr<distributed::ProcessGroupCustom>>(
      *m, "ProcessGroupCustom", ProcessGroup)
      .def_static("create",
                  distributed::ProcessGroupCustom::CreateProcessGroupCustom,
                  py::arg("store"),
                  py::arg("device_type"),
                  py::arg("rank"),
                  py::arg("world_size"),
                  py::arg("group_id") = 0,
                  py::call_guard<py::gil_scoped_release>());

#endif

#if defined(PADDLE_WITH_XPU_BKCL)
  auto processGroupBKCL =
      py::class_<distributed::ProcessGroupBKCL,
                 std::shared_ptr<distributed::ProcessGroupBKCL>>(
          *m, "ProcessGroupBKCL", ProcessGroupStream)
          .def_static("create",
                      distributed::ProcessGroupBKCL::CreateProcessGroupBKCL,
                      py::arg("store"),
                      py::arg("rank"),
                      py::arg("world_size"),
                      py::arg("group_id") = 0,
                      py::call_guard<py::gil_scoped_release>());
#endif

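  // Task represents an in-flight collective; wait() blocks up to the given
  // timeout and synchronize() blocks until completion, both with the GIL
  // released.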
  py::class_<distributed::ProcessGroup::Task,
             std::shared_ptr<distributed::ProcessGroup::Task>>(*m, "task")
      .def("is_completed", &distributed::ProcessGroup::Task::IsCompleted)
      .def("is_sync", &distributed::ProcessGroup::Task::IsSync)
      .def("wait",
           &distributed::ProcessGroup::Task::Wait,
           py::arg("timeout") = kWaitTimeout,
           py::call_guard<py::gil_scoped_release>())
      .def("synchronize",
           &distributed::ProcessGroup::Task::Synchronize,
           py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_GLOO)
  py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
      *m, "ProcessGroupGloo", ProcessGroup)
      .def_static("create",
                  distributed::ProcessGroupGloo::CreateProcessGroupGloo,
                  py::arg("store"),
                  py::arg("rank"),
                  py::arg("world_size"),
                  py::arg("group_id") = 0,
                  py::call_guard<py::gil_scoped_release>())
      .def_static("create_default_device",
                  &ProcessGroupGloo::createDefaultDevice);
#endif

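  // Buckets parameter tensors by size before reducer construction; the default
  // group size limit is 25 MB (25 * 1024 * 1024 bytes).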
  m->def(
      "eager_assign_group_by_size",
      [](py::handle py_tensors,
         std::vector<bool> is_sparse_gradient,
         std::vector<size_t> group_size_limits,
         std::vector<int64_t> tensor_indices) {
        auto tensors = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
        return distributed::Eager_AssignGroupBySize(
            tensors, is_sparse_gradient, group_size_limits, tensor_indices);
      },
      py::arg("tensors"),
      py::arg("is_sparse_gradient"),
      py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
      py::arg("tensor_indices") = std::vector<int64_t>{},
      py::call_guard<py::gil_scoped_release>());

  py::class_<distributed::EagerReducer,
             std::shared_ptr<distributed::EagerReducer>>(
      *m, "EagerReducer", R"DOC()DOC")
      .def(py::init(&CreateEagerReducer))
      .def(
          "prepare_for_backward",
          [](distributed::EagerReducer &self, py::handle py_tensors) {
            auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
            self.PrepareForBackward(params);
          },
          py::arg("tensors"),
          py::call_guard<py::gil_scoped_release>());
}

}  // end namespace pybind
}  // namespace paddle