//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace platform {
class NCCLContextMap;
}

namespace framework {
class Scope;
namespace details {

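// Graph attribute keys through which the caller hands this pass the loss
// variable name, the list of places, the local scopes, the build strategy,
// and the number of ranks.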
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";

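// Base pass that expands a program graph into a multi-device SSA graph.
// Concrete builders choose the gradient-aggregation scheme by overriding
// InsertCollectiveOp, DealWithSpecialOp, and InsertPostprocessOps.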
class MultiDevSSAGraphBuilderBase : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  virtual void Init() const;

  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const = 0;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;

  virtual void InsertPostprocessOps(ir::Graph *result) const = 0;

  bool UseGPU() const;

  bool NeedCollectiveOps() const;

  bool IsScaleLossOp(ir::Node *node) const;

  void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                              size_t num_places) const;

  void CreateScaleLossGradOp(ir::Graph *result,
                             const std::string &loss_grad_name,
                             ir::Node *out_var_node, size_t loss_scale,
                             proto::VarType::Type dtype) const;

  VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                            int dst_dev_id) const;

  void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                             int dev_id) const;

  bool IsSparseGradient(const std::string &og) const;

  void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;

  void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                         size_t src_dev_id) const;

  void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;

  void CreateFusedBroadcastOp(
      ir::Graph *result,
      const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;

  void SetCommunicationContext(OpHandleBase *op_handle,
                               const platform::Place &p) const;

  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
                         size_t device_id) const;

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  mutable platform::NCCLContextMap *nccl_ctxs_;
#endif

  mutable std::string loss_var_name_;
  mutable std::vector<platform::Place> places_;
  mutable std::vector<Scope *> local_scopes_;

  mutable BuildStrategy strategy_;
  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};

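// Builder for all-reduce mode: every place keeps a full copy of the
// parameters and gradients are aggregated with all-reduce ops.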
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
    return false;
  }

  virtual void InsertPostprocessOps(ir::Graph *result) const {}
};

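// Shared base for builders that shard variables across devices; it records
// which device owns each variable and balances placement by variable size.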
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  int GetVarDeviceID(const std::string &varname) const;

  int GetOpDeviceID(ir::Node *node) const;

  size_t GetAppropriateDeviceID(
      const std::vector<std::string> &var_names) const;

  virtual void ResetState() const;

  mutable std::unordered_map<std::string, int> sharded_var_device_;
  mutable std::vector<int64_t> balance_vars_;
};

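// Builder for reduce mode: each gradient is reduced onto a single owning
// device, the parameter update runs there, and the updated parameter is
// broadcast back to the other places.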
class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

  virtual void ResetState() const;

  int GetOpDeviceID(ir::Node *node,
                    std::unordered_map<std::string, std::vector<ir::Node *>>
                        *delay_ops) const;

  std::vector<ir::Node *> SortForReduceMode(
      const std::vector<ir::Node *> &topo_ops) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};

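// Builder for distributed (parameter-server) training: it inserts RPC and
// distributed-training ops and broadcasts received variables when needed.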
class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual void ResetState() const;

  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;

  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
  mutable bool need_broadcast_var_{false};
};

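// Registry of multi-devices graph builder pass names; presumably populated
// when the concrete passes are registered in the corresponding .cc file.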
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();

}  // namespace details
}  // namespace framework
}  // namespace paddle