//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <fstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {
namespace {
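// Insert control-dependency variables so that every op reading an old version
// of a variable runs before the op that writes its next version
// (write-after-read hazards).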
void PolishGraphToSupportDataHazards(ir::Graph *graph) {
  for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
    for (auto &name_pair : var_map) {
      if (name_pair.second.size() <= 1) {
        continue;
      }
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        OpHandleBase *write_op = (*it_new)->GeneratedOp();
        const auto &read_ops = (*it_old)->PendingOps();

        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op.
          if (read_op == write_op) {
            // The read and the write are the same op; no dependency is needed.
            continue;
          }
          bool has_dep = false;
          for (auto *r_out : read_op->Outputs()) {
            for (auto *w_in : write_op->Inputs()) {
              if (r_out->Node() == w_in->Node()) {
                has_dep = true;
                break;
              }
            }
          }
          if (has_dep) continue;

          auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
          read_op->AddOutput(dep_var);
          write_op->AddInput(dep_var);
          graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
        }
      }
    }
  }
}

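// Return the latest VarHandle of `node` on `place`, creating version 0 if the
// variable has no handle on that place yet.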
VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
                                      const platform::Place &place,
                                      size_t place_offset) {
  auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
  auto &var_holder = var_holders[node->Name()];
  VarHandle *var = nullptr;
  if (var_holder.empty()) {
    if (node->Var()) {
      var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
                          node->Name(), place);
    } else {
      var = new VarHandle(
          graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
          place_offset, node->Name(), place);
    }
    var_holder.emplace_back(var);
  } else {
    var = var_holder.rbegin()->get();
  }
  return var;
}

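// Create a new version of the output variable on `place` and register it as
// an output of `op_handle`.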
void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
                    ir::Node *new_node, const platform::Place &place,
                    size_t place_offset) {
  auto &vars =
      graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
  size_t version = vars.size();
  auto var =
      new VarHandle(new_node, version, place_offset, new_node->Name(), place);
  vars.emplace_back(var);
  op_handle->AddOutput(var);
}

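// Give each op without outputs a dummy control-dependency output, so that only
// variables are leaves of the graph.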
void AddOutputToLeafOps(ir::Graph *graph) {
  for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
    if (!op->Outputs().empty()) {
      continue;
    }
    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
    op->AddOutput(dummy_leaf);
  }
}
}  // namespace

static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";

void MultiDevSSAGraphBuilder::Init() const {
  all_vars_.clear();
  balance_vars_.clear();

  loss_var_name_ = Get<const std::string>(kLossVarName);
  places_ = Get<const std::vector<platform::Place>>(kPlaces);
  local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
  strategy_ = Get<const BuildStrategy>(kStrategy);
#ifdef PADDLE_WITH_CUDA
  nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif

  for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
    grad_names_.insert(GradVarName(p));
  }
  balance_vars_.resize(places_.size(), 0);
  if (strategy_.enable_data_balance_ && places_.size() == 1) {
    LOG(WARNING) << "There is no need to enable data balance when there is "
                    "only one place. enable_data_balance is set to False.";
    strategy_.enable_data_balance_ = false;
  }
}

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                ir::Node *node,
                                                size_t place_id) const {
  auto p = places_[place_id];
  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));

  for (ir::Node *input : node->inputs) {
    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
    op_handle->AddInput(var);
  }

  for (ir::Node *output : node->outputs) {
    ir::Node *new_node = nullptr;
    if (output->Var()) {
      new_node = result->CreateVarNode(output->Var());
    } else {
      new_node =
          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
    }
    CreateOpOutput(result, op_handle, new_node, p, place_id);
  }
}

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const std::vector<ir::Node *> &nodes) const {
  std::vector<std::string> send_vars;
  // Since parameters are all in block 0,
  // it is enough to only scan send ops in block 0.
  for (auto &node : nodes) {
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a graceful method to find the send op,
    // instead of the hard-coded string.
    if (op->Type() == "send") {
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const std::vector<ir::Node *> &nodes) const {
  std::vector<std::string> recv_vars;
  for (auto &node : nodes) {
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a graceful method to find the recv op,
    // instead of the hard-coded string.
    if (op->Type() == "recv") {
      auto op_vars = op->OutputArgumentNames();
      recv_vars.reserve(recv_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    ir::Node *node, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check whether any of opvars contains `.block` and appears in rpc_vars.
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // A variable name with the suffix `.block` means it is a variable
      // split by (DistributeTranspiler)
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
}

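// Pick the device with the smallest accumulated tensor size so far and charge
// it with the given variables (a simple size-based load balance).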
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
  for (auto var_name : var_names) {
    if (all_vars_.find(var_name) == all_vars_.end()) continue;
    auto var_desc = all_vars_.at(var_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GT(numel, 0);
    numel_sum += numel;
  }

  auto smallest =
      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
  size_t dev_id =
      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
  balance_vars_[dev_id] += numel_sum;
  return dev_id;
}

// Topologically sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameters/gradients before optimizer ops, a topological sort alone is
// insufficient (some optimizer ops might not depend on any node), so we
// manually move all optimizer nodes after the last backward node.
// However, this assumption made by SSAGraphBuilder should be relaxed in the
// future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
  size_t last_backward = 0;
  for (size_t i = 0; i < ret.size(); ++i) {
    if (boost::get<int>(
            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kBackward)) {
      last_backward = i;
    }
  }

  std::vector<ir::Node *> optimize_ops;
  std::vector<ir::Node *> sorted_ret;
  for (size_t i = 0; i < ret.size(); ++i) {
    if (i < last_backward) {
      if (boost::get<int>(ret[i]->Op()->GetAttr(
              OpProtoAndCheckerMaker::OpRoleAttrName())) ==
          static_cast<int>(OpRole::kOptimize)) {
        optimize_ops.push_back(ret[i]);
      } else {
        sorted_ret.push_back(ret[i]);
      }
    } else if (i == last_backward) {
      sorted_ret.push_back(ret[i]);
      // Verify that no operation before the optimize ops depends on them.
      std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
                                                  optimize_ops.end());
      for (ir::Node *n : sorted_ret) {
        for (ir::Node *in : n->inputs) {
          for (ir::Node *pre_n : in->inputs) {
            PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
                           "optimize operations cannot be depended on by "
                           "forward or backward node %s -> %s",
                           pre_n->Name(), n->Name());
          }
        }
      }
      sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(),
                        optimize_ops.end());
    } else {
      sorted_ret.push_back(ret[i]);
    }
  }
  return sorted_ret;
}

std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  Init();
  // Give the topology sort order and rebuild the graph structure.
  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
  auto nodes = graph->ReleaseNodes();
  ir::Graph &result = *graph;

  for (auto &node : nodes) {
    if (node->IsVar() && node->Var()) {
      all_vars_.emplace(node->Name(), node->Var());
    }
  }
  std::unordered_set<std::string> og_has_been_broadcast;

  // We cannot invoke resize. It is a bug in GCC 4.8.
  result.Set(kGraphVars, new GraphVars(places_.size()));
  result.Set(kGraphDepVars, new GraphDepVars);
  result.Set(kGraphOps, new GraphOps);
  result.Set(kShardedVarDevice, new ShardedVarDevice);

  // Find send/recv vars so that we can place the distributed-training-related
  // ops on place 0.
  auto send_vars = FindDistTrainSendVars(sorted_ops);
  auto recv_vars = FindDistTrainRecvVars(sorted_ops);

  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

  size_t cur_device_id = 0;
  bool is_forwarding = true;
  bool is_dist_train = false;

  for (ir::Node *node : sorted_ops) {
    if (boost::get<int>(
            node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
      int op_dev_id = CreateRPCOp(&result, node);
      PADDLE_ENFORCE(op_dev_id != -1,
                     "Can not schedule the RPC operator to the right place.");
      if (node->Op()->Type() == "recv") {
        auto recv_vars_attr =
            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
        PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
        if (recv_vars_attr[0].find(".block") == std::string::npos) {
          bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
        }
      }
      is_dist_train = true;
    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
      int op_dev_id = CreateDistTrainOp(&result, node);
      if (node->Op()->Type() == "concat") {
        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
        bcast_var_name_set[op_dev_id].emplace(origin_param_name);
      }
    } else if (IsScaleLossOp(node)) {
      // The user can customize loss@grad if use_default_grad_scale_ is not
      // set.
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
        // TODO(paddle-dev): Why is there no input for this op_handle?
        auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
        CreateScaleLossGradOp(&result, loss_grad_name);
      }
      // This assumes the backward generating code will ensure IsScaleLossOp
      // is true only for the op that scales the final scalar loss.
      // It also assumes the backward op will always follow the forward op in
      // the block.
      is_forwarding = false;
    } else {
      int op_dev_id = GetOpDeviceID(result, node);
      if (op_dev_id != -1) {  // This op only runs on one specific device.
        CreateComputationalOp(&result, node, op_dev_id);
        for (ir::Node *n : node->outputs) {
          graph->Get<ShardedVarDevice>(kShardedVarDevice)
              .emplace(n->Name(), op_dev_id);
        }
      } else {
        // This op runs on all devices, and its outputs may contain the
        // parameters' gradients.
        // TODO(paddle-dev): Why is the "read" op so special?
        if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
          node->Op()->SetAttr("throw_eof_exp", false);
          CreateComputationalOps(&result, node, places_.size());
          const auto &data_var_names = node->Op()->Output("Out");
          InsertDataBalanceOp(&result, data_var_names);
        } else {
          CreateComputationalOps(&result, node, places_.size());
        }

        if (!is_forwarding && places_.size() > 1) {
          // Currently, we assume that once a gradient is generated, it can be
          // broadcast, and each gradient is only broadcast once.
          if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                    OpProtoAndCheckerMaker::OpRoleAttrName())) &
                                static_cast<int>(OpRole::kBackward))) {
            try {
              auto backward_vars = boost::get<std::vector<std::string>>(
                  node->Op()->GetNullableAttr(
                      OpProtoAndCheckerMaker::OpRoleVarAttrName()));

              PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);

              for (size_t i = 0; i < backward_vars.size(); i += 2) {
                auto &p_name = backward_vars[i];
                auto &g_name = backward_vars[i + 1];
                VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

                switch (strategy_.reduce_) {
                  case BuildStrategy::ReduceStrategy::kReduce:
                    cur_device_id = GetAppropriateDeviceID({g_name});
                    CreateReduceOp(&result, g_name, cur_device_id);
                    graph->Get<ShardedVarDevice>(kShardedVarDevice)
                        .emplace(g_name, cur_device_id);
                    if (!is_dist_train) {
                      bcast_var_name_set[cur_device_id].emplace(p_name);
                    }
                    break;
                  case BuildStrategy::ReduceStrategy::kAllReduce:
                    if (IsSparseGradient(g_name)) {
                      CreateReduceOp(&result, g_name, 0);
                      CreateBroadcastOp(&result, g_name, 0);
                    } else {
                      InsertAllReduceOp(&result, g_name);
                    }
                    break;
                  default:
                    LOG(FATAL) << "Unknown reduce strategy.";
                    break;
                }
              }
            } catch (boost::bad_get e) {
            }
          }
        }
      }
    }
  }
  bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
  use_gpu = nccl_ctxs_ != nullptr;
#endif

  // Principles for inserting broadcast operators:
  // 1. Broadcast optimized parameters in the Reduce strategy;
  // 2. There is no need to broadcast optimized parameters in the AllReduce
  //    strategy, because the optimization sub-graph runs on every GPU;
  // 3. Always broadcast received parameters in distributed training.
  if ((use_gpu &&
       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
      is_dist_train) {
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      auto &to_bcast_set = bcast_var_name_set[dev_id];
      for (auto &bcast_name : to_bcast_set) {
        CreateBroadcastOp(&result, bcast_name, dev_id);
      }
    }
  }
  /*
   * The dependency graph has been constructed. However, there are still data
   * hazards that need to be handled.
   */
  PolishGraphToSupportDataHazards(&result);

  /*
   * Only variables should be the leaves of the graph.
   */
  AddOutputToLeafOps(&result);
  PADDLE_ENFORCE(!ir::HasCircle(result));
  return graph;
}

bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
  PADDLE_ENFORCE(all_vars_.count(og) != 0);
  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
    return true;
  }
  return false;
}

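// Bind op_handle to the default device context of place p. With CUDA, this is
// only needed when no NCCL contexts are available.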
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  if (nccl_ctxs_ == nullptr) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
#else
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
#endif
}

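// Broadcast the latest version of p_name from src_dev_id to every place,
// creating a new version of the variable on each of them.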
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
                                                const std::string &p_name,
                                                size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_);
#else
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_);
#endif
  result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);

  auto *in =
      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
    auto *out_var = new VarHandle(
        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
        i, p_name, p);
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                    ir::Node *node,
                                                    int dev_id) const {
  result->Get<GraphOps>(kGraphOps).emplace_back(
      new ComputationOpHandle(result->CreateOpNode(node->Op()),
                              local_scopes_[dev_id], places_[dev_id]));
  CreateOpHandleIOs(result, node, dev_id);
}

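// All-reduce the gradient og across places: the latest version on every place
// is consumed and a new version is produced on every place.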
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
                                                const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());

    auto var =
        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                      vars.size(), i, og, p);
    vars.emplace_back(var);
    op_handle->AddOutput(var);
  }
}

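// Insert a DataBalance op over the given data variables: the latest version on
// every place is consumed and a new version is produced on every place.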
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
    ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    for (const std::string &d_name : datas) {
      auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
      PADDLE_ENFORCE(!vars.empty());
      op_handle->AddInput(vars.back().get());
      auto var = new VarHandle(
          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
          vars.size(), i, d_name, p);
      vars.emplace_back(var);
      op_handle->AddOutput(var);
    }
  }
}

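// Under the kReduce strategy, pin an optimize op to the device that owns its
// gradient; return -1 when the op is not pinned to a single device.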
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
                                           ir::Node *node) const {
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
    return -1;
  }
  int op_role = boost::get<int>(
      node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
    return -1;
  }
  auto param_grad = boost::get<std::vector<std::string>>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));

  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
  int dev_id = GetVarDeviceID(graph, param_grad[1]);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                    node->Op()->Type(), param_grad[0], param_grad[1]);
  return dev_id;
}

int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
                                            const std::string &varname) const {
  auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
  auto got = sharded_var_device.find(varname);
  return got == sharded_var_device.end() ? -1 : got->second;
}

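// Create one ScaleLossGrad op per place to produce the initial loss gradient
// (loss_grad_name) on that place.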
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
    ir::Graph *result, const std::string &loss_grad_name) const {
  for (size_t i = 0; i < places_.size(); ++i) {
    // Insert ScaleCost OpHandle
    auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
    auto *op_handle = new ScaleLossGradOpHandle(
        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
        local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
    result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);

    // FIXME: Currently ScaleLossGradOp only uses device_count as the scale
    // factor, so it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    CreateOpOutput(
        result, op_handle,
        result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
        places_[i], i);
  }
}

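// Replicate the computation op of `node` on each of the first num_places
// places.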
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
                                                     ir::Node *node,
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
    result->Get<GraphOps>(kGraphOps).emplace_back(
        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
    CreateOpHandleIOs(result, node, scope_idx);
  }
}

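// Reduce the gradient og from all places onto dst_dev_id and return the
// resulting variable handle.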
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());
  }
  auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
  auto var =
      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
  vars.emplace_back(var);
  op_handle->AddOutput(var);
  return var;
}

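// Place a distributed-training op (split_byref, split_selected_rows or concat)
// on a single device and record the device of its output variables.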
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                                               ir::Node *node) const {
  int op_dev_id = -1;
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  if (node->Op()->Type() == "split_byref" ||
      node->Op()->Type() == "split_selected_rows") {
    // TODO(paddle-dev): getting the first var is not safe.
    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
        result->Get<ShardedVarDevice>(kShardedVarDevice)
            .emplace(varname, op_dev_id);
      }
    }
    for (auto &varname : output_var_names) {
      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
  } else if (node->Op()->Type() == "concat") {
    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
    for (auto &varname : output_var_names) {
      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
  } else {
    PADDLE_THROW(
        "the distributed training related op should be one of [split_byref, "
        "concat].");
  }

  PADDLE_ENFORCE(op_dev_id != -1,
                 "can not find the right place for distributed op: %s",
                 node->Op()->Type());

  CreateComputationalOp(result, node, op_dev_id);
  return op_dev_id;
}

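// Add the latest version of every input variable, on every place where it
// exists, as an input of the most recently created op handle.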
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
  for (ir::Node *input : node->inputs) {
    VarHandle *var = nullptr;
    for (int place_offset = 0; place_offset < num_places; ++place_offset) {
      auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
      auto &var_holder = var_holders[input->Name()];
      if (!var_holder.empty()) {
        var = var_holder.rbegin()->get();
        op_handle->AddInput(var);
      }
    }
  }
}

// Create RPC related op handles that connect their in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                         ir::Node *node) const {
  int op_dev_id = -1;
  if (node->Op()->Type() == "send") {
    // TODO(paddle-dev): getting the first var is not safe.
    op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
    PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                   "This hack no longer holds, please fix.");
    // A variable name which contains .block means it was split by the
    // split_byref op.
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
        node->inputs[0]->Name().find(".block") == std::string::npos) {
      std::vector<std::string> input_var_names;
      for (ir::Node *n : node->inputs) {
        input_var_names.push_back(n->Name());
      }
      auto send_param_grad = boost::get<std::vector<std::string>>(
          node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
      PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
      op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
      VLOG(10) << "send grad " << input_var_names[0] << " origin "
               << send_param_grad[1] << " place: " << op_dev_id;
      for (auto &varname : input_var_names) {
        result->Get<ShardedVarDevice>(kShardedVarDevice)
            .emplace(varname, op_dev_id);
      }
      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(send_param_grad[1], op_dev_id);
    }
  } else if (node->Op()->Type() == "recv") {
    std::vector<std::string> output_var_names;
    for (ir::Node *n : node->outputs) {
      output_var_names.push_back(n->Name());
    }
    auto recv_param_grad = boost::get<std::vector<std::string>>(
        node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
    if (recv_param_grad.size() == 2U) {
      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
      VLOG(10) << "recv param " << recv_param_grad[0]
               << " get grad place: " << recv_param_grad[1]
               << " place: " << op_dev_id;
    } else {
      op_dev_id = GetAppropriateDeviceID(output_var_names);
    }
    for (auto &varname : output_var_names) {
      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
  } else {
    // send_barrier and fetch_barrier run on place 0.
    op_dev_id = 0;
  }

  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                 node->Op()->Type());
  result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
      node->Op()->Type(), places_[op_dev_id]));

  if (node->Op()->Type() == "send") {
    CreateOpHandleIOs(result, node, op_dev_id);
  } else {
    // The inputs of send_barrier, recv, and fetch_barrier are dependency
    // variables; get them from all places.
    auto p = places_[op_dev_id];
    auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));

    SetOpInputsAllPlaces(result, node, places_.size());
    for (ir::Node *output : node->outputs) {
      int outvar_dev_id = op_dev_id;
      if (node->Op()->Type() == "fetch_barrier") {
        outvar_dev_id = GetVarDeviceID(*result, output->Name());
        PADDLE_ENFORCE_NE(outvar_dev_id, -1);
      }
      p = places_[outvar_dev_id];
      ir::Node *new_node = nullptr;
      if (output->Var()) {
        new_node = result->CreateVarNode(output->Var());
      } else {
        new_node =
            result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
      }
      CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
    }
  }
  return op_dev_id;
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
  return boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty, this is test mode.
}
}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(multi_devices_pass,
              paddle::framework::details::MultiDevSSAGraphBuilder)
    .RequirePassAttr(paddle::framework::details::kLossVarName)
    .RequirePassAttr(paddle::framework::details::kPlaces)
    .RequirePassAttr(paddle::framework::details::kParams)
    .RequirePassAttr(paddle::framework::details::kLocalScopes)
    .RequirePassAttr(paddle::framework::details::kStrategy);