// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

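// Count the number of not-yet-ready inputs of every op in the graph. Ops whose
// count is already zero form the bootstrap set that Run() launches first. An
// atomic copy of this dependency map, consumed by a single run, is prepared
// asynchronously on the prepare thread.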
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, ir::Graph *graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      places_(places),
      graph_(graph),
      fetch_ctxs_(places),
      pool_(strategy.num_threads_),
      // add one more thread for generating op_deps
      prepare_pool_(1) {
  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op, dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op);
    }
  }

  PrepareAtomicOpDeps();
}

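// Run the graph once: take the pre-built atomic dependency map, create a
// FetchOpHandle for every fetch target, launch all ops whose dependencies are
// already satisfied, and drain the completion queue until every op (including
// the fetch ops) has reported back.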
FeedFetchList FastThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();

  paddle::framework::FeedFetchList fetches;
  fetches.resize(fetch_tensors.size());
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<FetchOpHandle *> fetch_ops;
  std::vector<OpHandleBase *> ready_fetch_ops;

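  // Collect, for every fetch target, the variable handles on each place; the
  // last version (rbegin) is the one written by the variable's final writer.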
  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
      }
    }
  }

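  // Create one FetchOpHandle per fetch target, give it a device context for
  // every place, wire it to the collected variable handles, and mark it ready
  // if none of its inputs are still pending.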
  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto fetched_var_it = fetched_vars.find(var_name);
    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
                   "Cannot find fetched variable(%s).(Perhaps the main_program "
                   "is not set to ParallelExecutor)",
                   var_name);

    auto &vars = fetched_var_it->second;

    ir::Node *fetch_node =
        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
    auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
    fetch_ops.emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    int dep = static_cast<int>(op->NotReadyInputSize());
    (*op_deps)[op] = dep;
    if (dep == 0) {
      ready_fetch_ops.emplace_back(op);
    }
  }

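  // Launch every op that starts with zero outstanding dependencies. Each
  // worker pushes the number of ops it completed onto complete_q, or -1UL if
  // it caught an exception.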
  size_t num_complete = 0;
  remaining_ = 0;
  auto complete_q = std::make_shared<BlockingQueue<size_t>>();
  for (auto op : bootstrap_ops_) {
    RunOpAsync(op_deps.get(), op, complete_q);
  }
  for (auto op : ready_fetch_ops) {
    RunOpAsync(op_deps.get(), op, complete_q);
  }
  while (num_complete != op_deps->size()) {
    size_t num_comp = complete_q->Pop();
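    // -1UL is the sentinel a worker pushes after catching an exception: drain
    // the completions of the ops still in flight, then rethrow on this thread.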
    if (num_comp == -1UL) {
      int remaining = 0;
      while (true) {
        remaining = remaining_;
        if (remaining == 0) {
          break;
        }
        for (int i = 0; i < remaining; ++i) {
          complete_q->Pop();
        }
      }
      if (exception_.IsCaught()) {
        ClearFetchOp(graph_, &fetch_ops);
        exception_.ReThrow();
      }
    }
    num_complete += num_comp;
  }
  // Wait for the FetchOps to finish, then remove them from the graph.
  ClearFetchOp(graph_, &fetch_ops);
  return fetches;
}

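// Schedule `op` on the worker pool. After an op finishes, the dependency
// counters of the ops waiting on its outputs are decremented; the first op
// that becomes ready is executed inline on the same worker, while any others
// are re-dispatched through RunOpAsync. remaining_ tracks the number of
// in-flight workers so Run() can tell how many completions are outstanding.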
void FastThreadedSSAGraphExecutor::RunOpAsync(
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    OpHandleBase *op,
    const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
  ++remaining_;
  this->pool_.enqueue([=] {
    OpHandleBase *op_to_run = op;
    size_t complete = 0;
    while (op_to_run != nullptr) {
      try {
        if (LIKELY(!strategy_.dry_run_)) {
          op_to_run->Run(strategy_.use_cuda_);
        }
        ++complete;
      } catch (...) {
        exception_.Catch(std::current_exception());
        --remaining_;
        complete_q->Push(-1UL);
        return;
      }
      auto &outputs = op_to_run->Outputs();
      op_to_run = nullptr;
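      // Decrement the dependency count of every op waiting on these outputs.
      // Keep one newly ready op to run next on this worker and hand any
      // others to the pool.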
      for (auto &output : outputs) {
        for (auto &pending_op : output->PendingOps()) {
          std::atomic<int> &deps = op_deps->at(pending_op);
          if (deps.fetch_sub(1) == 1) {  // pending_op ready
            if (op_to_run == nullptr) {
              op_to_run = pending_op;
            } else {
              RunOpAsync(op_deps, pending_op, complete_q);
            }
          }
        }
      }
    }
    --remaining_;
    complete_q->Push(complete);
  });
}
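
// Rebuild, on the background prepare thread, the atomic copy of the dependency
// map that the next Run() will consume, keeping the copy off the critical path.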
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
  atomic_op_deps_ = prepare_pool_.enqueue([&] {
    auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first] = pair.second;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
  });
}

const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
}  // namespace details
}  // namespace framework
}  // namespace paddle