// nccl_wrapper.cc (2.4 KB) — originally committed by dongdaxiang.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/fleet/nccl_wrapper.h"

namespace paddle {
namespace framework {

std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
bool NCCLWrapper::is_initialized_ = false;

void NCCLWrapper::InitNCCL() {
24
#if defined(PADDLE_WITH_NCCL)
25
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
D
dongdaxiang 已提交
26
      &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
D
dongdaxiang 已提交
27
      nccl_info_.my_global_rank_));
D
dongdaxiang 已提交
28
#endif
D
dongdaxiang 已提交
29 30 31 32
  return;
}

void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
33
#if defined(PADDLE_WITH_NCCL)
D
dongdaxiang 已提交
34
  nccl_info_.nccl_id_ = nccl_info.nccl_id_;
D
dongdaxiang 已提交
35 36
#endif
  return;
D
dongdaxiang 已提交
37 38 39
}

NCCLInfo NCCLWrapper::GetNCCLId() {
40
#if defined(PADDLE_WITH_NCCL)
41 42
  PADDLE_ENFORCE_CUDA_SUCCESS(
      platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
D
dongdaxiang 已提交
43
#endif
D
dongdaxiang 已提交
44 45 46 47 48
  return nccl_info_;
}

void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
                              const int ranks) {
49
#if defined(PADDLE_WITH_NCCL)
D
dongdaxiang 已提交
50 51 52
  nccl_info_.local_rank_ = local_rank;
  nccl_info_.my_global_rank_ = global_rank;
  nccl_info_.global_ranks_ = ranks;
L
Leo Chen 已提交
53
  platform::SetDeviceId(local_rank);
54
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
D
dongdaxiang 已提交
55
#endif
D
dongdaxiang 已提交
56 57 58 59 60
  return;
}

void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
                          const std::vector<std::string>& var_names) {
61
#if defined(PADDLE_WITH_NCCL)
D
dongdaxiang 已提交
62 63 64 65
  for (auto& name : var_names) {
    auto var = scope.FindVar(name);
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    int32_t total_size = tensor->numel();
66
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
D
dongdaxiang 已提交
67 68
        reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
        root_rank, nccl_info_.comm_, nccl_info_.stream_));
D
dongdaxiang 已提交
69 70
    cudaStreamSynchronize(nccl_info_.stream_);
  }
D
dongdaxiang 已提交
71 72
#endif
  return;
D
dongdaxiang 已提交
73 74 75 76
}

}  // end namespace framework
}  // end namespace paddle