all_reduce.cc 8.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_NCCL

#include "paddle/fluid/imperative/all_reduce.h"

#include <nccl.h>

#include <algorithm>
#include <numeric>

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"

29 30
namespace paddle {
namespace imperative {
31

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
// Returns the place (device) on which the data held by `src` lives.
// A LoDTensor reports its own place; a SelectedRows (only when built against
// NCCL >= 2.2.12) reports the place of its value tensor. Any other variable
// type raises InvalidArgument.
static const platform::Place &GetVarPlace(const framework::Variable &src) {
  if (src.IsType<framework::LoDTensor>()) {
    return src.Get<framework::LoDTensor>().place();
  }
#if NCCL_VERSION_CODE >= 2212
  if (src.IsType<framework::SelectedRows>()) {
    const auto &sparse_var = src.Get<framework::SelectedRows>();
    return sparse_var.value().place();
  }
#endif
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Cannot get unsupported variable type %s for imperative allreduce, "
      "only "
      "LoDTensor and SelectedRows are supported.",
      platform::demangle(framework::ToTypeName(src.Type()))));
}
47 48

// Sums `src` element-wise across all ranks of `comm` into `dst` on `stream`.
// `dst` is resized to `src`'s shape and allocated on `src`'s place with
// `src`'s dtype; passing the same tensor for both performs the reduction in
// place. Only GPU tensors are accepted.
static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
                      const cudaStream_t stream,
                      const platform::NCCLComm *comm) {
  const auto &src_place = src.place();
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(src_place), true,
      platform::errors::Unimplemented(
          "Imperative mode does not support multi-CPU training yet."));

  const void *send_buf = src.data<void>();

  // Shape dst like src before asking for its buffer.
  dst->Resize(src.dims());
  void *recv_buf = dst->mutable_data(src.place(), src.type());
  const auto nccl_type = platform::ToNCCLDataType(src.type());

  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
      send_buf, recv_buf, src.numel(), nccl_type, ncclSum, comm->comm(),
      stream));
}

#if NCCL_VERSION_CODE >= 2212
// "All-reduces" a SelectedRows across all ranks by concatenation: afterwards
// `dst` holds the rows of every rank, laid out in rank order, together with
// the matching value slices, and `dst`'s height is copied from `src`.
// Duplicate row ids coming from different ranks are not merged here.
//
//   1. Exchange per-rank row counts with an in-place ncclAllGather.
//   2./3. Gather row ids and value data: with one ncclAllGather each when all
//      ranks have the same row count, otherwise with one ncclBroadcast per
//      rank that has at least one row.
//
// `stream` is either the calculation stream of the value tensor's place or a
// separate communication stream. In the latter case the calculation stream
// is waited on before communicating on buffers prepared by it, and `stream`
// is synchronized before the gathered counts are read on the host.
static void AllReduce(const framework::SelectedRows &src,
                      framework::SelectedRows *dst,
                      const ParallelStrategy &strategy,
                      const cudaStream_t stream,
                      const platform::NCCLComm *comm) {
  VLOG(3) << "SelectedRows AllReduce start";
  const auto &src_tensor = src.value();
  const auto &place = src_tensor.place();
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(place), true,
      platform::errors::Unimplemented(
          "Imperative mode does not support multi-CPU training yet."));

  auto dtype = src_tensor.type();
  auto nccl_dtype = platform::ToNCCLDataType(dtype);
  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
      platform::DeviceContextPool::Instance().Get(place));

  // True when the caller asked us to communicate on the calculation stream;
  // otherwise explicit waits are needed between the two streams.
  bool use_calc_stream = (dev_ctx->stream() == stream);

  // 1. Gather the row count of every worker. ncclAllGather is used here, but
  // other ways to implement this may be used in the future.
  const auto &src_rows = src.rows();
  framework::Vector<int64_t> rows_num_vector(strategy.nranks_);
  rows_num_vector[strategy.local_rank_] = static_cast<int64_t>(src_rows.size());
  // CUDAMutableData uses the calculation stream.
  auto *gpu_rows_num_ptr = rows_num_vector.CUDAMutableData(place);
  if (!use_calc_stream) {
    dev_ctx->Wait();
  }
  // sendbuff is this rank's own slot inside recvbuff (NCCL's in-place
  // allgather form); assumes strategy.local_rank_ equals this communicator's
  // rank — TODO confirm.
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64,
      comm->comm(), stream));

  // The counts are read on the host below, so the gather must be finished.
  if (!use_calc_stream) {
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
  }

  const auto *cpu_rows_num_ptr = rows_num_vector.data();
  auto rows_num =
      std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + strategy.nranks_,
                      static_cast<int64_t>(0));
  dst->set_height(src.height());

  VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
          << ", total rows number: " << rows_num
          << ", height: " << src.height();

  auto *dst_rows = dst->mutable_rows();
  dst_rows->resize(rows_num);
  auto *dst_rows_ptr = dst_rows->CUDAMutableData(place);
  const auto *src_rows_ptr = src_rows.CUDAData(place);

  auto *dst_tensor = dst->mutable_value();
  auto dims = src_tensor.dims();
  // Number of elements per row: the product of all non-leading dims. It is
  // computed directly instead of as product(dims) / dims[0] after setting
  // dims[0] = rows_num, which divides by zero when no rank contributes any
  // row (rows_num == 0).
  int64_t feature_size = 1;
  for (int i = 1; i < dims.size(); ++i) {
    feature_size *= dims[i];
  }
  dims[0] = rows_num;
  dst_tensor->Resize(dims);
  auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
  const auto *src_tensor_ptr = src_tensor.data<void>();

  auto sizeof_dtype = framework::SizeOfType(dtype);
  int64_t row_offset = 0;
  // Make sure the buffers prepared above on the calculation stream are ready
  // before the communication stream touches them.
  if (!use_calc_stream) {
    dev_ctx->Wait();
  }
  if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + strategy.nranks_,
                  [&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) {
    // Every rank contributes the same number of rows, so a single allgather
    // per buffer replaces the per-rank broadcasts below.
    auto row_sendcount = cpu_rows_num_ptr[0];
    VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
        src_rows_ptr, dst_rows_ptr, row_sendcount, ncclInt64, comm->comm(),
        stream));
    auto value_sendcount = cpu_rows_num_ptr[0] * feature_size;
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
        src_tensor_ptr, dst_tensor_ptr, value_sendcount, nccl_dtype,
        comm->comm(), stream));
    return;
  }
  for (int i = 0; i < strategy.nranks_; ++i) {
    if (cpu_rows_num_ptr[i] > 0) {
      // 2. Broadcast the rows of SelectedRows from rank i into its slot.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
          src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i],
          ncclInt64, i, comm->comm(), stream));
      // 3. Broadcast the tensor data of SelectedRows; the offset is in bytes.
      auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) +
                               row_offset * feature_size * sizeof_dtype;
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
          src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size,
          nccl_dtype, i, comm->comm(), stream));
      row_offset += cpu_rows_num_ptr[i];
    }
  }

  VLOG(3) << "Original SelectedRows rows: "
          << string::join_strings(src_rows, ',');
  VLOG(3) << "Result SelectedRows rows: "
          << string::join_strings(*dst_rows, ',');
}
#endif

// All-reduces `src` into `dst` using the NCCL communicator of `ring_id` on
// the device where `src` lives.
//
// - LoDTensor: element-wise sum; `dst` may alias `src`.
// - SelectedRows (NCCL >= 2.2.12 only): rows/values of all ranks are gathered
//   into `dst`. When `dst` aliases `src`, the result is built in a temporary,
//   the stream is synchronized, and the temporary is moved into `dst`.
//
// `use_calc_stream` selects the device context's calculation stream;
// otherwise the communicator's own stream is used. Any other variable type
// raises InvalidArgument.
void AllReduce(const framework::Variable &src, framework::Variable *dst,
               const ParallelStrategy &strategy, int ring_id,
               bool use_calc_stream) {
  const auto &var_place = GetVarPlace(src);
  auto *calc_ctx = static_cast<platform::CUDADeviceContext *>(
      platform::DeviceContextPool::Instance().Get(var_place));
  platform::NCCLComm *nccl_comm =
      platform::NCCLCommContext::Instance().Get(ring_id, var_place);

  cudaStream_t run_stream = nullptr;
  if (use_calc_stream) {
    run_stream = calc_ctx->stream();
  } else {
    run_stream = nccl_comm->stream();
  }

  if (src.IsType<framework::LoDTensor>()) {
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    auto *dense_out = dst->GetMutable<framework::LoDTensor>();
    AllReduce(src.Get<framework::LoDTensor>(), dense_out, run_stream,
              nccl_comm);
    return;
  }

#if NCCL_VERSION_CODE >= 2212
  if (src.IsType<framework::SelectedRows>()) {
    if (&src == dst) {
      // The gather cannot write into its own input, so build the result
      // aside and move it in once the stream has finished with it.
      framework::Variable gathered;
      AllReduce(src.Get<framework::SelectedRows>(),
                gathered.GetMutable<framework::SelectedRows>(), strategy,
                run_stream, nccl_comm);
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(run_stream));
      *dst = std::move(gathered);
      return;
    }

    if (!dst->IsType<framework::SelectedRows>()) {
      dst->Clear();
    }
    AllReduce(src.Get<framework::SelectedRows>(),
              dst->GetMutable<framework::SelectedRows>(), strategy, run_stream,
              nccl_comm);
    return;
  }
#endif

  PADDLE_THROW(platform::errors::InvalidArgument(
      "Unsupported variable type %s for imperative allreduce, only "
      "LoDTensor and SelectedRows are supported.",
      platform::demangle(framework::ToTypeName(src.Type()))));
}

// Convenience overload: all-reduces on communication ring 0 using the
// calculation stream of the variable's device.
void AllReduce(const framework::Variable &src, framework::Variable *dst,
               const ParallelStrategy &strategy) {
  constexpr int kDefaultRingId = 0;
  constexpr bool kUseCalcStream = true;
  AllReduce(src, dst, strategy, kDefaultRingId, kUseCalcStream);
}

}  // namespace imperative
}  // namespace paddle

#endif