//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"

#include <cstdlib>
#include <map>
#include <memory>
#include <mutex>
#include <utility>

namespace paddle {
namespace platform {

class HCCLCommImpl : public HCCLComm {
 public:
24 25
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }
26

27 28
  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }
29

30 31 32 33 34 35
  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

  int device_id() const override {
    return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
  }
36

L
lw921014 已提交
37 38 39 40 41 42 43
  ~HCCLCommImpl(){
    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclCommDestroy(comm_));
  }

  void set_comm(HcclComm comm) { comm_ = comm; }
  HcclComm comm() const override { return comm_; }

44 45 46 47 48 49 50 51
  aclrtStream stream() const override { return dev_ctx_->stream(); }

  void set_dev_ctx(std::unique_ptr<NPUDeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }

 private:
52 53 54
  int ring_id_;
  int nranks_;
  int rank_;
L
lw921014 已提交
55
  HcclComm comm_;
56 57 58
  std::unique_ptr<NPUDeviceContext> dev_ctx_;
};

L
lw921014 已提交
59 60 61 62 63
HCCLComm* HCCLCommContext::CreateHCCLComm(HcclRootInfo* hccl_id, int nranks,
                                          int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(hccl_id,
                          platform::errors::InvalidArgument(
                              "The hccl unique id should not be null."));
64
  PADDLE_ENFORCE_GT(
L
lw921014 已提交
65
      nranks, 1,
66
      platform::errors::InvalidArgument(
L
lw921014 已提交
67
          "Expected nranks > 1. But received nranks is %d.", nranks));
68
  PADDLE_ENFORCE_GE(rank, 0,
69 70 71
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.", rank));
  PADDLE_ENFORCE_LT(
L
lw921014 已提交
72
      rank, nranks,
73
      platform::errors::InvalidArgument(
74
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
L
lw921014 已提交
75
          rank, nranks));
76 77
  PADDLE_ENFORCE_GE(
      dev_id, 0,
78
      platform::errors::InvalidArgument(
79 80
          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));

L
lw921014 已提交
81 82
  HcclComm comm;
  PADDLE_ENFORCE_NPU_SUCCESS(aclrtSetDevice(dev_id));
83
  PADDLE_ENFORCE_NPU_SUCCESS(
L
lw921014 已提交
84 85 86
      platform::dynload::HcclCommInitRootInfo(nranks, hccl_id, rank, &comm));

 VLOG(1) << "initialized comm: " << &comm  << ", nranks: " << nranks << ", hccl_id: " << hccl_id << ", rank: " << rank;
87

L
lw921014 已提交
88
  auto* comm_wrapper = AssignHCCLComm(comm, nranks, rank, dev_id, ring_id);
89 90

  VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
L
lw921014 已提交
91
          << " has been created on device " << dev_id << ", with comm: " << comm_wrapper->comm();
92 93 94 95

  std::call_once(once_flag_, []() {
    std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
  });
96 97 98 99

  return comm_wrapper;
}

L
lw921014 已提交
100 101
HCCLComm* HCCLCommContext::AssignHCCLComm(HcclComm comm, int nranks, int rank,
                                          int dev_id, int ring_id) {
102
  std::unique_ptr<NPUDeviceContext> dev_ctx(
103
      new NPUDeviceContext(NPUPlace(dev_id)));
104 105

  HCCLCommImpl* c = new HCCLCommImpl;
106 107
  c->set_ring_id(ring_id);
  c->set_nranks(nranks);
108
  c->set_rank(rank);
L
lw921014 已提交
109
  c->set_comm(comm);
110
  c->set_dev_ctx(std::move(dev_ctx));
111 112 113 114 115 116 117 118 119 120

  comm_map_mutex_.lock();
  if (comm_map_.count(ring_id) == 0) {
    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<HCCLComm>>());
  }
  auto& dev2comm = comm_map_[ring_id];

  dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
  comm_map_mutex_.unlock();

L
lw921014 已提交
121 122 123 124 125 126
  if (ring_id == 0) {
    auto* dev_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::NPUPlace(dev_id)));
    dev_ctx->set_hccl_comm(comm);
  }
127

L
lw921014 已提交
128
  return comm_map_[ring_id][dev_id].get();
129 130 131 132 133 134 135 136
}

void HCCLCommContext::ReleaseHCCLComms() {
  for (auto& p : comm_map_) {
    for (auto& q : p.second) {
      q.second.reset();
    }
  }
137 138 139 140 141
}

}  // namespace platform
}  // namespace paddle
#endif