//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {

#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      nccl_ctxs_(nccl_ctxs),
      strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      strategy_(strategy) {
#endif
  for (auto &p : params) {
    grad_names_.insert(GradVarName(p));
  }
  balance_vars_.resize(places_.size(), 0);
  if (strategy_.enable_data_balance_ && places_.size() == 1) {
    LOG(WARNING) << "There is no need to enable data balance when there is "
                    "only one place. enable_data_balance is set to False.";
    strategy_.enable_data_balance_ = false;
  }
}

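// Hook up the most recently created op handle on place `place_id`: set its
// device context, add its input variable handles, and create new variable
// versions for its outputs.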
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
                                                size_t place_id) const {
  auto p = places_[place_id];
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));

  for (ir::Node *input : node->inputs) {
    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
    op_handle->AddInput(var);
  }

  for (ir::Node *output : node->outputs) {
    ir::Node *new_node = nullptr;
    if (output->Var()) {
      new_node = result->CreateVarNode(output->Var());
    } else {
      new_node =
          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
    }
    CreateOpOutput(result, op_handle, new_node, p, place_id);
  }
}

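// Collect the input variable names of every `send` op; these are the
// variables pushed to the parameter servers during distributed training.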
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
  std::vector<std::string> send_vars;
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
  for (auto &node : nodes) {
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a graceful method to find send op,
    // instead of the hard-coded string
    if (op->Type() == "send") {
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

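// Collect the output variable names of every `recv` op; these are the
// variables pulled back from the parameter servers.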
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
  std::vector<std::string> recv_vars;
  for (auto &node : nodes) {
    if (node->NodeType() != ir::Node::Type::kOperation) continue;
    OpDesc *op = node->Op();
    // TODO(Yancey1989): use a graceful method to find recv op,
    // instead of the hard-coded string
    if (op->Type() == "recv") {
      auto op_vars = op->OutputArgumentNames();
      recv_vars.reserve(recv_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

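// An op belongs to the distributed-training part of the program if any of its
// outputs is one of the send variables, or any of its inputs is one of the
// recv variables; only variables split into `.block` pieces are considered.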
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    ir::Node *node, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check whether any of opvars contains the `.block` suffix and appears in
   * rpc_vars.
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // a variable name with the suffix `.block` means it's a variable split
      // by the DistributeTranspiler
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
}

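// Greedy load balancing: charge the total element count of `var_names` to the
// device that currently holds the fewest elements and return that device id.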
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
    const std::vector<std::string> &var_names) const {
  int64_t numel_sum = 0;
  for (auto var_name : var_names) {
    auto var_desc = all_vars_.at(var_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GT(numel, 0);
    numel_sum += numel;
  }

  auto smallest =
      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
  size_t dev_id =
      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
  balance_vars_[dev_id] += numel_sum;
  return dev_id;
}

// Topologically sort the graph nodes from inputs to outputs.
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// to parameters/gradients before optimizer ops, a plain topological sort is
// insufficient (some optimizer ops might not depend on any node), so we
// manually move all optimizer nodes after the last backward node.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
  std::vector<ir::Node *> ret = ir::TopologySort(graph);
  size_t last_backward = 0;
  std::vector<ir::Node *> optimize_ops;
  std::vector<ir::Node *> sorted_ret;
  for (size_t i = 0; i < ret.size(); ++i) {
    if (boost::get<int>(
            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kBackward)) {
      sorted_ret.push_back(ret[i]);
      last_backward = sorted_ret.size();
    } else if (boost::get<int>(ret[i]->Op()->GetAttr(
                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
               static_cast<int>(OpRole::kOptimize)) {
      optimize_ops.push_back(ret[i]);
    } else {
      sorted_ret.push_back(ret[i]);
    }
  }

  sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
                    optimize_ops.end());

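  // Strip control-dependency variables from every node's inputs and outputs;
  // the multi-device pass builds its own dependency variables later.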
  for (ir::Node *n : sorted_ret) {
    n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
                                   [n](ir::Node *t) {
                                     return t->Name() ==
                                            ir::Node::kControlDepVarName;
                                   }),
                    n->inputs.end());
    n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
                                    [n](ir::Node *t) {
                                      return t->Name() ==
                                             ir::Node::kControlDepVarName;
                                    }),
                     n->outputs.end());
  }
  return sorted_ret;
}

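// Build the multi-device SSA graph: every op is assigned to one or all
// places, and communication op handles (scale_loss_grad, broadcast,
// all-reduce, reduce, data-balance, RPC) are inserted where needed.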
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
    std::unique_ptr<Graph> graph) const {
  // Rebuild the graph structure.
  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
  auto nodes = std::move(graph->nodes);
  graph->nodes.clear();

  for (auto &node : nodes) {
    if (node->NodeType() == ir::Node::Type::kVariable) {
      all_vars_.emplace(node->Name(), node->Var());
    }
  }

  Graph &result = *graph;
  std::unordered_set<std::string> og_has_been_broadcast;

  // We cannot invoke resize here; it triggers a bug in GCC 4.8.
  result.Set("vars", new GraphVars(places_.size()));
  result.Set("dep_vars", new GraphDepVars);
  result.Set("ops", new GraphOps);

  // find send/recv vars so that we can place the distributed training
  // related ops on place 0
  auto send_vars = FindDistTrainSendVars(nodes);
  auto recv_vars = FindDistTrainRecvVars(nodes);

  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

  size_t cur_device_id = 0;
  bool is_forwarding = true;

  for (ir::Node *node : sorted_ops) {
    if (boost::get<int>(
            node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
      CreateRPCOp(&result, node);
    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
      CreateDistTrainOp(&result, node);
    } else if (IsScaleLossOp(node)) {
      // the user can customize loss@grad when gradient_scale_ is kCustomized
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
        CreateScaleLossGradOp(&result);
      }
      // This assumes the backward-generating code ensures IsScaleLossOp is
      // true only for the op that scales the final scalar loss. It also
      // assumes that backward ops always follow the forward ops in the block.
      is_forwarding = false;
    } else {
      int op_dev_id = GetOpDeviceID(node);
      if (op_dev_id != -1) {  // This op only runs on one specific device.
        CreateComputationalOp(&result, node, op_dev_id);
        for (ir::Node *n : node->outputs) {
          var_name_on_devices_.emplace(n->Name(), op_dev_id);
        }
      } else {
        // This op runs on all devices, and its outputs may contain the
        // parameters' gradients.
        if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
          node->Op()->SetAttr("throw_eof_exp", false);
          CreateComputationalOps(&result, node, places_.size());
          const auto &data_var_names = node->Op()->Output("Out");
          InsertDataBalanceOp(&result, data_var_names);
        } else {
          CreateComputationalOps(&result, node, places_.size());
        }

        if (!is_forwarding && places_.size() > 1) {
          // Currently, we assume that once a gradient is generated, it can be
          // broadcast, and each gradient is broadcast only once.
          if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                    OpProtoAndCheckerMaker::OpRoleAttrName())) &
                                static_cast<int>(OpRole::kBackward))) {
            try {
              auto backward_vars = boost::get<std::vector<std::string>>(
                  node->Op()->GetNullableAttr(
                      OpProtoAndCheckerMaker::OpRoleVarAttrName()));

              PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);

              for (size_t i = 0; i < backward_vars.size(); i += 2) {
                auto &p_name = backward_vars[i];
                auto &g_name = backward_vars[i + 1];
                VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

                switch (strategy_.reduce_) {
                  case BuildStrategy::ReduceStrategy::kReduce:
                    cur_device_id = GetAppropriateDeviceID({g_name});
                    CreateReduceOp(&result, g_name, cur_device_id);
                    var_name_on_devices_.emplace(g_name, cur_device_id);
                    bcast_var_name_set[cur_device_id].emplace(p_name);
                    break;
                  case BuildStrategy::ReduceStrategy::kAllReduce:
                    if (IsSparseGradient(g_name)) {
                      CreateReduceOp(&result, g_name, 0);
                      CreateBroadcastOp(&result, g_name, 0);
                    } else {
                      InsertAllReduceOp(&result, g_name);
                    }
                    break;
                  default:
                    LOG(FATAL) << "Unknown reduce strategy.";
                    break;
                }
              }
            } catch (const boost::bad_get &e) {
            }
          }
        }
      }
    }
  }

  bool use_gpu = false;
#ifdef PADDLE_WITH_CUDA
  use_gpu = nccl_ctxs_ != nullptr;
#endif

  if (use_gpu ||
      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
    // Insert BCast Ops
    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
      auto &to_bcast_set = bcast_var_name_set[dev_id];
      for (auto &bcast_name : to_bcast_set) {
        CreateBroadcastOp(&result, bcast_name, dev_id);
      }
    }
  }

  /*
    Dependency graph has been constructed. However, there are still data
    hazards that need to be handled.
   */
  PolishGraphToSupportDataHazards(&result);

  /*
   * Only variables should be the leaves of the graph.
   */
  AddOutputToLeafOps(&result);
  return graph;
}

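// A gradient is treated as sparse when its variable type is SELECTED_ROWS.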
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
  PADDLE_ENFORCE(all_vars_.count(og) != 0);
  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
    return true;
  }
  return false;
}

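// Set the op handle's device context for place `p` from the global pool; on
// CUDA builds this is only done when no NCCL context map is available.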
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  if (nccl_ctxs_ == nullptr) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
#else
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
#endif
}

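// Broadcast the latest version of variable `p_name` from device `src_dev_id`
// to every place, appending a new variable version on each destination.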
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                const std::string &p_name,
                                                size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_);
#else
  auto *op_handle = new BroadcastOpHandle(
      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
      local_scopes_, places_);
#endif
  result->Get<GraphOps>("ops").emplace_back(op_handle);

  auto *in =
      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
    auto *out_var = new VarHandle(
        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
        i, p_name, p);
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

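// Create a ComputationOpHandle for `node` on a single device `dev_id` and
// wire up its inputs and outputs.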
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                    ir::Node *node,
                                                    int dev_id) const {
  result->Get<GraphOps>("ops").emplace_back(
      new ComputationOpHandle(result->CreateOpNode(node->Op()),
                              local_scopes_[dev_id], places_[dev_id]));
  CreateOpHandleIOs(result, node, dev_id);
}

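// All-reduce the gradient `og` across all places: each place contributes its
// latest version of the variable and receives the reduced result as a new
// version.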
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars")[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());

    auto var =
        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                      vars.size(), i, og, p);
    vars.emplace_back(var);
    op_handle->AddOutput(var);
  }
}

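// Insert a DataBalanceOpHandle that consumes the latest version of each
// variable in `datas` on every place and produces a rebalanced new version on
// every place, so devices stay fed when readers produce uneven batches.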
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
    Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    for (const std::string &d_name : datas) {
      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
      PADDLE_ENFORCE(!vars.empty());
      op_handle->AddInput(vars.back().get());
      auto var = new VarHandle(
          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
          vars.size(), i, d_name, p);
      vars.emplace_back(var);
      op_handle->AddOutput(var);
    }
  }
}

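// Returns true the first time the parameter gradient `og` is seen and records
// it in `og_has_been_broadcast`.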
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
    const std::string &og,
    std::unordered_set<std::string> *og_has_been_broadcast) const {
  bool is_pg_once =
      grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0;
  if (is_pg_once) {
    // Insert NCCL AllReduce Op
    og_has_been_broadcast->insert(og);
  }
  return is_pg_once;
}

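// Under the kReduce strategy, pin each optimizer op to the device that owns
// its gradient (recorded in var_name_on_devices_); return -1 when the op
// should run on every place.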
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
    return -1;
  }
  int op_role = boost::get<int>(
      node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
    return -1;
  }
  auto param_grad = boost::get<std::vector<std::string>>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));

  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
  int dev_id = GetVarDeviceID(param_grad[1]);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                    node->Op()->Type(), param_grad[0], param_grad[1]);
  return dev_id;
}

int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
  auto got = var_name_on_devices_.find(varname);
  return got == var_name_on_devices_.end() ? -1 : got->second;
}

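// Create one ScaleLossGradOpHandle per place; each produces loss@GRAD on its
// device, using the device count as the scale factor (see the FIXME below).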
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
  for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
    auto *communication_dev_ctx =
        nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
                   : platform::DeviceContextPool::Instance().Get(places_[i]);
#else
    auto *communication_dev_ctx =
        platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
    auto *op_handle = new ScaleLossGradOpHandle(
        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
        local_scopes_.size(), local_scopes_[i], places_[i],
        communication_dev_ctx);
    result->Get<GraphOps>("ops").emplace_back(op_handle);

    // FIXME: Currently ScaleLossGradOp only uses device_count as the scale
    // factor, so it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    CreateOpOutput(result, op_handle,
                   result->CreateEmptyNode(GradVarName(loss_var_name_),
                                           ir::Node::Type::kVariable),
                   places_[i], i);
  }
}

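// Replicate the op across `num_places` devices: one ComputationOpHandle per
// local scope / place pair.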
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
                                                     ir::Node *node,
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
    result->Get<GraphOps>("ops").emplace_back(
        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
    CreateOpHandleIOs(result, node, scope_idx);
  }
}

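// Reduce the gradient `og` from all places onto device `dst_dev_id` and
// return the VarHandle holding the reduced result.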
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
#else
  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
      local_scopes_, places_));
#endif
  auto *op_handle = result->Get<GraphOps>("ops").back().get();

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    SetCommunicationContext(op_handle, p);
    auto &vars = result->Get<GraphVars>("vars")[i][og];
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad.get());
  }
  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
  auto var =
      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
  vars.emplace_back(var);
  op_handle->AddOutput(var);
  return var;
}

// Make the current `op` depend on every existing op named `prev_op_name` by
// inserting a dummy dependency variable between them.
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                        const std::string &prev_op_name) const {
  for (auto &prev_op : result->Get<GraphOps>("ops")) {
    if (prev_op->Name() == prev_op_name) {
      auto *dep_var = new DummyVarHandle(
          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
      prev_op->AddOutput(dep_var);
      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
      op->AddInput(dep_var);
    }
  }
}

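// Place a distributed-training helper op (split_byref, split_selected_rows,
// or concat) on the device that owns its variables and record the placement
// of its outputs.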
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
                                                ir::Node *node) const {
  int op_dev_id = -1;
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  if (node->Op()->Type() == "split_byref" ||
      node->Op()->Type() == "split_selected_rows") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else if (node->Op()->Type() == "concat") {
    op_dev_id = GetVarDeviceID(input_var_names[0]);
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    PADDLE_THROW(
        "the distributed training related op should be in [split_byref, "
        "concat].");
  }

  PADDLE_ENFORCE(op_dev_id != -1,
                 "can not find the right place for distributed op: %s",
                 node->Op()->Type());

  CreateComputationalOp(result, node, op_dev_id);
  if (node->Op()->Type() == "concat") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
              "fetch_barrier");
  }
}

// Create RPC-related op handles and connect them to their upstream and
// downstream ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
  int op_dev_id = -1;
  if (node->Op()->Type() == "send") {
    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
    // a variable name that contains `.block` means it was split by the
    // split_byref op, so that we can balance the variable blocks across all
    // the pserver instances.
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
        node->inputs[0]->Name().find(".block") == std::string::npos) {
      std::vector<std::string> input_var_names;
      for (ir::Node *n : node->inputs) {
        input_var_names.push_back(n->Name());
      }
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
        var_name_on_devices_.emplace(varname, op_dev_id);
      }
    }
  } else if (node->Op()->Type() == "recv") {
    std::vector<std::string> output_var_names;
    for (ir::Node *n : node->outputs) {
      output_var_names.push_back(n->Name());
    }
    op_dev_id = GetAppropriateDeviceID(output_var_names);
    for (auto &varname : output_var_names) {
      var_name_on_devices_.emplace(varname, op_dev_id);
    }
  } else {
    // send_barrier and fetch_barrier ops can be scheduled on device 0
    op_dev_id = 0;
  }

  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                 node->Op()->Type());

  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
      node->Op()->Type(), places_[op_dev_id]));

  if (node->Op()->Type() == "send_barrier") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
  } else if (node->Op()->Type() == "recv") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
              "send_barrier");
  } else if (node->Op()->Type() == "fetch_barrier") {
    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
  } else if (node->Op()->Type() == "send") {
    // do nothing
  } else {
    PADDLE_THROW(
        "rpc op should be in ["
        "send, send_barrier, recv, fetch_barrier]");
  }

  CreateOpHandleIOs(result, node, op_dev_id);
}

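// An op scales the loss gradient iff its role is both kBackward and kLoss;
// when loss_var_name_ is empty the builder is in test mode and no op matches.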
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
  return boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty, this is test mode.
}
}  // namespace details
}  // namespace framework
}  // namespace paddle