//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

#include "paddle/fluid/framework/details/ssa_graph_builder.h"

namespace paddle {
namespace framework {
namespace details {
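// Takes ownership of the SSA graph. A thread pool is created only when the
// strategy asks for at least two threads; otherwise ops run synchronously on
// the calling thread (see RunOp below).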
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<ir::Graph> &&graph)
    : graph_(std::move(graph)),
      pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
                                       : nullptr),
      local_scopes_(local_scopes),
      places_(places),
      fetch_ctxs_(places),
      running_ops_(0),
      strategy_(strategy) {}

FeedFetchList ThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
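  // Dataflow bookkeeping for this run:
  //   pending_ops  maps each op to the number of distinct inputs it is still
  //                waiting for;
  //   pending_vars holds the variables that have not been produced yet;
  //   ready_vars   is the queue through which worker threads publish freshly
  //                produced variables back to this scheduling loop.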
  std::unordered_map<OpHandleBase *, size_t> pending_ops;
  std::unordered_set<VarHandleBase *> pending_vars;
  BlockingQueue<VarHandleBase *> ready_vars;
  std::unordered_set<OpHandleBase *> ready_ops;
  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
  // streams from multiple GPUs, it's faster to buffer them and schedule
  // them together, since we currently cannot overlap computation and
  // memcpy streams. Revisit this once overlapping becomes available.
  std::unordered_set<OpHandleBase *> delayed_ops;

  // Step 1. Transform the SSA graph into pending_ops and pending_vars.
  for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
    for (auto &name_pair : var_map) {
      for (auto &version_pair : name_pair.second) {
        InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
      }
    }
  }
  for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
    InsertPendingVar(&pending_vars, &ready_vars, var.get());
  }

  for (auto &op : graph_->Get<details::GraphOps>("ops")) {
    if (op->Inputs().empty()) {  // Special case: op has no inputs.
      ready_ops.insert(op.get());
    } else {
      InsertPendingOp(&pending_ops, op.get());
    }
  }

  // Step 2. Insert FetchOps
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
  std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
  std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
  FeedFetchList fetch_data(fetch_tensors.size());

  InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
                 &pending_ops, &pending_vars, &ready_vars, &fetch_data);

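  // Launches every op in the given set (via RunOp) and clears the set;
  // completed ops push their output variables into ready_vars.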
  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
    for (auto *op : set) {
      running_ops_++;
      RunOp(&ready_vars, op);
    }
    set.clear();
  };

  // Reset per-run state left over from any previous invocation.
  run_op_futures_.clear();
  exception_.reset();

  // Step 3. Execution
  while (!pending_vars.empty()) {
    // 1. Run all ready ops.
    // Keep looping until all vars are ready.
    //
    // NOTE: Delayed ops have a lower priority; they are scheduled only after
    // all ready_ops have been performed.
    if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) {
      run_all_ops(delayed_ops);
    } else {
      run_all_ops(ready_ops);
    }

    // 2. Wait for variables to become ready; the short timeout lets this
    // loop notice worker-thread exceptions promptly.
    bool timeout;
    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

    if (timeout) {
      std::unique_lock<std::mutex> l(exception_mu_);
      if (exception_) {
        l.unlock();
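        // Drop the lock while draining in-flight ops so that a worker thread
        // hitting another exception can still record it.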
        for (auto &run_op_future : run_op_futures_) {
          run_op_future.wait();
        }
        l.lock();
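        // Re-throw a copy of the recorded exception on the caller's thread,
        // preserving its concrete type.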
        std::exception *exp = exception_.get();
        if (dynamic_cast<platform::EOFException *>(exp)) {
          auto e = *static_cast<platform::EOFException *>(exp);
          throw e;
        } else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
          auto e = *static_cast<platform::EnforceNotMet *>(exp);
          throw e;
        } else {
          LOG(FATAL) << "Unknown exception.";
        }
      } else {
        continue;
      }
    }
    // 3. Remove the dependency on each ready var and collect the ops that
    // become ready as a result.
    for (auto ready_var : cur_ready_vars) {
      pending_vars.erase(ready_var);
      for (auto *op : ready_var->PendingOps()) {
        auto &deps = pending_ops[op];
        --deps;
        if (deps == 0) {
          if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
            delayed_ops.insert(op);
          } else {
            ready_ops.insert(op);
          }
        }
      }
    }
  }
  PADDLE_ENFORCE(ready_ops.empty());

  // Wait for the FetchOps to finish.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
  }

  return fetch_data;
}
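
// A minimal usage sketch (hypothetical caller; the variable names are
// illustrative only):
//
//   ThreadedSSAGraphExecutor executor(strategy, local_scopes, places,
//                                     std::move(graph));
//   FeedFetchList fetched = executor.Run({"loss"});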

void ThreadedSSAGraphExecutor::InsertFetchOps(
    const std::vector<std::string> &fetch_tensors,
    std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
    std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
    std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
    std::unordered_map<OpHandleBase *, size_t> *pending_ops,
    std::unordered_set<VarHandleBase *> *pending_vars,
    BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

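  // For every fetch target, collect the newest version of that variable from
  // each device's variable map.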
  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
      }
    }
  }

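  // Create one FetchOpHandle per fetch target. Its inputs are the per-device
  // versions of the variable; its output is a dummy var whose only purpose is
  // to track the op's completion in the dependency graph.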
  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto fetched_var_it = fetched_vars.find(var_name);
    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
                   "Cannot find fetched variable.(Perhaps the main_program "
                   "is not set to ParallelExecutor)");

    auto &vars = fetched_var_it->second;

    temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
    auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
                                 &local_scopes_);
    fetch_ops->emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
    auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
    op->AddOutput(fetch_dummy);
    fetch_dependencies->emplace(fetch_dummy);
    this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
    this->InsertPendingOp(pending_ops, op);
  }
}

void ThreadedSSAGraphExecutor::InsertPendingOp(
    std::unordered_map<OpHandleBase *, size_t> *pending_ops,
    OpHandleBase *op_instance) const {
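  // An op becomes ready once its dependency count, i.e. the number of its
  // distinct input variables, drops to zero.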
  pending_ops->insert({op_instance, op_instance->NoDupInputSize()});
}

void ThreadedSSAGraphExecutor::InsertPendingVar(
    std::unordered_set<VarHandleBase *> *pending_vars,
    BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
  pending_vars->insert(var);
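  // A variable with no generating op already exists before the run starts
  // (e.g. feeds or parameters), so it is ready immediately.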
  if (var->GeneratedOp() == nullptr) {
    ready_vars->Push(var);
  }
}

void ThreadedSSAGraphExecutor::RunOp(
    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
  auto op_run = [ready_var_q, op, this] {
    try {
      if (VLOG_IS_ON(10)) {
        VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
      }
      op->Run(strategy_.use_cuda_);
      VLOG(10) << op << " " << op->Name() << " Done ";
      running_ops_--;
      ready_var_q->Extend(op->Outputs());
      VLOG(10) << op << " " << op->Name() << " Signal posted";
    } catch (const platform::EOFException &ex) {
      std::lock_guard<std::mutex> l(exception_mu_);
      // EOFException will not cover up existing EnforceNotMet.
      if (exception_.get() == nullptr) {
        exception_.reset(new platform::EOFException(ex));
      }
    } catch (const platform::EnforceNotMet &ex) {
      std::lock_guard<std::mutex> l(exception_mu_);
      exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception caught";
    }
  };
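  // With a thread pool the op runs asynchronously and its future is tracked
  // in run_op_futures_; otherwise it runs inline on the calling thread.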
  if (pool_) {
    run_op_futures_.emplace_back(pool_->enqueue(op_run));
  } else {
    op_run();
  }
}
}  // namespace details
}  // namespace framework
}  // namespace paddle