//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class OpHandleBase;
struct VarHandle;
}  // namespace details
namespace ir {
class Graph;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
class NCCLCommunicator;
class NCCLContextMap;
#elif defined(PADDLE_WITH_XPU_BKCL)
class BKCLContextMap;
class BKCLCommunicator;
#endif
}

namespace framework {
class Scope;

namespace ir {

// Keys of the pass attributes the builder reads in Init(): they correspond to
// the cached loss_var_name_ and strategy_ members of MultiDevSSAGraphBuilderBase.
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";

class MultiDevSSAGraphBuilderBase : public ir::Pass {
X
Xin Pan 已提交
60
 protected:
61
  void ApplyImpl(ir::Graph *graph) const override;
62

63
  virtual void Init() const;
T
wip  
typhoonzero 已提交
64

C
chengduo 已提交
65 66
  virtual void CheckGraph(const ir::Graph &graph) const;

67
  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
Y
Yu Yang 已提交
68

69 70
  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
                                  const std::string &p_name,
71
                                  const std::string &g_name) const = 0;
X
Xin Pan 已提交
72

C
chengduo 已提交
73
  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
74 75

  virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
Y
Yu Yang 已提交
76

77 78
  bool UseGPU() const;

Q
Qiao Longfei 已提交
79 80
  virtual bool NeedCollectiveForGrad(const std::string &grad_name,
                                     std::vector<ir::Node *> ops) const;
81 82

  bool IsScaleLossOp(ir::Node *node) const;
Y
Yu Yang 已提交
83

X
Xin Pan 已提交
84
  void CreateComputationalOps(ir::Graph *result, ir::Node *node,
T
typhoonzero 已提交
85
                              size_t num_places) const;
Y
Yu Yang 已提交
86

87
  void CreateScaleLossGradOp(ir::Graph *result,
88
                             const std::string &loss_grad_name,
89
                             ir::Node *out_var_node, size_t loss_scale,
W
Wu Yi 已提交
90
                             proto::VarType::Type dtype) const;
91

92 93
  details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                                     size_t dst_dev_id) const;
94

X
Xin Pan 已提交
95
  void CreateComputationalOp(ir::Graph *result, ir::Node *node,
96
                             size_t dev_id) const;
Y
Yu Yang 已提交
97

98
  bool IsSparseGradient(const std::string &og) const;
Y
Yu Yang 已提交
99

100 101
  void CreateAllReduceOp(ir::Graph *result, ir::Node *node,
                         const std::string &og, bool is_encoded = false) const;
102

X
Xin Pan 已提交
103
  void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
C
chengduoZH 已提交
104
                         size_t src_dev_id) const;
C
chengduoZH 已提交
105

106 107
  void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;

108 109 110 111
  void CreateFusedBroadcastOp(
      ir::Graph *result,
      const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;

112
  void SetCommunicationContext(details::OpHandleBase *op_handle,
X
clean  
Xin Pan 已提交
113 114
                               const platform::Place &p) const;

115 116
  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
                         size_t device_id) const;
C
chengduo 已提交
117

118 119
  void CreateIsolatedVarNode(ir::Graph *result, ir::Node *var_node) const;

120
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
121
  mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
122
  mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
123 124 125
#elif defined(PADDLE_WITH_XPU_BKCL)
  mutable platform::BKCLContextMap *bkcl_ctxs_{nullptr};
  mutable platform::BKCLCommunicator *multi_bkcl_ctxs_{nullptr};
126
#endif
C
chengduo 已提交
127

X
clean  
Xin Pan 已提交
128 129 130 131
  mutable std::string loss_var_name_;
  mutable std::vector<platform::Place> places_;
  mutable std::vector<Scope *> local_scopes_;

132
  mutable details::BuildStrategy strategy_;
Y
Yancey1989 已提交
133
  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
134 135 136 137
};

class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
138 139
  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
                                  const std::string &p_name,
140 141 142
                                  const std::string &g_name) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const {}
G
gongweibao 已提交
143 144

  bool IsEncoded(const std::string &p_name) const;
145 146
};

class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
149 150
  void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
                          const std::string &p_name,
151
                          const std::string &g_name) const override {}
Q
can run  
Qiao Longfei 已提交
152

Q
Qiao Longfei 已提交
153
  bool NeedCollectiveForGrad(const std::string &grad_name,
C
chengduo 已提交
154
                             std::vector<ir::Node *> ops) const override {
Q
Qiao Longfei 已提交
155 156
    return false;
  }
Q
can run  
Qiao Longfei 已提交
157

158
  bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
Q
can run  
Qiao Longfei 已提交
159
    if (node->Op()->Type() == "recv") {
Q
Qiao Longfei 已提交
160
      VLOG(1) << "set recv op do_not_run to true";
161
      node->Op()->SetAttr("do_not_run", 1);
Q
can run  
Qiao Longfei 已提交
162 163
      node->Op()->Flush();
    }
Q
can run  
Qiao Longfei 已提交
164 165 166
    return false;
  }

167
  void InsertPostprocessOps(ir::Graph *result) const override {}
Q
can run  
Qiao Longfei 已提交
168 169
};

// Intermediate base for strategies that shard variables across devices:
// records each sharded variable's device and balances placement using the
// per-device accumulated sizes in balance_vars_.
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  int GetVarDeviceID(const std::string &varname) const;

  int GetOpDeviceID(ir::Node *node) const;

  // Chooses a device for the given variables based on balance_vars_
  // (presumably the least-loaded device — confirm in the .cc).
  size_t GetAppropriateDeviceID(
      const std::vector<std::string> &var_names) const;

  // Clears the mutable placement state below between pass applications.
  virtual void ResetState() const;

  // Maps a sharded variable name to the device id that owns it.
  mutable std::unordered_map<std::string, int> sharded_var_device_;
  mutable std::vector<int64_t> balance_vars_;
};

class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

189 190
  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
                                  const std::string &p_name,
191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
                                  const std::string &g_name) const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

  virtual void ResetState() const;

  int GetOpDeviceID(ir::Node *node,
                    std::unordered_map<std::string, std::vector<ir::Node *>>
                        *delay_ops) const;

  std::vector<ir::Node *> SortForReduceMode(
      const std::vector<ir::Node *> &topo_ops) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};

class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

219 220
  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
                                  const std::string &p_name,
221 222 223 224 225 226 227 228 229 230 231 232 233 234
                                  const std::string &g_name) const;

  virtual void ResetState() const;

  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;

  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
  mutable bool need_broadcast_var_{false};
};

// Accessor for a process-wide mutable set of names shared by the
// multi-device graph-pass family (contents defined in the .cc file).
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();

}  // namespace ir
}  // namespace framework
}  // namespace paddle