//   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace ir {

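// FuseAllReduceOpPass fuses the dense-gradient all_reduce ops that belong to
// the same parameter/gradient group (recorded in kGroupParamsAndDenseGrads)
// into a single fused_all_reduce op, reducing the number of communication
// calls issued during training.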
class FuseAllReduceOpPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
    if (Get<size_t>(details::kNRanks) <= 1) {
      VLOG(6) << "The number of places is " << Get<size_t>(details::kNRanks)
              << ", so there is no need to apply FuseAllReduceOpPass.";
      return;
    }

    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);

#if defined(PADDLE_WITH_NCCL)
    auto *multi_nccl_ctxs =
        &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
#endif

    ir::Graph &result = *graph;
    auto &params_grads =
        result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
    size_t num_of_all_reduce = params_grads.size();
    std::unordered_set<std::string> grads;
    grads.reserve(num_of_all_reduce);
    for (auto p_g : params_grads) {
      grads.insert(p_g.second);
    }

    std::unordered_map<std::string, Node *> all_reduce_ops =
        GetAllReduceOps(result, places, grads);

    VLOG(6) << "Found all_reduce ops: " << all_reduce_ops.size();
    if (all_reduce_ops.empty()) {
      return;
    }

    PADDLE_ENFORCE_EQ(
        all_reduce_ops.size(), grads.size(),
        platform::errors::Unimplemented(
            "The number of all_reduce OpHandle is not equal to the "
            "number of grads. Maybe some gradients are sparse type, "
            "it is not supported currently."));

    auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
        details::kGroupParamsAndDenseGrads);

    LOG(WARNING) << string::Sprintf(
        "Found %d all_reduce operators. To speed up training, some of them "
        "are fused; after fusion, the number of all_reduce ops is %d.",
        all_reduce_ops.size(), group_params_grads.size());

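    // Build one fused all_reduce op per group of parameters/gradients.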
    for (auto &group_p_g : group_params_grads) {
      size_t group_size = group_p_g.size();
      PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
      std::vector<ir::Node *> group_all_reduce_ops;
      group_all_reduce_ops.reserve(group_size);
      for (auto &p_g : group_p_g) {
        group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second));
      }
#if defined(PADDLE_WITH_NCCL)
      InsertFusedAllReduce(places, local_scopes, group_size,
                           group_all_reduce_ops, multi_nccl_ctxs, &result);
#else
      InsertFusedAllReduce(places, local_scopes, group_size,
                           group_all_reduce_ops, &result);
#endif
    }
  }

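  // Collects every AllReduceOpHandle in the graph and returns it keyed by the
  // name of the gradient it reduces; each gradient must appear in `grads`.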
  std::unordered_map<std::string, Node *> GetAllReduceOps(
      const Graph &result, const std::vector<platform::Place> &places,
      const std::unordered_set<std::string> &grads) const {
    size_t num_place = places.size();
    std::unordered_map<std::string, Node *> all_reduce_ops;
    all_reduce_ops.reserve(grads.size());
    for (auto &node : result.Nodes()) {
      if (node->IsOp()) {
        PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
        auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
            &node->Wrapper<details::OpHandleBase>());
        if (all_reduce_op_handle) {
#if defined(PADDLE_WITH_DGC)
          PADDLE_ENFORCE_NE(
              all_reduce_op_handle->Name(), "sparse_all_reduce",
              "DGC doesn't support fuse for now, if you want to use DGC "
              "you need set strategy.fuse_all_reduce_ops = False.");
#endif
          auto inputs = details::DynamicCast<details::VarHandle>(
              all_reduce_op_handle->Inputs());
          PADDLE_ENFORCE_EQ(inputs.size(), num_place);
          // The inputs' name should be the same.
          auto &grad_name = inputs[0]->name();
          for (size_t i = 1; i < inputs.size(); ++i) {
            PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
                              "The input name should be the same.");
          }
          PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
          all_reduce_ops.emplace(grad_name, node);
        }
      }
    }
    return all_reduce_ops;
  }

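  // Detaches the given all_reduce ops from their input/output var handles,
  // removes them from the graph, and creates a single fused all_reduce op
  // that takes over all of their inputs and outputs.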
  void InsertFusedAllReduce(const std::vector<platform::Place> &places,
                            const std::vector<Scope *> &local_scopes,
                            const size_t num_of_all_reduce,
                            const std::vector<ir::Node *> &all_reduce_ops,
#if defined(PADDLE_WITH_NCCL)
                            const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif
                            ir::Graph *result) const {
    std::vector<details::VarHandleBase *> inputs;
    std::vector<details::VarHandleBase *> outputs;
    for (auto &op : all_reduce_ops) {
      auto &op_handle = op->Wrapper<details::OpHandleBase>();
      inputs.insert(inputs.end(), op_handle.Inputs().begin(),
                    op_handle.Inputs().end());
      // Remove this op from the consumer list of each of its input vars.
      for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
               [&op_handle](details::VarHandleBase *var_handle) {
                 var_handle->RemoveOutput(&op_handle, op_handle.Node());
               });

      outputs.insert(outputs.end(), op_handle.Outputs().begin(),
                     op_handle.Outputs().end());
      // Clear this op as the generating op of each of its output vars.
      for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
               [](details::VarHandleBase *var_handle) {
                 var_handle->ClearGeneratedOp();
               });

      result->RemoveNode(op_handle.Node());
    }

#if defined(PADDLE_WITH_NCCL)
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, multi_nccl_ctxs, result);
#else
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, result);
#endif
  }

 private:
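  // Creates the FusedAllReduceOpHandle node and connects the collected
  // inputs and outputs to it.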
  void CreateFusedAllReduceOp(
      const std::vector<details::VarHandleBase *> &inputs,
      const std::vector<details::VarHandleBase *> &outputs,
      const size_t num_of_all_reduce,
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_NCCL)
      const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif
      ir::Graph *result) const {
#if defined(PADDLE_WITH_NCCL)
    auto *op_handle = new details::FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce, multi_nccl_ctxs);
#else
    auto *op_handle = new details::FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce);
#endif

    for (auto in : inputs) {
      op_handle->AddInput(in);
    }

    for (auto out : outputs) {
      op_handle->AddOutput(out);
    }

#if defined(PADDLE_WITH_NCCL)
    if (!multi_nccl_ctxs) {
      SetCommunicationContext(places, op_handle);
    }
#else
    SetCommunicationContext(places, op_handle);
#endif
  }

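  // Assigns the default device context of each place to the fused op; used
  // when no NCCL communicator is available.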
  void SetCommunicationContext(
      const std::vector<platform::Place> &places,
      details::FusedAllReduceOpHandle *op_handle) const {
    for (size_t i = 0; i < places.size(); ++i) {
      op_handle->SetDeviceContext(
          places[i], platform::DeviceContextPool::Instance().Get(places[i]));
    }
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_all_reduce_op_pass,
              paddle::framework::ir::FuseAllReduceOpPass)
    .RequirePassAttr(paddle::framework::details::kNRanks);