build_strategy.cc 16.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

D
dzhwinter 已提交
17 18
#include <glog/logging.h>
#include <memory>
19
#include <unordered_set>
Q
Qiao Longfei 已提交
20
#include <utility>
21
#include "paddle/fluid/framework/details/reduce_op_handle.h"
22
#include "paddle/fluid/framework/ir/graph.h"
D
dzhwinter 已提交
23
#include "paddle/fluid/framework/ir/graph_helper.h"
W
WangZhen 已提交
24
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
25
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
26 27 28
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
29

30 31
DECLARE_bool(use_mkldnn);

32 33 34 35
namespace paddle {
namespace framework {
namespace details {

36
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
Y
Yancey1989 已提交
37 38
  // Should fix the allreduce op order if scheduling
  // them in multiple threads or processes to avoid hang.
Y
Yancey1989 已提交
39
  // NOTE: ParallelGraph would execute this pass on each graph, so
Y
Yancey1989 已提交
40
  // don't need to append it here.
Y
Yancey1989 已提交
41
  return (!strategy.enable_sequential_execution_ &&
Y
Yancey1989 已提交
42 43
          strategy.num_trainers_ > 1) &&
         !strategy.enable_parallel_graph_;
44 45
}

46 47 48 49
// Builds the ordered list of IR passes that ParallelExecutor runs over a
// graph, driven entirely by the flags in the given BuildStrategy. The
// append order in the constructor is significant: fusion and memory passes
// must run before AppendMultiDevPass converts the graph to the multi-device
// SSA form, and multi_devices_check_pass must come last.
class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record a graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add graph_viz_pass";
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
    VLOG(1) << "Add record_skip_memory_opt_vars_pass";
    AppendPass("record_skip_memory_opt_vars_pass");

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      VLOG(1) << "Add mkldnn_placement_pass";
      AppendPass("mkldnn_placement_pass");
    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
      LOG(WARNING)
          << "mkldnn_enabled_op_types specify the operator type list to "
             "use MKLDNN acceleration. It is null in default, means "
             "that all the operators supported by MKLDNN will be "
             "accelerated. And it should not be set when "
             "FLAGS_use_mkldnn=false.";
    }
#else
    // Built without MKLDNN: requesting it at runtime is a hard error.
    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
                   "Please compile with MKLDNN first to use MKLDNN");
#endif

    if (strategy_.enable_sequential_execution_) {
      VLOG(1) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }

    // Add op fusion.
    if (strategy.sync_batch_norm_) {
      AppendPass("sync_batch_norm_pass");
    }

    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
      VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }

    // TODO(zjl): refactor MemoryOptimizePass to fit
    // new strategy, which does not need to set
    // var.persistable = True
    if (strategy_.use_legacy_memory_optimize_strategy_) {
      if (strategy_.enable_inplace_) {
        VLOG(5) << "Add inplace_pass";
        AppendPass("inplace_pass");
      }
    }

    if (strategy_.fuse_elewise_add_act_ops_) {
      VLOG(1) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // for single card training, fuse_all_reduce_ops is unnecessary.
    // coalesce_grad_tensor_pass should be before of MultiDevPass.
    if (strategy_.fuse_all_reduce_ops_) {
      VLOG(1) << "Add coalesce_grad_tensor_pass";
      AppendPass("coalesce_grad_tensor_pass");
    }

    // Fuse all the optimization operators.
    // Both distributed mode and kReduce mode disable optimizer-op fusion;
    // strategy_ is this builder's own copy, so mutating it is local.
    if (strategy_.is_distribution_) {
      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
                 "Non-distributed mode.";
      strategy_.fuse_all_optimizer_ops_ = false;
    }
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
        strategy_.is_distribution_) {
      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
                 "mode.";
      strategy_.fuse_all_optimizer_ops_ = false;
    }
    if (strategy_.fuse_all_optimizer_ops_) {
      // NOTE: fuse_all_xx_ops will count the number of xx operator first,
      // if the number is zero, fuse_all_reduce_ops will do nothing.
      // Currently, only one type of optimization algorithm can be fused.
      VLOG(1) << "Add fuse_adam_op_pass";
      AppendPass("fuse_adam_op_pass");
      VLOG(1) << "Add fuse_sgd_op_pass";
      AppendPass("fuse_sgd_op_pass");
      VLOG(1) << "Add fuse_momentum_op_pass";
      AppendPass("fuse_momentum_op_pass");
    }

    // Add a graph viz pass to record a graph.
    if (!strategy.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    // Publish trainer endpoints/id to the process-wide CollectiveContext.
    CollectiveContext *context = CollectiveContext::GetInstance();
    context->endpoints_ = strategy_.trainers_endpoints_;
    context->trainer_id_ = strategy_.trainer_id_;
    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
                         strategy_.trainers_endpoints_.size(),
                     "trainer_id_ < endpoints_ size");
    }
    VLOG(1) << "CollectiveContext:" << context->String();

    // NOTE(dzh): memory optimize should be a runtime pass.
    // However, after multi_devices_pass, VarHandle, OpHandle is
    // the de-fact IR, any reuse on Graph is meaningless.
    // A side-effect of that, memory optimize cannot forsee the fetched vars
    // , so fetchlist should be set persistable before call the Run interface.
    if (strategy_.use_legacy_memory_optimize_strategy_) {
      if (strategy_.memory_optimize_) {
        VLOG(5) << "Add memory_optimize_pass";
        AppendPass("memory_optimize_pass");
      }
    }

    // runtime_context_cache pass should be the last pass to enable the attr of
    // all original and fused operators. But no operators can be enabled this
    // attr if putting it after MultiDevPass.
    if (strategy_.cache_runtime_context_) {
      VLOG(1) << "Add runtime_context_cache_pass";
      AppendPass("runtime_context_cache_pass");
    }

    // Convert the single-device graph into the multi-device SSA graph.
    AppendMultiDevPass(strategy_);

    if (strategy_.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
      // first, if the number is zero, fuse_all_reduce_ops will do nothing.
      VLOG(1) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record a graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      VLOG(1) << "Add multi_devices_print_pass";
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                          "_multi_devices_graph");
      multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
                                                 new std::string(graph_path));
      multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
          "graph_printer", new ir::GraphvizSSAGraphPrinter);
    }

    // experimental shows that the program will be faster if append
    // all_reduce_deps_pass here.
    if (!strategy_.enable_parallel_graph_ &&
        (SeqOnlyAllReduceOps(strategy_) ||
         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
      VLOG(1) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
        !strategy_.is_distribution_ &&
        strategy_.enable_backward_optimizer_op_deps_) {
      VLOG(1) << "Add backward_op_deps_pass";
      AppendPass("backward_optimizer_op_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
      VLOG(1) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }

    // Verify that the graph is correct for multi-device executor.
    VLOG(1) << "Add multi_devices_check_pass";
    AppendPass("multi_devices_check_pass");
  }

  // Convert graph to run on multi-devices.
  // Picks exactly one multi-device builder pass based on the strategy:
  // async mode > distributed (parameter-server) mode > reduce strategy.
  void AppendMultiDevPass(const BuildStrategy &strategy) {
    ir::Pass *multi_devices_pass = nullptr;

    if (strategy_.async_mode_) {
      VLOG(1) << "Add async_multi_devices_pass";
      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
    } else if (strategy_.is_distribution_) {
      VLOG(1)
          << "Add dist_multi_devices_pass, multi device parameter server mode";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
        VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
        VLOG(1) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
      }
    }
    // The pass borrows strategy_, which outlives it (both live in builder).
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);
  }

 private:
  // Private mutable copy: the constructor may disable flags (e.g.
  // fuse_all_optimizer_ops_) without affecting the caller's strategy.
  BuildStrategy strategy_;
};

260
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
X
Xin Pan 已提交
261 262
    bool finalize_strategy) const {
  if (is_finalized_) {
263 264
    return pass_builder_;
  }
265
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
X
Xin Pan 已提交
266 267
  if (finalize_strategy) {
    is_finalized_ = true;
268
  }
X
fix  
Xin Pan 已提交
269
  return pass_builder_;
270 271
}

272
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
273
  return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
274 275
}

276 277 278 279 280
// Runs every pass in the (lazily created) pass builder over `graph`,
// injecting the per-pass runtime attributes (places, scopes, loss var,
// nranks, NCCL context) that each pass type requires.
//
// @param graph        the graph to transform; ownership stays with caller.
// @param places       devices the program will run on (borrowed by passes).
// @param loss_var_name name of the loss variable (borrowed by passes).
// @param local_scopes per-device scopes (borrowed by passes).
// @param nranks       total number of ranks participating in training.
// @param use_cuda     whether execution targets GPU.
// @param nccl_ctxs    NCCL communicator (CUDA non-Windows builds only).
// @return the transformed graph (passes may replace it).
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                const std::vector<platform::Place> &places,
                                const std::string &loss_var_name,
                                const std::vector<Scope *> &local_scopes,
                                const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                const bool use_cuda,
                                platform::NCCLCommunicator *nccl_ctxs) const {
#else
                                const bool use_cuda) const {
#endif
  VLOG(3) << "apply all passes";
  // Create a default one if not finalized by user.
  CreatePassesFromStrategy(false);

  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
    VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
    if (IsMultiDevPass(pass->Type())) {
      // Multi-device builder passes need the full device/scope context.
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(ir::kLossVarName);
      pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      pass->Erase(ir::kNRanks);
      pass->Set<size_t>(ir::kNRanks, new size_t(nranks));

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "coalesce_grad_tensor_pass" ||
               pass->Type() == "fuse_adam_op_pass" ||
               pass->Type() == "fuse_sgd_op_pass" ||
               pass->Type() == "fuse_momentum_op_pass" ||
               pass->Type() == "fuse_all_reduce_op_pass") {
      // Fusion/coalescing passes need places and scopes only.
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
        pass->Erase(kNCCLCtxs);
        pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
        pass->Erase(kUseHierarchicalAllReduce);
        pass->Set<bool>(kUseHierarchicalAllReduce,
                        new bool(use_hierarchical_allreduce_));
#endif
      }
      // NOTE: a second `else if (pass->Type() == "coalesce_grad_tensor_pass")`
      // branch used to follow here; it was unreachable because this branch
      // already matches that pass type (and performed the identical
      // kPlaces/kLocalScopes setup), so it has been removed.
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
      pass->Erase(kUseHierarchicalAllReduce);
      pass->Set<bool>(kUseHierarchicalAllReduce,
                      new bool(use_hierarchical_allreduce_));
#endif
      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                << ", num_trainers:" << num_trainers_;
    } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
      if (!use_cuda) {
        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
                        "GPU, skipped.";
        continue;
      }
    } else if (pass->Type() == "inplace_pass") {
      pass->Erase(ir::kUseCuda);
      pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
    } else if (pass->Type() == "mkldnn_placement_pass") {
      pass->Set("mkldnn_enabled_op_types",
                new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
    } else if (pass->Type() == "backward_optimizer_op_deps_pass") {
      if (!use_cuda) {
        VLOG(1) << "backward_optimizer_op_deps_pass is only supported on "
                   "GPU, skipped.";
        continue;
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(graph);
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  VLOG(3) << "All Passes Applied";
  return graph;
}
D
dzhwinter 已提交
375

376 377 378 379
}  // namespace details
}  // namespace framework
}  // namespace paddle

Q
qingqing01 已提交
380
// Force-link registration of every pass that ParallelExecutorPassBuilder
// may append, so the pass registry contains them even under static linking.
USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(coalesce_grad_tensor_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
// mkldnn_placement_pass only exists in MKLDNN-enabled builds.
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif