multi_devices_graph_builder.cc 26.8 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
C
chengduoZH 已提交
14
#include <algorithm>
Y
Yancey1989 已提交
15
#include <fstream>
C
chengduoZH 已提交
16
#include <string>
C
chengduoZH 已提交
17
#include <utility>
C
chengduoZH 已提交
18 19
#include <vector>

20
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
C
chengduoZH 已提交
21
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
Y
Yu Yang 已提交
22
#include "paddle/fluid/framework/details/computation_op_handle.h"
23
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
C
chengduoZH 已提交
24
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
C
chengduoZH 已提交
25
#include "paddle/fluid/framework/details/reduce_op_handle.h"
Y
Yancey1989 已提交
26
#include "paddle/fluid/framework/details/rpc_op_handle.h"
Y
Yu Yang 已提交
27
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
X
better  
Xin Pan 已提交
28
#include "paddle/fluid/framework/ir/graph_helper.h"
X
Xin Pan 已提交
29
#include "paddle/fluid/framework/ir/node.h"
Y
Fix bug  
yuyang18 已提交
30
#include "paddle/fluid/framework/op_info.h"
Y
Yu Yang 已提交
31
#include "paddle/fluid/framework/scope.h"
Y
Yu Yang 已提交
32

Y
Yu Yang 已提交
33 34 35
namespace paddle {
namespace framework {
namespace details {
Y
Yu Yang 已提交
36 37

// Builds a multi-device SSA graph builder over `places` / `local_scopes`.
// The CUDA build additionally keeps the (possibly null) NCCL context map.
// The shared body records the gradient name of every parameter in `params`,
// zeroes one load counter per place, and turns data balancing off when there
// is only a single place.
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      nccl_ctxs_(nccl_ctxs),
      strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      strategy_(strategy) {
#endif
  // Remember the gradient variable name of every parameter.
  for (const std::string &param_name : params) {
    grad_names_.insert(GradVarName(param_name));
  }

  // One accumulated-element counter per place, all starting at zero;
  // GetAppropriateDeviceID reads and updates these.
  balance_vars_.resize(places_.size(), 0);

  // Balancing data across a single place is pointless, so disable it.
  const bool single_place = places_.size() == 1;
  if (single_place && strategy_.enable_data_balance_) {
    LOG(WARNING) << "It is no need to enable data balance when there is only "
                    "one place. enable_data_balance is set to False.";
    strategy_.enable_data_balance_ = false;
  }
}

X
Xin Pan 已提交
71 72
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                ir::Node *node,
Y
Yu Yang 已提交
73 74
                                                size_t place_id) const {
  auto p = places_[place_id];
X
Xin Pan 已提交
75
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
X
Xin Pan 已提交
76 77
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
T
wip  
typhoonzero 已提交
78

79 80
  for (ir::Node *input : node->inputs) {
    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
T
wip  
typhoonzero 已提交
81 82 83
    op_handle->AddInput(var);
  }

84
  for (ir::Node *output : node->outputs) {
X
polish  
Xin Pan 已提交
85 86 87 88 89 90 91 92
    ir::Node *new_node = nullptr;
    if (output->Var()) {
      new_node = result->CreateVarNode(output->Var());
    } else {
      new_node =
          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
    }
    CreateOpOutput(result, op_handle, new_node, p, place_id);
T
wip  
typhoonzero 已提交
93 94
  }
}
Y
fix pe  
Yancey1989 已提交
95 96

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
97
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
Y
fix pe  
Yancey1989 已提交
98
  std::vector<std::string> send_vars;
Y
Yancey1989 已提交
99 100
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
101
  for (auto &node : nodes) {
X
Xin Pan 已提交
102
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
103
    OpDesc *op = node->Op();
Y
Yancey1989 已提交
104 105
    // TODO(Yancey1989): use a graceful method to find send op,
    // instead of the the hard code string
106
    if (op->Type() == "send") {
Y
fix pe  
Yancey1989 已提交
107 108 109 110 111 112 113 114 115 116
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
117
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
Y
fix pe  
Yancey1989 已提交
118
  std::vector<std::string> recv_vars;
119
  for (auto &node : nodes) {
X
Xin Pan 已提交
120
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
121
    OpDesc *op = node->Op();
Y
Yancey1989 已提交
122 123 124
    // TODO(Yancey1989): use a graceful method to find recv op,
    // instead of the hard code string
    if (op->Type() == "recv") {
Y
fix pe  
Yancey1989 已提交
125 126 127 128 129 130 131 132 133 134
      auto op_vars = op->OutputArgumentNames();
      recv_vars.reserve(recv_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
135
    ir::Node *node, const std::vector<std::string> &send_vars,
Y
fix pe  
Yancey1989 已提交
136 137
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
T
typhoonzero 已提交
138 139 140
    return false;
  }

Y
Yu Yang 已提交
141 142 143 144
  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
Y
fix pe  
Yancey1989 已提交
145
                    const std::vector<std::string> &rpc_vars) -> bool {
T
typhoonzero 已提交
146
    for (auto &var : opvars) {
Y
Yancey1989 已提交
147 148 149
      // a variable name with the suffix `.block` means it's a splited
      // variable by (DistributeTranspiler)
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
T
typhoonzero 已提交
150
      if (var.find(".block") != std::string::npos &&
Y
fix pe  
Yancey1989 已提交
151
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
Y
Yu Yang 已提交
152
        return true;
T
typhoonzero 已提交
153 154
      }
    }
Y
Yu Yang 已提交
155
    return false;
T
typhoonzero 已提交
156 157
  };

158 159 160
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
X
Xin Pan 已提交
161
    input_var_names.push_back(input->Name());
162 163
  }
  for (ir::Node *output : node->outputs) {
X
Xin Pan 已提交
164
    output_var_names.push_back(output->Name());
165 166 167 168
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
T
typhoonzero 已提交
169 170
}

Y
Yancey1989 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
  for (auto var_name : var_names) {
    auto var_desc = all_vars_.at(var_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GT(numel, 0);
    numel_sum += numel;
  }

  auto smallest =
      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
  size_t dev_id =
      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
  balance_vars_[dev_id] += numel_sum;
  return dev_id;
}

X
better  
Xin Pan 已提交
191 192 193 194 195
// Topology sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
// some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes.
X
Xin Pan 已提交
196 197 198
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
X
better  
Xin Pan 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
  size_t last_backward = 0;
  std::vector<ir::Node *> optimize_ops;
  std::vector<ir::Node *> sorted_ret;
  for (size_t i = 0; i < ret.size(); ++i) {
    if (boost::get<int>(
            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kBackward)) {
      sorted_ret.push_back(ret[i]);
      last_backward = sorted_ret.size();
    } else if (boost::get<int>(ret[i]->Op()->GetAttr(
                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
               static_cast<int>(OpRole::kOptimize)) {
      optimize_ops.push_back(ret[i]);
    } else {
      sorted_ret.push_back(ret[i]);
    }
  }

  sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
                    optimize_ops.end());
  return sorted_ret;
}

X
Xin Pan 已提交
222 223
// Rebuilds `graph` as a multi-device SSA graph and returns the same object.
//
// Op nodes are visited in SortOpsAndDelayOptimizeOp order and turned into op
// handles stored in the graph's "ops" attribute. Variable versions go into
// "vars" and dependency-only variables into "dep_vars".
//   * kRPC-role ops                         -> CreateRPCOp
//   * distributed-training split/concat ops -> CreateDistTrainOp
//   * the loss-scaling op -> CreateScaleLossGradOp (unless the gradient scale
//     is kCustomized); it also marks the switch from forward to backward
//   * everything else -> one computation op on the device chosen by
//     GetOpDeviceID, or one computation op per device
// For per-device backward ops with more than one place, every
// (parameter, gradient) pair in the op_role_var attribute is communicated
// according to strategy_.reduce_.
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
    std::unique_ptr<ir::Graph> graph) const {
  // Rebuild the graph structure. Sorting happens before the nodes are moved
  // out of `graph`. The old nodes stay alive in `nodes` until this function
  // returns, so the pointers in `sorted_ops` remain valid throughout.
  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
  auto nodes = std::move(graph->nodes);
  graph->nodes.clear();

  for (auto &node : nodes) {
    if (node->NodeType() == ir::Node::Type::kVariable) {
      all_vars_.emplace(node->Name(), node->Var());
    }
  }

  ir::Graph &result = *graph;

  // We cannot invoke resize. It is a bug of GCC 4.8
  result.Set("vars", new GraphVars(places_.size()));
  result.Set("dep_vars", new GraphDepVars);
  result.Set("ops", new GraphOps);

  // find send/recv vars so that we can place the distributed training
  // related op in the place 0
  auto send_vars = FindDistTrainSendVars(nodes);
  auto recv_vars = FindDistTrainRecvVars(nodes);

  // Parameters each device must broadcast after all ops are created. Only
  // the kReduce strategy below fills this.
  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

  bool is_forwarding = true;

  for (ir::Node *node : sorted_ops) {
    const int op_role = boost::get<int>(
        node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));

    if (op_role == static_cast<int>(OpRole::kRPC)) {
      CreateRPCOp(&result, node);
    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
      CreateDistTrainOp(&result, node);
    } else if (IsScaleLossOp(node)) {
      // user can customize loss@grad if not use_default_grad_scale_
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
        CreateScaleLossGradOp(&result);
      }
      // This assumes the backward generating code will ensure IsScaleLossOp
      // is true only for the op that scale the final scalar loss.
      // It also assumes backward op will always follow the forward op in
      // the block.
      is_forwarding = false;
    } else {
      int op_dev_id = GetOpDeviceID(node);
      if (op_dev_id != -1) {  // This op only runs on one specific device.
        CreateComputationalOp(&result, node, op_dev_id);
        for (ir::Node *n : node->outputs) {
          var_name_on_devices_.emplace(n->Name(), op_dev_id);
        }
      } else {
        // This op runs on all devices, and its output may have parameter's
        // gradients.
        const bool balance_read_data =
            node->Op()->Type() == "read" && strategy_.enable_data_balance_;
        if (balance_read_data) {
          // Must be set before the per-device ops are created from this
          // OpDesc.
          node->Op()->SetAttr("throw_eof_exp", false);
        }
        CreateComputationalOps(&result, node, places_.size());
        if (balance_read_data) {
          InsertDataBalanceOp(&result, node->Op()->Output("Out"));
        }

        const bool is_backward_op =
            (op_role & static_cast<int>(OpRole::kBackward)) != 0;
        if (!is_forwarding && places_.size() > 1 && is_backward_op) {
          // Currently, we assume that once gradient is generated, it can be
          // broadcast, and each gradient is only broadcast once.
          std::vector<std::string> backward_vars;
          try {
            backward_vars = boost::get<std::vector<std::string>>(
                node->Op()->GetNullableAttr(
                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));
          } catch (const boost::bad_get &) {
            // op_role_var does not hold a string list (presumably the op has
            // no such attribute), so there is no gradient to communicate.
            // backward_vars stays empty. Only this read is guarded, so
            // failures while building the communication ops still propagate.
          }

          // op_role_var is a flat list: [param, grad, param, grad, ...].
          PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
          for (size_t i = 0; i < backward_vars.size(); i += 2) {
            const auto &p_name = backward_vars[i];
            const auto &g_name = backward_vars[i + 1];
            VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

            switch (strategy_.reduce_) {
              case BuildStrategy::ReduceStrategy::kReduce: {
                size_t cur_device_id = GetAppropriateDeviceID({g_name});
                CreateReduceOp(&result, g_name, cur_device_id);
                var_name_on_devices_.emplace(g_name, cur_device_id);
                bcast_var_name_set[cur_device_id].emplace(p_name);
                break;
              }
              case BuildStrategy::ReduceStrategy::kAllReduce:
                if (IsSparseGradient(g_name)) {
                  CreateReduceOp(&result, g_name, 0);
                  CreateBroadcastOp(&result, g_name, 0);
                } else {
                  InsertAllReduceOp(&result, g_name);
                }
                break;
              default:
                LOG(FATAL) << "Unknown reduce strategy ";
                break;
            }
          }
        }
      }
    }
  }

  bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
  use_gpu = nccl_ctxs_ != nullptr;
#endif

  // NOTE(review): bcast_var_name_set is only filled in the kReduce case above.
  // The kAllReduce half of this condition therefore never inserts anything,
  // and kReduce without NCCL contexts skips the broadcast entirely. Confirm
  // this is intended.
  if (use_gpu ||
      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
    // Insert BCast Ops
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      for (auto &bcast_name : bcast_var_name_set[dev_id]) {
        CreateBroadcastOp(&result, bcast_name, dev_id);
      }
    }
  }

  /*
   * Only variables should be the leaves of graph.
   */
  AddOutputToLeafOps(&result);
  return graph;
}

Y
Yancey1989 已提交
360 361 362
// Returns true when the variable named `og` (which must be known in
// all_vars_) has type SELECTED_ROWS, i.e. it is a sparse gradient.
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
  PADDLE_ENFORCE(all_vars_.count(og) != 0);
  return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
}

368 369 370 371 372 373 374 375 376 377 378 379 380
// Gives `op_handle` the default pooled device context for place `p`, unless
// this is a CUDA build with NCCL contexts (nccl_ctxs_ non-null). In that case
// nothing is set here; presumably the op handle takes its context from the
// NCCL map it was constructed with (confirm).
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  const bool use_default_ctx = nccl_ctxs_ == nullptr;
#else
  const bool use_default_ctx = true;
#endif
  if (use_default_ctx) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
}

X
Xin Pan 已提交
381
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
C
chengduoZH 已提交
382
                                                const std::string &p_name,
C
chengduoZH 已提交
383
                                                size_t src_dev_id) const {
C
chengduoZH 已提交
384
#ifdef PADDLE_WITH_CUDA
X
polish  
Xin Pan 已提交
385 386 387
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_);
C
chengduoZH 已提交
388
#else
X
polish  
Xin Pan 已提交
389 390 391
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_);
C
chengduoZH 已提交
392
#endif
X
Xin Pan 已提交
393
  result->Get<GraphOps>("ops").emplace_back(op_handle);
X
Xin Pan 已提交
394

X
Xin Pan 已提交
395 396
  auto *in =
      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
C
chengduoZH 已提交
397 398 399 400
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
C
chengduoZH 已提交
401
    SetCommunicationContext(op_handle, p);
X
Xin Pan 已提交
402
    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
X
polish  
Xin Pan 已提交
403 404 405
    auto *out_var = new VarHandle(
        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
        i, p_name, p);
C
chengduoZH 已提交
406 407 408 409 410
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

X
Xin Pan 已提交
411
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
412
                                                    ir::Node *node,
C
chengduoZH 已提交
413
                                                    int dev_id) const {
414
  result->Get<GraphOps>("ops").emplace_back(
X
Xin Pan 已提交
415
      new ComputationOpHandle(result->CreateOpNode(node->Op()),
416 417
                              local_scopes_[dev_id], places_[dev_id]));
  CreateOpHandleIOs(result, node, dev_id);
C
chengduoZH 已提交
418 419
}

X
Xin Pan 已提交
420
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
C
chengduoZH 已提交
421
                                                const std::string &og) const {
Y
Yu Yang 已提交
422
#ifdef PADDLE_WITH_CUDA
X
polish  
Xin Pan 已提交
423 424 425
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
C
chengduoZH 已提交
426
#else
X
Xin Pan 已提交
427
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
X
polish  
Xin Pan 已提交
428 429
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
C
chengduoZH 已提交
430
#endif
X
Xin Pan 已提交
431
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
Y
Yu Yang 已提交
432 433 434

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
C
chengduoZH 已提交
435
    SetCommunicationContext(op_handle, p);
X
Xin Pan 已提交
436
    auto &vars = result->Get<GraphVars>("vars")[i][og];
Y
Yu Yang 已提交
437 438
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
Y
Yu Yang 已提交
439 440
    op_handle->AddInput(prev_grad.get());

X
Xin Pan 已提交
441
    auto var =
X
polish  
Xin Pan 已提交
442 443
        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                      vars.size(), i, og, p);
Y
Yu Yang 已提交
444 445 446 447 448
    vars.emplace_back(var);
    op_handle->AddOutput(var);
  }
}

449
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
X
Xin Pan 已提交
450
    ir::Graph *result, const std::vector<std::string> &datas) const {
F
fengjiayi 已提交
451
#ifdef PADDLE_WITH_CUDA
X
polish  
Xin Pan 已提交
452 453 454
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
F
fengjiayi 已提交
455
#else
X
Xin Pan 已提交
456
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
X
polish  
Xin Pan 已提交
457 458
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_));
F
fengjiayi 已提交
459
#endif
X
Xin Pan 已提交
460
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
461 462 463 464
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    for (const std::string &d_name : datas) {
X
Xin Pan 已提交
465
      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
466 467
      PADDLE_ENFORCE(!vars.empty());
      op_handle->AddInput(vars.back().get());
X
polish  
Xin Pan 已提交
468 469 470
      auto var = new VarHandle(
          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
          vars.size(), i, d_name, p);
471 472 473 474 475 476
      vars.emplace_back(var);
      op_handle->AddOutput(var);
    }
  }
}

Y
Yu Yang 已提交
477 478 479 480 481 482 483 484 485 486 487 488
// Returns true the first time `og` is seen as a parameter gradient. That is,
// `og` is in grad_names_ and was not yet in `*og_has_been_broadcast`; `og` is
// then added to that set. Returns false otherwise and leaves the set
// unchanged.
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
    const std::string &og,
    std::unordered_set<std::string> *og_has_been_broadcast) const {
  if (grad_names_.count(og) == 0) {
    return false;
  }
  // insert().second is true exactly when `og` was not present before.
  return og_has_been_broadcast->insert(og).second;
}

489
// Returns the single device an op must run on, or -1 if it runs everywhere.
// Only optimize-role ops under the kReduce strategy are pinned. Such an op
// carries exactly one [param, grad] pair and runs on the device recorded for
// its gradient, which must be known.
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
  const bool reduce_strategy =
      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce;
  if (!reduce_strategy) {
    return -1;
  }

  const bool is_optimize_op =
      boost::get<int>(node->Op()->GetAttr(
          framework::OpProtoAndCheckerMaker::OpRoleAttrName())) ==
      static_cast<int>(framework::OpRole::kOptimize);
  if (!is_optimize_op) {
    return -1;
  }

  auto param_grad = boost::get<std::vector<std::string>>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);

  int dev_id = GetVarDeviceID(param_grad[1]);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                    node->Op()->Type(), param_grad[0], param_grad[1]);
  return dev_id;
}

// Returns the device recorded for `varname` in var_name_on_devices_, or -1
// when no device has been recorded for it.
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
  const auto it = var_name_on_devices_.find(varname);
  if (it == var_name_on_devices_.end()) {
    return -1;
  }
  return it->second;
}

X
Xin Pan 已提交
513
// Appends one ScaleLossGradOpHandle per place. Each handle gets:
//  * the number of local scopes;
//  * that place's scope and place;
//  * a communication device context: in CUDA builds, the NCCL context when
//    nccl_ctxs_ is set, otherwise the pool's context for the place; in
//    non-CUDA builds, the pool's CPUPlace context.
// Its output is the loss-gradient variable on that place, registered via
// CreateOpOutput.
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
  const std::string loss_grad_name = GradVarName(loss_var_name_);
  for (size_t i = 0; i < places_.size(); ++i) {
    const auto &place = places_[i];
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
    auto *communication_dev_ctx =
        nccl_ctxs_ ? nccl_ctxs_->DevCtx(place)
                   : platform::DeviceContextPool::Instance().Get(place);
#else
    auto *communication_dev_ctx =
        platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
    auto *op_handle = new ScaleLossGradOpHandle(
        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
        local_scopes_.size(), local_scopes_[i], place, communication_dev_ctx);
    result->Get<GraphOps>("ops").emplace_back(op_handle);

    // FIXME: Currently ScaleLossGradOp only use device_count as scale
    // factor. So it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    ir::Node *loss_grad_node =
        result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable);
    CreateOpOutput(result, op_handle, loss_grad_node, place, i);
  }
}

X
Xin Pan 已提交
543
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
544
                                                     ir::Node *node,
T
typhoonzero 已提交
545 546
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
Y
Yu Yang 已提交
547 548
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
X
Xin Pan 已提交
549 550
    result->Get<GraphOps>("ops").emplace_back(
        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
551
    CreateOpHandleIOs(result, node, scope_idx);
Y
Yu Yang 已提交
552 553 554
  }
}

X
Xin Pan 已提交
555
// Appends a ReduceOpHandle for gradient `og`. It consumes the latest version
// of `og` on every device (each must already exist) and produces a single new
// version on device `dst_dev_id`, which it returns.
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
  ir::Node *reduce_node =
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation);
#ifdef PADDLE_WITH_CUDA
  OpHandleBase *op_handle =
      new ReduceOpHandle(reduce_node, local_scopes_, places_, nccl_ctxs_);
#else
  OpHandleBase *op_handle =
      new ReduceOpHandle(reduce_node, local_scopes_, places_);
#endif
  result->Get<GraphOps>("ops").emplace_back(op_handle);

  auto &graph_vars = result->Get<GraphVars>("vars");
  for (size_t i = 0; i < places_.size(); ++i) {
    SetCommunicationContext(op_handle, places_[i]);
    auto &vars = graph_vars[i][og];
    PADDLE_ENFORCE(!vars.empty());
    op_handle->AddInput(vars.back().get());
  }

  auto &dst_vars = graph_vars[dst_dev_id][og];
  auto var =
      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                    dst_vars.size(), dst_dev_id, og, places_[dst_dev_id]);
  dst_vars.emplace_back(var);
  op_handle->AddOutput(var);
  return var;
}

586 587
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
X
Xin Pan 已提交
588
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
Y
fix pe  
Yancey1989 已提交
589
                                        const std::string &prev_op_name) const {
X
Xin Pan 已提交
590
  for (auto &prev_op : result->Get<GraphOps>("ops")) {
Y
fix pe  
Yancey1989 已提交
591
    if (prev_op->Name() == prev_op_name) {
X
polish  
Xin Pan 已提交
592 593
      auto *dep_var = new DummyVarHandle(
          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
Y
Yancey1989 已提交
594
      prev_op->AddOutput(dep_var);
X
Xin Pan 已提交
595
      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
Y
fix pe  
Yancey1989 已提交
596
      op->AddInput(dep_var);
Y
Yancey1989 已提交
597 598 599 600
    }
  }
}

X
Xin Pan 已提交
601
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
602
                                                ir::Node *node) const {
Y
Yancey1989 已提交
603
  int op_dev_id = -1;
604 605 606
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
X
Xin Pan 已提交
607
    input_var_names.push_back(input->Name());
608 609
  }
  for (ir::Node *output : node->outputs) {
X
Xin Pan 已提交
610
    output_var_names.push_back(output->Name());
611 612 613 614 615
  }

  if (node->Op()->Type() == "split_byref" ||
      node->Op()->Type() == "split_selected_rows") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
Y
Yancey1989 已提交
616
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
617 618
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
Y
Yancey1989 已提交
619 620 621
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
622
    for (auto &varname : output_var_names) {
Y
Yancey1989 已提交
623 624
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
625 626 627
  } else if (node->Op()->Type() == "concat") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
    for (auto &varname : output_var_names) {
Y
yi.wu 已提交
628 629
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
Y
Yancey1989 已提交
630 631 632 633 634 635 636
  } else {
    PADDLE_ENFORCE(
        "the distribute training related op should be in [split_byref, "
        "concat].");
  }

  PADDLE_ENFORCE(op_dev_id != -1,
637 638
                 "can not find right place for distributed op: %s",
                 node->Op()->Type());
Y
Yancey1989 已提交
639

640 641
  CreateComputationalOp(result, node, op_dev_id);
  if (node->Op()->Type() == "concat") {
X
Xin Pan 已提交
642
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
X
Xin Pan 已提交
643
              "fetch_barrier");
Y
Yancey1989 已提交
644 645 646
  }
}

647
// Create RPC related op handles that connects its in ops and out ops.
X
Xin Pan 已提交
648 649
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                          ir::Node *node) const {
Y
Yancey1989 已提交
650
  int op_dev_id = -1;
651
  if (node->Op()->Type() == "send") {
X
Xin Pan 已提交
652
    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
Y
Yancey1989 已提交
653 654
    // the variable name which contains .block means it was splited by
    // split_byref op
655 656
    // so that we can balance the variable blocks to all the pserver
    // instances.
Y
Yancey1989 已提交
657
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
X
Xin Pan 已提交
658
        node->inputs[0]->Name().find(".block") == std::string::npos) {
659 660
      std::vector<std::string> input_var_names;
      for (ir::Node *n : node->inputs) {
X
Xin Pan 已提交
661
        input_var_names.push_back(n->Name());
662 663 664
      }
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
Y
Yancey1989 已提交
665 666 667
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
668 669 670
  } else if (node->Op()->Type() == "recv") {
    std::vector<std::string> output_var_names;
    for (ir::Node *n : node->outputs) {
X
Xin Pan 已提交
671
      output_var_names.push_back(n->Name());
672 673 674
    }
    op_dev_id = GetAppropriateDeviceID(output_var_names);
    for (auto &varname : output_var_names) {
Y
Yancey1989 已提交
675 676 677 678 679 680 681 682
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    // send_barrier and fetch_barrier op can be scheduled on device 0
    op_dev_id = 0;
  }

  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
683
                 node->Op()->Type());
Y
Yancey1989 已提交
684

685 686 687
  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
      node->Op()->Type(), places_[op_dev_id]));
Y
fix pe  
Yancey1989 已提交
688

689
  if (node->Op()->Type() == "send_barrier") {
X
Xin Pan 已提交
690
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
691
  } else if (node->Op()->Type() == "recv") {
X
Xin Pan 已提交
692
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
X
Xin Pan 已提交
693
              "send_barrier");
694
  } else if (node->Op()->Type() == "fetch_barrier") {
X
Xin Pan 已提交
695
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
696
  } else if (node->Op()->Type() == "send") {
Y
Yancey1989 已提交
697 698 699
    // do nothing
  } else {
    PADDLE_THROW(
Y
Yancey1989 已提交
700
        "rpc op should be in ["
701
        "send, send_barrier. recv, fetch_barrier]");
Y
Yancey1989 已提交
702 703
  }

704
  CreateOpHandleIOs(result, node, op_dev_id);
Y
Yu Yang 已提交
705 706
}

707
// Returns true for the op whose role is exactly (kBackward | kLoss), i.e. the
// op that scales the loss gradient. Always false in test mode, signalled by
// an empty loss_var_name_.
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
  const int role = boost::get<int>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
  const int scale_loss_role = static_cast<int>(OpRole::kBackward) |
                              static_cast<int>(OpRole::kLoss);
  return role == scale_loss_role && !loss_var_name_.empty();
}
Y
Yu Yang 已提交
714 715 716
}  // namespace details
}  // namespace framework
}  // namespace paddle