/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

#include <glog/logging.h>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace details {

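// Returns true when the order of all_reduce ops must be fixed explicitly:
// multi-trainer runs that use neither sequential execution nor ParallelGraph.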
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
  // The order of the all_reduce ops must be fixed when they are scheduled
  // across multiple threads or processes; otherwise execution may hang.
  // NOTE: ParallelGraph already runs this pass on each of its sub-graphs,
  // so there is no need to append it here.
  return (!strategy.enable_sequential_execution_ &&
          strategy.num_trainers_ > 1) &&
         !strategy.enable_parallel_graph_;
}

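// Assembles the ir::Pass pipeline that ParallelExecutor runs over the
// program graph, in the order dictated by the given BuildStrategy.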
class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record a graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    if (strategy_.enable_sequential_execution_) {
      VLOG(10) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }

    // Enable synchronized batch normalization across devices.
    if (strategy.sync_batch_norm_) {
      AppendPass("sync_batch_norm_pass");
    }

    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
      VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }

    // NOTE(dzhwinter): Notes on automatic inplace optimization.
    // 1. Passes that modify the program desc must be placed before the
    //    inplace pass.
    // 2. Manually configured inplace ops must also be set up before the
    //    inplace pass.

    // Add the automatic inplace pass.
    if (strategy_.enable_inplace_) {
      VLOG(10) << "Add inplace_pass";
      AppendPass("inplace_pass");
    }

    if (strategy.fuse_elewise_add_act_ops_) {
      VLOG(10) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // For single-card training, fuse_all_reduce_ops is unnecessary.
    // alloc_continuous_space_for_grad_pass must run before the
    // multi-devices pass.
    if (strategy.fuse_all_reduce_ops_) {
      VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
      AppendPass("alloc_continuous_space_for_grad_pass");
    }

    // Add a graph viz pass to record a graph.
    if (!strategy.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

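    // Record the trainer topology (endpoints and this trainer's id) in the
    // singleton context so that distributed passes can look it up later.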
    CollectiveContext *context = CollectiveContext::GetInstance();
    context->endpoints_ = strategy_.trainers_endpoints_;
    context->trainer_id_ = strategy_.trainer_id_;
    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
                         strategy_.trainers_endpoints_.size(),
                     "trainer_id_ < endpoints_ size");
    }
    VLOG(1) << "CollectiveContext:" << context->String();

    // NOTE(dzh): Memory optimization should ideally be a runtime pass.
    // However, after multi_devices_pass the VarHandles and OpHandles are the
    // de facto IR, and any reuse decision made on the Graph is meaningless.
    // As a side effect, memory optimization cannot foresee the fetched vars,
    // so the fetch list must be marked persistable before calling Run.
    if (strategy.memory_optimize_) {
      VLOG(10) << "Add memory_optimize_pass";
      AppendPass("memory_optimize_pass");
    }

    AppendMultiDevPass(strategy);

    if (strategy.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_op_pass first counts the all_reduce operators;
      // if the count is zero, the pass does nothing.
      VLOG(10) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record a graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                          "_multi_devices_graph");
      multi_devices_print_pass->Set<std::string>(kGraphvizPath,
                                                 new std::string(graph_path));
      multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
          "graph_printer", new details::GraphvizSSAGraphPrinter);
    }

    // Verify that the graph is correct for multi-device executor.
    AppendPass("multi_devices_check_pass");

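    // At verbose log level 2 and above, also fix the all_reduce order
    // (presumably to keep debug runs deterministic and comparable).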
    if (VLOG_IS_ON(2)) {
      AppendPass("all_reduce_deps_pass");
    }

    if (SeqOnlyAllReduceOps(strategy)) {
      VLOG(10) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
      VLOG(10) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }
  }

  // Convert the graph so it can run on multiple devices.
  void AppendMultiDevPass(const BuildStrategy &strategy) {
    ir::Pass *multi_devices_pass = nullptr;
    if (strategy_.is_distribution_) {
      VLOG(10) << "Add dist_multi_devices_pass";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
        VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
        VLOG(10) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
      }
    }
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);
  }

 private:
  BuildStrategy strategy_;
};

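// Builds the pass pipeline described by this strategy, or returns the one
// that was already built. When finalize_strategy is true the strategy is
// frozen and later calls simply reuse the same builder.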
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
    bool finalize_strategy) const {
  if (is_finalized_) {
    return pass_builder_;
  }
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
  if (finalize_strategy) {
    is_finalized_ = true;
  }
  return pass_builder_;
}

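// A pass counts as a multi-device pass if it is registered with
// MultiDevSSAGraphBuilder.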
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}

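// Runs every pass in the builder over the graph, wiring in the runtime
// context (places, local scopes, loss var name, nranks and, on CUDA builds,
// the NCCL context map) that individual passes expect.
//
// A minimal usage sketch on a CUDA build (illustrative only; the variable
// names here are assumptions, not part of this file):
//   BuildStrategy build_strategy;
//   build_strategy.fuse_elewise_add_act_ops_ = true;
//   graph = build_strategy.Apply(graph, places, loss_var_name, local_scopes,
//                                nranks, /*use_cuda=*/true, nccl_ctxs);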
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                const std::vector<platform::Place> &places,
                                const std::string &loss_var_name,
                                const std::vector<Scope *> &local_scopes,
                                const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                const bool use_cuda,
                                platform::NCCLContextMap *nccl_ctxs) const {
#else
                                const bool use_cuda) const {
#endif
  // Create a default one if not finalized by user.
  CreatePassesFromStrategy(false);

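  // Give each pass the runtime objects it needs, then apply it to the graph.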
  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
    if (IsMultiDevPass(pass->Type())) {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLossVarName);
      pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      pass->Erase(kNRanks);
      pass->Set<size_t>(kNRanks, new size_t(nranks));

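      // Hand the NCCL context map to the pass only when CUDA is in use.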
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                << ", num_trainers:" << num_trainers_;
    } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
      if (!use_cuda) {
        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
                        "GPU, skipped.";
        continue;
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(graph);
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

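// Reference each pass so that its static registration is linked into the
// final binary.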
USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_all_reduce_op_pass);