// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace distributed {
namespace {

using OperatorBase = RuntimeGraph::OperatorBase;
using OpRole = paddle::framework::OpRole;
using OpRegistry = paddle::framework::OpRegistry;
using ProgramDesc = paddle::framework::ProgramDesc;

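// op_role is a bit field: the loss op carries kLoss OR'ed into kForward or
// kBackward, so both the plain and the OR'ed patterns are matched below.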
bool IsForward(int32_t op_role) {
  return (op_role == static_cast<int32_t>(OpRole::kForward)) ||
         (op_role == (static_cast<int32_t>(OpRole::kForward) |
                      static_cast<int32_t>(OpRole::kLoss)));
}

bool IsLRSched(int32_t op_role) {
  return op_role == static_cast<int32_t>(OpRole::kLRSched);
}

bool IsBackward(int32_t op_role) {
  return (op_role == static_cast<int32_t>(OpRole::kBackward)) ||
         (op_role == (static_cast<int32_t>(OpRole::kBackward) |
                      static_cast<int32_t>(OpRole::kLoss)));
}

bool IsOptimize(int32_t op_role) {
  return op_role == static_cast<int32_t>(OpRole::kOptimize);
}

struct DistCoord {
  int32_t dp_idx;
  int32_t pp_idx;
  int32_t mp_idx;
};

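// Maps a linear rank to a (dp, pp, mp) coordinate and back, with dp as the
// outermost dimension and mp as the innermost:
//   rank = dp_idx * pp_degree * mp_degree + pp_idx * mp_degree + mp_idx
// e.g. with degrees (2, 2, 2), rank 5 corresponds to coord (1, 0, 1).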
class DistCoordSys final {
 public:
  DistCoordSys(int32_t dp_degree, int32_t pp_degree, int32_t mp_degree)
      : dp_degree_(dp_degree), pp_degree_(pp_degree), mp_degree_(mp_degree) {}
  DistCoord RankToCoord(int64_t rank) const;
  int64_t CoordToRank(const DistCoord& coord) const;

 private:
  DISABLE_COPY_AND_ASSIGN(DistCoordSys);
  bool InvalidCoord(const DistCoord& coord) const;
  int32_t dp_degree_;
  int32_t pp_degree_;
  int32_t mp_degree_;
};

DistCoord DistCoordSys::RankToCoord(int64_t rank) const {
  DistCoord coord;
  coord.mp_idx = rank % mp_degree_;
  rank /= mp_degree_;
  coord.pp_idx = rank % pp_degree_;
  rank /= pp_degree_;
  coord.dp_idx = rank % dp_degree_;
  return coord;
}

int64_t DistCoordSys::CoordToRank(const DistCoord& coord) const {
  if (InvalidCoord(coord)) {
    return -1;
  }
  return coord.dp_idx * pp_degree_ * mp_degree_ + coord.pp_idx * mp_degree_ +
         coord.mp_idx;
}

bool DistCoordSys::InvalidCoord(const DistCoord& coord) const {
  return coord.mp_idx < 0 || coord.mp_idx >= mp_degree_ || coord.pp_idx < 0 ||
         coord.pp_idx >= pp_degree_ || coord.dp_idx < 0 ||
         coord.dp_idx >= dp_degree_;
}

}  // namespace

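// Each rank owns one task node per functionality, in the order below, so task
// (and interceptor) ids are assigned as cur_rank * functionality_order.size()
// plus the functionality's offset.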
std::vector<OpRole> RuntimeGraph::functionality_order = {
    OpRole::kLRSched, OpRole::kForward, OpRole::kBackward, OpRole::kOptimize};

RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
                           const FleetExecutorDesc& exe_desc)
    : exe_desc_(exe_desc) {
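  // With a single pipeline stage there is nothing to split: wrap the whole
  // program into one Compute task owned by the current rank.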
  if (exe_desc.pp_degree() == 1) {
    int64_t cur_rank = exe_desc_.cur_rank();
    int64_t max_run_times = exe_desc_.num_micro_batches();
    int64_t max_slot_nums = exe_desc_.num_slots();
    auto task_node = std::make_unique<TaskNode>(program, cur_rank,
                                                max_run_times, max_slot_nums);
    task_node->SetType("Compute");
    task_nodes_.emplace_back(std::move(task_node));
    int64_t task_id = task_nodes_[0]->task_id();
    intercepter_id_to_rank_.insert({task_id, cur_rank});
    intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
  } else {
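    // Multi-stage pipeline: split the program by functionality, register each
    // task with an interceptor, then wire up the 1F1B-style dependences and
    // the per-rank runtime info.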
    SplitProgramBasedFunctionality(program);
    AssignTaskToIntercepter();
    FakeDependence();
    FakeRuntimeInfo();
  }
}

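// Instantiates every op in block 0, buckets the ops into the four
// coarse-grained roles (LRSched/Forward/Backward/Optimize), and builds one
// task node per functionality for the current rank.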
void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
  for (const auto& op_desc : program.Block(0).AllOps()) {
    ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
  }
  std::unordered_map<int32_t, std::vector<OperatorBase*>> role_to_ops;
  for (const auto& op : ops_) {
    int32_t op_role = op->Attr<int32_t>("op_role");
    OpRole new_op_role;
    if (IsLRSched(op_role)) {
      new_op_role = OpRole::kLRSched;
    } else if (IsForward(op_role)) {
      new_op_role = OpRole::kForward;
    } else if (IsBackward(op_role)) {
      new_op_role = OpRole::kBackward;
    } else if (IsOptimize(op_role)) {
      new_op_role = OpRole::kOptimize;
    } else {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "The op %s is None of LRSched, Forward, Backward or Optimize.",
          op->Type()));
    }
    int32_t new_op_role_id = static_cast<int32_t>(new_op_role);
    if (role_to_ops.find(new_op_role_id) == role_to_ops.end()) {
      role_to_ops.insert({new_op_role_id, {}});
    }
    role_to_ops.at(new_op_role_id).emplace_back(op.get());
  }
  int64_t cur_rank = exe_desc_.cur_rank();
  DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
                         exe_desc_.mp_degree());
  const auto& coord = coord_sys.RankToCoord(cur_rank);
  int pipeline_stage = coord.pp_idx;
  int64_t num_pipeline_stages = exe_desc_.pp_degree();
  // TODO(fleet_executor dev): start-up steps should be a config `num_slots`
  int64_t start_up_steps = num_pipeline_stages - pipeline_stage;
  int64_t num_micro_batches = exe_desc_.num_micro_batches();
  int64_t task_id = cur_rank * functionality_order.size();
  for (std::size_t i = 0; i < functionality_order.size(); ++i) {
    VLOG(3) << "Runtime graph is creating task node for: " << task_id << ".";
    OpRole role = functionality_order[i];
    int32_t role_id = static_cast<int32_t>(role);
    int64_t max_run_times = num_micro_batches;
    int64_t max_slot_nums = start_up_steps;
    // NOTE: use the short path; each interceptor should run for max_run_times
    std::vector<OperatorBase*> task_ops{};
    if (role_to_ops.find(role_id) != role_to_ops.end()) {
      task_ops = role_to_ops.at(role_id);
    }
    std::unique_ptr<TaskNode> task_node = std::make_unique<TaskNode>(
        role_id, task_ops, cur_rank, task_id, max_run_times, max_slot_nums);
    if (IsLRSched(role_id) || IsOptimize(role_id)) {
      task_node->SetType("Amplifier");
      if (IsLRSched(role_id)) {
        task_node->SetRunPerSteps(max_run_times);
      } else {
        task_node->SetRunAtOffset(max_run_times - 1);
        task_node->SetRunPerSteps(max_run_times);
      }
    } else {
      task_node->SetType("Compute");
    }
    task_nodes_.emplace_back(std::move(task_node));
    ++task_id;
  }
}

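// Builds the locally inferred dependence edges: intra-rank edges chain the
// functionalities in order, while inter-rank edges connect forward and
// backward tasks to the adjacent pipeline stages.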
void RuntimeGraph::FakeDependence() {
  int64_t cur_rank = exe_desc_.cur_rank();
  DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
                         exe_desc_.mp_degree());
  const auto& coord = coord_sys.RankToCoord(cur_rank);
  DistCoord upstream_coord = coord, downstream_coord = coord;
  upstream_coord.pp_idx -= 1;
  downstream_coord.pp_idx += 1;
  int64_t pp_upstream = coord_sys.CoordToRank(upstream_coord);
  int64_t pp_downstream = coord_sys.CoordToRank(downstream_coord);
  bool is_first_stage = (pp_upstream == -1);
  bool is_last_stage = (pp_downstream == -1);

  int32_t num_of_functionality =
      static_cast<int32_t>(functionality_order.size());
  // lr(1:m) -> forward -> backward -> (m:1)optimize
  //               ↑          ↓
  // lr(1:m) -> forward -> backward -> (m:1)optimize
  //               ↑          ↓
  // lr(1:m) -> forward -> backward -> (m:1)optimize
  for (std::size_t i = 0; i < task_nodes_.size(); ++i) {
    auto& node = task_nodes_[i];
    bool is_forward = IsForward(node->role());
    bool is_backward = IsBackward(node->role());

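    // Task ids are laid out as rank * num_of_functionality + offset, so
    // prev/next are the neighbouring functionalities on this rank, while
    // upstream/downstream are the same functionality on adjacent stages.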
    int64_t cur_id = cur_rank * num_of_functionality + i;
    int64_t prev_id = cur_id - 1;
    int64_t next_id = cur_id + 1;

    int64_t upstream_id = pp_upstream * num_of_functionality + i;
    int64_t downstream_id = pp_downstream * num_of_functionality + i;

    // 1F1B: the last stage's pp_buff_size should be 1, while the first
    // stage's should be pp_degree.
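    // e.g. with pp_degree = 4, stages 0..3 buffer 4, 3, 2 and 1 micro-batches
    // respectively.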
    int64_t pp_buff_size = exe_desc_.pp_degree() - coord.pp_idx;

    std::vector<std::pair<int64_t, int64_t>> ups;
    std::vector<std::pair<int64_t, int64_t>> downs;

    if (i != 0) {  // not lr
      int64_t buff_size = is_backward ? pp_buff_size : 2;
      ups.emplace_back(prev_id, buff_size);
    }
    if (i != task_nodes_.size() - 1) {  // not optimize
      int64_t buff_size = is_forward ? pp_buff_size : 2;
      downs.emplace_back(next_id, buff_size);
    }

    if (is_forward) {
      if (!is_first_stage) {
        ups.emplace_back(upstream_id, 2);
      }
      if (!is_last_stage) {
        downs.emplace_back(downstream_id, 2);
      }
    } else if (is_backward) {
      if (!is_last_stage) {
        ups.emplace_back(downstream_id, 2);
      }
      if (!is_first_stage) {
        downs.emplace_back(upstream_id, 2);
      }
    }

    for (const auto& up : ups) {
      VLOG(3) << "Task(" << cur_id << ") AddUpstream Task(" << up.first
              << ") with buff_size=" << up.second;
      node->AddUpstreamTask(up.first, up.second);
    }
    for (const auto& down : downs) {
      VLOG(3) << "Task(" << cur_id << ") AddDownstream Task(" << down.first
              << ") with buff_size=" << down.second;
      node->AddDownstreamTask(down.first, down.second);
    }
  }
}

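// Registers each task node under its task id, which doubles as the
// interceptor id; a repeated id indicates a malformed graph.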
void RuntimeGraph::AssignTaskToIntercepter() {
  for (const auto& task : task_nodes_) {
    int64_t intercepter_id = task->task_id();
    VLOG(3) << "Runtime graph is assigning task to interceptor: "
            << intercepter_id << " with type: " << task->type() << ".";
    if (intercepter_id_to_node_.find(intercepter_id) !=
        intercepter_id_to_node_.end()) {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "Repeated intercepter id: %d", intercepter_id));
    }
    intercepter_id_to_node_.insert({intercepter_id, task.get()});
  }
}

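// Every rank hosts one interceptor per functionality; rebuild the
// interceptor-id-to-rank mapping for the whole cluster accordingly.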
void RuntimeGraph::FakeRuntimeInfo() {
  int64_t nrank = exe_desc_.cluster_info().size();
  int32_t num_of_functionality =
      static_cast<int32_t>(functionality_order.size());
  for (int64_t i = 0; i < nrank; ++i) {
    for (int32_t j = 0; j < num_of_functionality; ++j) {
      int64_t intercepter_id = i * num_of_functionality + j;
      intercepter_id_to_rank_.insert({intercepter_id, i});
    }
  }
}

std::string RuntimeGraph::DebugString() const {
  std::ostringstream os;
  os << "\nRuntime Graph Debug: \n";
  for (const auto& task : task_nodes_) {
    os << task->DebugString();
    os << "\n";
  }
  return os.str();
}

}  // namespace distributed
}  // namespace paddle