multi_devices_graph_builder.cc 19.5 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
C
chengduoZH 已提交
14
#include <algorithm>
Y
Yancey1989 已提交
15
#include <fstream>
C
chengduoZH 已提交
16
#include <string>
C
chengduoZH 已提交
17
#include <utility>
C
chengduoZH 已提交
18 19
#include <vector>

20
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
C
chengduoZH 已提交
21
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
Y
Yu Yang 已提交
22
#include "paddle/fluid/framework/details/computation_op_handle.h"
C
chengduoZH 已提交
23
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
C
chengduoZH 已提交
24
#include "paddle/fluid/framework/details/reduce_op_handle.h"
Y
Yancey1989 已提交
25
#include "paddle/fluid/framework/details/rpc_op_handle.h"
Y
Yu Yang 已提交
26
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
Y
Fix bug  
yuyang18 已提交
27
#include "paddle/fluid/framework/op_info.h"
Y
Yu Yang 已提交
28
#include "paddle/fluid/framework/scope.h"
Y
Yu Yang 已提交
29

Y
Yu Yang 已提交
30 31 32
namespace paddle {
namespace framework {
namespace details {
Y
Yu Yang 已提交
33 34

// Builds an SSA graph spanning `places`; `local_scopes[i]` is the scope used
// for `places[i]`. The names of the gradients of `params` are cached so later
// passes can recognise parameter gradients.
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      nccl_ctxs_(nccl_ctxs),
      strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
      strategy_(strategy) {
#endif
  // Record GradVarName(param) for every parameter.
  for (const auto &param_name : params) {
    grad_names_.emplace(GradVarName(param_name));
  }
}

Y
Yu Yang 已提交
62 63
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
                                                const OpDesc &op,
Y
Yu Yang 已提交
64 65
                                                size_t place_id) const {
  auto p = places_[place_id];
T
wip  
typhoonzero 已提交
66
  auto *op_handle = result->ops_.back().get();
X
Xin Pan 已提交
67 68
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
T
wip  
typhoonzero 已提交
69

Y
Yu Yang 已提交
70 71 72
  for (auto &each_var_name : op.InputArgumentNames()) {
    VarHandle *var =
        CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
T
wip  
typhoonzero 已提交
73 74 75
    op_handle->AddInput(var);
  }

Y
Yu Yang 已提交
76 77
  for (auto &each_var_name : op.OutputArgumentNames()) {
    CreateOpOutput(result, op_handle, each_var_name, p, place_id);
T
wip  
typhoonzero 已提交
78 79
  }
}
Y
fix pe  
Yancey1989 已提交
80 81 82 83

std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
    const ProgramDesc &program) const {
  std::vector<std::string> send_vars;
Y
Yancey1989 已提交
84 85
  // since parameters are all in block 0,
  // it's enough to only scan send ops in block 0
Y
fix pe  
Yancey1989 已提交
86
  for (auto *op : program.Block(0).AllOps()) {
Y
Yancey1989 已提交
87 88
    // TODO(Yancey1989): use a graceful method to find send op,
    // instead of the the hard code string
89
    if (op->Type() == "send") {
Y
fix pe  
Yancey1989 已提交
90 91 92 93 94 95 96 97 98 99 100 101 102
      auto op_vars = op->InputArgumentNames();
      send_vars.reserve(send_vars.size() +
                        std::distance(op_vars.begin(), op_vars.end()));
      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return send_vars;
}

// Collects the output variable names of every "recv" op in block 0, in op
// order. Duplicates are kept.
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
    const ProgramDesc &program) const {
  std::vector<std::string> recv_vars;
  for (auto *op : program.Block(0).AllOps()) {
    // TODO(Yancey1989): use a graceful method to find recv op,
    // instead of the hard code string
    if (op->Type() == "recv") {
      const auto &op_vars = op->OutputArgumentNames();
      // Plain range insert: reserving the exact new size on every op would
      // defeat the vector's geometric growth and recopy all names each time.
      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
    }
  }
  return recv_vars;
}

bool MultiDevSSAGraphBuilder::IsDistTrainOp(
    const OpDesc &op, const std::vector<std::string> &send_vars,
    const std::vector<std::string> &recv_vars) const {
  if (send_vars.size() == 0 || recv_vars.size() == 0) {
T
typhoonzero 已提交
119 120 121
    return false;
  }

Y
Yu Yang 已提交
122 123 124 125
  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
Y
fix pe  
Yancey1989 已提交
126
                    const std::vector<std::string> &rpc_vars) -> bool {
T
typhoonzero 已提交
127
    for (auto &var : opvars) {
Y
Yancey1989 已提交
128 129 130
      // a variable name with the suffix `.block` means it's a splited
      // variable by (DistributeTranspiler)
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
T
typhoonzero 已提交
131
      if (var.find(".block") != std::string::npos &&
Y
fix pe  
Yancey1989 已提交
132
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
Y
Yu Yang 已提交
133
        return true;
T
typhoonzero 已提交
134 135
      }
    }
Y
Yu Yang 已提交
136
    return false;
T
typhoonzero 已提交
137 138
  };

Y
Yancey1989 已提交
139 140
  return checker(op.OutputArgumentNames(), send_vars) ||
         checker(op.InputArgumentNames(), recv_vars);
T
typhoonzero 已提交
141 142
}

Y
Yu Yang 已提交
143 144
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    const ProgramDesc &program) const {
145
  VLOG(3) << "Building ....";
C
chengduoZH 已提交
146
  std::unordered_map<std::string, VarDesc *> all_vars;
C
fix ci  
chengduoZH 已提交
147
  for (auto *var : program.Block(0).AllVars()) {
C
chengduoZH 已提交
148
    all_vars[var->Name()] = var;
C
fix ci  
chengduoZH 已提交
149
  }
C
chengduoZH 已提交
150

Y
Yu Yang 已提交
151
  auto graph = new SSAGraph();
Y
Yu Yang 已提交
152
  SSAGraph &result = *graph;
C
chengduoZH 已提交
153
  std::unordered_set<std::string> og_has_been_broadcast;
Y
Yu Yang 已提交
154 155 156 157 158

  // We cannot invoke resize. It is a bug of GCC 4.8
  result.vars_ = std::vector<
      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
      places_.size());
Y
Yu Yang 已提交
159

Y
fix pe  
Yancey1989 已提交
160 161 162 163
  // find send/recv vars so that we can place the distributed training
  // realted op in the place 0
  auto send_vars = FindDistTrainSendVars(program);
  auto recv_vars = FindDistTrainRecvVars(program);
T
typhoonzero 已提交
164

C
chengduoZH 已提交
165 166 167 168 169
  std::vector<std::unordered_set<std::string>> var_name_on_devices;
  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  var_name_on_devices.resize(places_.size());
  bcast_var_name_set.resize(places_.size());

C
chengduoZH 已提交
170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
  size_t cur_device_id = 0;
  std::vector<int64_t> balance_grads(places_.size(), 0);

  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
    auto var_desc = all_vars.at(g_name);
    PADDLE_ENFORCE_NOT_NULL(var_desc);
    auto dim = framework::make_ddim(var_desc->GetShape());
    int64_t numel = framework::product(dim);
    PADDLE_ENFORCE_GE(numel, 0);
    auto smallest =
        std::min_element(std::begin(balance_grads), std::end(balance_grads));
    size_t dev_id =
        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
    balance_grads[dev_id] += numel;
    return dev_id;
  };

Y
Yu Yang 已提交
187
  bool is_forwarding = true;
188 189 190 191 192 193 194 195
  int rpc_op_device_id = 0;
  auto schedule_rpc_op = [&]() -> void {
    rpc_op_device_id++;
    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
      rpc_op_device_id = 0;
    }
  };

Y
Yu Yang 已提交
196
  for (auto *op : program.Block(0).AllOps()) {
Y
Yancey1989 已提交
197 198 199
    if (boost::get<int>(
            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
Y
Yancey1989 已提交
200
      // append rpc op if program is distributed trainer main program.
Y
Yu Yang 已提交
201
      // always use the first device
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
      if (op->Type() == "send_vars") {
        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
        if (got == remote_vars_devices_.end()) {
          schedule_rpc_op();
        } else {
          rpc_op_device_id = got->second;
        }
        CreateRPCOp(&result, *op, rpc_op_device_id);
      } else if (op->Type() == "recv") {
        schedule_rpc_op();
        for (auto &varname : op->OutputArgumentNames()) {
          remote_vars_devices_.insert({varname, rpc_op_device_id});
        }
        CreateRPCOp(&result, *op, rpc_op_device_id);
      } else {
        CreateRPCOp(&result, *op, 0);
      }
Y
fix pe  
Yancey1989 已提交
219
    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
220 221 222 223 224 225 226
      if (op->Type() == "split_byref") {
        schedule_rpc_op();
        for (auto &varname : op->OutputArgumentNames()) {
          remote_vars_devices_.insert({varname, rpc_op_device_id});
        }
        CreateDistTrainOp(&result, *op, rpc_op_device_id);
      }
Y
Yancey1989 已提交
227
      if (op->Type() == "concat") {
228
        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
Y
Yancey1989 已提交
229
        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
230
                       "can not find right place to concatenate received var.");
231 232 233 234
        CreateDistTrainOp(&result, *op, got->second);
      } else {
        CreateDistTrainOp(&result, *op, 0);
      }
Y
Yu Yang 已提交
235
    } else if (IsScaleLossOp(*op)) {
Y
Yu Yang 已提交
236
      // user can customize loss@grad if not use_default_grad_scale_
Y
yuyang18 已提交
237 238
      if (strategy_.gradient_scale_ !=
          BuildStrategy::GradientScaleStrategy::kCustomized) {
Y
Yu Yang 已提交
239 240
        CreateScaleLossGradOp(&result);
      }
Y
Yu Yang 已提交
241
      is_forwarding = false;
Y
Yu Yang 已提交
242
    } else {
C
chengduoZH 已提交
243 244 245 246 247 248 249 250 251
      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
      if (op_dev_id == -1) {  // var on all device
        CreateComputationalOps(&result, *op, places_.size());
      } else {
        CreateComputationalOp(&result, *op, op_dev_id);
        for (auto &var_name : op->OutputArgumentNames()) {
          var_name_on_devices[op_dev_id].emplace(var_name);
        }
      }
C
chengduoZH 已提交
252
      if (!is_forwarding && places_.size() > 1) {
Y
Yu Yang 已提交
253
        // Currently, we assume that once gradient is generated, it can be
Y
Yu Yang 已提交
254
        // broadcast, and each gradient is only broadcast once.
Y
yuyang18 已提交
255 256 257
        if (static_cast<bool>(boost::get<int>(op->GetAttr(
                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
                              static_cast<int>(OpRole::kBackward))) {
Y
yuyang18 已提交
258 259 260 261 262 263 264
          try {
            auto backward_vars =
                boost::get<std::vector<std::string>>(op->GetNullableAttr(
                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));

            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);

Y
Fix bug  
yuyang18 已提交
265
            for (size_t i = 0; i < backward_vars.size(); i += 2) {
Y
yuyang18 已提交
266 267
              auto &p_name = backward_vars[i];
              auto &g_name = backward_vars[i + 1];
Y
yuyang18 已提交
268 269
              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

Y
yuyang18 已提交
270 271
              switch (strategy_.reduce_) {
                case BuildStrategy::ReduceStrategy::kReduce:
C
chengduoZH 已提交
272
                  cur_device_id = get_appropriate_dev(g_name);
Y
yuyang18 已提交
273 274 275 276 277
                  CreateReduceOp(&result, g_name, cur_device_id);
                  var_name_on_devices[cur_device_id].emplace(g_name);
                  bcast_var_name_set[cur_device_id].emplace(p_name);
                  break;
                case BuildStrategy::ReduceStrategy::kAllReduce:
C
chengduoZH 已提交
278
                  if (IsSparseGradient(all_vars, g_name)) {
Y
yuyang18 已提交
279 280 281
                    CreateReduceOp(&result, g_name, 0);
                    CreateBroadcastOp(&result, g_name, 0);
                  } else {
C
chengduoZH 已提交
282
                    InsertAllReduceOp(&result, g_name);
Y
yuyang18 已提交
283 284 285
                  }
                  break;
              }
C
chengduoZH 已提交
286
            }
Y
yuyang18 已提交
287
          } catch (boost::bad_get e) {
Y
Yu Yang 已提交
288 289 290 291 292 293
          }
        }
      }
    }
  }

C
chengduoZH 已提交
294 295 296 297 298 299 300
  // Insert BCast Ops
  for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
    auto &to_bcast_set = bcast_var_name_set[dev_id];
    for (auto &bcast_name : to_bcast_set) {
      CreateBroadcastOp(&result, bcast_name, dev_id);
    }
  }
Y
Yu Yang 已提交
301 302 303 304 305
  /*
    Dependency graph has been constructed. However, there are still data
    harzaeds need to be handled.
   */
  PolishGraphToSupportDataHazards(&result);
Y
Yu Yang 已提交
306

Y
Yu Yang 已提交
307 308 309 310 311
  /*
   * Only variables should be the leaves of graph.
   */
  AddOutputToLeafOps(&result);

Y
Yu Yang 已提交
312
  return std::unique_ptr<SSAGraph>(graph);
Y
Yu Yang 已提交
313 314
}

C
fix ci  
chengduoZH 已提交
315
bool MultiDevSSAGraphBuilder::IsSparseGradient(
C
chengduoZH 已提交
316
    const std::unordered_map<std::string, VarDesc *> &all_vars,
C
fix ci  
chengduoZH 已提交
317
    const std::string &og) const {
C
chengduoZH 已提交
318 319
  PADDLE_ENFORCE(all_vars.count(og) != 0);
  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
C
fix ci  
chengduoZH 已提交
320 321 322
    return true;
  }
  return false;
323 324
}

325 326 327 328 329 330 331 332 333 334 335 336 337
// Attaches the default pool device context for `p` to `op_handle`.
// In CUDA builds this happens only when no NCCL context map was supplied.
// When one is present nothing is attached here (presumably the NCCL-aware
// handle uses nccl_ctxs_ itself; confirm in the op handle implementations).
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  if (nccl_ctxs_ != nullptr) {
    return;
  }
#endif
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
}

C
chengduoZH 已提交
338 339
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                const std::string &p_name,
C
chengduoZH 已提交
340
                                                size_t src_dev_id) const {
C
chengduoZH 已提交
341 342 343 344 345 346 347
#ifdef PADDLE_WITH_CUDA
  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
#else
  auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
#endif

  result->ops_.emplace_back(op_handle);
C
chengduoZH 已提交
348
  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
C
chengduoZH 已提交
349 350 351 352
  op_handle->AddInput(in);

  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
C
chengduoZH 已提交
353 354
    SetCommunicationContext(op_handle, p);
    auto &vars = result->vars_.at(i).at(p_name);
C
chengduoZH 已提交
355 356 357 358 359 360 361 362 363 364 365 366 367 368
    auto *out_var = new VarHandle(vars.size(), i, p_name, p);
    vars.emplace_back(out_var);
    op_handle->AddOutput(out_var);
  }
}

// Adds a single computation handle for `op`, running in device `dev_id`'s
// local scope, and wires its inputs and outputs on that device.
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
                                                    const OpDesc &op,
                                                    int dev_id) const {
  Scope *scope = local_scopes_[dev_id];
  const auto &place = places_[dev_id];
  result->ops_.emplace_back(new ComputationOpHandle(op, scope, place));
  CreateOpHandleIOs(result, op, dev_id);
}

C
chengduoZH 已提交
369 370
// Adds one all-reduce op over gradient `og`. It takes the latest version of
// `og` on every device as input and appends a new VarHandle for `og` on
// every device as output. PADDLE_ENFORCE fails if a device has no version
// of `og` yet.
void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
                                                const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
  result->ops_.emplace_back(
      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
#endif
  auto *all_reduce = result->ops_.back().get();

  size_t dev = 0;
  for (const auto &place : places_) {
    SetCommunicationContext(all_reduce, place);
    auto &versions = result->vars_[dev][og];
    PADDLE_ENFORCE(!versions.empty());
    all_reduce->AddInput(versions.back().get());

    // NOTE(review): the new handle reuses the input's version number
    // (size() - 1 is taken before the push), while CreateBroadcastOp uses
    // size(). Confirm this is intended.
    auto *reduced = new VarHandle(versions.size() - 1, dev, og, place);
    versions.emplace_back(reduced);
    all_reduce->AddOutput(reduced);
    ++dev;
  }
}

// Returns true the first time it is called with a parameter-gradient name
// `og` (one recorded in grad_names_), and records `og` in
// `og_has_been_broadcast`. Returns false for non-parameter gradients and for
// names already recorded.
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
    const std::string &og,
    std::unordered_set<std::string> *og_has_been_broadcast) const {
  if (grad_names_.count(og) == 0 || og_has_been_broadcast->count(og) != 0) {
    return false;
  }
  og_has_been_broadcast->insert(og);
  return true;
}

C
chengduoZH 已提交
405 406 407
int MultiDevSSAGraphBuilder::GetOpDeviceID(
    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
    const OpDesc &op) const {
Y
yuyang18 已提交
408
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
C
chengduoZH 已提交
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
    return -1;
  }

  int var_dev_id = -1;
  for (auto &var_name : op.InputArgumentNames()) {
    if (var_dev_id != -1) break;
    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
      if (var_name_on_devices[i].count(var_name)) {
        var_dev_id = static_cast<int>(i);
        break;
      }
    }
  }
  return var_dev_id;
}

Y
Yu Yang 已提交
425 426 427 428
// Adds one ScaleLossGrad op per device. Each op outputs the gradient of the
// loss variable on its own device.
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
  for (size_t dev = 0; dev < places_.size(); ++dev) {
    const auto &place = places_[dev];
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
    // Use the NCCL device context when NCCL contexts were supplied.
    auto *communication_dev_ctx =
        nccl_ctxs_ ? nccl_ctxs_->DevCtx(place)
                   : platform::DeviceContextPool::Instance().Get(place);
#else
    auto *communication_dev_ctx =
        platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif

    auto *op_handle =
        new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[dev],
                                  place, communication_dev_ctx);
    result->ops_.emplace_back(op_handle);

    // FIXME: Currently ScaleLossGradOp only use device_count as scale
    // factor. So it does not depend on any other operators.
    // VarHandle *loss = GetVarHandle(loss_var_name, place);
    // loss->pending_ops_.emplace_back(op_handle);
    // op_handle->inputs_.emplace_back(loss);

    CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), place, dev);
  }
}

void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
T
typhoonzero 已提交
454 455 456
                                                     const OpDesc &op,
                                                     size_t num_places) const {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
Y
Yu Yang 已提交
457 458 459
    auto p = places_[scope_idx];
    auto s = local_scopes_[scope_idx];
    result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
Y
Yu Yang 已提交
460
    CreateOpHandleIOs(result, op, scope_idx);
Y
Yu Yang 已提交
461 462 463
  }
}

C
chengduoZH 已提交
464 465 466
// Adds a reduce op over gradient `og`. It takes the latest version of `og` on
// every device as input and writes a single new version on `dst_dev_id`,
// which is returned. PADDLE_ENFORCE fails if a device has no version of `og`.
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
                                                   const std::string &og,
                                                   int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
  result->ops_.emplace_back(
      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
  result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
#endif
  auto *reduce = result->ops_.back().get();

  for (size_t dev = 0; dev < places_.size(); ++dev) {
    SetCommunicationContext(reduce, places_[dev]);
    auto &versions = result->vars_[dev][og];
    PADDLE_ENFORCE(!versions.empty());
    reduce->AddInput(versions.back().get());
  }

  // NOTE(review): like InsertAllReduceOp, the output takes version
  // size() - 1 (computed before the push). Confirm this is intended.
  auto &dst_versions = result->vars_[dst_dev_id][og];
  auto *reduced = new VarHandle(dst_versions.size() - 1, dst_dev_id, og,
                                places_[dst_dev_id]);
  dst_versions.emplace_back(reduced);
  reduce->AddOutput(reduced);
  return reduced;
}

Y
fix pe  
Yancey1989 已提交
491 492
// Makes `op` depend on every op in the graph named `prev_op_name`. For each
// match, a fresh DummyVarHandle becomes that op's output and `op`'s input.
// The graph takes ownership of it through dep_vars_.
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
                                        const std::string &prev_op_name) const {
  for (auto &candidate : result->ops_) {
    if (candidate->Name() != prev_op_name) {
      continue;
    }
    auto *dep_var = new DummyVarHandle();
    candidate->AddOutput(dep_var);
    result->dep_vars_.emplace(dep_var);
    op->AddInput(dep_var);
  }
}

Y
Yancey1989 已提交
503
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
504 505 506
                                                const OpDesc &op,
                                                int place_id) const {
  CreateComputationalOp(result, op, place_id);
Y
Yancey1989 已提交
507 508 509 510 511
  if (op.Type() == "concat") {
    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
  }
}

512
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
513 514 515
                                          int device_id) const {
  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
                                            op.Type(), places_[device_id]));
Y
fix pe  
Yancey1989 已提交
516

Y
Yancey1989 已提交
517
  if (op.Type() == "send_barrier") {
518
    ConnectOp(result, result->ops_.back().get(), "send");
Y
Yancey1989 已提交
519
  } else if (op.Type() == "recv") {
Y
fix pe  
Yancey1989 已提交
520
    ConnectOp(result, result->ops_.back().get(), "send_barrier");
Y
Yancey1989 已提交
521
  } else if (op.Type() == "fetch_barrier") {
Y
fix pe  
Yancey1989 已提交
522
    ConnectOp(result, result->ops_.back().get(), "recv");
523
  } else if (op.Type() == "send") {
Y
Yancey1989 已提交
524 525 526
    // do nothing
  } else {
    PADDLE_THROW(
Y
Yancey1989 已提交
527
        "rpc op should be in ["
528
        "send, send_barrier. recv, fetch_barrier]");
Y
Yancey1989 已提交
529 530
  }

Y
Yancey1989 已提交
531 532
  // TODO(Yancey1989): schedule rpc op on different place may
  // increate throughput
533
  CreateOpHandleIOs(result, op, device_id);
Y
Yu Yang 已提交
534 535 536
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
Y
yuyang18 已提交
537 538
  return boost::get<int>(
             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
Y
Fix bug  
yuyang18 已提交
539 540 541
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty. This is test mode
Y
Yu Yang 已提交
542
}
Y
Yu Yang 已提交
543 544 545
}  // namespace details
}  // namespace framework
}  // namespace paddle