distributed_py.cc 58.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <fcntl.h>

#include <utility>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
25
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
26
#include "paddle/fluid/distributed/collective/Types.h"
27
#include "paddle/fluid/distributed/collective/Utils.h"
28
#include "paddle/fluid/distributed/collective/reducer.h"
29 30 31 32 33 34 35
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/api/all.h"

36
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
37 38 39
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif

W
wuhuachaocoding 已提交
40 41 42 43
#if defined(PADDLE_WITH_MPI)
#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h"
#endif

44 45 46 47
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
#endif

48 49 50 51
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"
#endif

52 53 54 55 56
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
#include "paddle/fluid/distributed/collective/ProcessGroupHeter.h"
#endif

57 58 59 60 61
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#endif

62 63
#include "paddle/phi/kernels/sync_batch_norm_kernel.h"

64 65 66 67 68 69 70
namespace py = pybind11;

namespace paddle {
namespace pybind {

using Tensor = paddle::experimental::Tensor;

71 72 73 74 75
/// @brief Creates a distributed::EagerReducer from a Python parameter list.
///
/// `py_tensors` is converted with CastPyArg2VectorOfTensor; every other
/// argument is forwarded unchanged to the EagerReducer constructor.
///
/// @param py_tensors             Python object holding the parameter tensors.
///                               The `0` passed to the cast helper is
///                               presumably the argument position used in
///                               error messages. NOTE(review): confirm.
/// @param group_indices          Forwarded to EagerReducer.
/// @param is_sparse_gradient     Forwarded to EagerReducer.
/// @param process_group          Process group shared with the reducer.
///                               Taken by value and moved into the reducer.
/// @param group_size_limits      Forwarded to EagerReducer.
/// @param find_unused_parameters Forwarded to EagerReducer.
/// @return A newly constructed reducer owned by a std::shared_ptr.
std::shared_ptr<distributed::EagerReducer> CreateEagerReducer(
    py::handle py_tensors,
    const std::vector<std::vector<size_t>> &group_indices,
    const std::vector<bool> &is_sparse_gradient,
    std::shared_ptr<distributed::ProcessGroup> process_group,
    const std::vector<size_t> &group_size_limits,
    bool find_unused_parameters) {
  auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
  // `params` and `process_group` are owned locals that are not used again, so
  // hand them over instead of copying them. Where the constructor takes these
  // by value, this saves a vector copy and an atomic shared_ptr refcount
  // round-trip. Where it takes them by const reference, this is harmless.
  return std::make_shared<distributed::EagerReducer>(std::move(params),
                                                     group_indices,
                                                     is_sparse_gradient,
                                                     std::move(process_group),
                                                     group_size_limits,
                                                     find_unused_parameters);
}

87 88 89 90 91 92 93 94
#if defined(PADDLE_WITH_GLOO)
using ProcessGroupGloo = paddle::distributed::ProcessGroupGloo;
using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore;
using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions;
#endif

static std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";  // NOLINT

95 96 97
static UNUSED void *use_ccl_comm_func =
    phi::detail::GetCCLComm(phi::CPUPlace());

98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
void BindDistributed(py::module *m) {
  py::enum_<distributed::ReduceOp>(*m, "ReduceOp")
      .value("SUM", distributed::ReduceOp::SUM)
      .value("AVG", distributed::ReduceOp::AVG)
      .value("MAX", distributed::ReduceOp::MAX)
      .value("MIN", distributed::ReduceOp::MIN)
      .value("PRODUCT", distributed::ReduceOp::PRODUCT);

  py::class_<distributed::AllreduceOptions>(*m, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduce_op", &distributed::AllreduceOptions::reduce_op);

  py::class_<distributed::BroadcastOptions>(*m, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("source_rank", &distributed::BroadcastOptions::source_rank)
      .def_readwrite("source_root",
                     &distributed::BroadcastOptions::source_root);

B
Baibaifan 已提交
116 117 118 119
  py::class_<distributed::BarrierOptions>(*m, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("place_ids", &distributed::BarrierOptions::place_ids);

120 121 122 123 124
  py::class_<distributed::ReduceOptions>(*m, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op)
      .def_readwrite("source_root", &distributed::ReduceOptions::root_rank);

125 126 127 128 129 130
  auto ProcessGroup =
      py::class_<distributed::ProcessGroup,
                 std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
          .def("rank", &distributed::ProcessGroup::GetRank)
          .def("size", &distributed::ProcessGroup::GetSize)
          .def("name", &distributed::ProcessGroup::GetBackendName)
131
          .def(
L
LiYuRio 已提交
132
              "all_reduce",
133 134 135 136 137
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
138
                auto p_dense =
139
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
140 141 142 143
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::AllreduceOptions opts{op};
                return self.AllReduce(out_dense, in_dense, opts, sync_op);
144 145 146 147 148 149
              },
              py::arg("tensor"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

150 151 152 153 154 155 156
          .def(
              "broadcast",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
157
                auto p_dense =
158
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
159 160 161 162
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::BroadcastOptions opts{src};
                return self.Broadcast(out_dense, in_dense, opts, sync_op);
163 164 165 166 167 168
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
          .def(
              "send",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Send(tensors, dst, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst_rank,
                 int nranks,
                 int rank_id,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
197 198 199
                int64_t numel = (*dense).numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
                return self.Send_Partial(
                    *dense, dst_rank, offset, send_numel, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("num"),
              py::arg("id"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Recv(tensors, src, sync_op);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src_rank,
                 int nranks,
                 int rank_id,
                 bool sync_op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
238 239 240
                int64_t numel = (*dense).numel();
                int64_t recv_numel = numel / nranks;
                int64_t offset = recv_numel * rank_id;
241 242 243 244 245 246 247 248 249 250
                return self.Recv_Partial(
                    *dense, src_rank, offset, recv_numel, sync_op);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("num"),
              py::arg("id"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

251 252
          .def(
              "all_gather",
253 254
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor_list,
255
                 py::handle py_in_tensor,
256 257 258 259
                 bool sync_op) {
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
260
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
261
                    concat_out_tensor.impl());
262 263 264 265 266 267
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;
268

269
                const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
270
                auto task = self.AllGather(out_dense, in_dense, sync_op);
271
                distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
272
                task->UpdateWaitChain(dev_ctx);
273 274 275
                return task;
              },
              py::arg("out"),
276
              py::arg("in"),
277 278 279 280
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
L
LiYuRio 已提交
281
              "all_gather_into_tensor",
282 283
              [](distributed::ProcessGroup &self,
                 py::handle py_out_tensor,
284
                 py::handle py_in_tensor,
285 286
                 bool sync_op) {
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
287
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
288
                    out_tensor.impl());
289 290 291 292 293 294
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;
295

296
                return self.AllGather(out_dense, in_dense, sync_op);
297 298
              },
              py::arg("out"),
299
              py::arg("in"),
300 301 302
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

303
          .def(
L
LiYuRio 已提交
304
              "all_to_all",
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor_list,
                 bool sync_op) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                // in_tensor_list should not be empty
324
                const auto &dev_ctx =
325 326 327
                    self.GetDeviceContext(in_tensor_list.back().place());
                auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
                distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
328
                task->UpdateWaitChain(dev_ctx);
329 330 331 332 333 334 335 336
                return task;
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
L
LiYuRio 已提交
337
              "all_to_all_tensor",
338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 bool sync_op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                return self.AllToAll(in_wrapper, out_wrapper, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

359
          .def(
L
LiYuRio 已提交
360
              "all_to_all_single",
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 std::vector<int64_t> &in_sizes,
                 std::vector<int64_t> &out_sizes,
                 bool sync_op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                return self.AllToAllSingle(
                    in_wrapper, out_wrapper, in_sizes, out_sizes, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("in_sizes"),
              py::arg("out_sizes"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 int dst,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                distributed::ReduceOptions opts{op, dst};
                auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Reduce(tensors, tensors, opts, sync_op);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(
                    in_wrapper, out_wrapper, opts, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 distributed::ReduceOp op,
                 bool sync_op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(
                    in_wrapper, out_wrapper, opts, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("op"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor,
                 int src,
                 bool sync_op) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ScatterOptions opts{src};
                return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_tensor",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int src,
                 bool sync_op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ScatterOptions opts{src};
                return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
              py::arg("sync_op"),
              py::call_guard<py::gil_scoped_release>())

L
LiYuRio 已提交
517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763
          .def(
              "barrier",
              [](distributed::ProcessGroup &self, std::vector<int> place_ids) {
                distributed::BarrierOptions opts;
                opts.place_ids = place_ids;
                return self.Barrier(opts);
              },
              py::arg("place_ids") = std::vector<int>{},
              py::call_guard<py::gil_scoped_release>())

          // TODO(liyurui): Interface below will be removed in the future.
          .def(
              "allreduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                distributed::AllreduceOptions opts;
                opts.reduce_op = op;
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.AllReduce(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("op") = distributed::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int source_rank) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                distributed::BroadcastOptions opts;
                opts.source_rank = source_rank;
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Broadcast(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("source_rank"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Send(tensors, dst);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int dst_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                int64_t numel = (*dense).numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
                return self.Send_Partial(*dense, dst_rank, offset, send_numel);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Recv(tensors, src);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_tensor,
                 int src_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                int64_t numel = (*dense).numel();
                int64_t recv_numel = numel / nranks;
                int64_t offset = recv_numel * rank_id;
                return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.AllGather(in_tensors, out_tensors);
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "all_gather_partial",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int nranks,
                 int rank_id) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                int64_t numel = (*in_dense).numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
                return self.AllGather_Partial(
                    in_tensors, out_tensors, offset, send_numel);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.AllToAll(in_tensors, out_tensors);
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall_single",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 std::vector<int64_t> in_sizes,
                 std::vector<int64_t> out_sizes) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.AllToAll_Single(
                    in_tensors, out_tensors, in_sizes, out_sizes);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("in_sizes"),
              py::arg("out_sizes"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 int dst,
                 distributed::ReduceOp op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                distributed::ReduceOptions opts;
                opts.reduce_op = op;
                opts.root_rank = dst;
                auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Reduce(tensors, tensors, opts);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op") = distributed::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter",
              [](distributed::ProcessGroup &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int src) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                distributed::ScatterOptions opts;
                opts.root_rank = src;
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                return self.Scatter(in_tensors, out_tensors, opts);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
764
              py::call_guard<py::gil_scoped_release>());
765

766 767 768 769
  auto ProcessGroupStream =
      py::class_<distributed::ProcessGroupStream,
                 std::shared_ptr<distributed::ProcessGroupStream>>(
          *m, "ProcessGroupStream", ProcessGroup)
770
          .def(
L
LiYuRio 已提交
771
              "all_gather_on_calc_stream",
772
              [](distributed::ProcessGroupStream &self,
773 774
                 py::handle py_out_tensor_list,
                 py::handle py_in_tensor) {
775 776 777
                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
778
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
779
                    concat_out_tensor.impl());
780 781 782 783 784 785
                auto *out_dense = p_out_tensor.get();

                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;
786

787
                const auto &dev_ctx =
788
                    self.GetDeviceContext(in_tensor.place(), true);
789 790
                auto task = self.AllGather(out_dense,
                                           in_dense,
791 792 793 794 795 796
                                           /*sync_op*/ true,
                                           /*use_calc_stream*/ true);
                distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
                return task;
              },
              py::arg("out"),
797
              py::arg("in"),
798 799 800
              py::call_guard<py::gil_scoped_release>())

          .def(
L
LiYuRio 已提交
801
              "all_gather_into_tensor_on_calc_stream",
802
              [](distributed::ProcessGroupStream &self,
803 804
                 py::handle py_out_tensor,
                 py::handle py_in_tensor) {
805
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
806
                auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
807
                    out_tensor.impl());
808
                auto *out_dense = p_out_tensor.get();
809

810 811 812 813 814 815 816
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto in_dense = *p_in_tensor;

                return self.AllGather(out_dense,
                                      in_dense,
817 818 819 820
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("out"),
821
              py::arg("in"),
822 823
              py::call_guard<py::gil_scoped_release>())

824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854
          .def(
              "all_gather_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int nranks,
                 int rank_id) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
                int64_t numel = (*in_dense).numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
                return self.AllGather_Partial(in_tensors,
                                              out_tensors,
                                              offset,
                                              send_numel,
                                              /*sync_op*/ true,
                                              /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

855
          .def(
L
LiYuRio 已提交
856
              "all_reduce_on_calc_stream",
857 858 859 860
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 distributed::ReduceOp op) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
861
                auto p_dense =
862
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
863 864 865 866 867
                auto in_dense = *p_dense;
                auto *out_dense = p_dense.get();
                distributed::AllreduceOptions opts{op};
                return self.AllReduce(out_dense,
                                      in_dense,
868 869 870 871 872
                                      opts,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
L
LiYuRio 已提交
873
              py::arg("op") = distributed::ReduceOp::SUM,
874 875
              py::call_guard<py::gil_scoped_release>())

876
          .def(
L
LiYuRio 已提交
877
              "all_to_all_on_calc_stream",
878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor_list) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor_list =
                    CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
                Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                // in_tensor_list must not be empty
896
                const auto &dev_ctx = self.GetDeviceContext(
897 898 899 900 901 902 903 904 905 906 907 908 909
                    in_tensor_list.back().place(), /*use_calc_stream*/ true);
                auto task = self.AllToAll(in_wrapper,
                                          out_wrapper,
                                          /*sync_op*/ true,
                                          /*use_calc_stream*/ true);
                distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
                return task;
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
L
LiYuRio 已提交
910
              "all_to_all_tensor_on_calc_stream",
911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                return self.AllToAll(in_wrapper,
                                     out_wrapper,
                                     /*sync_op*/ true,
                                     /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::call_guard<py::gil_scoped_release>())

          .def(
L
LiYuRio 已提交
934
              "all_to_all_single_on_calc_stream",
935 936 937
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
938 939
                 std::vector<int64_t> &in_sizes,
                 std::vector<int64_t> &out_sizes) {
940
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
941 942 943 944
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

945
                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                return self.AllToAllSingle(in_wrapper,
                                           out_wrapper,
                                           in_sizes,
                                           out_sizes,
                                           /*sync_op*/ true,
                                           /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("in_sizes"),
              py::arg("out_sizes"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "broadcast_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
969
                auto p_dense =
970
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
971 972 973 974 975
                auto *out_dense = p_dense.get();
                auto in_dense = *p_dense;
                distributed::BroadcastOptions opts{src};
                return self.Broadcast(out_dense,
                                      in_dense,
976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042
                                      opts,
                                      /*sync_op*/ true,
                                      /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 int dst,
                 distributed::ReduceOp op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                distributed::ReduceOptions opts{op, dst};
                auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Reduce(tensors,
                                   tensors,
                                   opts,
                                   /*sync_op*/ true,
                                   /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor,
                 distributed::ReduceOp op) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(in_wrapper,
                                          out_wrapper,
                                          opts,
                                          /*sync_op*/ true,
                                          /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "reduce_scatter_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 distributed::ReduceOp op) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
1043 1044
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
1045 1046 1047
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
1048 1049
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
1050 1051 1052 1053 1054 1055 1056 1057
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ReduceScatterOptions opts{op};
                return self.ReduceScatter(in_wrapper,
                                          out_wrapper,
                                          opts,
                                          /*sync_op*/ true,
                                          /*use_calc_stream*/ true);
1058 1059 1060
              },
              py::arg("in"),
              py::arg("out"),
1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119
              py::arg("op"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor_list,
                 py::handle py_out_tensor,
                 int src) {
                auto in_tensor_list =
                    CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
                Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    concat_in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ScatterOptions opts{src};
                return self.Scatter(in_wrapper,
                                    out_wrapper,
                                    opts,
                                    /*sync_op*/ true,
                                    /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "scatter_tensor_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_in_tensor,
                 py::handle py_out_tensor,
                 int src) {
                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    in_tensor.impl());
                std::vector<phi::DenseTensor> in_wrapper = {*in_dense};

                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
                    out_tensor.impl());
                std::vector<phi::DenseTensor> out_wrapper = {*out_dense};

                distributed::ScatterOptions opts{src};
                return self.Scatter(in_wrapper,
                                    out_wrapper,
                                    opts,
                                    /*sync_op*/ true,
                                    /*use_calc_stream*/ true);
              },
              py::arg("in"),
              py::arg("out"),
              py::arg("src"),
1120 1121
              py::call_guard<py::gil_scoped_release>())

1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149
          .def(
              "send_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int dst) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Send(tensors,
                                 dst,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "send_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int dst_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
1150 1151 1152
                int64_t numel = (*dense).numel();
                int64_t send_numel = numel / nranks;
                int64_t offset = send_numel * rank_id;
1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193
                return self.Send_Partial(*dense,
                                         dst_rank,
                                         offset,
                                         send_numel,
                                         /*sync_op*/ true,
                                         /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("dst"),
              py::arg("num"),
              py::arg("id"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
                std::vector<phi::DenseTensor> tensors = {*dense};
                return self.Recv(tensors,
                                 src,
                                 /*sync_op*/ true,
                                 /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "recv_partial_on_calc_stream",
              [](distributed::ProcessGroupStream &self,
                 py::handle py_tensor,
                 int src_rank,
                 int nranks,
                 int rank_id) {
                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
                auto dense =
                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
1194 1195 1196
                int64_t numel = (*dense).numel();
                int64_t recv_numel = numel / nranks;
                int64_t offset = recv_numel * rank_id;
1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207
                return self.Recv_Partial(*dense,
                                         src_rank,
                                         offset,
                                         recv_numel,
                                         /*sync_op*/ true,
                                         /*use_calc_stream*/ true);
              },
              py::arg("tensor"),
              py::arg("src"),
              py::arg("num"),
              py::arg("id"),
1208 1209
              py::call_guard<py::gil_scoped_release>());

1210
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
  // NCCL/RCCL-backed process group, registered as a subclass of the
  // ProcessGroupStream binding.
  auto processGroupNCCL =
      py::class_<distributed::ProcessGroupNCCL,
                 std::shared_ptr<distributed::ProcessGroupNCCL>>(
          *m, "ProcessGroupNCCL", ProcessGroupStream)
          .def(py::init<const std::shared_ptr<distributed::Store> &,
                        int,
                        int,
                        const platform::CUDAPlace &,
                        int>(),
               py::arg("store"),
               py::arg("rank"),
               py::arg("world_size"),
               py::arg("place"),
               py::arg("group_id") = 0,
               py::call_guard<py::gil_scoped_release>());

  // Static helpers forwarding to ProcessGroupNCCL::GroupStart / GroupEnd.
  processGroupNCCL.def_static(
      "group_start", []() { distributed::ProcessGroupNCCL::GroupStart(); });
  processGroupNCCL.def_static(
      "group_end", []() { distributed::ProcessGroupNCCL::GroupEnd(); });

#endif

#if defined(PADDLE_WITH_MPI)
  // MPI-backed process group. No Python constructor is bound; instances are
  // obtained through the static `create(ranks, gid)` factory.
  py::class_<distributed::ProcessGroupMPI,
             std::shared_ptr<distributed::ProcessGroupMPI>>(
      *m, "ProcessGroupMPI", ProcessGroup)
      .def_static(
          "create",
          [](const std::vector<int> &ranks,
             int gid) -> std::shared_ptr<distributed::ProcessGroupMPI> {
            return paddle::distributed::ProcessGroupMPI::CreateProcessGroupMPI(
                ranks, gid);
          })
      .def("get_rank",
           &distributed::ProcessGroup::GetRank,
           py::call_guard<py::gil_scoped_release>())
      .def("get_world_size",
           &distributed::ProcessGroup::GetSize,
           py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
  // Heterogeneous process group. The place type is NPUPlace when built with
  // Ascend CL and CUDAPlace otherwise.
  py::class_<distributed::ProcessGroupHeter,
             std::shared_ptr<distributed::ProcessGroupHeter>>(
      *m, "ProcessGroupHeter", ProcessGroup)
      .def(py::init<const std::shared_ptr<distributed::Store> &,
                    int,
                    int,
#if defined(PADDLE_WITH_ASCEND_CL)
                    const platform::NPUPlace &,
#else
                    const platform::CUDAPlace &,
#endif
                    int,
                    int,
                    int,
                    int,
                    int,
                    bool,
                    std::string,
                    int,
                    int>(),
           py::arg("store"),
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("place"),
           py::arg("gid") = 0,
           py::arg("local_rank") = 0,
           py::arg("local_size") = 1,
           py::arg("gloo_rank") = 0,
           py::arg("gloo_size") = 1,
           py::arg("with_switch") = false,
           py::arg("switch_endpoint") = "",
           // src_rank / dst_rank bind to the trailing `int` parameters above.
           // They previously defaulted to the string "", which pybind11 can
           // never convert to int, so omitting them always failed; they are
           // declared without a default so the signature reflects that.
           py::arg("src_rank"),
           py::arg("dst_rank"),
           py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_ASCEND_CL)
  // HCCL-backed process group bound to an NPUPlace.
  py::class_<distributed::ProcessGroupHCCL,
             std::shared_ptr<distributed::ProcessGroupHCCL>>(
      *m, "ProcessGroupHCCL", ProcessGroup)
      .def(py::init<const std::shared_ptr<distributed::Store> &,
                    int,
                    int,
                    const platform::NPUPlace &,
                    int>(),
           py::arg("store"),
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("place"),
           py::arg("group_id") = 0,
           py::call_guard<py::gil_scoped_release>());

#endif

#if defined(PADDLE_WITH_CUSTOM_DEVICE)
  // Process group for custom (plug-in) devices, bound to a CustomPlace.
  py::class_<distributed::ProcessGroupCustom,
             std::shared_ptr<distributed::ProcessGroupCustom>>(
      *m, "ProcessGroupCustom", ProcessGroup)
      .def(py::init<const std::shared_ptr<distributed::Store> &,
                    int,
                    int,
                    const platform::CustomPlace &,
                    int>(),
           py::arg("store"),
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("place"),
           py::arg("group_id") = 0,
           py::call_guard<py::gil_scoped_release>());

#endif

  // Exposes distributed::ProcessGroup::Task to Python as `task`. `wait` and
  // `synchronize` release the GIL for the duration of the call.
  py::class_<distributed::ProcessGroup::Task,
             std::shared_ptr<distributed::ProcessGroup::Task>>(*m, "task")
      .def("is_completed", &distributed::ProcessGroup::Task::IsCompleted)
      .def("is_sync", &distributed::ProcessGroup::Task::IsSync)
      .def("wait",
           &distributed::ProcessGroup::Task::Wait,
           py::arg("timeout") = kWaitTimeout,
           py::call_guard<py::gil_scoped_release>())
      .def("synchronize",
           &distributed::ProcessGroup::Task::Synchronize,
           py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_GLOO)
  // Gloo-backed CPU process group. Two constructors are bound: one taking
  // explicit GlooOptions, and a convenience one that builds the options
  // itself from the environment.
  py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
      *m, "ProcessGroupGloo", ProcessGroup)
      .def(py::init<const std::shared_ptr<paddle::distributed::Store> &,
                    int,
                    int,
                    const platform::CPUPlace &,
                    int,
                    std::shared_ptr<GlooOptions> &>(),
           py::call_guard<py::gil_scoped_release>())
      .def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
                       int rank,
                       int world_size,
                       const platform::CPUPlace &place,
                       int gid) {
             auto opts = GlooOptions::create();
             // If the environment variable named by GLOO_SOCKET_IFNAME_ENV
             // holds a value longer than one character, create the device
             // for that interface; otherwise use the default device.
             // NOTE(review): a one-character value is ignored by the `> 1`
             // test; confirm that threshold is intended.
             char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
             if (ifname && strlen(ifname) > 1) {
               opts->device = ProcessGroupGloo::createDeviceForInterface(
                   std::string(ifname));
             } else {
               opts->device = ProcessGroupGloo::createDefaultDevice();
             }
             return std::make_shared<ProcessGroupGloo>(
                 store, rank, world_size, place, gid, opts);
           }),
           py::arg("store"),
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("place"),
           py::arg("group_id") = 0,
           py::call_guard<py::gil_scoped_release>())
      .def_static("create_default_device",
                  &ProcessGroupGloo::createDefaultDevice);
#endif

  // Forwards to distributed::Eager_AssignGroupBySize. The default
  // group_size_limits is a single limit of 25 * 1024 * 1024 (presumably
  // bytes — TODO confirm against Eager_AssignGroupBySize).
  m->def(
      "eager_assign_group_by_size",
      [](py::handle py_tensors,
         std::vector<bool> is_sparse_gradient,
         std::vector<size_t> group_size_limits,
         std::vector<int64_t> tensor_indices) {
        auto tensors = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
        return distributed::Eager_AssignGroupBySize(
            tensors, is_sparse_gradient, group_size_limits, tensor_indices);
      },
      py::arg("tensors"),
      py::arg("is_sparse_gradient"),
      py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
      py::arg("tensor_indices") = std::vector<int64_t>{},
      py::call_guard<py::gil_scoped_release>());

  py::class_<distributed::EagerReducer,
1392 1393
             std::shared_ptr<distributed::EagerReducer>>(
      *m, "EagerReducer", R"DOC()DOC")
1394
      .def(py::init(&CreateEagerReducer))
1395 1396
      .def(
          "prepare_for_backward",
1397
          [](distributed::EagerReducer &self, py::handle py_tensors) {
1398
            auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
1399
            self.PrepareForBackward(params);
1400
          },
1401 1402
          py::arg("tensors"),
          py::call_guard<py::gil_scoped_release>());
1403 1404 1405 1406
}

}  // end namespace pybind
}  // namespace paddle