fleet_executor.cc 7.2 KB
Newer Older
L
LiYuRio 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
16
#include "paddle/fluid/distributed/fleet_executor/global.h"
17
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
L
LiYuRio 已提交
18
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
19
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
20 21 22
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
23
#include "paddle/fluid/framework/operator.h"
L
LiYuRio 已提交
24
#include "paddle/fluid/framework/program_desc.h"
25 26
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
L
LiYuRio 已提交
27 28 29 30 31

namespace paddle {
namespace distributed {

// Builds an executor from a serialized executor description.
//
// `exe_desc_str` is parsed into `exe_desc_`; a parse failure is reported
// through PADDLE_ENFORCE with a PreconditionNotMet error. After a successful
// parse the global message bus is created and initialized.
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
  const bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
  PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
                                 "Error occurs while parsing string to proto"));

  // The message bus is meant to be created and initialized only once.
  GlobalVal<MessageBus>::Create();
  InitMessageBus();
}

40 41
// Drops the child scopes of the root scope and releases every carrier that
// was registered through Init().
FleetExecutor::~FleetExecutor() {
  // Init() stores the caller's scope pointer before validating it, so a failed
  // Init() can leave root_scope_ null. Skip the cleanup instead of
  // dereferencing a null pointer while destroying the executor.
  if (root_scope_ != nullptr) {
    root_scope_->DropKids();
  }
  for (const auto& carrier_id : carrier_ids_) {
    GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
  }
}
L
LiYuRio 已提交
46

47
void FleetExecutor::Init(
48 49 50
    const std::string& carrier_id, const framework::ProgramDesc& program_desc,
    framework::Scope* scope, const platform::Place& place,
    const std::vector<TaskNode*>& task_nodes,
51
    const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
52 53 54 55 56 57 58 59
  PADDLE_ENFORCE_GT(task_nodes.size(), 0,
                    platform::errors::InvalidArgument(
                        "Fleet executor is inited with empty task node"));
  // TODO(fleet_exe devs): the unused_vars should be got from run time graph
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (auto task_node : task_nodes) {
    for (auto op : task_node->ops()) {
      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
60
    }
61
  }
62 63 64 65 66 67 68 69 70 71 72 73 74
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
  runtime_graph_ = std::make_shared<RuntimeGraph>();
  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
  for (auto task_node : task_nodes) {
    task_node->SetUnusedVars(unused_vars);
    int64_t interceptor_id = task_node->task_id();
    interceptor_id_to_task.emplace(interceptor_id, task_node);
  }
  runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
  runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
  for (auto& unique_op : ops) {
    unique_op.release();
  }
75 76 77 78 79 80 81 82 83 84 85
  root_scope_ = scope;
  place_ = place;
  PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
                                           "root_scope_ can not be nullptr"));
  minibatch_scope_ = &root_scope_->NewScope();
  int64_t num_micro_batches = exe_desc_.num_micro_batches();
  microbatch_scopes_.resize(num_micro_batches);
  for (int i = 0; i < num_micro_batches; ++i) {
    microbatch_scopes_[i] = &minibatch_scope_->NewScope();
    CopyParameters(i, program_desc);
  }
86
  VLOG(5) << runtime_graph_->DebugString();
87 88 89
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier_ids_.insert(carrier_id);
90 91
  // Set current running carrier
  GlobalVal<std::string>::Set(new std::string(carrier_id));
92
  InitCarrier(carrier);
93
  GlobalVal<MessageBus>::Get()->Barrier();
94 95
}

96 97 98 99
// Initializes `carrier` with this executor's rank, the runtime graph's
// interceptor maps, the root/minibatch/microbatch scopes and the place.
void FleetExecutor::InitCarrier(Carrier* carrier) {
  const int64_t cur_rank = exe_desc_.cur_rank();
  const auto& id_to_rank = runtime_graph_->interceptor_id_to_rank();
  const auto& id_to_node = runtime_graph_->interceptor_id_to_node();
  carrier->Init(cur_rank, id_to_rank, id_to_node, root_scope_,
                minibatch_scope_, microbatch_scopes_, place_);
}

102 103 104 105 106 107 108
void FleetExecutor::InitMessageBus() {
  std::stringstream ss;
  ss << "\nThe DNS table of the message bus is: \n";
  int64_t cur_rank = exe_desc_.cur_rank();
  std::unordered_map<int64_t, std::string> rank_to_addr;
  std::string addr;
  for (const auto& rank_info : exe_desc_.cluster_info()) {
109
    // init the dns map
110 111 112 113 114 115 116 117
    int64_t rank = rank_info.rank();
    std::string ip_port = rank_info.ip_port();
    ss << rank << "\t->\t" << ip_port << "\n";
    rank_to_addr.insert(std::make_pair(rank, ip_port));
    if (rank == cur_rank) {
      addr = ip_port;
    }
  }
118 119
  if (addr == "") {
    PADDLE_ENFORCE_EQ(
120
        rank_to_addr.size(), 1,
121 122 123 124 125 126 127 128 129 130
        platform::errors::NotFound("Empty address is not valid for "
                                   "paddle.distributed.launch method."));
    PADDLE_ENFORCE_EQ(
        cur_rank, 0,
        platform::errors::NotFound("Address is empty but cur rank is not 0."));
  }
  VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is "
          << (addr == "" ? "empty" : addr) << ".";
  VLOG(3) << "The number of ranks are "
          << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
131
  VLOG(5) << ss.str();
132
  GlobalVal<MessageBus>::Get()->Init(cur_rank, rank_to_addr, addr);
L
LiYuRio 已提交
133 134
}

135
void FleetExecutor::Run(const std::string& carrier_id) {
136 137 138 139 140 141 142 143
  Carrier* carrier = GlobalMap<std::string, Carrier>::Get(carrier_id);
  // Set current running carrier
  if (*GlobalVal<std::string>::Get() != carrier_id) {
    GlobalVal<std::string>::Set(new std::string(carrier_id));
    // TODO(liyurui): Move barrier to service
    GlobalVal<MessageBus>::Get()->Barrier();
  }
  carrier->Start();
Y
Yuang Liu 已提交
144 145 146 147 148 149 150 151 152
  for (auto* micro_scop : microbatch_scopes_) {
    // By default, we should delete all kid scopes after run executor because
    // some operators may create local scope when running, such as while_op.
    // But when while_op also create a local executor to run it's sub block,
    // the sub scopes it created should not be dropped immediately, because
    // while_grad_op will use some variables created during while_op run, so
    // we need to keep the kids and wait for the outer executor to drop them.
    micro_scop->DropKids();
  }
L
LiYuRio 已提交
153 154
}

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
// Creates the variables of block 0 of `program` for one micro batch.
//
// Persistable variables are created in the root scope, and only while
// handling micro batch 0. Non-persistable variables are created in the scope
// of the micro batch identified by `microbatch_id`.
void FleetExecutor::CopyParameters(int microbatch_id,
                                   const framework::ProgramDesc& program) {
  auto& global_block = program.Block(0);

  for (auto& var : global_block.AllVars()) {
    if (var->Persistable()) {
      // Persistable variables are only created for the first micro batch.
      if (microbatch_id != 0) {
        continue;
      }
      auto* root_var = root_scope_->Var(var->Name());
      InitializeVariable(root_var, var->GetType());
      VLOG(5) << "Create persistable var: " << var->Name()
              << ", which pointer is " << root_var;
    } else {
      auto* micro_var = microbatch_scopes_[microbatch_id]->Var(var->Name());
      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
              << microbatch_id << ", which pointer is " << micro_var << ".";
      InitializeVariable(micro_var, var->GetType());
    }
  }
}

}  // namespace distributed
}  // namespace paddle