//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"

#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace paddle {
namespace platform {

class HCCLCommImpl : public HCCLComm {
 public:
24 25
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }
26

27 28
  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }
29

30 31 32 33 34 35
  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

  int device_id() const override {
    return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
  }
36 37 38 39 40 41 42 43 44

  aclrtStream stream() const override { return dev_ctx_->stream(); }

  void set_dev_ctx(std::unique_ptr<NPUDeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }

 private:
45 46 47
  int ring_id_;
  int nranks_;
  int rank_;
48 49 50
  std::unique_ptr<NPUDeviceContext> dev_ctx_;
};

51 52 53 54 55
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_GT(
      world_rank_ids.size(), 1,
      platform::errors::InvalidArgument(
          "Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size()));
56
  PADDLE_ENFORCE_GE(rank, 0,
57 58 59 60
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.", rank));
  PADDLE_ENFORCE_LT(
      rank, world_rank_ids.size(),
61
      platform::errors::InvalidArgument(
62 63 64 65
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank, world_rank_ids.size()));
  PADDLE_ENFORCE_GE(
      dev_id, 0,
66
      platform::errors::InvalidArgument(
67 68 69 70 71 72 73 74 75 76 77 78 79
          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
  PADDLE_ENFORCE_GE(
      ring_id, 0,
      platform::errors::InvalidArgument(
          "Expected ring_id >= 0. But received ring_id is %d.", ring_id));

  auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id);

  // HACK(sunpeng17): hcom API requires bind stream to a model
  // but we don't need model in Paddle, so we feed stream pointer as model pointer
  PADDLE_ENFORCE_NPU_SUCCESS(
      platform::dynload::hcom_bind_model(comm_wrapper->stream(),
                                         comm_wrapper->stream()));
80

81 82 83 84 85 86 87 88 89 90 91
  // Get world_rank_ids registered in gen_nccl_id op
  std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id);
  PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group(
      group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data()));

  VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id << ", group name: " << group_name;

  std::call_once(once_flag_, []() {
    std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
  });
92 93 94 95

  return comm_wrapper;
}

96
HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) {
97
  std::unique_ptr<NPUDeviceContext> dev_ctx(
98
      new NPUDeviceContext(NPUPlace(dev_id)));
99 100

  HCCLCommImpl* c = new HCCLCommImpl;
101 102
  c->set_ring_id(ring_id);
  c->set_nranks(nranks);
103 104
  c->set_rank(rank);
  c->set_dev_ctx(std::move(dev_ctx));
105 106 107 108 109 110 111 112 113 114 115

  comm_map_mutex_.lock();
  if (comm_map_.count(ring_id) == 0) {
    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<HCCLComm>>());
  }
  auto& dev2comm = comm_map_[ring_id];

  dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
  comm_map_mutex_.unlock();

  return comm_map_[ring_id][dev_id].get();
116 117
}

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
void HCCLCommContext::InitHcomWorldGroup() {
  const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE);
  PADDLE_ENFORCE_NOT_NULL(
      rank_table_file,
      platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null."));

  const char *rank_id = getenv(ENV_RANK_ID);
  PADDLE_ENFORCE_NOT_NULL(
      rank_id,
      platform::errors::InvalidArgument("The RANK_ID environment variable should not be null."));

  PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id));
  VLOG(3) << "Successfully initialized hcom. rank_table_file: "
    << rank_table_file << ", rank_id " << rank_id;
}

// Destroys every communicator held in comm_map_ (ring id -> dev id -> comm).
//
// The map entries themselves remain, holding null pointers. CreateHCCLComm
// registers this function with std::atexit.
void HCCLCommContext::ReleaseHCCLComms() {
  for (auto& ring_entry : comm_map_) {
    for (auto& dev_entry : ring_entry.second) {
      dev_entry.second.reset();
    }
  }
}

}  // namespace platform
}  // namespace paddle
#endif