/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/details/build_strategy.h"

#include <glog/logging.h>
#include <memory>
#include <utility>

#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace details {

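// Returns true only for multi-trainer builds that use neither sequential
// execution nor parallel-graph mode; only in that case is
// all_reduce_deps_pass appended to pin the all-reduce order.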
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
  // The all-reduce op order must be fixed when the ops are scheduled across
  // multiple threads or processes; otherwise the job may hang.
  // NOTE: ParallelGraph executes this pass on each sub-graph, so there is
  // no need to append it here.
  return (!strategy.enable_sequential_execution_ &&
          strategy.num_trainers_ > 1) &&
         !strategy.enable_parallel_graph_;
}

class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
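    // The conditions below assemble the compile-time pass pipeline.  The
    // append order matters: fusion and inplace passes run on the single
    // program graph first, then the multi-devices conversion, and finally
    // the all-reduce fusion, check, and dependency passes.
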
    // Add a graph viz pass to record a graph.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

    if (strategy_.enable_sequential_execution_) {
      VLOG(10) << "Add sequential_execution_pass";
      AppendPass("sequential_execution_pass");
    }

    // Add sync_batch_norm_pass to synchronize batch-norm statistics across
    // devices.
    if (strategy.sync_batch_norm_) {
      AppendPass("sync_batch_norm_pass");
    }

    // Add op fusion.
    if (strategy.fuse_relu_depthwise_conv_) {
      VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
      AppendPass("fuse_relu_depthwise_conv_pass");
    }

    // NOTE(dzhwinter): Notes on automatic inplace:
    // 1. Passes that modify the program desc should be placed before the
    //    inplace pass.
    // 2. Manually configured inplace ops should also be handled before
    //    inplace_pass.

    // Add the automatic inplace pass.
    if (strategy_.enable_inplace_) {
      VLOG(10) << "Add inplace_pass";
      AppendPass("inplace_pass");
    }

    if (strategy.fuse_elewise_add_act_ops_) {
      VLOG(10) << "Add fuse_elewise_add_act_pass";
      AppendPass("fuse_elewise_add_act_pass");
    }

    // For single-card training, fuse_all_reduce_ops is unnecessary.
    // alloc_continuous_space_for_grad_pass should run before MultiDevPass.
    if (strategy.fuse_all_reduce_ops_) {
      VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
      AppendPass("alloc_continuous_space_for_grad_pass");
    }

    // Add a graph viz pass to record a graph.
    if (!strategy.debug_graphviz_path_.empty()) {
      auto viz_pass = AppendPass("graph_viz_pass");
      const std::string graph_path = string::Sprintf(
          "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
      viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
    }

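    // Record the trainer endpoints and trainer id in the singleton
    // CollectiveContext consumed by the distributed collective op handles.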
    CollectiveContext *context = CollectiveContext::GetInstance();
    context->endpoints_ = strategy_.trainers_endpoints_;
    context->trainer_id_ = strategy_.trainer_id_;
    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
    if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
                         strategy_.trainers_endpoints_.size(),
                     "trainer_id_ < endpoints_ size");
    }
    VLOG(1) << "CollectiveContext:" << context->String();

    // NOTE(dzh): Memory optimization should ideally be a runtime pass.
    // However, after multi_devices_pass, VarHandle and OpHandle become the
    // de-facto IR, and any reuse decision made on the Graph is meaningless.
    // A side effect of this is that memory optimization cannot foresee the
    // fetched vars, so the fetch list must be set persistable before
    // calling the Run interface.
    if (strategy.memory_optimize_) {
      VLOG(10) << "Add memory_optimize_pass";
      AppendPass("memory_optimize_pass");
    }

    AppendMultiDevPass(strategy);

    if (strategy.fuse_all_reduce_ops_) {
      // NOTE: fuse_all_reduce_ops first counts the number of all_reduce
      // operators; if the count is zero, fuse_all_reduce_op_pass does
      // nothing.
      VLOG(10) << "Add fuse_all_reduce_op_pass";
      AppendPass("fuse_all_reduce_op_pass");
    }

    // Add a graph print pass to record a graph with device info.
    if (!strategy_.debug_graphviz_path_.empty()) {
      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
      const std::string graph_path =
          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                          "_multi_devices_graph");
      multi_devices_print_pass->Set<std::string>(kGraphvizPath,
                                                 new std::string(graph_path));
      multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
          "graph_printer", new details::GraphvizSSAGraphPrinter);
    }

    // Verify that the graph is correct for the multi-device executor.
    AppendPass("multi_devices_check_pass");

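    // With verbose logging (v >= 2), also add all_reduce_deps_pass; a fixed
    // all-reduce order makes runs easier to compare while debugging.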
    if (VLOG_IS_ON(2)) {
      AppendPass("all_reduce_deps_pass");
    }

    if (SeqOnlyAllReduceOps(strategy)) {
      VLOG(10) << "Add all_reduce_deps_pass";
      AppendPass("all_reduce_deps_pass");
    }

    if (strategy_.remove_unnecessary_lock_) {
      VLOG(10) << "Add modify_op_lock_and_record_event_pass";
      AppendPass("modify_op_lock_and_record_event_pass");
    }
  }

  // Append the pass that converts the graph to run on multiple devices.
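  // Which concrete pass is used depends on the strategy: a distributed job
  // uses dist_multi_devices_pass; otherwise reduce_ selects between the
  // all-reduce and reduce variants.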
  void AppendMultiDevPass(const BuildStrategy &strategy) {
    ir::Pass *multi_devices_pass = nullptr;
    if (strategy_.is_distribution_) {
      VLOG(10) << "Add dist_multi_devices_pass";
      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
    } else {
      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
        VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
        multi_devices_pass =
            AppendPass("all_reduce_mode_multi_devices_pass").get();
      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
        VLOG(10) << "Add reduce_mode_multi_devices_pass";
        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
      } else {
        PADDLE_THROW("Unknown reduce strategy.");
      }
    }
    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                         &strategy_);
  }

 private:
  BuildStrategy strategy_;
};

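// Lazily build the pass pipeline from this strategy.  A usage sketch
// (assuming the caller holds a BuildStrategy and wants an extra pass):
//   auto builder = build_strategy.CreatePassesFromStrategy(false);
//   builder->AppendPass("graph_viz_pass");  // optional, already registered
// Passing finalize_strategy = true sets is_finalized_; later calls simply
// return the cached pass_builder_.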
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
    bool finalize_strategy) const {
  if (is_finalized_) {
    return pass_builder_;
  }
  pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
  if (finalize_strategy) {
    is_finalized_ = true;
  }
  return pass_builder_;
}

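// A pass counts as a multi-device pass when its name appears in
// MultiDevSSAGraphBuilder(); such passes get the place/scope/NCCL
// attributes rebound in Apply() below.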
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}

std::unique_ptr<ir::Graph> BuildStrategy::Apply(
    std::unique_ptr<ir::Graph> graph,
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
    const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
    const bool use_cuda) const {
#endif
  // Create a default pass builder if the user has not finalized one.
  CreatePassesFromStrategy(false);

  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
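    // Rebind the attributes each pass needs.  SetNotOwned stores raw
    // pointers only; the places, scopes, and NCCL contexts remain owned by
    // the caller.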
    if (IsMultiDevPass(pass->Type())) {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLossVarName);
      pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
      pass->Erase(kNRanks);
      pass->Set<size_t>(kNRanks, new size_t(nranks));

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
      pass->Erase(kNCCLCtxs);
      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
    } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
      pass->Erase(kPlaces);
      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
      pass->Erase(kLocalScopes);
      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                    &local_scopes);
    } else if (pass->Type() == "sequential_execution_pass") {
      LOG(INFO) << "set enable_sequential_execution:"
                << enable_sequential_execution_;
    } else if (pass->Type() == "all_reduce_deps_pass") {
      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                << ", num_trainers:" << num_trainers_;
    } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
      if (!use_cuda) {
        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
                        "GPU, skipped.";
        continue;
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
    graph = pass->Apply(std::move(graph));
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

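// USE_PASS references each pass's static registrar so that the passes this
// file appends by name stay linked into the binary even when nothing else
// refers to their symbols.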
USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_all_reduce_op_pass);