// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "boost/optional.hpp"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

// Forward declarations. In this header these types are only referred to
// through raw pointers or std::shared_ptr, so declarations are sufficient.
namespace paddle {
namespace platform {
class NCCLCommunicator;
}  // namespace platform

namespace framework {
namespace ir {
class PassBuilder;
class Graph;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace framework {
namespace details {

struct BuildStrategy {
C
chengduo 已提交
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
  // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and
  // kReduce, for CPU and GPU. If you use kAllReduce, different threads
  // optimize their parameters separately. If you use kReduce, the optimizations
  // of parameters are distributed to different threads.
  // For example, a model has 100 parameters and is running with four threads,
  // if you choose kAllReduce, every thread is to optimize 100 parameters
  // separately, if you choose kReduce, every thread is to optimize 25
  // parameters.
  // Of particular note is, if you use kReduce when using CPU training,
  // all the parameters are shared between different threads. This feature will
  // save memory.
  // FIXME(zcd): The result of the two modes(kAllReduce and kReduce) maybe not
  // equal for GPU. Because, the result of the different order of summing maybe
  // different, for example, the result of `a+b+c+d` may be different with the
  // result of `c+a+b+d`.
  // For GPU, the implementation of kAllReduce and kReduce is adopted NCCL,
  // so the result of kAllReduce and kReduce maybe not equal.
  // For CPU, if you want to fix the order of summing to make the result
  // of kAllReduce and kReduce no diff, you can add
  // `FLAGS_cpu_deterministic=true` to env.
Y
yuyang18 已提交
71 72 73 74 75
  enum class ReduceStrategy { kAllReduce = 0, kReduce = 1 };

  enum class GradientScaleStrategy {
    kCoeffNumDevice = 0,
    kOne = 1,
C
chengduo 已提交
76 77
    // user can customize gradient scale to use, and just feed
    // it into exe.run().
Y
yuyang18 已提交
78 79 80
    kCustomized = 2,
  };

Y
yuyang18 已提交
81
  ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
Y
yuyang18 已提交
82
  GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
Y
yuyang18 已提交
83 84

  std::string debug_graphviz_path_{""};
F
fengjiayi 已提交
85

C
chengduo 已提交
86 87 88
  // Add dependency between backward ops and optimization ops, make sure that
  // all the backward ops are finished before running the optimization ops.
  // It might make the training speed of data parallelism faster.
89
  bool enable_backward_optimizer_op_deps_{true};
C
chengduo 已提交
90 91 92 93 94 95 96 97
  // TODO(dev-paddle): enable_sequential_execution depends on
  // kStaleProgramOpDescs, it is not appropriate, because kStaleProgramOpDescs
  // will be removed in the near future.
  bool enable_sequential_execution_{false};
  bool remove_unnecessary_lock_{true};
  // TODO(dev-paddle): cache_runtime_context may cause some models to hang up
  // while running.
  bool cache_runtime_context_{false};
C
chengduo 已提交
98

C
chengduo 已提交
99 100 101
  // Operator fusion
  // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
  // cycle.
Z
Zhen Wang 已提交
102
  bool fuse_bn_act_ops_{false};
Z
Zhang Ting 已提交
103
  bool fuse_bn_add_act_ops_{true};
104 105
  bool fuse_elewise_add_act_ops_{false};
  bool enable_auto_fusion_{false};
C
chengduo 已提交
106 107
  // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
  // should not be sparse types
108
  boost::optional<bool> fuse_all_optimizer_ops_{false};
C
chengduo 已提交
109
  boost::optional<bool> fuse_all_reduce_ops_{boost::none};
C
chengduo 已提交
110 111
  // fuse_relu_depthwise_conv can fuse the `relu ->
  // depthwise_conv`
112
  bool fuse_relu_depthwise_conv_{false};
C
chengduo 已提交
113 114 115 116
  // NOTE(zcd): In reduce mode, fusing broadcast ops may make the program
  // faster. Because fusing broadcast OP equals delaying the execution of all
  // broadcast Ops, in this case, all nccl streams are used only for reduce
  // operations for a period of time.
C
chengduo 已提交
117
  boost::optional<bool> fuse_broadcast_ops_{boost::none};
C
chengduo 已提交
118
  // replace batch_norm with sync_batch_norm.
Q
qingqing01 已提交
119 120
  bool sync_batch_norm_{false};

C
chengduo 已提交
121 122 123 124 125 126 127
  // mkldnn_enabled_op_types specify the operator type list to
  // use MKLDNN acceleration. It is null in default, means
  // that all the operators supported by MKLDNN will be
  // accelerated. And it should not be set when
  // FLAGS_use_mkldnn=false
  std::unordered_set<std::string> mkldnn_enabled_op_types_;

128 129 130 131
  // By default, memory_optimize would be opened if gc is disabled, and
  // be closed if gc is enabled.
  // Users can forcely enable/disable memory_optimize by setting True/False.
  boost::optional<bool> memory_optimize_{boost::none};
132 133 134 135

  // Turn on inplace by default.
  bool enable_inplace_{true};

136 137 138
  // Turn off inplace addto by default.
  bool enable_addto_{false};

139 140 141 142
  // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
  // num_trainers is 1, so the current fields of build_strategy doesn't tell if
  // it's distributed model.
  bool is_distribution_{false};
Q
can run  
Qiao Longfei 已提交
143
  bool async_mode_{false};
144
  int num_trainers_{1};
145 146
  int trainer_id_{0};
  std::vector<std::string> trainers_endpoints_;
147

C
chengduo 已提交
148
  // NCCL config
149 150 151 152
  size_t nccl_comm_num_{1};
  // The picture is here:
  // https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
  bool use_hierarchical_allreduce_{false};
T
tianshuo78520a 已提交
153
  // Nccl ranks in a node when use hierarchical allreduce, it's set to gpu
154 155
  // cards' number in most cases.
  size_t hierarchical_allreduce_inter_nranks_{0};
T
tianshuo78520a 已提交
156
  // Nccl ranks bewteen nodes when use hierarchical allreduce, it's set to
157 158 159
  // nodes number.
  size_t hierarchical_allreduce_exter_nranks_{0};

X
Xin Pan 已提交
160 161 162 163 164
  // NOTE:
  // Before you add new options, think if it's a general strategy that works
  // with other strategy. If not, the strategy should be created through
  // CreatePassesFromStrategy and the pass can be managed separately.

X
Xin Pan 已提交
165
  // User normally doesn't need to call this API.
X
Xin Pan 已提交
166
  // The PassBuilder allows for more customized insert, remove of passes
X
Xin Pan 已提交
167 168 169
  // from python side.
  // A new PassBuilder is created based on configs defined above and
  // passes are owned by the PassBuilder.
170
  std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
X
Xin Pan 已提交
171 172 173
      bool finalize_strategy) const;

  bool IsFinalized() const { return is_finalized_; }
174

175 176
  bool IsMultiDevPass(const std::string &pass_name) const;

X
Xin Pan 已提交
177 178
  // Apply the passes built by the pass_builder_. The passes will be
  // applied to the Program and output an ir::Graph.
179 180 181 182
  ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
                   const std::string &loss_var_name,
                   const std::vector<Scope *> &local_scopes,
                   const size_t &nranks,
183
#if defined(PADDLE_WITH_NCCL)
184
                   const bool use_cuda,
185
                   platform::NCCLCommunicator *nccl_ctxs) const;
186
#else
187
                   const bool use_cuda) const;
188 189
#endif

190 191 192 193 194 195 196
  // If set true, ParallelExecutor would build the main_program into multiple
  // graphs,
  // each of the graphs would run with one device. This approach can achieve
  // better performance
  // on some scenarios.
  mutable bool enable_parallel_graph_ = false;

197
 private:
X
Xin Pan 已提交
198
  mutable bool is_finalized_ = false;
199
  mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
Y
yuyang18 已提交
200 201 202 203 204
};

}  // namespace details
}  // namespace framework
}  // namespace paddle