/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

#include <glog/logging.h>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace details {

static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
  // The all-reduce op order must be fixed when the ops are scheduled across
  // multiple threads or processes; otherwise execution may hang.
  // NOTE: ParallelGraph runs this pass on each sub-graph, so there is no need
  // to append it here.
  return (!strategy.enable_sequential_execution_ &&
          strategy.num_trainers_ > 1) &&
         !strategy.enable_parallel_graph_;
}

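// ParallelExecutorPassBuilder translates a BuildStrategy into an ordered pass
// pipeline: debugging and fusion passes first, then memory-related passes,
// the multi-devices conversion pass (see AppendMultiDevPass below), and
// finally the multi-devices check pass plus dependency/lock passes.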
class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
    // Add a graph viz pass to record a graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    if (strategy_.enable_sequential_execution_) {
      VLOG(10) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }

    // Replace batch_norm ops with sync_batch_norm ops when synchronized
    // batch normalization is enabled.
    if (strategy.sync_batch_norm_) {
      AppendPass("sync_batch_norm_pass");
    }

    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
      VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }

    // NOTE(dzhwinter): Notes on automatic inplace:
    // 1. Passes that modify the program desc should run before the inplace
    //    pass.
    // 2. Manually configured inplace ops should also be set up before the
    //    inplace_pass.

    // Add the automatic inplace pass.
    if (strategy_.enable_inplace_) {
      VLOG(10) << "Add inplace_pass";
      AppendPass("inplace_pass");
    }

    if (strategy.fuse_elewise_add_act_ops_) {
      VLOG(10) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // For single-card training, fuse_all_reduce_ops is unnecessary.
    // alloc_continuous_space_for_grad_pass should run before the
    // multi-devices pass.
    if (strategy.fuse_all_reduce_ops_) {
      VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
      AppendPass("alloc_continuous_space_for_grad_pass");
    }

    // Add a graph viz pass to record a graph.
    if (!strategy.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

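    // Record this trainer's id and the peer endpoints in the global
    // CollectiveContext so later passes and collective ops can look them up.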
    CollectiveContext *context = CollectiveContext::GetInstance();
    context->endpoints_ = strategy_.trainers_endpoints_;
    context->trainer_id_ = strategy_.trainer_id_;
    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
                         strategy_.trainers_endpoints_.size(),
                     "trainer_id_ < endpoints_ size");
    }
    VLOG(1) << "CollectiveContext:" << context->String();

    // NOTE(dzh): Memory optimization should ideally be a runtime pass.
    // However, after multi_devices_pass the VarHandle/OpHandle graph is the
    // de facto IR, and any reuse decision made on the Graph afterwards is
    // meaningless. As a side effect, memory optimization cannot foresee the
    // fetched vars, so the fetch list must be set persistable before calling
    // the Run interface.
    if (strategy.memory_optimize_) {
      VLOG(10) << "Add memory_optimize_pass";
      AppendPass("memory_optimize_pass");
    }

    AppendMultiDevPass(strategy);

    if (strategy.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_op_pass first counts the all_reduce operators
      // in the graph; if the count is zero, the pass does nothing.
      VLOG(10) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record a graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                          "_multi_devices_graph");
      multi_devices_print_pass->Set<std::string>(kGraphvizPath,
                                                 new std::string(graph_path));
      multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
          "graph_printer", new details::GraphvizSSAGraphPrinter);
    }

    // Verify that the graph is correct for multi-device executor.
    AppendPass("multi_devices_check_pass");

    if (SeqOnlyAllReduceOps(strategy)) {
      VLOG(10) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
      VLOG(10) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }
  }

  // Append the pass that converts the graph to run on multiple devices.
  void AppendMultiDevPass(const BuildStrategy &strategy) {
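    // Distributed training always uses dist_multi_devices_pass; otherwise the
    // reduce strategy chooses between all-reduce mode and reduce mode.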
    ir::Pass *multi_devices_pass = nullptr;
    if (strategy_.is_distribution_) {
      VLOG(10) << "Add dist_multi_devices_pass";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
        VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
        VLOG(10) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
      }
    }
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);
  }

 private:
  BuildStrategy strategy_;
};

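// Usage sketch (hypothetical caller; in the codebase this is driven by
// ParallelExecutor):
//   BuildStrategy strategy;
//   strategy.fuse_elewise_add_act_ops_ = true;
//   auto builder = strategy.CreatePassesFromStrategy(/*finalize_strategy=*/true);
//   // `builder` now holds the pass pipeline; once the strategy is finalized,
//   // later calls simply return the same builder.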
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
    bool finalize_strategy) const {
  if (is_finalized_) {
    return pass_builder_;
  }
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
  if (finalize_strategy) {
    is_finalized_ = true;
  }
  return pass_builder_;
}

bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}

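// Example call site (a sketch; the exact argument list depends on whether the
// build has CUDA/NCCL enabled):
//   #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
//     graph = strategy.Apply(std::move(graph), places, loss_var_name,
//                            local_scopes, nranks, use_cuda, nccl_ctxs);
//   #else
//     graph = strategy.Apply(std::move(graph), places, loss_var_name,
//                            local_scopes, nranks, use_cuda);
//   #endif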
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
    std::unique_ptr<ir::Graph> graph,
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
    const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
    const bool use_cuda) const {
#endif
  // Create a default one if not finalized by user.
  CreatePassesFromStrategy(false);

  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
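    // Hand each pass the runtime resources it needs (places, scopes, loss var,
    // NCCL contexts) as graph attributes; Erase followed by SetNotOwned/Set
    // clears any value left over from a previous Apply call.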
    if (IsMultiDevPass(pass->Type())) {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLossVarName);
      pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      pass->Erase(kNRanks);
      pass->Set<size_t>(kNRanks, new size_t(nranks));

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                << ", num_trainers:" << num_trainers_;
    } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
      if (!use_cuda) {
        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
                        "GPU, skipped.";
        continue;
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(std::move(graph));
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_all_reduce_op_pass);