//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"

#include "paddle/fluid/framework/variable_helper.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif

namespace paddle {
namespace framework {
namespace details {

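// Create a temporary local scope under `scope` and make sure every variable in
// `var_infos` exists: persistable variables are created in the parent scope,
// all others in the new local scope.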
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
                                    Scope *scope) {
  VLOG(3) << "NewTempScopeAndInitVars";
  Scope &local_scope = scope->NewScope();
  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
      &local_scope;

  for (auto &info : var_infos) {
    if (scope->FindVar(info.name_) != nullptr) {
      continue;
    }

    if (info.persistable_) {  // Persistable
      InitializeVariable(scope->Var(info.name_), info.type_);
    } else {
      InitializeVariable(local_scope.Var(info.name_), info.type_);
    }
  }
}

// Collect the RpcContexts of the remote send and recv ops in the graph.
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
  using RpcCtxMap = operators::distributed::RpcCtxMap;
  VLOG(3) << "ProcessGraph";
  RpcCtxMap send_varname_to_ctx;
  RpcCtxMap recv_varname_to_ctx;
  for (auto &node : graphs[0]->Nodes()) {
    VLOG(3) << "node name " << node->Name();
    if (node && node->IsOp()) {
      if (node->Name() == "send") {
        auto send_var_name = node->Op()->Input("X")[0];
        auto send_varnames = boost::get<std::vector<std::string>>(
            node->Op()->GetNullableAttr("send_varnames"));
        auto epmap = boost::get<std::vector<std::string>>(
            node->Op()->GetNullableAttr("epmap"));
        auto height_section = boost::get<std::vector<int64_t>>(
            node->Op()->GetNullableAttr("sections"));
        auto trainer_id =
            boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
        send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
            send_var_name, send_varnames, epmap, height_section, trainer_id);
        VLOG(3) << "find and init an send op: "
                << send_varname_to_ctx[send_var_name];
      } else if (node->Name() == "recv") {
        auto recv_var_name = node->Op()->Output("Out")[0];
        auto recv_varnames = boost::get<std::vector<std::string>>(
            node->Op()->GetNullableAttr("recv_varnames"));
        auto epmap = boost::get<std::vector<std::string>>(
            node->Op()->GetNullableAttr("epmap"));
        auto trainer_id =
            boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
        recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
            recv_var_name, recv_varnames, epmap, {}, trainer_id);
        VLOG(3) << "find and remove an recv op: "
                << recv_varname_to_ctx[recv_var_name];
      }
    }
  }

  // init communicator here
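  // The communicator is only needed when the graph contains send ops,
  // i.e. in async distributed training.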
  if (send_varname_to_ctx.size() > 0) {
    VLOG(3) << "this is distribute mode, will use communicator";
    operators::distributed::Communicator::Init(send_varname_to_ctx,
                                               recv_varname_to_ctx, scope);
    operators::distributed::Communicator::GetInstance()->Start();
  }
#endif
}

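// AsyncSSAGraphExecutor owns one ThreadedSSAGraphExecutor per place. The
// constructor splits the thread budget across devices, collects variable info
// from graph 0, pre-creates variables in every local scope, and scans the
// graph for distributed send/recv ops.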
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
    : strategy_(std::move(strategy)),
      local_scopes_(std::move(local_scopes)),
      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
      places_(std::move(places)),
      graphs_(std::move(graphs)) {
  VLOG(3) << "build AsyncSSAGraphExecutor";
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());

  // Split the thread pool budget evenly across devices (at least one thread
  // per device).
  strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                               ? 1UL
                               : strategy_.num_threads_ / places_.size();
  VLOG(1) << "set num_threads: " << strategy_.num_threads_
          << " to run the operators of the graph on each device.";
  for (size_t i = 0; i < places.size(); ++i) {
    executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
  }

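  // Record name/type/persistable information for every variable node in
  // graph 0 so that all local scopes can be initialized consistently.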
  for (auto &node : graphs_[0]->Nodes()) {
    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
      var_infos_.emplace_back();
      var_infos_.back().name_ = node->Var()->Name();
      var_infos_.back().type_ = node->Var()->GetType();
      var_infos_.back().persistable_ = node->Var()->Persistable();
    }
  }
  for (auto *scope : local_scopes_) {
    NewTempScopeAndInitVars(var_infos_, scope);
  }
  ProcessGraph(graphs_, local_scopes_[0]);
}

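// Launch one background thread per extra place (index >= 1). Each thread keeps
// running its executor until an exception is thrown; the exception is stored
// in exception_holder_ so that the main thread can rethrow it.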
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
  VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
  for (size_t i = 1; i < places_.size(); ++i) {
    auto call = [this, i]() -> void {
      VLOG(3) << "start off python thread " << i;
      try {
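        // Run the graph repeatedly with no fetch targets; the loop only ends
        // when an exception is thrown and caught below.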
        while (true) {
          executors_[i]->Run({});
        }
      } catch (...) {
        exception_holder_.Catch(std::current_exception());
        VLOG(3) << "get exception type = " << exception_holder_.Type();
      }
      VLOG(3) << "thread " << i << " exited!";
    };
    run_futures_.emplace_back(pool_->enqueue(std::move(call)));
  }
}

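// If any executor has recorded an exception, wait for all background futures
// to finish, drop them, and rethrow the exception on the calling thread.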
void AsyncSSAGraphExecutor::HandleException() {
  if (exception_holder_.IsCaught()) {
    for (auto &f : run_futures_) {
      VLOG(3) << "wait future";
      f.wait();
    }
    VLOG(3) << "caught exception " << exception_holder_.Type()
            << ", rethrow it";
    run_futures_.clear();
    exception_holder_.ReThrow();
  }
}

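// Run() is driven by the calling (Python) thread and only executes graph 0
// directly. On the first call with multiple places, the off-python loops are
// started for the remaining places; exceptions raised there are rethrown here.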
FeedFetchList AsyncSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Start the background train loops only once.
  if (run_futures_.size() == 0 && places_.size() > 1) {
    exception_holder_.Clear();
    StartOffPythonTrainLoop();
  }

  if (places_.size() == 1) {
    exception_holder_.Clear();
  } else {
    HandleException();
  }

  FeedFetchList fetch_data;
  fetch_data.reserve(fetch_tensors.size());

  try {
    fetch_data = executors_[0]->Run(fetch_tensors);
  } catch (...) {
    exception_holder_.Catch(std::current_exception());
  }

  HandleException();

  FeedFetchList ret;
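  // Only executor 0 produces fetch results; merge each fetched tensor onto the
  // CPU place before returning.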
  for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
    std::vector<const LoDTensor *> lodtensor_ptrs;
    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
    ret.emplace_back();
    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
  }
  return ret;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle