ProcessGroupBKCL.cc 19.4 KB
Newer Older
J
james 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupBKCL.h"

#include "paddle/fluid/distributed/collective/BKCLTools.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
23
#include "paddle/phi/core/errors.h"
J
james 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

namespace paddle {
namespace distributed {
using XPUDeviceContext = paddle::platform::XPUDeviceContext;

// Builds a task bound to `place`. Every task owns its own XPU event manager:
// Collective() records it on the communication context for tasks that do not
// run on the calc stream, and Wait() blocks the calc context on it.
ProcessGroupBKCL::BKCLTask::BKCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream), place_(place) {
  // Assigned in the body (not the init list) to stay independent of the
  // member declaration order in the header.
  comm_event_ = std::make_shared<XPUEventManager>();
}

// Nothing to release explicitly; members clean up after themselves.
ProcessGroupBKCL::BKCLTask::~BKCLTask() = default;

// Event queries are not supported on XPU (see the warning text), so the task
// always reports itself as completed. The warning is emitted only once.
bool ProcessGroupBKCL::BKCLTask::IsCompleted() {
  constexpr bool kReportedAsCompleted = true;
  LOG_FIRST_N(WARNING, 1) << "XPU do not support event query now.";
  return kReportedAsCompleted;
}

// TODO(sheniang03): Add timeout for wait, now timeout unused
// Blocks the calc context of `place_` on this task's communication event.
// Tasks flagged as barriers additionally synchronize the host with the
// device. Always returns true.
bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) {
  // Waiting explicitly on a calc-stream task does nothing; only log it.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
               "useless.";
    return true;
  }

  auto& pool = platform::DeviceContextPool::Instance();
  const auto* calc_ctx = static_cast<XPUContext*>(pool.Get(place_));
  comm_event_->Block(*calc_ctx);

  if (!barrier_) {
    return true;
  }

  // Barrier tasks must also block the CPU.
  // TODO(zhangxiaoci) There is no such function that can sync entire device
  // for xpu (for now), so all we can do is sync whatever stream that we know
  // and hope for the best. Note that for correctness the communication stream
  // needs to be in sync mode.
  platform::XPUDeviceGuard guard(place_.GetDeviceId());
  xpu_wait();
  calc_ctx->Wait();
  return true;
}

// Equivalent to Wait(kWaitTimeout); the boolean result is discarded.
void ProcessGroupBKCL::BKCLTask::Synchronize() {
  static_cast<void>(Wait(kWaitTimeout));
}

// Creates the process group and keeps `store` for exchanging the BKCL unique
// id. No communicator is created here: Collective() calls
// CreateBKCLEnvCache() lazily on first use.
ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupStream(rank, size, gid), store_(store) {}
J
james 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113

// Thin wrapper over bkcl_group_start(); a failing status is raised through
// PADDLE_ENFORCE_XPU_SUCCESS.
void ProcessGroupBKCL::GroupStart() {
  auto status = bkcl_group_start();
  PADDLE_ENFORCE_XPU_SUCCESS(status);
}

// Thin wrapper over bkcl_group_end(); a failing status is raised through
// PADDLE_ENFORCE_XPU_SUCCESS.
void ProcessGroupBKCL::GroupEnd() {
  auto status = bkcl_group_end();
  PADDLE_ENFORCE_XPU_SUCCESS(status);
}

std::shared_ptr<ProcessGroupBKCL::BKCLTask> ProcessGroupBKCL::CreateTask(
    const Place& place,
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupBKCL::BKCLTask>(
      place, rank, comm_type, is_sync, use_calc_stream);
}

// Shares the BKCL unique id of this group through the key/value store.
//
// Rank 0 publishes the BKCL_UNIQUE_ID_BYTES bytes pointed to by `bkcl_id`
// under a gid-specific key. Every other rank reads that value back into
// `bkcl_id`. The fetched payload must be exactly BKCL_UNIQUE_ID_BYTES long;
// anything else would overflow (or only partially fill) the id buffer.
void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) {
  auto key = "ProcessGroupBKCL/bkcl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    auto id = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(bkcl_id),
        reinterpret_cast<uint8_t*>(bkcl_id) + BKCL_UNIQUE_ID_BYTES);
    store_->set(key, id);
  } else {
    const auto& ret = store_->get(key);
    // Guard the memcpy below: it writes into a fixed-size BKCLUniqueId.
    PADDLE_ENFORCE_EQ(
        ret.size(),
        static_cast<size_t>(BKCL_UNIQUE_ID_BYTES),
        platform::errors::PreconditionNotMet(
            "The BKCL unique id fetched from the store has an unexpected "
            "size."));
    std::memcpy(bkcl_id, ret.data(), ret.size());
  }
}

void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
                                          const std::string& place_key) {
114
  platform::XPUDeviceGuard guard(place.GetDeviceId());
J
james 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
  BKCLUniqueId bkcl_id;
  if (rank_ == 0) {
    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id));
  }
  BroadcastUniqueBKCLID(&bkcl_id);

  VLOG(3) << "init bkcl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", bkcl uniqueid: " << SerializeBKCLUniqueId(bkcl_id);

  calc_event_ = std::make_shared<XPUEventManager>();
  auto* calc_ctx = static_cast<phi::XPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  // must use XPUDeviceContext here to make sure XPUContext::Init() is called
  auto comm_ctx = std::make_unique<XPUDeviceContext>(place);
  BKCLContext_t bkcl_comm;
  BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id));
  comm_ctx->SetBkclContext(bkcl_comm);

  place_to_calc_ctx_[place_key] = calc_ctx;
  place_to_comm_ctx_[place_key] = std::move(comm_ctx);
}

// Orders the communication context of `place` after the work already queued
// on its calc context. It records `calc_event_` on the calc context and then
// blocks the comm context on it.
void ProcessGroupBKCL::SyncCalcStream(const Place& place) {
  const auto key = GetKeyFromPlace(place);
  calc_event_->Record(*place_to_calc_ctx_[key]);
  calc_event_->Block(*place_to_comm_ctx_[key]);
}

// Common driver for every BKCL collective in this file.
//
// Steps:
//   1. Makes sure a communicator and cached contexts exist for the tensor's
//      place. They are created on first use, and also whenever this place
//      has not been seen before.
//   2. Unless running on the calc stream, orders the comm context after the
//      calc context (SyncCalcStream).
//   3. Invokes `fn(out_tensor, in_tensor, bkcl_context, stream)`. `fn` must
//      return the status of the BKCL call it issues; a non-success status
//      raises an error instead of being silently dropped.
//   4. For comm-stream tasks, records the task's event on the comm context so
//      that Wait() can block on it.
// `sync_op` is only forwarded to the created task; no wait happens here.
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    Fn fn,
    CommType op_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = in_tensor.place();
  const auto& key = GetKeyFromPlace(place);

  // Checking only calc_event_ is not enough: a place seen for the first time
  // would otherwise get a null comm context from the map below.
  if (!calc_event_ ||
      place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateBKCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, op_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto bkcl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  const int bkcl_ret = static_cast<int>(
      fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream));
  PADDLE_ENFORCE_EQ(
      bkcl_ret,
      static_cast<int>(BKCL_SUCCESS),
      platform::errors::External(
          "BKCL collective communication failed with error code %d.",
          bkcl_ret));

  if (!use_calc_stream) {
    task->comm_event_->Record(*comm_ctx);
  }

  return task;
}

// All-reduces `in_tensor` into `out_tensor` across the group using
// `opts.reduce_op`.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto all_reduce_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           BKCLContext_t comm,
                           const XPUStream& stream) {
    const auto bkcl_dtype = platform::ToBKCLDataType(
        framework::TransToProtoVarType(input.type()));
    const auto red_type = ToBKCLRedType(opts.reduce_op);
    return bkcl_all_reduce(comm,
                           input.data(),
                           output->data(),
                           input.numel(),
                           bkcl_dtype,
                           red_type,
                           stream);
  };
  return Collective(out_tensor,
                    in_tensor,
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}

// Broadcasts `in_tensor` from rank `opts.source_rank + opts.source_root` into
// `out_tensor` on every rank.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto broadcast_fn = [&](phi::DenseTensor* output,
                          const phi::DenseTensor& input,
                          BKCLContext_t comm,
                          const XPUStream& stream) {
    const int root_rank = opts.source_rank + opts.source_root;
    const auto bkcl_dtype = platform::ToBKCLDataType(
        framework::TransToProtoVarType(input.type()));
    return bkcl_broadcast(comm,
                          input.data(),
                          output->data(),
                          input.numel(),
                          bkcl_dtype,
                          root_rank,
                          stream);
  };
  return Collective(out_tensor,
                    in_tensor,
                    broadcast_fn,
                    CommType::BROADCAST,
                    sync_op,
                    use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
238 239
    int64_t offset,  // for compatibility, no use now
    int64_t numel,   // for compatibility, no use now
J
james 已提交
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          BKCLContext_t comm,
          const XPUStream& stream) {
        return bkcl_all_gather(
            comm,
            input.data(),
            input.numel(),
            output->data(),
            platform::ToBKCLDataType(
                framework::TransToProtoVarType(input.type())),
            stream);
      },
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

J
james 已提交
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
// Reduce emulated with bkcl_all_reduce.
//
// Every rank all-reduces `in_tensor` into a temporary copy of `*out_tensor`.
// That temporary is allocated with the input's dtype, so any dtype accepted
// by platform::ToBKCLDataType works, not only fp32/fp16/int32. The root rank
// then assigns the temporary to `*out_tensor`.
// NOTE(review): `output_t` is copy-constructed from `*output` and may share
// its allocation; if so, non-root outputs are written as well.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          BKCLContext_t comm,
          const XPUStream& stream) {
        phi::DenseTensor output_t(*output);
        const auto& place = input.place();
        auto* calc_ctx = static_cast<phi::XPUContext*>(
            platform::DeviceContextPool::Instance().Get(place));
        // dtype-generic allocation replaces a switch that silently left the
        // buffer unallocated for every dtype except fp32/fp16/int32.
        calc_ctx->Alloc(&output_t, input.dtype());
        int ret =
            bkcl_all_reduce(comm,
                            input.data(),
                            output_t.data(),
                            input.numel(),
                            platform::ToBKCLDataType(
                                framework::TransToProtoVarType(input.type())),
                            ToBKCLRedType(opts.reduce_op),
                            stream);
        if (rank_ == opts.root_rank) {
          *output = output_t;
        }
        return ret;
      },
      CommType::REDUCE,
      sync_op,
      use_calc_stream);
}

J
james 已提交
314 315
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
    const BarrierOptions& opts) {
316 317 318 319 320
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must greater or equal than 0."));
  platform::XPUPlace place(opts.device_id);
J
james 已提交
321
  auto allocator = std::unique_ptr<phi::Allocator>(
322
      new paddle::experimental::DefaultAllocator(place));
J
james 已提交
323 324 325 326 327 328 329 330 331 332 333 334 335
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto bkcl_task = dynamic_cast<BKCLTask*>(task.get());
  bkcl_task->barrier_ = true;
  return task;
}

336
// Returns the communication (non-calc-stream) device context for `place`.
phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
    const Place& place) const {
  constexpr bool kUseCalcStream = false;
  return GetDeviceContext(place, kUseCalcStream);
}

341
// Returns the cached device context for `place`: the calc context when
// `use_calc_stream` is true, otherwise the communication context.
// Raises InvalidArgument when no context has been cached for `place`, i.e.
// no collective has run on it yet. Previously the calc branch dereferenced
// an unchecked `find()` result, which is undefined behavior on a miss.
phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    PADDLE_ENFORCE_NE(iter,
                      place_to_calc_ctx_.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find device context in process group."));
    return iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(iter,
                      place_to_comm_ctx_.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find device context in process group."));
    return iter->second.get();
  }
}

// below are old apis
// Legacy single-tensor AllReduce. It behaves exactly like the `sync_op`
// overload with sync_op = true, so it forwards there. The qualified call
// skips virtual dispatch, so this file's implementation always runs.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
  return ProcessGroupBKCL::AllReduce(
      in_tensors, out_tensors, opts, /*sync_op*/ true);
}

// Legacy AllReduce over vectors that must each hold exactly one tensor.
// Always runs on the communication stream.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts,
    bool sync_op) {
  const auto require_single_tensor =
      [](const std::vector<phi::DenseTensor>& tensors) {
        PADDLE_ENFORCE_EQ(
            tensors.size(),
            1,
            platform::errors::InvalidArgument(
                "BKCL only support single tensor collective communication."));
      };
  require_single_tensor(in_tensors);
  require_single_tensor(out_tensors);

  auto all_reduce_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           BKCLContext_t comm,
                           const XPUStream& stream) {
    const auto bkcl_dtype = platform::ToBKCLDataType(
        framework::TransToProtoVarType(input.type()));
    const auto red_type = ToBKCLRedType(opts.reduce_op);
    return bkcl_all_reduce(comm,
                           input.data(),
                           output->data(),
                           input.numel(),
                           bkcl_dtype,
                           red_type,
                           stream);
  };
  return Collective(&out_tensors.front(),
                    in_tensors.front(),
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    /*use_calc_stream*/ false);
}

// Legacy single-tensor Broadcast. It behaves exactly like the `sync_op`
// overload with sync_op = true, so it forwards there. The qualified call
// skips virtual dispatch, so this file's implementation always runs.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
  return ProcessGroupBKCL::Broadcast(
      in_tensors, out_tensors, opts, /*sync_op*/ true);
}

// Legacy Broadcast over vectors that must each hold exactly one tensor.
// The root is `opts.source_rank * in_tensors.size() + opts.source_root`.
// Always runs on the communication stream.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts,
    bool sync_op) {
  for (const auto* tensors : {&in_tensors, &out_tensors}) {
    PADDLE_ENFORCE_EQ(
        tensors->size(),
        1,
        platform::errors::InvalidArgument(
            "BKCL only support single tensor collective communication."));
  }

  auto broadcast_fn = [&](phi::DenseTensor* output,
                          const phi::DenseTensor& input,
                          BKCLContext_t comm,
                          const XPUStream& stream) {
    const auto root =
        opts.source_rank * in_tensors.size() + opts.source_root;
    const auto bkcl_dtype = platform::ToBKCLDataType(
        framework::TransToProtoVarType(input.type()));
    return bkcl_broadcast(comm,
                          input.data(),
                          output->data(),
                          input.numel(),
                          bkcl_dtype,
                          root,
                          stream);
  };
  return Collective(&out_tensors.front(),
                    in_tensors.front(),
                    broadcast_fn,
                    CommType::BROADCAST,
                    sync_op,
                    /*use_calc_stream*/ false);
}

// Legacy single-tensor AllGather. It behaves exactly like the `sync_op`
// overload with sync_op = true, so it forwards there. The qualified call
// skips virtual dispatch, so this file's implementation always runs.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  return ProcessGroupBKCL::AllGather(
      in_tensors, out_tensors, /*sync_op*/ true);
}

// Legacy AllGather over vectors that must each hold exactly one tensor.
// All outputs must live on an XPU place. Always runs on the communication
// stream.
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op) {
  PADDLE_ENFORCE_EQ(
      in_tensors.size(),
      1,
      platform::errors::InvalidArgument(
          "BKCL only support single tensor collective communication."));
  PADDLE_ENFORCE_EQ(
      out_tensors.size(),
      1,
      platform::errors::InvalidArgument(
          "BKCL only support single tensor collective communication."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInXPUPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in XPUPlace."));

  phi::DenseTensor& gathered = out_tensors.front();
  const phi::DenseTensor& local = in_tensors.front();
  auto all_gather_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           BKCLContext_t comm,
                           const XPUStream& stream) {
    const auto bkcl_dtype = platform::ToBKCLDataType(
        framework::TransToProtoVarType(input.type()));
    return bkcl_all_gather(comm,
                           input.data(),
                           input.numel(),
                           output->data(),
                           bkcl_dtype,
                           stream);
  };
  return Collective(&gathered,
                    local,
                    all_gather_fn,
                    CommType::ALLGATHER,
                    sync_op,
                    /*use_calc_stream*/ false);
}

L
LiYuRio 已提交
585 586 587 588 589 590 591 592
std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
  auto process_group =
      std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
  return process_group;
}

J
james 已提交
593 594
}  //  namespace distributed
}  //  namespace paddle