/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace details {

static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
  // The allreduce op order must be fixed when the ops are scheduled by
  // multiple threads or processes; otherwise the trainers may hang.
  // NOTE: ParallelGraph executes this pass on each sub-graph, so there is
  // no need to append it here.
  return (!strategy.enable_sequential_execution_ &&
          strategy.num_trainers_ > 1) &&
         !strategy.enable_parallel_graph_;
}
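
// For example (illustrative): with num_trainers_ = 2,
// enable_sequential_execution_ = false, and enable_parallel_graph_ = false,
// this returns true, and all_reduce_deps_pass will later be appended to fix
// the allreduce order across trainers.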

class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record the original graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add graph_viz_pass";
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    // NOTE(zcd): record_skip_memory_opt_vars_pass should be the first pass.
    VLOG(1) << "Add record_skip_memory_opt_vars_pass";
    AppendPass("record_skip_memory_opt_vars_pass");

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      VLOG(1) << "Add mkldnn_placement_pass";
      AppendPass("mkldnn_placement_pass");
    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
      LOG(WARNING)
          << "mkldnn_enabled_op_types specifies the list of operator types "
             "to accelerate with MKLDNN. It is empty by default, which means "
             "that all operators supported by MKLDNN will be accelerated. "
             "It should not be set when FLAGS_use_mkldnn=false.";
    }
#else
    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
                   "Please compile with MKLDNN first to use MKLDNN");
#endif
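
    // Example (illustrative): with FLAGS_use_mkldnn=true, restricting MKLDNN
    // acceleration to convolutions alone could be configured as
    //   strategy.mkldnn_enabled_op_types_ = {"conv2d"};
    // the placement pass then marks only those operators to use MKLDNN.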

    if (strategy_.enable_sequential_execution_) {
      VLOG(1) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }

    // Convert batch_norm to sync_batch_norm for multi-device training.
    if (strategy.sync_batch_norm_) {
      AppendPass("sync_batch_norm_pass");
    }

    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
      VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }

    // TODO(zjl): refactor MemoryOptimizePass to fit the new strategy,
    // which does not need to set var.persistable = True.
    if (strategy_.use_legacy_memory_optimize_strategy_) {
      if (strategy_.enable_inplace_) {
        VLOG(5) << "Add inplace_pass";
        AppendPass("inplace_pass");
      }
    }

    if (strategy_.fuse_elewise_add_act_ops_) {
      VLOG(1) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // For single-card training, fuse_all_reduce_ops is unnecessary.
    // coalesce_grad_tensor_pass must run before the multi-devices passes.
    if (strategy_.fuse_all_reduce_ops_) {
      VLOG(1) << "Add coalesce_grad_tensor_pass";
      AppendPass("coalesce_grad_tensor_pass");
    }

    // Fuse all the optimization operators.
    if (strategy_.is_distribution_) {
      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
                 "Non-distributed mode.";
      strategy_.fuse_all_optimizer_ops_ = false;
    }
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
        strategy_.is_distribution_) {
      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
                 "mode.";
      strategy_.fuse_all_optimizer_ops_ = false;
    }
    if (strategy_.fuse_all_optimizer_ops_) {
      // NOTE: fuse_all_xx_ops will count the number of xx operators first;
      // if that number is zero, the pass does nothing.
      // Currently, only one type of optimization algorithm can be fused.
      VLOG(1) << "Add fuse_adam_op_pass";
      AppendPass("fuse_adam_op_pass");
      VLOG(1) << "Add fuse_sgd_op_pass";
      AppendPass("fuse_sgd_op_pass");
      VLOG(1) << "Add fuse_momentum_op_pass";
      AppendPass("fuse_momentum_op_pass");
    }
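    // e.g. (illustrative): if the program contains only Adam ops,
    // fuse_adam_op_pass fuses them, while fuse_sgd_op_pass and
    // fuse_momentum_op_pass find zero candidates and do nothing.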

    // Add a graph viz pass to record the fused graph.
    if (!strategy.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    CollectiveContext *context = CollectiveContext::GetInstance();
    context->endpoints_ = strategy_.trainers_endpoints_;
    context->trainer_id_ = strategy_.trainer_id_;
    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
    if (strategy_.trainer_id_ > 0 && !strategy_.trainers_endpoints_.empty()) {
      PADDLE_ENFORCE(static_cast<size_t>(strategy_.trainer_id_) <
                         strategy_.trainers_endpoints_.size(),
                     "trainer_id_ < endpoints_ size");
    }
    VLOG(1) << "CollectiveContext:" << context->String();
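    // For example (illustrative): trainer_id_ = 1 with two endpoints
    // {"127.0.0.1:6170", "127.0.0.1:6171"} passes the check since 1 < 2;
    // trainer_id_ = 2 would fail the enforce.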

    // NOTE(dzh): memory optimize should be a runtime pass.
    // However, after multi_devices_pass, VarHandle and OpHandle become
    // the de-facto IR, and any reuse on the Graph is meaningless.
    // As a side effect, memory optimize cannot foresee the fetched vars,
    // so the fetch list must be set persistable before calling the Run
    // interface.
    if (strategy_.use_legacy_memory_optimize_strategy_) {
      if (strategy_.memory_optimize_) {
        VLOG(5) << "Add memory_optimize_pass";
        AppendPass("memory_optimize_pass");
      }
    }

    // runtime_context_cache_pass should be the last pass here so that the
    // cache attribute covers all original and fused operators. No operator
    // could get this attribute if the pass ran after the multi-devices pass.
    if (strategy_.cache_runtime_context_) {
      VLOG(1) << "Add runtime_context_cache_pass";
      AppendPass("runtime_context_cache_pass");
    }

    AppendMultiDevPass(strategy_);

    if (strategy_.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_ops first counts the number of all_reduce
      // operators; if the number is zero, the pass does nothing.
      VLOG(1) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record the graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add multi_devices_print_pass";
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                          "_multi_devices_graph");
      multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
                                                 new std::string(graph_path));
      multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
          "graph_printer", new ir::GraphvizSSAGraphPrinter);
    }

    // Experiments show that the program runs faster when
    // all_reduce_deps_pass is appended here.
    if (!strategy_.enable_parallel_graph_ &&
        (SeqOnlyAllReduceOps(strategy_) ||
         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
      VLOG(1) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
        !strategy_.is_distribution_ &&
        strategy_.enable_backward_optimizer_op_deps_) {
      VLOG(1) << "Add backward_optimizer_op_deps_pass";
      AppendPass("backward_optimizer_op_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
      VLOG(1) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }

    // Verify that the graph is correct for the multi-device executor.
    VLOG(1) << "Add multi_devices_check_pass";
    AppendPass("multi_devices_check_pass");
  }
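
  // Pipeline overview (sketch; each stage is subject to its flag above):
  // graph_viz -> record_skip_memory_opt_vars -> mkldnn_placement ->
  // sequential_execution -> sync_batch_norm -> fuse_relu_depthwise_conv ->
  // inplace -> fuse_elewise_add_act -> coalesce_grad_tensor ->
  // fuse_{adam,sgd,momentum} -> graph_viz -> memory_optimize ->
  // runtime_context_cache -> multi_devices -> fuse_all_reduce ->
  // multi_devices_print -> all_reduce_deps -> backward_optimizer_op_deps ->
  // modify_op_lock_and_record_event -> multi_devices_check.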

  // Convert the graph to run on multiple devices.
  void AppendMultiDevPass(const BuildStrategy &strategy) {
    ir::Pass *multi_devices_pass = nullptr;

    if (strategy_.async_mode_) {
      VLOG(1) << "Add async_multi_devices_pass";
      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
    } else if (strategy_.is_distribution_) {
      VLOG(1)
          << "Add dist_multi_devices_pass, multi device parameter server mode";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
        VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
        VLOG(1) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
      }
    }
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);
  }
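
  // Selection summary (sketch): async_mode_ -> async_multi_devices_pass;
  // is_distribution_ -> dist_multi_devices_pass; otherwise reduce_ picks
  // all_reduce_mode_multi_devices_pass (kAllReduce) or
  // reduce_mode_multi_devices_pass (kReduce).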

 private:
  BuildStrategy strategy_;
};

std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
    bool finalize_strategy) const {
  if (is_finalized_) {
    return pass_builder_;
  }
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
  if (finalize_strategy) {
    is_finalized_ = true;
  }
  return pass_builder_;
}
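
// Usage sketch (illustrative): a caller typically finalizes the strategy once
// and lets Apply() reuse the cached builder afterwards.
//
//   BuildStrategy strategy;
//   strategy.fuse_elewise_add_act_ops_ = true;
//   auto builder =
//       strategy.CreatePassesFromStrategy(/*finalize_strategy=*/true);
//   // builder->AllPasses() now lists the configured pipeline in order.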

bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
  return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
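
// For example, IsMultiDevPass("reduce_mode_multi_devices_pass") returns true,
// while IsMultiDevPass("graph_viz_pass") returns false.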

ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                const std::vector<platform::Place> &places,
                                const std::string &loss_var_name,
                                const std::vector<Scope *> &local_scopes,
                                const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                const bool use_cuda,
                                platform::NCCLCommunicator *nccl_ctxs) const {
#else
                                const bool use_cuda) const {
#endif
  VLOG(3) << "apply all passes";
  // Create a default pass builder if the user has not finalized one.
  CreatePassesFromStrategy(false);

  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
    VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
    if (IsMultiDevPass(pass->Type())) {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(ir::kLossVarName);
      pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      pass->Erase(ir::kNRanks);
      pass->Set<size_t>(ir::kNRanks, new size_t(nranks));

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "coalesce_grad_tensor_pass" ||
               pass->Type() == "fuse_adam_op_pass" ||
               pass->Type() == "fuse_sgd_op_pass" ||
               pass->Type() == "fuse_momentum_op_pass" ||
               pass->Type() == "fuse_all_reduce_op_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
        pass->Erase(kNCCLCtxs);
        pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
        pass->Erase(kUseHierarchicalAllReduce);
        pass->Set<bool>(kUseHierarchicalAllReduce,
                        new bool(use_hierarchical_allreduce_));
#endif
      }
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
      pass->Erase(kUseHierarchicalAllReduce);
      pass->Set<bool>(kUseHierarchicalAllReduce,
                      new bool(use_hierarchical_allreduce_));
#endif
      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                << ", num_trainers:" << num_trainers_;
    } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
      if (!use_cuda) {
        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
                        "GPU, skipped.";
        continue;
      }
    } else if (pass->Type() == "inplace_pass") {
      pass->Erase(ir::kUseCuda);
      pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
    } else if (pass->Type() == "mkldnn_placement_pass") {
      pass->Set("mkldnn_enabled_op_types",
                new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
    } else if (pass->Type() == "backward_optimizer_op_deps_pass") {
      if (!use_cuda) {
        VLOG(1) << "backward_optimizer_op_deps_pass is only supported on "
                   "GPU, skipped.";
        continue;
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(graph);
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  VLOG(3) << "All Passes Applied";
  return graph;
}
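
// Call sketch (illustrative; assumes a CUDA build and caller-owned graph,
// scope, and places; a real NCCLCommunicator is needed for multi-GPU runs):
//
//   std::vector<platform::Place> places = {platform::CUDAPlace(0)};
//   std::vector<Scope *> local_scopes = {&scope};
//   graph = strategy.Apply(graph, places, "loss", local_scopes,
//                          /*nranks=*/1, /*use_cuda=*/true,
//                          /*nccl_ctxs=*/nullptr);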

}  // namespace details
}  // namespace framework
}  // namespace paddle

USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(coalesce_grad_tensor_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif