multi_devices_graph_builder.cc 24.1 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
C
chengduoZH 已提交
14
#include <algorithm>
Y
Yancey1989 已提交
15
#include <fstream>
C
chengduoZH 已提交
16
#include <string>
C
chengduoZH 已提交
17
#include <utility>
C
chengduoZH 已提交
18 19
#include <vector>

20
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
C
chengduoZH 已提交
21
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
Y
Yu Yang 已提交
22
#include "paddle/fluid/framework/details/computation_op_handle.h"
23
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
C
chengduoZH 已提交
24
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
C
chengduoZH 已提交
25
#include "paddle/fluid/framework/details/reduce_op_handle.h"
Y
Yancey1989 已提交
26
#include "paddle/fluid/framework/details/rpc_op_handle.h"
Y
Yu Yang 已提交
27
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
X
Xin Pan 已提交
28
#include "paddle/fluid/framework/ir/node.h"
Y
Fix bug  
yuyang18 已提交
29
#include "paddle/fluid/framework/op_info.h"
Y
Yu Yang 已提交
30
#include "paddle/fluid/framework/scope.h"
Y
Yu Yang 已提交
31

Y
Yu Yang 已提交
32 33 34
namespace paddle {
namespace framework {
namespace details {
Y
Yu Yang 已提交
35 36

#ifdef PADDLE_WITH_CUDA
// Constructs a builder that replicates a program over `places`.
//
// `params` are parameter names; their gradient names (GradVarName) are kept
// in grad_names_. `nccl_ctxs` may be null, in which case communication ops
// use the pooled device contexts (see SetCommunicationContext).
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      nccl_ctxs_(nccl_ctxs),
      strategy_(strategy) {
#else
// Constructs a builder that replicates a program over `places` (CPU-only
// build, so there are no NCCL contexts).
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      strategy_(strategy) {
#endif
  // Remember the gradient variable name of every parameter.
  for (const auto &param_name : params) {
    grad_names_.insert(GradVarName(param_name));
  }

  // One load counter per place, all starting at zero; consumed by
  // GetAppropriateDeviceID.
  balance_vars_.resize(places_.size(), 0);

  // Data balancing is pointless with a single place, so switch it off.
  const bool single_place = places_.size() == 1;
  if (single_place && strategy_.enable_data_balance_) {
    LOG(WARNING) << "It is no need to enable data balance when there is only "
                    "one place. enable_data_balance is set to False.";
    strategy_.enable_data_balance_ = false;
  }
}

X
Xin Pan 已提交
70
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
Y
Yu Yang 已提交
71 72
                                                size_t place_id) const {
  auto p = places_[place_id];
X
Xin Pan 已提交
73 74
  auto *op_handle =
      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
X
Xin Pan 已提交
75 76
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
T
wip  
typhoonzero 已提交
77

Y
Yu Yang 已提交
78 79 80
  for (auto &each_var_name : op.InputArgumentNames()) {
    VarHandle *var =
        CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
T
wip  
typhoonzero 已提交
81 82 83
    op_handle->AddInput(var);
  }

Y
Yu Yang 已提交
84 85
  for (auto &each_var_name : op.OutputArgumentNames()) {
    CreateOpOutput(result, op_handle, each_var_name, p, place_id);
T
wip  
typhoonzero 已提交
86 87
  }
}
Y
fix pe  
Yancey1989 已提交
88 89 90 91

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const ProgramDesc &program) const {
  std::vector<std::string> send_vars;
Y
Yancey1989 已提交
92 93
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
Y
fix pe  
Yancey1989 已提交
94
  for (auto *op : program.Block(0).AllOps()) {
Y
Yancey1989 已提交
95 96
    // TODO(Yancey1989): use a graceful method to find send op,
    // instead of the the hard code string
97
    if (op->Type() == "send") {
Y
fix pe  
Yancey1989 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

// Collects the output variable names of every "recv" op in block 0, in
// program order (duplicates are kept).
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const ProgramDesc &program) const {
  std::vector<std::string> recv_vars;
  for (auto *op : program.Block(0).AllOps()) {
    // TODO(Yancey1989): use a graceful method to find recv op,
    // instead of the hard code string
    if (op->Type() == "recv") {
      const auto op_vars = op->OutputArgumentNames();
      // Range insert grows the vector geometrically by itself. An exact-size
      // reserve() before every append would defeat that and make repeated
      // appends quadratic.
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    const OpDesc &op, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
T
typhoonzero 已提交
127 128 129
    return false;
  }

Y
Yu Yang 已提交
130 131 132 133
  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
Y
fix pe  
Yancey1989 已提交
134
                    const std::vector<std::string> &rpc_vars) -> bool {
T
typhoonzero 已提交
135
    for (auto &var : opvars) {
Y
Yancey1989 已提交
136 137 138
      // a variable name with the suffix `.block` means it's a splited
      // variable by (DistributeTranspiler)
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
T
typhoonzero 已提交
139
      if (var.find(".block") != std::string::npos &&
Y
fix pe  
Yancey1989 已提交
140
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
Y
Yu Yang 已提交
141
        return true;
T
typhoonzero 已提交
142 143
      }
    }
Y
Yu Yang 已提交
144
    return false;
T
typhoonzero 已提交
145 146
  };

Y
Yancey1989 已提交
147 148
  return checker(op.OutputArgumentNames(), send_vars) ||
         checker(op.InputArgumentNames(), recv_vars);
T
typhoonzero 已提交
149 150
}

Y
Yancey1989 已提交
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
  for (auto var_name : var_names) {
    auto var_desc = all_vars_.at(var_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GT(numel, 0);
    numel_sum += numel;
  }

  auto smallest =
      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
  size_t dev_id =
      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
  balance_vars_[dev_id] += numel_sum;
  return dev_id;
}

Y
Yu Yang 已提交
171 172
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    const ProgramDesc &program) const {
X
Xin Pan 已提交
173
  std::unique_ptr<Graph> graph(new Graph);
C
fix ci  
chengduoZH 已提交
174
  for (auto *var : program.Block(0).AllVars()) {
Y
Yancey1989 已提交
175
    all_vars_.emplace(var->Name(), var);
C
fix ci  
chengduoZH 已提交
176
  }
C
chengduoZH 已提交
177

X
Xin Pan 已提交
178
  Graph &result = *graph;
C
chengduoZH 已提交
179
  std::unordered_set<std::string> og_has_been_broadcast;
Y
Yu Yang 已提交
180 181

  // We cannot invoke resize. It is a bug of GCC 4.8
X
Xin Pan 已提交
182
  result.attrs["vars"] = new std::vector<
Y
Yu Yang 已提交
183 184
      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
      places_.size());
X
Xin Pan 已提交
185 186 187
  result.attrs["dep_vars"] =
      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();
Y
Yu Yang 已提交
188

Y
fix pe  
Yancey1989 已提交
189 190 191 192
  // find send/recv vars so that we can place the distributed training
  // realted op in the place 0
  auto send_vars = FindDistTrainSendVars(program);
  auto recv_vars = FindDistTrainRecvVars(program);
T
typhoonzero 已提交
193

C
chengduoZH 已提交
194 195 196
  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

C
chengduoZH 已提交
197
  size_t cur_device_id = 0;
Y
Yu Yang 已提交
198
  bool is_forwarding = true;
199

Y
Yu Yang 已提交
200
  for (auto *op : program.Block(0).AllOps()) {
Y
Yancey1989 已提交
201 202 203
    if (boost::get<int>(
            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
Y
Yancey1989 已提交
204
      CreateRPCOp(&result, *op);
Y
fix pe  
Yancey1989 已提交
205
    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
Y
Yancey1989 已提交
206
      CreateDistTrainOp(&result, *op);
Y
Yu Yang 已提交
207
    } else if (IsScaleLossOp(*op)) {
Y
Yu Yang 已提交
208
      // user can customize loss@grad if not use_default_grad_scale_
Y
yuyang18 已提交
209 210
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
Y
Yu Yang 已提交
211 212
        CreateScaleLossGradOp(&result);
      }
213 214 215 216
      // This assumes the backward generating code will ensure IsScaleLossOp
      // is true only for the op that scale the final scalar loss.
      // It also assumes backward op will always follow the forward op in
      // the block.
Y
Yu Yang 已提交
217
      is_forwarding = false;
Y
Yu Yang 已提交
218
    } else {
219
      int op_dev_id = GetOpDeviceID(*op);
C
chengduo 已提交
220
      if (op_dev_id != -1) {  // This op only runs on one specific device.
C
chengduoZH 已提交
221 222
        CreateComputationalOp(&result, *op, op_dev_id);
        for (auto &var_name : op->OutputArgumentNames()) {
223
          var_name_on_devices_.emplace(var_name, op_dev_id);
C
chengduoZH 已提交
224
        }
C
chengduo 已提交
225 226 227
      } else {
        // This op runs on all devices, and its output may have parameter's
        // gradients.
F
fengjiayi 已提交
228
        if (op->Type() == "read" && strategy_.enable_data_balance_) {
F
fengjiayi 已提交
229 230
          op->SetAttr("throw_eof_exp", false);
          CreateComputationalOps(&result, *op, places_.size());
231 232
          const auto &data_var_names = op->Output("Out");
          InsertDataBalanceOp(&result, data_var_names);
F
fengjiayi 已提交
233 234
        } else {
          CreateComputationalOps(&result, *op, places_.size());
235 236
        }

C
chengduo 已提交
237 238 239 240 241 242 243 244 245 246
        if (!is_forwarding && places_.size() > 1) {
          // Currently, we assume that once gradient is generated, it can be
          // broadcast, and each gradient is only broadcast once.
          if (static_cast<bool>(boost::get<int>(op->GetAttr(
                                    OpProtoAndCheckerMaker::OpRoleAttrName())) &
                                static_cast<int>(OpRole::kBackward))) {
            try {
              auto backward_vars =
                  boost::get<std::vector<std::string>>(op->GetNullableAttr(
                      OpProtoAndCheckerMaker::OpRoleVarAttrName()));
Y
yuyang18 已提交
247

C
chengduo 已提交
248
              PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
Y
yuyang18 已提交
249

C
chengduo 已提交
250 251 252 253
              for (size_t i = 0; i < backward_vars.size(); i += 2) {
                auto &p_name = backward_vars[i];
                auto &g_name = backward_vars[i + 1];
                VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
Y
yuyang18 已提交
254

C
chengduo 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
                switch (strategy_.reduce_) {
                  case BuildStrategy::ReduceStrategy::kReduce:
                    cur_device_id = GetAppropriateDeviceID({g_name});
                    CreateReduceOp(&result, g_name, cur_device_id);
                    var_name_on_devices_.emplace(g_name, cur_device_id);
                    bcast_var_name_set[cur_device_id].emplace(p_name);
                    break;
                  case BuildStrategy::ReduceStrategy::kAllReduce:
                    if (IsSparseGradient(g_name)) {
                      CreateReduceOp(&result, g_name, 0);
                      CreateBroadcastOp(&result, g_name, 0);
                    } else {
                      InsertAllReduceOp(&result, g_name);
                    }
                    break;
                  default:
                    LOG(FATAL) << "Unknown reduce strategy ";
                    break;
                }
Y
yuyang18 已提交
274
              }
C
chengduo 已提交
275
            } catch (boost::bad_get e) {
C
chengduoZH 已提交
276
            }
Y
Yu Yang 已提交
277 278 279 280 281 282
          }
        }
      }
    }
  }

283 284 285 286 287 288 289 290 291 292 293 294 295
  bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
  use_gpu = nccl_ctxs_ != nullptr;
#endif

  if (use_gpu ||
      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
    // Insert BCast Ops
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      auto &to_bcast_set = bcast_var_name_set[dev_id];
      for (auto &bcast_name : to_bcast_set) {
        CreateBroadcastOp(&result, bcast_name, dev_id);
      }
C
chengduoZH 已提交
296 297
    }
  }
298

Y
Yu Yang 已提交
299 300
  /*
    Dependency graph has been constructed. However, there are still data
301
    hazards need to be handled.
Y
Yu Yang 已提交
302 303
   */
  PolishGraphToSupportDataHazards(&result);
Y
Yu Yang 已提交
304

Y
Yu Yang 已提交
305 306 307 308 309
  /*
   * Only variables should be the leaves of graph.
   */
  AddOutputToLeafOps(&result);

X
Xin Pan 已提交
310 311 312 313 314 315 316 317 318
  std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
  ssa_graph->vars_ =
      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
  ssa_graph->ops_ =
      std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
  ssa_graph->dep_vars_ =
      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));

  return std::move(ssa_graph);
Y
Yu Yang 已提交
319 320
}

Y
Yancey1989 已提交
321 322 323
// A gradient is treated as sparse iff its block-0 variable has type
// SELECTED_ROWS. `og` must be registered in all_vars_.
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
  PADDLE_ENFORCE(all_vars_.count(og) != 0);
  const auto var_type = all_vars_.at(og)->GetType();
  return var_type == proto::VarType::SELECTED_ROWS;
}

329 330 331 332 333 334 335 336 337 338 339 340 341
// Gives `op_handle` the pooled device context of `p`, except in CUDA builds
// where NCCL contexts were supplied. In that case nothing is set here;
// presumably the NCCL-aware op handle uses its own contexts (not visible in
// this file).
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
  bool use_pool_ctx = true;
#ifdef PADDLE_WITH_CUDA
  use_pool_ctx = (nccl_ctxs_ == nullptr);
#endif
  if (use_pool_ctx) {
    auto *pool_ctx = platform::DeviceContextPool::Instance().Get(p);
    op_handle->SetDeviceContext(p, pool_ctx);
  }
}

X
Xin Pan 已提交
342
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
C
chengduoZH 已提交
343
                                                const std::string &p_name,
C
chengduoZH 已提交
344
                                                size_t src_dev_id) const {
C
chengduoZH 已提交
345 346 347 348 349 350
#ifdef PADDLE_WITH_CUDA
  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
#else
  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
#endif

X
Xin Pan 已提交
351 352 353 354 355 356
  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
                 ->at(src_dev_id)
                 .at(p_name)
                 .back()
                 .get();
C
chengduoZH 已提交
357 358 359 360
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
C
chengduoZH 已提交
361
    SetCommunicationContext(op_handle, p);
X
Xin Pan 已提交
362 363
    auto &vars =
        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
C
chengduoZH 已提交
364 365 366 367 368 369
    auto *out_var = new VarHandle(vars.size(), i, p_name, p);
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

X
Xin Pan 已提交
370
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
C
chengduoZH 已提交
371 372
                                                    const OpDesc &op,
                                                    int dev_id) const {
X
Xin Pan 已提交
373 374 375
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(
          new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
C
chengduoZH 已提交
376 377 378
  CreateOpHandleIOs(result, op, dev_id);
}

X
Xin Pan 已提交
379
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
C
chengduoZH 已提交
380
                                                const std::string &og) const {
Y
Yu Yang 已提交
381
#ifdef PADDLE_WITH_CUDA
X
Xin Pan 已提交
382 383
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
C
chengduoZH 已提交
384
#else
X
Xin Pan 已提交
385 386
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
C
chengduoZH 已提交
387
#endif
X
Xin Pan 已提交
388 389
  auto *op_handle =
      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
Y
Yu Yang 已提交
390 391 392

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
C
chengduoZH 已提交
393
    SetCommunicationContext(op_handle, p);
X
Xin Pan 已提交
394
    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
Y
Yu Yang 已提交
395 396
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
Y
Yu Yang 已提交
397 398
    op_handle->AddInput(prev_grad.get());

399
    auto var = new VarHandle(vars.size(), i, og, p);
Y
Yu Yang 已提交
400 401 402 403 404
    vars.emplace_back(var);
    op_handle->AddOutput(var);
  }
}

405
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
X
Xin Pan 已提交
406
    Graph *result, const std::vector<std::string> &datas) const {
F
fengjiayi 已提交
407
#ifdef PADDLE_WITH_CUDA
X
Xin Pan 已提交
408 409 410
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(
          new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
F
fengjiayi 已提交
411
#else
X
Xin Pan 已提交
412 413
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
F
fengjiayi 已提交
414
#endif
X
Xin Pan 已提交
415 416
  auto *op_handle =
      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
417 418 419 420
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    for (const std::string &d_name : datas) {
X
Xin Pan 已提交
421 422
      auto &vars =
          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
423 424 425 426 427 428 429 430 431
      PADDLE_ENFORCE(!vars.empty());
      op_handle->AddInput(vars.back().get());
      auto var = new VarHandle(vars.size(), i, d_name, p);
      vars.emplace_back(var);
      op_handle->AddOutput(var);
    }
  }
}

Y
Yu Yang 已提交
432 433 434 435 436 437 438 439 440 441 442 443
// Returns true the first time `og` is seen as a parameter gradient (a member
// of grad_names_), and records it in *og_has_been_broadcast. Returns false,
// without modifying the set, for non-parameter-gradients and for names
// already recorded. This only records the name; it creates no op.
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
    const std::string &og,
    std::unordered_set<std::string> *og_has_been_broadcast) const {
  if (grad_names_.count(og) == 0) {
    return false;
  }
  // insert().second is true exactly when og was not yet in the set.
  return og_has_been_broadcast->insert(og).second;
}

444
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
Y
yuyang18 已提交
445
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
C
chengduoZH 已提交
446 447
    return -1;
  }
448 449 450 451
  int op_role = boost::get<int>(
      op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
    return -1;
C
chengduoZH 已提交
452
  }
453 454 455 456 457 458 459 460
  auto param_grad = boost::get<std::vector<std::string>>(
      op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));

  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
  int dev_id = GetVarDeviceID(param_grad[1]);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
                    param_grad[0]);
  return dev_id;
461 462 463 464 465
}

// Returns the device recorded for `varname` in var_name_on_devices_, or -1
// if no device has been recorded for it.
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
  auto it = var_name_on_devices_.find(varname);
  if (it == var_name_on_devices_.end()) {
    return -1;
  }
  return it->second;
}

X
Xin Pan 已提交
468
// Appends one ScaleLossGradOpHandle per place. Each handle produces
// GradVarName(loss_var_name_) on its place and is constructed with the
// number of local scopes.
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
  for (size_t dev = 0; dev < places_.size(); ++dev) {
    const auto &place = places_[dev];
    // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
    // NCCL device context when NCCL contexts exist, else the pooled one.
    auto *comm_ctx = nccl_ctxs_
                         ? nccl_ctxs_->DevCtx(place)
                         : platform::DeviceContextPool::Instance().Get(place);
#else
    // CPU-only build: always the CPU device context.
    auto *comm_ctx =
        platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif

    auto *scale_op = new ScaleLossGradOpHandle(
        local_scopes_.size(), local_scopes_[dev], place, comm_ctx);
    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(scale_op);

    // FIXME: Currently ScaleLossGradOp only use device_count as scale
    // factor. So it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    CreateOpOutput(result, scale_op, GradVarName(loss_var_name_), place, dev);
  }
}

X
Xin Pan 已提交
496
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
T
typhoonzero 已提交
497 498 499
                                                     const OpDesc &op,
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
Y
Yu Yang 已提交
500 501
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
X
Xin Pan 已提交
502 503
    boost::any_cast<GraphOps *>(result->attrs["ops"])
        ->emplace_back(new ComputationOpHandle(op, s, p));
Y
Yu Yang 已提交
504
    CreateOpHandleIOs(result, op, scope_idx);
Y
Yu Yang 已提交
505 506 507
  }
}

X
Xin Pan 已提交
508
// Appends a ReduceOpHandle for gradient `og`. It consumes the latest version
// of `og` on every place (which must exist) and produces one new version on
// `dst_dev_id` only. Returns that new variable handle.
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
  auto *ops = boost::any_cast<GraphOps *>(result->attrs["ops"]);
#ifdef PADDLE_WITH_CUDA
  ops->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
  ops->emplace_back(new ReduceOpHandle(local_scopes_, places_));
#endif
  OpHandleBase *reduce_op = ops->back().get();
  auto &graph_vars = *boost::any_cast<GraphVars *>(result->attrs["vars"]);

  for (size_t dev = 0; dev < places_.size(); ++dev) {
    SetCommunicationContext(reduce_op, places_[dev]);
    auto &vars = graph_vars[dev][og];
    PADDLE_ENFORCE(!vars.empty());
    reduce_op->AddInput(vars.back().get());
  }

  auto &dst_versions = graph_vars[dst_dev_id][og];
  auto *reduced = new VarHandle(dst_versions.size(), dst_dev_id, og,
                                places_[dst_dev_id]);
  dst_versions.emplace_back(reduced);
  reduce_op->AddOutput(reduced);
  return reduced;
}

537 538
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
X
Xin Pan 已提交
539
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
Y
fix pe  
Yancey1989 已提交
540
                                        const std::string &prev_op_name) const {
X
Xin Pan 已提交
541
  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
Y
fix pe  
Yancey1989 已提交
542
    if (prev_op->Name() == prev_op_name) {
Y
Yancey1989 已提交
543 544
      auto *dep_var = new DummyVarHandle();
      prev_op->AddOutput(dep_var);
X
Xin Pan 已提交
545 546
      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
          ->emplace(dep_var);
Y
fix pe  
Yancey1989 已提交
547
      op->AddInput(dep_var);
Y
Yancey1989 已提交
548 549 550 551
    }
  }
}

X
Xin Pan 已提交
552
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
Y
Yancey1989 已提交
553 554
                                                const OpDesc &op) const {
  int op_dev_id = -1;
Y
yi.wu 已提交
555
  if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
Y
Yancey1989 已提交
556 557 558 559 560 561 562 563 564 565 566 567
    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
      for (auto &varname : op.InputArgumentNames()) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
    for (auto &varname : op.OutputArgumentNames()) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else if (op.Type() == "concat") {
    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
Y
yi.wu 已提交
568 569 570
    for (auto &varname : op.OutputArgumentNames()) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
Y
Yancey1989 已提交
571 572 573 574 575 576 577 578 579 580
  } else {
    PADDLE_ENFORCE(
        "the distribute training related op should be in [split_byref, "
        "concat].");
  }

  PADDLE_ENFORCE(op_dev_id != -1,
                 "can not find right place for distributed op: %s", op.Type());

  CreateComputationalOp(result, op, op_dev_id);
Y
Yancey1989 已提交
581
  if (op.Type() == "concat") {
X
Xin Pan 已提交
582 583 584
    ConnectOp(result,
              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
              "fetch_barrier");
Y
Yancey1989 已提交
585 586 587
  }
}

588
// Create RPC related op handles that connects its in ops and out ops.
X
Xin Pan 已提交
589
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
Y
Yancey1989 已提交
590 591 592 593 594 595
                                          const OpDesc &op) const {
  int op_dev_id = -1;
  if (op.Type() == "send") {
    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
    // the variable name which contains .block means it was splited by
    // split_byref op
596 597
    // so that we can balance the variable blocks to all the pserver
    // instances.
Y
Yancey1989 已提交
598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
      for (auto &varname : op.InputArgumentNames()) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
  } else if (op.Type() == "recv") {
    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
    for (auto &varname : op.OutputArgumentNames()) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    // send_barrier and fetch_barrier op can be scheduled on device 0
    op_dev_id = 0;
  }

  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                 op.Type());

X
Xin Pan 已提交
618 619 620
  boost::any_cast<GraphOps *>(result->attrs["ops"])
      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
                                     places_[op_dev_id]));
Y
fix pe  
Yancey1989 已提交
621

Y
Yancey1989 已提交
622
  if (op.Type() == "send_barrier") {
X
Xin Pan 已提交
623 624 625
    ConnectOp(result,
              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
              "send");
Y
Yancey1989 已提交
626
  } else if (op.Type() == "recv") {
X
Xin Pan 已提交
627 628 629
    ConnectOp(result,
              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
              "send_barrier");
Y
Yancey1989 已提交
630
  } else if (op.Type() == "fetch_barrier") {
X
Xin Pan 已提交
631 632 633
    ConnectOp(result,
              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
              "recv");
634
  } else if (op.Type() == "send") {
Y
Yancey1989 已提交
635 636 637
    // do nothing
  } else {
    PADDLE_THROW(
Y
Yancey1989 已提交
638
        "rpc op should be in ["
639
        "send, send_barrier. recv, fetch_barrier]");
Y
Yancey1989 已提交
640 641
  }

Y
Yancey1989 已提交
642
  CreateOpHandleIOs(result, op, op_dev_id);
Y
Yu Yang 已提交
643 644 645
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
Y
yuyang18 已提交
646 647
  return boost::get<int>(
             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
Y
Fix bug  
yuyang18 已提交
648 649 650
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty. This is test mode
Y
Yu Yang 已提交
651
}
Y
Yu Yang 已提交
652 653 654
}  // namespace details
}  // namespace framework
}  // namespace paddle