broadcast_op_handle.cc 9.3 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

C
chengduoZH 已提交
15
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
16

Y
Yu Yang 已提交
17 18
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
Y
Yancey1989 已提交
19
#include "paddle/fluid/platform/profiler.h"
C
chengduoZH 已提交
20 21 22 23 24

namespace paddle {
namespace framework {
namespace details {

C
chengduoZH 已提交
25
void BroadcastOpHandle::RunImpl() {
26
  platform::RecordEvent record_event(Name());
Y
Yancey1989 已提交
27

C
chengduoZH 已提交
28
  if (places_.size() == 1) return;
Y
Yu Yang 已提交
29

C
chengduoZH 已提交
30
  // The input and output may have dummy vars.
C
chengduo 已提交
31
  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
Y
Yu Yang 已提交
32
  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
C
chengduoZH 已提交
33

C
chengduo 已提交
34
  PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
35 36 37 38 39 40 41 42 43
                    platform::errors::PreconditionNotMet(
                        "The number of inputs should be 1, but got %d.",
                        in_var_handles.size()));
  PADDLE_ENFORCE_EQ(out_var_handles.size(), places_.size(),
                    platform::errors::PreconditionNotMet(
                        "The number of outputs and the number of places should "
                        "be equal, but got the number of outputs is %d and the "
                        "number of places is %d.",
                        out_var_handles.size(), places_.size()));
C
chengduoZH 已提交
44

C
chengduo 已提交
45 46
  VarHandle *in_var_handle = in_var_handles[0];

47
  BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
48 49 50 51 52
}

void BroadcastOpHandle::BroadcastOneVar(
    const VarHandle &in_var_handle,
    const std::vector<VarHandle *> &out_var_handles,
53
    const std::vector<Scope *> &var_scopes) {
C
chengduoZH 已提交
54
  auto *in_var =
G
gongweibao 已提交
55
      var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
56 57 58
  PADDLE_ENFORCE_NOT_NULL(
      in_var, platform::errors::NotFound("Variable %s is not found in scopes.",
                                         in_var_handle.name()));
Y
Yu Yang 已提交
59
  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
60
  if (UNLIKELY(!in_tensor.IsInitialized())) {
G
gongweibao 已提交
61
    VLOG(3) << "in var " << in_var_handle.name() << "not inited, return!";
62 63
    return;
  }
C
chengduoZH 已提交
64

65
  InitOutputValue(in_var_handle, out_var_handles);
C
chengduoZH 已提交
66

C
chengduoZH 已提交
67
  if (platform::is_cpu_place(in_tensor.place())) {
68
    WaitInputVarGenerated();
C
chengduoZH 已提交
69
    for (auto *out_var_handle : out_var_handles) {
70
      if (out_var_handle->IsTheSameVar(in_var_handle)) {
C
chengduoZH 已提交
71 72
        continue;
      }
G
gongweibao 已提交
73 74 75
      auto &out_p = out_var_handle->place();
      auto *out_var = var_scopes.at(out_var_handle->scope_idx())
                          ->FindVar(out_var_handle->name());
C
chengduoZH 已提交
76

C
chengduoZH 已提交
77
      RunAndRecordEvent(out_p, [in_tensor, out_var] {
C
chengduoZH 已提交
78
        paddle::framework::TensorCopy(
C
chengduoZH 已提交
79
            in_tensor, platform::CPUPlace(),
C
chengduoZH 已提交
80 81 82
            &VariableVisitor::GetMutableTensor(out_var));
      });
    }
83
  } else if (platform::is_gpu_place(in_tensor.place())) {
84
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
C
chengduoZH 已提交
85
    VarHandle *out_handle = nullptr;
86 87
    int root_id =
        BOOST_GET_CONST(platform::CUDAPlace, in_tensor.place()).device;
C
chengduoZH 已提交
88 89
    std::vector<std::function<void()>> broadcast_calls;

C
chengduoZH 已提交
90 91 92
    int type = platform::ToNCCLDataType(in_tensor.type());
    size_t numel = static_cast<size_t>(in_tensor.numel());

C
chengduoZH 已提交
93
    for (auto out_var_handle : out_var_handles) {
G
gongweibao 已提交
94 95
      Variable *out_var = var_scopes.at(out_var_handle->scope_idx())
                              ->FindVar(out_var_handle->name());
C
chengduoZH 已提交
96

C
chengduoZH 已提交
97
      int dst_id =
98
          BOOST_GET_CONST(platform::CUDAPlace, out_var_handle->place()).device;
C
chengduoZH 已提交
99

C
chengduoZH 已提交
100
      auto &nccl_ctx = nccl_ctxs_->at(dst_id);
C
chengduoZH 已提交
101 102

      void *send_recv_buffer = nullptr;
C
chengduoZH 已提交
103
      if (root_id == dst_id) {
104
        send_recv_buffer = const_cast<void *>(in_tensor.data());
C
chengduoZH 已提交
105 106
        out_handle = out_var_handle;
      } else {
C
chengduoZH 已提交
107 108
        send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
                               .Resize(in_tensor.dims())
G
gongweibao 已提交
109
                               .mutable_data(out_var_handle->place());
C
chengduoZH 已提交
110 111
      }

C
chengduoZH 已提交
112 113
      broadcast_calls.emplace_back(
          [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
114
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
C
chengduoZH 已提交
115 116 117
                send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
                root_id, nccl_ctx.comm_, nccl_ctx.stream()));
          });
Y
Yu Yang 已提交
118 119
    }

120
    WaitInputVarGenerated();
121 122 123 124 125
    this->RunAndRecordEvent([&] {
      {
        platform::NCCLGroupGuard guard;
        for (auto &call : broadcast_calls) {
          call();
C
chengduoZH 已提交
126
        }
127
      }
C
chengduoZH 已提交
128

129
      if (!out_handle->IsTheSameVar(in_var_handle)) {
G
gongweibao 已提交
130 131
        auto out_var = var_scopes.at(in_var_handle.scope_idx())
                           ->FindVar(out_var_handles[0]->name());
132
        paddle::framework::TensorCopy(
G
gongweibao 已提交
133 134
            in_tensor, in_var_handle.place(),
            *(dev_ctxes_.at(in_var_handle.place())),
135 136 137
            &VariableVisitor::GetMutableTensor(out_var));
      }
    });
C
chengduo 已提交
138 139 140
    for (auto &p : places_) {
      nccl_ctxs_->DevCtx(p)->Wait();
    }
C
chengduoZH 已提交
141
#else
142 143
    PADDLE_THROW(
        platform::errors::PreconditionNotMet("Not compiled with NCLL."));
144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
#endif
  } else {
#if defined(PADDLE_WITH_XPU_BKCL)
    VarHandle *out_handle = nullptr;
    int root_id = BOOST_GET_CONST(platform::XPUPlace, in_tensor.place()).device;
    std::vector<std::function<void()>> broadcast_calls;

    int type = platform::ToBKCLDataType(in_tensor.type());
    size_t numel = static_cast<size_t>(in_tensor.numel());

    for (auto out_var_handle : out_var_handles) {
      Variable *out_var = var_scopes.at(out_var_handle->scope_idx())
                              ->FindVar(out_var_handle->name());

      int dst_id =
          BOOST_GET_CONST(platform::XPUPlace, out_var_handle->place()).device;

      auto &bkcl_ctx = bkcl_ctxs_->at(dst_id);

      void *send_recv_buffer = nullptr;
      if (root_id == dst_id) {
165
        send_recv_buffer = const_cast<void *>(in_tensor.data());
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
        out_handle = out_var_handle;
      } else {
        send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
                               .Resize(in_tensor.dims())
                               .mutable_data(out_var_handle->place());
      }

      broadcast_calls.emplace_back([send_recv_buffer, numel, type, root_id,
                                    &bkcl_ctx] {
        PADDLE_ENFORCE_EQ(
            bkcl_broadcast(bkcl_ctx.comm(), send_recv_buffer, send_recv_buffer,
                           numel, static_cast<BKCLDataType>(type), root_id,
                           nullptr),
            BKCL_SUCCESS,
            platform::errors::Unavailable("bkcl_broadcast failed"));
      });
    }

    WaitInputVarGenerated();
    this->RunAndRecordEvent([&] {
      {
        PADDLE_ENFORCE_EQ(
            bkcl_group_start(), BKCL_SUCCESS,
            platform::errors::Unavailable("bkcl_group_start failed"));
        for (auto &call : broadcast_calls) {
          call();
        }
        PADDLE_ENFORCE_EQ(
            bkcl_group_end(), BKCL_SUCCESS,
            platform::errors::Unavailable("bkcl_group_end failed"));
      }

      if (!out_handle->IsTheSameVar(in_var_handle)) {
        auto out_var = var_scopes.at(in_var_handle.scope_idx())
                           ->FindVar(out_var_handles[0]->name());
        paddle::framework::TensorCopy(
            in_tensor, in_var_handle.place(),
            *(dev_ctxes_.at(in_var_handle.place())),
            &VariableVisitor::GetMutableTensor(out_var));
      }
    });
#else
    PADDLE_THROW(
        platform::errors::PreconditionNotMet("Not compiled with BKCL."));
C
chengduoZH 已提交
210
#endif
C
chengduoZH 已提交
211 212 213
  }
}

214 215 216
void BroadcastOpHandle::InitOutputValue(
    const VarHandle &in_var_handle,
    const std::vector<VarHandle *> &out_var_handles) const {
217
  auto &var_scopes = local_exec_scopes_;
218
  auto *in_var =
G
gongweibao 已提交
219
      var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
220 221 222 223 224 225 226 227 228

  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

  // NOTE: The tensors' Place of input and output must be all on GPU or all on
  // CPU.
  for (auto *out_var_handle : out_var_handles) {
    if (out_var_handle->IsTheSameVar(in_var_handle)) {
      continue;
    }
G
gongweibao 已提交
229 230 231
    auto t_out_p = out_var_handle->place();
    auto *out_var = var_scopes.at(out_var_handle->scope_idx())
                        ->FindVar(out_var_handle->name());
232 233 234
    PADDLE_ENFORCE_NOT_NULL(out_var, platform::errors::NotFound(
                                         "Variable %s is not found in scopes.",
                                         out_var_handle->name()));
235
    if (is_gpu_place(in_tensor.place())) {
236 237 238
      PADDLE_ENFORCE_EQ(platform::is_gpu_place(t_out_p), true,
                        platform::errors::PreconditionNotMet(
                            "Places of input and output must be all on GPU."));
239 240 241 242 243 244 245 246 247
    } else {
      t_out_p = platform::CPUPlace();
    }
    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                            in_tensor.type());
  }
}

C
chengduoZH 已提交
248
// Fixed name of this op handle (RunImpl() uses it as the profiler event
// label).
std::string BroadcastOpHandle::Name() const {
  const std::string kHandleName = "broadcast";
  return kHandleName;
}
C
chengduoZH 已提交
249 250 251
}  // namespace details
}  // namespace framework
}  // namespace paddle