// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"

#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/CustomCCLTools.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/place.h"

// Declares the `xccl_blocking_wait` boolean flag. NOTE(review): none of the
// code visible in this file reads the flag — confirm whether it is still
// needed here.
DECLARE_bool(xccl_blocking_wait);

// Sleep interval, in milliseconds, between two completion polls in
// ProcessGroupCustom::CustomTask::Wait.
constexpr int64_t kWaitBlockTImeout = 10;

namespace paddle {
namespace distributed {

// For every place, records cclEvents[i] on the communication context
// dev_ctx[i] and then calls Block on the pool-owned default context of that
// place with the same event. The three containers are indexed in parallel and
// must hold at least places.size() entries.
//
// NOTE(review): this presumably makes the default context wait for the
// communication context. The visible caller invokes it right before launching
// a collective on the communication context, which would normally need the
// opposite direction — confirm the Record/Block order is intended.
void SyncDefaultStream(
    const std::vector<Place>& places,
    std::vector<CustomEventManager>& cclEvents,                    // NOLINT
    std::vector<std::unique_ptr<CustomDeviceContext>>& dev_ctx) {  // NOLINT
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  for (size_t idx = 0; idx < places.size(); ++idx) {
    auto* pool_ctx = static_cast<platform::CustomDeviceContext*>(
        ctx_pool.Get(places[idx]));
    auto& event = cclEvents[idx];
    event.Record(*dev_ctx[idx]);
    event.Block(*pool_ctx);
  }
}

// Builds a shared CustomTask for the given places, rank, communication type
// and input tensors; all arguments are forwarded unchanged to the CustomTask
// constructor.
std::shared_ptr<ProcessGroupCustom::CustomTask> ProcessGroupCustom::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  auto task = std::make_shared<ProcessGroupCustom::CustomTask>(
      places, rank, comm_type, inputs);
  return task;
}

// Initializes the base Task with (rank, inputs, comm_type), keeps a copy of
// `places`, and sizes the per-place event and communicator containers so they
// hold one slot for each place.
ProcessGroupCustom::CustomTask::CustomTask(
    const std::vector<Place>& places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs)
    : Task(rank, inputs, comm_type), places_(places) {
  const size_t num_places = places.size();
  control_events_.resize(num_places);
  cclComms_.resize(num_places);
}

// Nothing to release explicitly; members clean up after themselves.
ProcessGroupCustom::CustomTask::~CustomTask() = default;

// Stores a copy of the `outputs` vector in a newly allocated shared vector
// held by outputs_.
void ProcessGroupCustom::CustomTask::SetOutputs(
    std::vector<phi::DenseTensor>& outputs) {  // NOLINT
  using TensorList = std::vector<phi::DenseTensor>;
  outputs_ = std::make_shared<TensorList>(outputs);
}

void ProcessGroupCustom::CustomTask::SynchronizeStreams() {
  for (size_t i = 0; i < places_.size(); ++i) {
    auto* default_ctx = static_cast<platform::CustomDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(places_[i]));
    phi::DeviceGuard guard(default_ctx->GetPlace());
    phi::stream::Stream stream(default_ctx->GetPlace(), default_ctx->stream());
    stream.WaitEvent(control_events_[i].GetCustomEvent());
  }
}

// Returns true only when Query() succeeds for the control event of every
// place; stops at the first event whose Query() returns false.
bool ProcessGroupCustom::CustomTask::IsCompleted() {
  const size_t num_places = places_.size();
  size_t idx = 0;
  while (idx < num_places && control_events_[idx].Query()) {
    ++idx;
  }
  return idx == num_places;
}

// Synchronizes the default streams with this task and then polls
// IsCompleted(), sleeping kWaitBlockTImeout milliseconds between polls, until
// the task reports completion. Always returns true.
//
// NOTE(review): `timeout` is currently not consulted, so the call can block
// indefinitely.
bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
  SynchronizeStreams();
  for (;;) {
    if (IsCompleted()) {
      return true;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
  }
}

// Delegates to Wait(kWaitTimeout); the boolean result of Wait is discarded.
void ProcessGroupCustom::CustomTask::Synchronize() {
  Wait(kWaitTimeout);
}

// Constructs a process group for the custom device type `device_type`.
//
// `rank`, `size` and `gid` are forwarded to the ProcessGroup base class;
// `store` is kept in store_ (used by BroadcastUniqueCustomID to exchange CCL
// root ids) and `device_type` is kept in device_type_.
ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store,
                                       const std::string& device_type,
                                       int rank,
                                       int size,
                                       int gid)
    : ProcessGroup(rank, size, gid), store_(store), device_type_(device_type) {}

// Exchanges the CCL root ids through store_: rank 0 writes every entry of
// `ccl_ids` under the key "ProcessGroupCustom/ccl_ids/<index>", every other
// rank overwrites its entries with the values read back from the same keys.
//
// NOTE(review): the key contains neither a group id nor a place key, so two
// exchanges through the same store would use identical keys — confirm this is
// acceptable.
void ProcessGroupCustom::BroadcastUniqueCustomID(
    std::vector<phi::ccl::CCLRootId>& ccl_ids) {  // NOLINT
  const bool is_root = (rank_ == 0);
  for (size_t idx = 0; idx < ccl_ids.size(); ++idx) {
    auto store_key = "ProcessGroupCustom/ccl_ids/" + std::to_string(idx);
    if (is_root) {
      store_->set(store_key, ccl_ids[idx]);
    } else {
      ccl_ids[idx] = store_->get(store_key);
    }
  }
}

// create CustomCCLManager cache for places_key
// Creates and caches the per-place communication resources for `places_key`.
//
// Rank 0 obtains a CCL root id, which is then exchanged with the other ranks
// via BroadcastUniqueCustomID. For each entry of `places` a
// CustomCCLCommManager and a dedicated CustomDeviceContext are created, plus
// one CustomEventManager per place. The resulting vectors are stored under
// `places_key` in places_to_customcomm_, places_to_ctx_ and places_to_events_.
//
// Fails the PADDLE_ENFORCE check with PreconditionNotMet when `places_key` is
// empty. `places` must be non-empty because the device type is taken from
// places.back().
void ProcessGroupCustom::CreateCustomManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Not able to create/get the HCCL Communicator since "
                        "the NPU place are not known"));
  const std::string device_type = places.back().GetDeviceType();

  std::vector<std::shared_ptr<CustomCCLCommManager>> ccl_comms;
  ccl_comms.resize(places.size());

  // using vector just for broadcast
  std::vector<phi::ccl::CCLRootId> ccl_ids;
  ccl_ids.resize(1);
  auto& ccl_id = ccl_ids.front();

  // Only rank 0 generates the id; the other ranks receive it in the broadcast.
  if (rank_ == 0) {
    phi::DeviceManager::CCLGetUniqueId(device_type, &ccl_id);
  }
  BroadcastUniqueCustomID(ccl_ids);

  VLOG(3) << "init custom ccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
          << ", custom ccl uniqueid: " << SerializeCustomCCLUniqueId(ccl_id);

  std::vector<std::unique_ptr<CustomDeviceContext>> dev_ctx;
  dev_ctx.resize(places.size());

  // The buffer is allocated with new[], so it must be owned by the array form
  // of unique_ptr; the non-array form would release it with plain delete.
  std::unique_ptr<phi::ccl::CCLComm[]> comms(
      new phi::ccl::CCLComm[places.size()]);
  for (size_t i = 0; i < places.size(); ++i) {
    phi::DeviceGuard guard(places[i]);
    ccl_comms[i] = CustomCCLCommManager::Create(
        device_type, GetSize(), GetRank(), &ccl_id, comms.get() + i);
    dev_ctx[i].reset(new CustomDeviceContext(places[i]));
  }

  std::vector<CustomEventManager> events;
  events.resize(places.size());

  // These caches will be useful to process sync/wait/communicate
  places_to_events_.emplace(places_key, std::move(events));
  places_to_customcomm_.emplace(places_key, std::move(ccl_comms));
  places_to_ctx_.emplace(places_key, std::move(dev_ctx));
}

// Shared driver for the collective operations of this process group.
//
// Steps, in order:
//   1. derive the place list and cache key from `inputs`, creating the cached
//      communicators/contexts/events under mutex_ on first use;
//   2. call SyncDefaultStream for the cached events and contexts;
//   3. create the task and attach `outputs`;
//   4. invoke fn(input, output, comm, stream) once per input on the cached
//      communication context of the matching place;
//   5. record each of the task's control events on that same context.
// Returns the created task.
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    const bool cache_missing = places_to_customcomm_.count(key) == 0;
    if (cache_missing) {
      CreateCustomManagerCache(key, places);
    }
  }

  auto& ccl_comms = places_to_customcomm_[key];
  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
  auto task = CreateTask(places, rank_, op_type, inputs);
  task->SetOutputs(outputs);

  // Launch the operation for every input on its communication stream.
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    phi::DeviceGuard guard(places[idx]);
    auto& comm_ctx = places_to_ctx_[key][idx];
    phi::stream::Stream stream(places[idx], comm_ctx->stream());
    fn(inputs[idx], outputs[idx], ccl_comms[idx]->GetCustomCCLComm(), stream);
  }

  // Record completion events only after all operations have been issued.
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    phi::DeviceGuard guard(places[idx]);
    auto& comm_ctx = places_to_ctx_[key][idx];
    task->control_events_[idx].Record(*comm_ctx);
  }
  return task;
}

// All-gather over the custom CCL backend.
//
// Checks that all input and output tensors live on a CustomPlace of
// device_type_ (InvalidArgument otherwise), then runs CCLAllGather for each
// input/output pair through Collective, sending input.numel() elements of the
// input's dtype. Returns the task created by Collective.
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));

  auto all_gather_fn = [&](phi::DenseTensor& input,
                           phi::DenseTensor& output,
                           phi::ccl::CCLComm comm,
                           const phi::stream::Stream& stream) {
    const auto ccl_dtype = phi::ccl::ToCCLDataType(input.dtype());
    return phi::DeviceManager::CCLAllGather(device_type_,
                                            input.data(),
                                            output.data(),
                                            input.numel(),
                                            ccl_dtype,
                                            comm,
                                            stream);
  };
  return Collective(
      in_tensors, out_tensors, all_gather_fn, CommType::ALLGATHER);
}

// Returns `raw_pointer` advanced by `offset` elements of the given data type.
//
// Supported types: FLOAT32, FLOAT64, INT32, INT64 and FLOAT16 (stepped as a
// 16-bit element). Any other type raises an Unimplemented error.
void* XcclGetPointerByOffset(void* raw_pointer,
                             size_t offset,
                             experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return static_cast<float*>(raw_pointer) + offset;
  }
  if (type == experimental::DataType::FLOAT64) {
    return static_cast<double*>(raw_pointer) + offset;
  }
  if (type == experimental::DataType::INT32) {
    return static_cast<int32_t*>(raw_pointer) + offset;
  }
  if (type == experimental::DataType::INT64) {
    return static_cast<int64_t*>(raw_pointer) + offset;
  }
  if (type == experimental::DataType::FLOAT16) {
    // Half precision has no native host type here; int16_t has the same width.
    return static_cast<int16_t*>(raw_pointer) + offset;
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "This datatype in xccl is not supported."));
  return nullptr;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
265 266
    int64_t offset,
    int64_t length) {
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        return phi::DeviceManager::CCLAllGather(
            device_type_,
            XcclGetPointerByOffset(input.data(), offset, input.dtype()),
            output.data(),
            length,
            phi::ccl::ToCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}

// All-reduce over the custom CCL backend.
//
// Checks that all input and output tensors live on a CustomPlace of
// device_type_ (InvalidArgument otherwise), then runs CCLAllReduce for each
// input/output pair through Collective, reducing input.numel() elements with
// the reduction selected by opts.reduce_op. Returns the task created by
// Collective.
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
    std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        return phi::DeviceManager::CCLAllReduce(
            device_type_,
            input.data(),
            output.data(),
            input.numel(),
            phi::ccl::ToCCLDataType(input.dtype()),
            ToCustomCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE);
}

// Broadcast over the custom CCL backend.
//
// Checks that all input and output tensors live on a CustomPlace of
// device_type_ (InvalidArgument otherwise), then runs CCLBroadcast through
// Collective. The root is opts.source_rank * in_tensors.size() +
// opts.source_root; the root rank passes its input buffer, every other rank
// passes its output buffer. Returns the task created by Collective.
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
    std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(in_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All inputs should be in CustomPlace(%s).", device_type_));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCustomPlace(out_tensors, device_type_),
      true,
      platform::errors::InvalidArgument(
          "All outputs should be in CustomPlace(%s).", device_type_));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        int root = opts.source_rank * in_tensors.size() + opts.source_root;
        if (rank_ == root) {
          // Root side: the data to send comes from the input tensor.
          return phi::DeviceManager::CCLBroadcast(
              device_type_,
              input.data(),
              input.numel(),
              phi::ccl::ToCCLDataType(input.dtype()),
              root,
              comm,
              stream);
        } else {
          // Non-root side: the output tensor is the buffer handed to the call.
          return phi::DeviceManager::CCLBroadcast(
              device_type_,
              output.data(),
              output.numel(),
              phi::ccl::ToCCLDataType(output.dtype()),
              root,
              comm,
              stream);
        }
      },
      CommType::BROADCAST);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
    const BarrierOptions& opts) {
  // Only support single card single process
378 379 380 381 382 383
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must greater or equal than 0."));
  platform::CustomPlace place(device_type_, opts.device_id);
  std::vector<phi::CustomPlace> places = {place};
384 385 386 387 388
  std::vector<phi::DenseTensor> barrierTensors;
  barrierTensors.reserve(places.size());

  for (auto& place : places) {
    phi::DeviceGuard guard(place);
389 390 391 392
    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
    auto allocator = std::unique_ptr<phi::Allocator>(
        new paddle::experimental::DefaultAllocator(place));
    barrierTensors.emplace_back(allocator.get(), meta);
393 394 395 396 397 398 399
  }
  auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors);
  auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
  xccl_task->barrierTensors_ = std::move(barrierTensors);
  return task;
}

// Returns the communicator handle of the first cached comm manager for the
// single-place key derived from `place`. Fails with InvalidArgument when no
// cache entry exists for that key.
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
  const std::vector<Place> single_place(1, place);
  const auto places_key = GetKeyFromPlaces(single_place);
  const auto iter = places_to_customcomm_.find(places_key);
  PADDLE_ENFORCE_NE(iter,
                    places_to_customcomm_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find nccl comm in process group."));
  const auto& comm_managers = iter->second;
  return comm_managers[0]->GetCustomCCLComm();
}

}  //  namespace distributed
}  //  namespace paddle