//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DECLARE_bool(sync_nccl_allreduce);
#endif

namespace paddle {
namespace framework {
namespace details {

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
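// The constructor comes in three compile-time variants: this NCCL/RCCL one
// for CUDA/ROCm builds, a BKCL one for XPU builds, and a CPU-only fallback.
// All of them require exactly one local scope per place.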
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places,
                                     const platform::NCCLCommunicator *ctxs)
    : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "The number of places and the number of local scopes "
                        "should be equal, but got number of places is %d and "
                        "number of local scopes is %d.",
                        places_.size(), local_scopes_.size()));
}
#elif defined(PADDLE_WITH_XPU_BKCL)
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places,
                                     const platform::BKCLCommunicator *ctxs)
    : BKCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "The number of places and the number of local scopes "
                        "should be equal, but got number of places is %d and "
                        "number of local scopes is %d.",
                        places_.size(), local_scopes_.size()));
}
#else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places)
    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "The number of places and the number of local scopes "
                        "should be equal, but got number of places is %d and "
                        "number of local scopes is %d.",
                        places_.size(), local_scopes_.size()));
}
#endif

void AllReduceOpHandle::RunImpl() {
  platform::RecordEvent record_event(
      Name(), platform::TracerEventType::Communication, 1);
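  // Wait until all input variables are generated, then hand the non-dummy
  // input/output variable handles over to AllReduceImpl.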
  WaitInputVarGenerated();
  std::vector<VarHandleBase *> inputs = this->Inputs();
  std::vector<VarHandleBase *> outputs = this->Outputs();
  auto in_var_handles = DynamicCast<VarHandle>(inputs);
  auto out_var_handles = DynamicCast<VarHandle>(outputs);
  AllReduceImpl(in_var_handles, out_var_handles);
}

void AllReduceOpHandle::AllReduceImpl(
    const std::vector<VarHandle *> &in_var_handles,
    const std::vector<VarHandle *> &out_var_handles) {
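  // Check that every place holds a tensor with the same numel, dtype, and
  // place type, collect the raw data pointers, then dispatch to AllReduceFunc.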
  size_t num_places = places_.size();
  PADDLE_ENFORCE_EQ(in_var_handles.size(), num_places,
                    platform::errors::InvalidArgument(
                        "The NoDummyInputSize should be equal "
                        "to the number of places, but got NoDummyInputSize is "
                        "%d and the number of places is %d.",
                        in_var_handles.size(), num_places));
  PADDLE_ENFORCE_EQ(
      in_var_handles.size(), out_var_handles.size(),
      platform::errors::InvalidArgument(
          "The NoDummyInputSize and NoDummyOutputSize should be "
          "equal, but got NoDummyInputSize is %d and NoDummyOutputSize is %d.",
          in_var_handles.size(), out_var_handles.size()));
  PADDLE_ENFORCE_EQ(
      local_exec_scopes_.size(), num_places,
      platform::errors::InvalidArgument(
          "The number of local scopes should be equal "
          "to the number of places, but got the number of local scopes is "
          "%d and the number of places is %d.",
          local_exec_scopes_.size(), num_places));

  std::vector<const void *> lod_tensor_data;
  std::vector<platform::Place> places;
  lod_tensor_data.reserve(num_places);
  places.reserve(num_places);
  int64_t numel = -1;
  bool is_gpu_place = false;
#if defined(PADDLE_WITH_XPU_BKCL)
  bool is_xpu_place = false;
#endif
  auto dtype = static_cast<framework::proto::VarType::Type>(0);
  for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
    auto &local_scope = local_exec_scopes_[i];
    auto var = local_scope->FindVar(in_var_handles[i]->name());
    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound(
                                     "Variable %s is not found in local scope.",
                                     in_var_handles[i]->name()));
    auto &lod_tensor = var->Get<LoDTensor>();

    if (i == 0) {
      numel = static_cast<int64_t>(lod_tensor.numel());
      // Only check place 0 here; every other place's numel is enforced to
      // match place 0's numel below.
      PADDLE_ENFORCE_GT(
          numel, 0,
          platform::errors::PreconditionNotMet(
              "The numel of tensor %s should be > 0, but got numel is %d.",
              in_var_handles[i]->name(), numel));
      dtype = framework::TransToProtoVarType(lod_tensor.dtype());
      is_gpu_place = platform::is_gpu_place(lod_tensor.place());
#if defined(PADDLE_WITH_XPU_BKCL)
      is_xpu_place = platform::is_xpu_place(lod_tensor.place());
#endif
    }
    PADDLE_ENFORCE_EQ(
        numel, static_cast<int64_t>(lod_tensor.numel()),
        platform::errors::PreconditionNotMet(
            "The size of tensors of the same variable in different local "
            "scopes should be equal."));
    PADDLE_ENFORCE_EQ(
        dtype, framework::TransToProtoVarType(lod_tensor.dtype()),
        platform::errors::PreconditionNotMet(
            "The dtype of tensors of the same variable in different local "
            "scopes should be equal."));
#if defined(PADDLE_WITH_XPU_BKCL)
    PADDLE_ENFORCE_EQ(is_xpu_place, platform::is_xpu_place(lod_tensor.place()),
                      platform::errors::PreconditionNotMet(
                          "The place type of tensors of the same variable "
                          "in different local scopes should be equal."));
#endif
    PADDLE_ENFORCE_EQ(is_gpu_place, platform::is_gpu_place(lod_tensor.place()),
                      platform::errors::PreconditionNotMet(
                          "The place type of tensors of the same variable "
                          "in different local scopes should be equal."));

    lod_tensor_data.emplace_back(lod_tensor.data());
    places.emplace_back(lod_tensor.place());

    VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
             << ", out_name:" << out_var_handles[i]->name();

    PADDLE_ENFORCE_EQ(
        in_var_handles[i]->name(), out_var_handles[i]->name(),
        platform::errors::InvalidArgument(
            "The name of input and output of all_reduce op should be equal, "
            "but got input is %s and output is %s.",
            in_var_handles[i]->name(), out_var_handles[i]->name()));
  }

  std::vector<std::string> grad_var_names;
  grad_var_names.reserve(num_places);
  for (auto &out_var : out_var_handles) {
    grad_var_names.emplace_back(out_var->Name());
  }

  AllReduceFunc(lod_tensor_data, dtype, numel, places, grad_var_names);
}

void AllReduceOpHandle::AllReduceFunc(
    std::vector<const void *> lod_tensor_data,
    const framework::proto::VarType::Type &dtype, int64_t numel,
    const std::vector<platform::Place> &places,
    const std::vector<std::string> &out_var_names) {
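  // Dispatch on the place type: NCCL/RCCL for GPU places, BKCL for XPU
  // places; on CPU the tensors are reduced into scope 0 and the result is
  // copied back to the remaining scopes.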
  if (platform::is_gpu_place(places[0])) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_,
                            platform::errors::InvalidArgument(
                                "The nccl context should not be NULL."));
    ncclDataType_t nccl_dtype = platform::ToNCCLDataType(dtype);
    std::vector<std::function<void()>> all_reduce_calls;
    for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
      auto &p = places[i];
      void *buffer = const_cast<void *>(lod_tensor_data.at(i));
      all_reduce_calls.emplace_back([=] {
        NCCLAllReduce(p, buffer, buffer, numel, nccl_dtype, ncclSum);
      });
    }
    NCCLAllReduceFunc(all_reduce_calls);
#else
    PADDLE_THROW(
        platform::errors::PreconditionNotMet("Not compiled with GPU."));
#endif
  } else if (platform::is_xpu_place(places[0])) {
#if defined(PADDLE_WITH_XPU_BKCL)
    PADDLE_ENFORCE_NOT_NULL(bkcl_ctxs_,
                            platform::errors::InvalidArgument(
                                "The bkcl context should not be NULL."));
    BKCLDataType bkcl_dtype = platform::ToBKCLDataType(dtype);
    std::vector<std::function<void()>> all_reduce_calls;
    for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
      auto &p = places[i];
      void *buffer = const_cast<void *>(lod_tensor_data.at(i));
      all_reduce_calls.emplace_back([=] {
        BKCLAllReduce(p, buffer, buffer, numel, bkcl_dtype, BKCL_ADD);
      });
    }
    BKCLAllReduceFunc(all_reduce_calls);
#else
    PADDLE_THROW(
        platform::errors::PreconditionNotMet("Not compiled with BKCL."));
#endif
  } else {  // Special case: gradients of CPU-only operators (e.g. CRF).
    auto &trg = *local_exec_scopes_[0]
                     ->FindVar(out_var_names[0])
                     ->GetMutable<LoDTensor>();

    // Reduce all tensors into trg on the CPU.
    ReduceBufferData func(lod_tensor_data, trg.data(), numel);
    VisitDataType(framework::TransToProtoVarType(trg.dtype()), func);

    for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
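      // Copy the reduced result held in trg into this scope's output tensor.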
      auto &scope = local_exec_scopes_[i];
      auto &p = places[i];
      auto *var = scope->FindVar(out_var_names[i]);

      size_t size =
          numel * SizeOfType(framework::TransToProtoVarType(trg.dtype()));
      RunAndRecordEvent(p, [&trg, var, p, size] {
        auto dst_ptr = var->GetMutable<framework::LoDTensor>()->data();
        platform::CPUPlace cpu_place;
        memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size);
      });
    }
  }
  VLOG(10) << Name() << " size:" << numel * SizeOfType(dtype);
}

#if defined(PADDLE_WITH_XPU_BKCL)
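// Runs the queued BKCL all-reduce calls; multiple calls are wrapped in a
// bkcl_group_start/bkcl_group_end pair so they are issued as one group.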
void AllReduceOpHandle::BKCLAllReduceFunc(
    const std::vector<std::function<void()>> &all_reduce_calls) {
  this->RunAndRecordEvent([&] {
    if (all_reduce_calls.size() == 1UL) {
      all_reduce_calls[0]();
    } else {
      PADDLE_ENFORCE_EQ(
          bkcl_group_start(), BKCL_SUCCESS,
          platform::errors::PreconditionNotMet("bkcl_group_start failed"));
      for (auto &call : all_reduce_calls) {
        call();
      }
      PADDLE_ENFORCE_EQ(
          bkcl_group_end(), BKCL_SUCCESS,
          platform::errors::PreconditionNotMet("bkcl_group_end failed"));
    }
  });
}
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
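// Runs the queued NCCL all-reduce calls; multiple calls are grouped with
// platform::NCCLGroupGuard, and SyncNCCLAllReduce is invoked afterwards.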
void AllReduceOpHandle::NCCLAllReduceFunc(
    const std::vector<std::function<void()>> &all_reduce_calls) {
  this->RunAndRecordEvent([&] {
    if (all_reduce_calls.size() == 1UL) {
      // Do not use an NCCL group when NCCL is managed per thread per device.
      all_reduce_calls[0]();
    } else {
      platform::NCCLGroupGuard guard;
      for (auto &call : all_reduce_calls) {
        call();
      }
    }
  });

  SyncNCCLAllReduce();
}

void AllReduceOpHandle::SyncNCCLAllReduce() {
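  // When FLAGS_sync_nccl_allreduce is enabled, synchronize every place's
  // NCCL stream and surface any asynchronous device error immediately.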
  if (FLAGS_sync_nccl_allreduce) {
    for (auto &p : places_) {
      int dev_id = p.device;
      auto *nccl_ctxs =
          nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
      auto &nccl_ctx = nccl_ctxs->at(dev_id);
      auto stream = nccl_ctx.stream();

      platform::GpuStreamSync(stream);
      PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    }
  }
}
#endif

std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
}  // namespace details
}  // namespace framework
}  // namespace paddle