//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {

#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      nccl_ctxs_(nccl_ctxs),
      strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      strategy_(strategy) {
#endif
  for (auto &p : params) {
    grad_names_.insert(GradVarName(p));
  }
  balance_vars_.resize(places_.size(), 0);
  if (strategy_.enable_data_balance_ && places_.size() == 1) {
    LOG(WARNING) << "There is no need to enable data balance when there is "
                    "only one place. enable_data_balance is set to False.";
    strategy_.enable_data_balance_ = false;
  }
}

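// Bind the op handle most recently added to `result` to places_[place_id]'s
// device context, then connect `node`'s inputs and outputs to it as var
// handles on that place.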
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                ir::Node *node,
                                                size_t place_id) const {
  auto p = places_[place_id];
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));

  for (ir::Node *input : node->inputs) {
    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
    op_handle->AddInput(var);
  }

  for (ir::Node *output : node->outputs) {
    ir::Node *new_node = nullptr;
    if (output->Var()) {
      new_node = result->CreateVarNode(output->Var());
    } else {
      new_node =
          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
    }
    CreateOpOutput(result, op_handle, new_node, p, place_id);
  }
}

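// Collect the input variable names of every `send` op, i.e. the variables
// that will be sent to the parameter servers.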
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
  std::vector<std::string> send_vars;
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
  for (auto &node : nodes) {
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a more graceful method to find the send op,
    // instead of the hard-coded string
    if (op->Type() == "send") {
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

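// Collect the output variable names of every `recv` op, i.e. the variables
// received back from the parameter servers.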
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
  std::vector<std::string> recv_vars;
  for (auto &node : nodes) {
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a more graceful method to find the recv op,
    // instead of the hard-coded string
    if (op->Type() == "recv") {
      auto op_vars = op->OutputArgumentNames();
      recv_vars.reserve(recv_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

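// An op belongs to the distributed-training part of the graph when one of its
// outputs is a send variable or one of its inputs is a recv variable (only
// `.block` variables produced by the DistributeTranspiler are considered).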
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    ir::Node *node, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check whether any name in opvars contains `.block` and also appears in
   * rpc_vars.
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // a variable name with the suffix `.block` means it's a variable split
      // by the DistributeTranspiler
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
}

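// Greedy load balancing: return the device that currently holds the smallest
// total number of elements and charge it with the combined numel of
// `var_names`.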
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
  for (auto var_name : var_names) {
    auto var_desc = all_vars_.at(var_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GT(numel, 0);
    numel_sum += numel;
  }

  auto smallest =
      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
  size_t dev_id =
      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
  balance_vars_[dev_id] += numel_sum;
  return dev_id;
}

// Topologically sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameters/gradients before optimizer ops, a plain topological sort is
// insufficient (some optimizer ops might not depend on any nodes), so we
// manually move all optimizer nodes after the last backward node.
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
  size_t last_backward = 0;
  std::vector<ir::Node *> optimize_ops;
  std::vector<ir::Node *> sorted_ret;
  for (size_t i = 0; i < ret.size(); ++i) {
    if (boost::get<int>(
            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kBackward)) {
      sorted_ret.push_back(ret[i]);
      last_backward = sorted_ret.size();
    } else if (boost::get<int>(ret[i]->Op()->GetAttr(
                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
               static_cast<int>(OpRole::kOptimize)) {
      optimize_ops.push_back(ret[i]);
    } else {
      sorted_ret.push_back(ret[i]);
    }
  }

  sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
                    optimize_ops.end());

  for (ir::Node *n : sorted_ret) {
    n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
                                   [n](ir::Node *t) {
                                     return t->Name() ==
                                            ir::Node::kControlDepVarName;
                                   }),
                    n->inputs.end());
    n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
                                    [n](ir::Node *t) {
                                      return t->Name() ==
                                             ir::Node::kControlDepVarName;
                                    }),
                     n->outputs.end());
  }
  return sorted_ret;
}

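// Build the multi-device SSA graph: computation ops are replicated across all
// places (or pinned to one device under the reduce strategy), and
// all-reduce / reduce / broadcast / data-balance / RPC op handles are inserted
// to move parameters and gradients between places.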
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
    std::unique_ptr<ir::Graph> graph) const {
  // Rebuild the graph structure.
  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
  auto nodes = std::move(graph->nodes);
  graph->nodes.clear();

  for (auto &node : nodes) {
    if (node->NodeType() == ir::Node::Type::kVariable) {
      all_vars_.emplace(node->Name(), node->Var());
    }
  }

  ir::Graph &result = *graph;
  std::unordered_set<std::string> og_has_been_broadcast;

  // We cannot invoke resize. It is a bug of GCC 4.8
  result.Set("vars", new GraphVars(places_.size()));
  result.Set("dep_vars", new GraphDepVars);
  result.Set("ops", new GraphOps);

  // find send/recv vars so that we can place the distributed training
  // related ops on place 0
  auto send_vars = FindDistTrainSendVars(nodes);
  auto recv_vars = FindDistTrainRecvVars(nodes);

  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

  size_t cur_device_id = 0;
  bool is_forwarding = true;

  for (ir::Node *node : sorted_ops) {
    if (boost::get<int>(
            node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
      CreateRPCOp(&result, node);
    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
      CreateDistTrainOp(&result, node);
    } else if (IsScaleLossOp(node)) {
      // user can customize loss@grad if gradient_scale_ is kCustomized
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
        CreateScaleLossGradOp(&result);
      }
      // This assumes the backward generating code will ensure IsScaleLossOp
      // is true only for the op that scales the final scalar loss.
      // It also assumes backward ops will always follow the forward ops in
      // the block.
      is_forwarding = false;
    } else {
      int op_dev_id = GetOpDeviceID(node);
      if (op_dev_id != -1) {  // This op only runs on one specific device.
        CreateComputationalOp(&result, node, op_dev_id);
        for (ir::Node *n : node->outputs) {
          var_name_on_devices_.emplace(n->Name(), op_dev_id);
        }
      } else {
        // This op runs on all devices, and its output may have parameter's
        // gradients.
        if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
          node->Op()->SetAttr("throw_eof_exp", false);
          CreateComputationalOps(&result, node, places_.size());
          const auto &data_var_names = node->Op()->Output("Out");
          InsertDataBalanceOp(&result, data_var_names);
        } else {
          CreateComputationalOps(&result, node, places_.size());
        }

        if (!is_forwarding && places_.size() > 1) {
          // Currently, we assume that once gradient is generated, it can be
          // broadcast, and each gradient is only broadcast once.
          if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                    OpProtoAndCheckerMaker::OpRoleAttrName())) &
                                static_cast<int>(OpRole::kBackward))) {
            try {
              auto backward_vars = boost::get<std::vector<std::string>>(
                  node->Op()->GetNullableAttr(
                      OpProtoAndCheckerMaker::OpRoleVarAttrName()));

              PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);

              for (size_t i = 0; i < backward_vars.size(); i += 2) {
                auto &p_name = backward_vars[i];
                auto &g_name = backward_vars[i + 1];
                VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

                switch (strategy_.reduce_) {
                  case BuildStrategy::ReduceStrategy::kReduce:
                    cur_device_id = GetAppropriateDeviceID({g_name});
                    CreateReduceOp(&result, g_name, cur_device_id);
                    var_name_on_devices_.emplace(g_name, cur_device_id);
                    bcast_var_name_set[cur_device_id].emplace(p_name);
                    break;
                  case BuildStrategy::ReduceStrategy::kAllReduce:
                    if (IsSparseGradient(g_name)) {
                      CreateReduceOp(&result, g_name, 0);
                      CreateBroadcastOp(&result, g_name, 0);
                    } else {
                      InsertAllReduceOp(&result, g_name);
                    }
                    break;
                  default:
                    LOG(FATAL) << "Unknown reduce strategy.";
                    break;
                }
              }
            } catch (boost::bad_get e) {
            }
          }
        }
      }
    }
  }

  bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
  use_gpu = nccl_ctxs_ != nullptr;
#endif

  if (use_gpu ||
      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
    // Insert BCast Ops
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      auto &to_bcast_set = bcast_var_name_set[dev_id];
      for (auto &bcast_name : to_bcast_set) {
        CreateBroadcastOp(&result, bcast_name, dev_id);
      }
    }
  }

  /*
    Dependency graph has been constructed. However, there are still data
    hazards that need to be handled.
   */
  PolishGraphToSupportDataHazards(&result);

  /*
   * Only variables should be the leaves of graph.
   */
  AddOutputToLeafOps(&result);
  return graph;
}

bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
  PADDLE_ENFORCE(all_vars_.count(og) != 0);
  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
    return true;
  }
  return false;
}

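// Set the device context used by a communication op handle. With CUDA this is
// only needed when no NCCL contexts are available; otherwise the handle uses
// the NCCL contexts it received at construction time.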
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  if (nccl_ctxs_ == nullptr) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
#else
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
#endif
}

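// Broadcast the latest version of variable `p_name` from device `src_dev_id`
// to every place, appending a new var handle on each destination.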
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
                                                const std::string &p_name,
                                                size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_);
#else
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_);
#endif
  result->Get<GraphOps>("ops").emplace_back(op_handle);

  auto *in =
      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
    auto *out_var = new VarHandle(
        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
        i, p_name, p);
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                    ir::Node *node,
                                                    int dev_id) const {
  result->Get<GraphOps>("ops").emplace_back(
      new ComputationOpHandle(result->CreateOpNode(node->Op()),
                              local_scopes_[dev_id], places_[dev_id]));
  CreateOpHandleIOs(result, node, dev_id);
}

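// Insert an all-reduce op handle that aggregates gradient `og` across all
// places and appends the reduced result as a new version of the variable on
// every device.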
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
                                                const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars")[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());

    auto var =
        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                      vars.size(), i, og, p);
    vars.emplace_back(var);
    op_handle->AddOutput(var);
  }
}

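// Insert a data-balance op handle that takes the reader output variables on
// every place as input and emits rebalanced versions of them, so that batches
// are distributed evenly across devices.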
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
    ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    for (const std::string &d_name : datas) {
      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
      PADDLE_ENFORCE(!vars.empty());
      op_handle->AddInput(vars.back().get());
      auto var = new VarHandle(
          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
          vars.size(), i, d_name, p);
      vars.emplace_back(var);
      op_handle->AddOutput(var);
    }
  }
}

bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
    const std::string &og,
    std::unordered_set<std::string> *og_has_been_broadcast) const {
  bool is_pg_once =
      grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0;
  if (is_pg_once) {
    // Insert NCCL AllReduce Op
    og_has_been_broadcast->insert(og);
  }
  return is_pg_once;
}

int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
    return -1;
  }
  int op_role = boost::get<int>(
      node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
    return -1;
  }
  auto param_grad = boost::get<std::vector<std::string>>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));

  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
  int dev_id = GetVarDeviceID(param_grad[1]);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                    node->Op()->Type(), param_grad[0], param_grad[1]);
  return dev_id;
}

int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
  auto got = var_name_on_devices_.find(varname);
  return got == var_name_on_devices_.end() ? -1 : got->second;
}

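// On every place, insert a ScaleLossGrad op handle that produces the initial
// gradient of the loss, scaled by the device count (see the FIXME below).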
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
  for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
    auto *communication_dev_ctx =
        nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
                   : platform::DeviceContextPool::Instance().Get(places_[i]);
#else
    auto *communication_dev_ctx =
        platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
    auto *op_handle = new ScaleLossGradOpHandle(
        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
        local_scopes_.size(), local_scopes_[i], places_[i],
        communication_dev_ctx);
    result->Get<GraphOps>("ops").emplace_back(op_handle);

    // FIXME: Currently ScaleLossGradOp only uses device_count as the scale
    // factor, so it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    CreateOpOutput(result, op_handle,
                   result->CreateEmptyNode(GradVarName(loss_var_name_),
                                           ir::Node::Type::kVariable),
                   places_[i], i);
  }
}

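// Replicate the computation op on every place: one ComputationOpHandle per
// (scope, place) pair, each wired up through CreateOpHandleIOs.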
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
                                                     ir::Node *node,
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
    result->Get<GraphOps>("ops").emplace_back(
        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
    CreateOpHandleIOs(result, node, scope_idx);
  }
}

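// Gather the latest version of gradient `og` from every place, reduce it onto
// device `dst_dev_id`, and return the var handle of the reduced result.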
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars")[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());
  }
  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
  auto var =
      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
  vars.emplace_back(var);
  op_handle->AddOutput(var);
  return var;
}

// Find each occurrence of `prev_op_name` and make the current `op` depend
// on it.
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
                                        const std::string &prev_op_name) const {
  for (auto &prev_op : result->Get<GraphOps>("ops")) {
    if (prev_op->Name() == prev_op_name) {
      auto *dep_var = new DummyVarHandle(
          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
      prev_op->AddOutput(dep_var);
      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
      op->AddInput(dep_var);
    }
  }
}

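// Place a distributed-training op (split_byref / split_selected_rows / concat)
// on a single device and record which device owns each of its output
// variables.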
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                                                ir::Node *node) const {
  int op_dev_id = -1;
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  if (node->Op()->Type() == "split_byref" ||
      node->Op()->Type() == "split_selected_rows") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else if (node->Op()->Type() == "concat") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    PADDLE_THROW(
        "the distributed training related op should be in [split_byref, "
        "concat].");
  }

  PADDLE_ENFORCE(op_dev_id != -1,
                 "can not find the right place for distributed op: %s",
                 node->Op()->Type());

  CreateComputationalOp(result, node, op_dev_id);
  if (node->Op()->Type() == "concat") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
              "fetch_barrier");
  }
}

// Create RPC related op handles that connect their input ops and output ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                          ir::Node *node) const {
  int op_dev_id = -1;
  if (node->Op()->Type() == "send") {
    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
    // a variable name which contains .block means it was split by the
    // split_byref op, so we can balance the variable blocks across all
    // the pserver instances.
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
        node->inputs[0]->Name().find(".block") == std::string::npos) {
      std::vector<std::string> input_var_names;
      for (ir::Node *n : node->inputs) {
        input_var_names.push_back(n->Name());
      }
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
  } else if (node->Op()->Type() == "recv") {
    std::vector<std::string> output_var_names;
    for (ir::Node *n : node->outputs) {
      output_var_names.push_back(n->Name());
    }
    op_dev_id = GetAppropriateDeviceID(output_var_names);
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    // send_barrier and fetch_barrier op can be scheduled on device 0
    op_dev_id = 0;
  }

  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                 node->Op()->Type());

  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
      node->Op()->Type(), places_[op_dev_id]));

  if (node->Op()->Type() == "send_barrier") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
  } else if (node->Op()->Type() == "recv") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
              "send_barrier");
  } else if (node->Op()->Type() == "fetch_barrier") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
  } else if (node->Op()->Type() == "send") {
    // do nothing
  } else {
    PADDLE_THROW(
        "rpc op should be in ["
        "send, send_barrier, recv, fetch_barrier]");
  }

  CreateOpHandleIOs(result, node, op_dev_id);
}

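// The op that scales the final scalar loss is the one marked with both the
// kBackward and kLoss roles; in test mode (empty loss_var_name_) no op
// matches.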
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
  return boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty, this is test mode.
}
}  // namespace details
}  // namespace framework
}  // namespace paddle