// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"

#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace framework {
namespace details {

FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<Scope *> &local_exec_scopes,
    const std::vector<platform::Place> &places, ir::Graph *graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      local_exec_scopes_(local_exec_scopes),
      places_(places),
      graph_(graph),
      fetch_ctxs_(places),
      // add one more thread to generate op_deps
      prepare_pool_(1) {
  if (ir::IsTopologySortOperationsUnique(*graph_)) {
    VLOG(10)
        << "Change thread number to 1 because the toposort order is unique";
    strategy_.num_threads_ = 1;
  }
  pool_.reset(new ::ThreadPool(strategy_.num_threads_));
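  // Record the static dependency count of every op; ops with no pending
  // inputs become the bootstrap ops that kick off each run.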
  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op, dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op);
    }
  }
  PADDLE_ENFORCE_GT(op_deps_.size(), 0,
                    platform::errors::PreconditionNotMet(
                        "The graph doesn't have operators."));
  PrepareAtomicOpDeps();
}

FetchResultType FastThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors, bool return_merged) {
  VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
  std::unique_ptr<platform::RecordEvent> event(
      new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
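  // Take the dependency counters prepared asynchronously during the previous
  // iteration, then immediately kick off preparation of a fresh copy for the
  // next Run.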
  std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();
  size_t num_ops = op_deps->size();

  FetchResultType fetches;
  if (return_merged) {
    fetches = FetchList(fetch_tensors.size());
  } else {
    fetches = FetchUnmergedList(fetch_tensors.size());
  }
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<OpHandleBase *> fetch_ops;
  std::vector<OpHandleBase *> ready_fetch_ops;
  exception_.Clear();
  InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
                 &fetch_ops, &ready_fetch_ops, return_merged);
  event.reset(nullptr);
  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
    // If num_threads is 1, we can record the order of operator execution
    // in the first iteration and, in subsequent iterations, run the
    // recorded operators directly. This strategy can make execution faster.
    VLOG(3) << "Run the traced ops.";
    bool is_exception_free =
        RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
    if (!is_exception_free) {
      ExecutionFinal(&fetch_ops);
    }
  } else {
    traced_ops_.clear();
    remaining_ = 0;
    auto complete_q = std::make_shared<BlockingQueue<size_t>>();
    VLOG(3) << "number of bootstrap_ops_: " << bootstrap_ops_.size();
    VLOG(3) << "number of ready_fetch_ops: " << ready_fetch_ops.size();
    for (auto op : bootstrap_ops_) {
      RunOpAsync(op_deps.get(), op, complete_q);
    }
    for (auto op : ready_fetch_ops) {
      RunOpAsync(op_deps.get(), op, complete_q);
    }

    size_t num_complete = 0;
    while (num_complete != op_deps->size()) {
      size_t num_comp = complete_q->Pop();
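      // -1UL is the sentinel a worker pushes after catching an exception;
      // drain the completions of the ops still in flight, then rethrow.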
      if (num_comp == -1UL) {
        int remaining = 0;
        while (true) {
          remaining = remaining_;
          if (remaining == 0) {
            break;
          }
          for (int i = 0; i < remaining; ++i) {
            complete_q->Pop();
          }
        }
        if (exception_.IsCaught()) {
          ExecutionFinal(&fetch_ops);
        }
      }
      num_complete += num_comp;
    }
  }
  // Wait FetchOps.
  ClearFetchOp(graph_, &fetch_ops);

  for (auto &place : places_) {
    fetch_ctxs_.Get(place)->Wait();
  }

  return fetches;
}

void FastThreadedSSAGraphExecutor::InsertFetchOps(
    const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
    std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    std::vector<OpHandleBase *> *fetch_ops,
    std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
  std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
                                                   fetch_tensors.end());
  for (auto &fetch_var_name : fetch_tensor_set) {
    for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors.at(i);
    auto fetched_var_it = fetched_vars->find(var_name);
    PADDLE_ENFORCE_NE(
        fetched_var_it, fetched_vars->end(),
        platform::errors::PreconditionNotMet(
            "Cannot find fetched variable(%s) in current computation graph. "
            "Possible reasons are:\n"
            "  1. The variable to be fetched is not defined in main program.\n"
            "  2. The variable to be fetched is not an input or output of any "
            "operator.\n"
            "  3. Confirm that you have used the fetch `Variable` format "
            "instead of the string literal('%s') in `fetch_list` parameter "
            "when using `executor.run` method. In other words, the format of "
            "`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) "
            "is recommended.",
            var_name, var_name));

    auto &vars = fetched_var_it->second;

    ir::Node *fetch_node =
        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
    auto *op = new FetchAsyncOpHandle(fetch_node, fetches, i, &local_scopes_,
                                      &local_exec_scopes_, return_merged);
    fetch_ops->emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

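    // The ops that generate fetched vars must keep their locks and record
    // events, so the fetch op can synchronize with them correctly.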
    for (auto *var : vars) {
      auto *op = var->GeneratedOp();
      auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
      if (compute_op) {
        compute_op->SetLockAndRecordEventFree(false);
      }
    }

    int dep = static_cast<int>(op->NotReadyInputSize());
    (*op_deps)[op] = dep;
    if (dep == 0) {
      ready_fetch_ops->emplace_back(op);
    }
  }
}

bool FastThreadedSSAGraphExecutor::RunOp(
    OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
    size_t *complete) {
  RunOpSync(op);
  if (LIKELY(!exception_.IsCaught())) {
    if (LIKELY(!strategy_.dry_run_)) {
      RecordOps(op);
    }
    ++(*complete);
    return true;
  } else {
    --remaining_;
    complete_q->Push(-1UL);
    return false;
  }
}

void FastThreadedSSAGraphExecutor::RunOpAsync(
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    OpHandleBase *op,
    const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
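  // remaining_ counts asynchronous tasks that have been scheduled but not yet
  // finished, so Run() can drain the completion queue when an exception occurs.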
  ++remaining_;
  this->pool_->enqueue([=] {
    std::deque<OpHandleBase *> op_queue;
    op_queue.push_front(op);

    size_t complete = 0;
    while (!op_queue.empty()) {
      OpHandleBase *op_to_run = op_queue.back();
      op_queue.pop_back();

      // An op that transfers data across multiple devices may block the
      // emission of other computations. For example:
      // step 1: queue=[Share, Allreduce], where Share has high priority
      // step 2: Share executes, pending_op=Grad, queue=[Allreduce, Grad]
      // step 3: Allreduce runs synchronously. Although Allreduce and Grad
      // have no topological dependency, Grad must still wait for Allreduce
      // to complete before it can be scheduled.
      // In this scenario, computation and communication may not overlap.
      // Therefore, emit the ops already in the queue before running a
      // multi-device op.
      if (op_to_run->IsMultiDeviceTransfer()) {
        while (!op_queue.empty()) {
          OpHandleBase *post_op = op_queue.back();
          op_queue.pop_back();
          RunOpAsync(op_deps, post_op, complete_q);
        }
      }
      VLOG(3) << "start to run op: " << op_to_run->Name();
      if (!RunOp(op_to_run, complete_q, &complete)) {
        return;
      }
      auto &outputs = op_to_run->Outputs();
      op_to_run = nullptr;
      for (auto &output : outputs) {
        for (auto &pending_op : output->PendingOps()) {
          std::atomic<int> &deps = op_deps->at(pending_op);
          if (deps.fetch_sub(1) != 1) continue;

          // NOTE(zjl): op with highest priority should run
          // first without switching to another thread.
          if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
            op_queue.push_back(pending_op);
          } else if (pending_op->IsMultiDeviceTransfer()) {
            // multi device ops should be scheduled prior to computing ops
            op_queue.push_front(pending_op);
          } else {
            if (op_to_run == nullptr) {
              op_to_run = pending_op;
            } else {
              RunOpAsync(op_deps, pending_op, complete_q);
            }
          }
        }
      }

      if (op_to_run != nullptr) {
        op_queue.push_front(op_to_run);
      }
    }
    --remaining_;
    complete_q->Push(complete);
  });
}

void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
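  // Copy the static dependency counts into fresh std::atomic<int> counters on
  // the prepare thread, so each Run() starts from an untouched dependency map.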
  atomic_op_deps_ = prepare_pool_.enqueue([&] {
    auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first] = pair.second;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
  });
}

const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }

void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
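  // Only trace non-fetch ops in single-thread mode; the recorded order is
  // replayed by RunTracedOps in later iterations.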
  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchAsyncOpHandle *>(op)) {
    traced_ops_.emplace_back(op);
  }
}

void FastThreadedSSAGraphExecutor::ExecutionFinal(
    std::vector<OpHandleBase *> *fetch_ops) {
  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
  // NOTE: If a new exception is thrown during this ClearFetchOp call, the
  // exception triggered first would be lost instead of being rethrown.
  // Therefore, the cleanup is performed only when an EOF exception is
  // caught; for any other exception, ClearFetchOp is skipped.
  if (exception_.Type() == "EOF") {
    ClearFetchOp(graph_, fetch_ops);
  }
  exception_.ReThrow();
}

bool FastThreadedSSAGraphExecutor::RunTracedOps(
    const std::vector<OpHandleBase *> &traced_ops) {
  for (auto &op : traced_ops) {
    if (!RunOpSync(op)) return false;
  }
  return true;
}

bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
  try {
    VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
    if (LIKELY(!strategy_.dry_run_)) {
      op->Run(strategy_.use_device_);
    }
    VLOG(10) << op << " " << op->Name() << " Done ";
    return true;
  } catch (...) {
    exception_.Catch(std::current_exception());
    return false;
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle