// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace framework {
namespace details {

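// Builds the static dependency table: for each op handle in the graph,
// count how many of its inputs are not ready yet. Ops whose count is
// zero can run immediately and are collected as bootstrap ops.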
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<Scope *> &local_exec_scopes,
    const std::vector<platform::Place> &places, ir::Graph *graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      local_exec_scopes_(local_exec_scopes),
      places_(places),
      graph_(graph),
      fetch_ctxs_(places),
      pool_(strategy.num_threads_),
      // Add one more thread to generate op_deps.
      prepare_pool_(1) {
  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op, dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op);
    }
  }
  PADDLE_ENFORCE_GT(op_deps_.size(), 0,
                    platform::errors::PreconditionNotMet(
                        "The graph doesn't have operators."));
  PrepareAtomicOpDeps();
}

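// Runs one iteration over the graph and returns the fetched results.
// There are two execution paths: a single-threaded replay of the op
// order traced in the first iteration, and the general path that
// schedules ops asynchronously as their dependency counters reach zero.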
FetchResultType FastThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors, bool return_merged) {
  VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
  std::unique_ptr<platform::RecordEvent> event(
      new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
  std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();
  size_t num_ops = op_deps->size();

  FetchResultType fetches;
  if (return_merged) {
    fetches = FeedFetchList(fetch_tensors.size());
  } else {
    fetches = FetchUnmergedList(fetch_tensors.size());
  }
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<OpHandleBase *> fetch_ops;
  std::vector<OpHandleBase *> ready_fetch_ops;
  exception_.Clear();

  InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
                 &fetch_ops, &ready_fetch_ops, return_merged);
  event.reset(nullptr);
  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
    // If num_threads is 1, we can record the order in which the
    // operators execute during the first iteration, and in subsequent
    // iterations run the recorded operators directly. This strategy
    // can make execution faster.
    VLOG(3) << "Run the traced ops.";
    bool is_exception_free =
        RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
    if (!is_exception_free) {
      ExecutionFinal(&fetch_ops);
    }
  } else {
    traced_ops_.clear();
    remaining_ = 0;
    auto complete_q = std::make_shared<BlockingQueue<size_t>>();
    for (auto op : bootstrap_ops_) {
      RunOpAsync(op_deps.get(), op, complete_q);
    }
    for (auto op : ready_fetch_ops) {
      RunOpAsync(op_deps.get(), op, complete_q);
    }

    size_t num_complete = 0;
    while (num_complete != op_deps->size()) {
      size_t num_comp = complete_q->Pop();
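      // -1UL is the sentinel pushed by RunOp when an op throws an
      // exception; drain the in-flight tasks before rethrowing.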
      if (num_comp == -1UL) {
        int remaining = 0;
        while (true) {
          remaining = remaining_;
          if (remaining == 0) {
            break;
          }
          for (int i = 0; i < remaining; ++i) {
            complete_q->Pop();
          }
        }
        if (exception_.IsCaught()) {
          ExecutionFinal(&fetch_ops);
        }
      }
      num_complete += num_comp;
    }
  }
  // Wait for the fetch ops to finish and remove them from the graph.
  ClearFetchOp(graph_, &fetch_ops);
  return fetches;
}

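// Creates one FetchOpHandle per fetched variable and wires it to the
// latest version of that variable on every device. Fetch ops whose
// inputs are all ready are returned via ready_fetch_ops so they can be
// scheduled immediately.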
void FastThreadedSSAGraphExecutor::InsertFetchOps(
    const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
    std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    std::vector<OpHandleBase *> *fetch_ops,
    std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
  std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
                                                   fetch_tensors.end());
  for (auto &fetch_var_name : fetch_tensor_set) {
    for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors.at(i);
    auto fetched_var_it = fetched_vars->find(var_name);
    PADDLE_ENFORCE_NE(
        fetched_var_it, fetched_vars->end(),
        platform::errors::PreconditionNotMet(
            "Cannot find fetched variable(%s) in current computation graph. "
            "Possible reasons are:\n"
            "  1. The variable to be fetched is not defined in main program.\n"
            "  2. The variable to be fetched is not an input or output of any "
            "operator.",
            var_name));

    auto &vars = fetched_var_it->second;

    ir::Node *fetch_node =
        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
                                 &local_exec_scopes_, return_merged);
    fetch_ops->emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    int dep = static_cast<int>(op->NotReadyInputSize());
    (*op_deps)[op] = dep;
    if (dep == 0) {
      ready_fetch_ops->emplace_back(op);
    }
  }
}

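// Runs a single op synchronously. On success, the op is recorded into
// the trace and *complete is incremented; on exception, remaining_ is
// decremented and the -1UL sentinel is pushed to complete_q so that
// Run() can abort.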
bool FastThreadedSSAGraphExecutor::RunOp(
    OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
    size_t *complete) {
  RunOpSync(op);
  if (LIKELY(!exception_.IsCaught())) {
    if (LIKELY(!strategy_.dry_run_)) {
      RecordOps(op);
    }
    ++(*complete);
    return true;
  } else {
    --remaining_;
    complete_q->Push(-1UL);
    return false;
  }
}

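// Schedules an op onto the thread pool. After each op finishes, the
// worker keeps the first newly ready successor to run on the same
// thread and dispatches any further ready successors as new async
// tasks; highest-priority successors are pushed onto the local queue
// so that they run first, without a thread switch.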
void FastThreadedSSAGraphExecutor::RunOpAsync(
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    OpHandleBase *op,
    const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
  ++remaining_;
  this->pool_.enqueue([=] {
    std::deque<OpHandleBase *> op_queue;
    op_queue.push_front(op);

    size_t complete = 0;
    while (!op_queue.empty()) {
      OpHandleBase *op_to_run = op_queue.back();
      op_queue.pop_back();

      if (!RunOp(op_to_run, complete_q, &complete)) {
        return;
      }

      auto &outputs = op_to_run->Outputs();
      op_to_run = nullptr;
      for (auto &output : outputs) {
        for (auto &pending_op : output->PendingOps()) {
          std::atomic<int> &deps = op_deps->at(pending_op);
          if (deps.fetch_sub(1) != 1) continue;

          // NOTE(zjl): op with highest priority should run
          // first without switching to another thread.
          if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
            op_queue.push_back(pending_op);
          } else {
            if (op_to_run == nullptr) {
              op_to_run = pending_op;
            } else {
              RunOpAsync(op_deps, pending_op, complete_q);
            }
          }
        }
      }

      if (op_to_run != nullptr) {
        op_queue.push_front(op_to_run);
      }
    }
    --remaining_;
    complete_q->Push(complete);
  });
}

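// Asynchronously copies the static dependency counts into a fresh map
// of atomic counters for the next Run() call, so the preparation
// overlaps with the current iteration.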
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
  atomic_op_deps_ = prepare_pool_.enqueue([&] {
    auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first] = pair.second;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
  });
}

const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }

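// Records the execution order of non-fetch ops when running with a
// single thread, so that later iterations can replay the trace.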
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
    traced_ops_.emplace_back(op);
  }
}

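// Cleans up the fetch ops and rethrows the exception caught earlier.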
void FastThreadedSSAGraphExecutor::ExecutionFinal(
    std::vector<OpHandleBase *> *fetch_ops) {
  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
  ClearFetchOp(graph_, fetch_ops);
  exception_.ReThrow();
}

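// Replays a previously recorded op sequence, stopping at the first op
// that throws.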
bool FastThreadedSSAGraphExecutor::RunTracedOps(
    const std::vector<OpHandleBase *> &traced_ops) {
  for (auto &op : traced_ops) {
    if (!RunOpSync(op)) return false;
  }
  return true;
}

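// Runs one op, catching any exception into exception_; returns true
// iff the op finished without throwing.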
bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
  try {
    VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
    if (LIKELY(!strategy_.dry_run_)) {
      op->Run(strategy_.use_cuda_);
    }
    VLOG(10) << op << " " << op->Name() << " Done ";
    return true;
  } catch (...) {
    exception_.Catch(std::current_exception());
    return false;
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle