fuse_all_reduce_op_pass.cc
//   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace ir {

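// This pass replaces each group of dense-gradient all_reduce ops with a single
// fused_all_reduce op, so every gradient group is communicated with one
// collective call instead of one call per gradient.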
class FuseAllReduceOpPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
    ir::Graph &result = *graph;
    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *multi_nccl_ctxs =
        &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
#endif

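    // Collect the names of all dense gradients that are to be all-reduced.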
    auto &params_grads =
        result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
    size_t num_of_all_reduce = params_grads.size();
    std::unordered_set<std::string> grads;
    grads.reserve(num_of_all_reduce);
    for (const auto &p_g : params_grads) {
      grads.insert(p_g.second);
    }

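    // Map each gradient name to the all_reduce op node that reduces it.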
    std::unordered_map<std::string, Node *> all_reduce_ops =
        GetAllReduceOps(result, places, grads);

    VLOG(10) << "Found all_reduce_ops: " << all_reduce_ops.size();
    if (all_reduce_ops.empty()) {
      return;
    }

    PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
                      "The number of all_reduce OpHandles is not equal to the "
                      "number of gradients. Some gradients may be of sparse "
                      "type, which is not supported currently.");
    VLOG(10) << "Insert fused_all_reduce";

    auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
        details::kGroupParamsAndDenseGrads);

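    // Replace the all_reduce ops of each gradient group with one fused op.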
    for (auto &group_p_g : group_params_grads) {
      size_t group_size = group_p_g.size();
      PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
      std::vector<ir::Node *> group_all_reduce_ops;
      group_all_reduce_ops.reserve(group_size);
      for (auto &p_g : group_p_g) {
        group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second));
      }
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      InsertFusedAllReduce(places, local_scopes, group_size,
                           group_all_reduce_ops, multi_nccl_ctxs, &result);
#else
      InsertFusedAllReduce(places, local_scopes, group_size,
                           group_all_reduce_ops, &result);
#endif
    }
  }

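  // Scan the graph for AllReduceOpHandle nodes and index them by the name of
  // the gradient they reduce.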
  std::unordered_map<std::string, Node *> GetAllReduceOps(
      const Graph &result, const std::vector<platform::Place> &places,
      const std::unordered_set<std::string> &grads) const {
    size_t num_place = places.size();
    std::unordered_map<std::string, Node *> all_reduce_ops;
    all_reduce_ops.reserve(grads.size());
    for (auto &node : result.Nodes()) {
      if (node->IsOp()) {
        PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
        auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
            &node->Wrapper<details::OpHandleBase>());
        if (all_reduce_op_handle) {
          auto inputs = details::DynamicCast<details::VarHandle>(
              all_reduce_op_handle->Inputs());
          PADDLE_ENFORCE_EQ(inputs.size(), num_place);
          // All inputs should refer to the same gradient name.
          auto &grad_name = inputs[0]->name();
          for (size_t i = 1; i < inputs.size(); ++i) {
            PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
                              "All inputs should have the same name.");
          }
          PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
          all_reduce_ops.emplace(grad_name, node);
        }
      }
    }
    return all_reduce_ops;
  }

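  // Detach the given all_reduce ops from the graph and create a single
  // fused_all_reduce op that takes over all of their inputs and outputs.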
  void InsertFusedAllReduce(const std::vector<platform::Place> &places,
                            const std::vector<Scope *> &local_scopes,
                            const size_t num_of_all_reduce,
                            const std::vector<ir::Node *> &all_reduce_ops,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                            const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif
                            ir::Graph *result) const {
    std::vector<details::VarHandleBase *> inputs;
    std::vector<details::VarHandleBase *> outputs;
    for (auto &op : all_reduce_ops) {
      auto &op_handle = op->Wrapper<details::OpHandleBase>();
      inputs.insert(inputs.end(), op_handle.Inputs().begin(),
                    op_handle.Inputs().end());
      // Detach this op from its input var handles.
      for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
               [&op_handle](details::VarHandleBase *var_handle) {
                 var_handle->RemoveOutput(&op_handle, op_handle.Node());
               });

      outputs.insert(outputs.end(), op_handle.Outputs().begin(),
                     op_handle.Outputs().end());
      // Detach this op from its output var handles by clearing their
      // generated op.
      for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
               [](details::VarHandleBase *var_handle) {
                 var_handle->ClearGeneratedOp();
               });

      result->RemoveNode(op_handle.Node());
    }

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, multi_nccl_ctxs, result);
#else
    CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                           local_scopes, result);
#endif
  }

 private:
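  // Create the fused_all_reduce op handle and connect it to the inputs and
  // outputs collected from the original all_reduce ops.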
  void CreateFusedAllReduceOp(
      const std::vector<details::VarHandleBase *> &inputs,
      const std::vector<details::VarHandleBase *> &outputs,
      const size_t num_of_all_reduce,
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif
      ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    auto *op_handle = new details::FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce, multi_nccl_ctxs);
#else
    auto *op_handle = new details::FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce);
#endif

    for (auto in : inputs) {
      op_handle->AddInput(in);
    }

    for (auto out : outputs) {
      op_handle->AddOutput(out);
    }

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
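    // Fall back to per-place device contexts when no NCCL communicator is
    // available.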
    if (!multi_nccl_ctxs) {
      SetCommunicationContext(places, op_handle);
    }
#else
    SetCommunicationContext(places, op_handle);
#endif
  }

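  // Assign each place its default device context so the fused op can run
  // without an NCCL communicator.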
  void SetCommunicationContext(
      const std::vector<platform::Place> &places,
      details::FusedAllReduceOpHandle *op_handle) const {
    for (size_t i = 0; i < places.size(); ++i) {
      op_handle->SetDeviceContext(
          places[i], platform::DeviceContextPool::Instance().Get(places[i]));
    }
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_all_reduce_op_pass,
              paddle::framework::ir::FuseAllReduceOpPass);