//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {
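// Executes an SSAGraph across the given places and their local scopes. With
// num_threads >= 2, ops run on an internal thread pool; otherwise every op
// runs synchronously on the calling thread.
//
// A minimal usage sketch (hypothetical setup; assumes the caller has already
// built the local scopes, places, and SSA graph, and that "loss" is a
// fetchable var in that graph):
//
//   ThreadedSSAGraphExecutor exec(/*num_threads=*/4, /*use_event=*/true,
//                                 local_scopes, places, std::move(graph));
//   FeedFetchList fetched = exec.Run({"loss"});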
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
    size_t num_threads, bool use_event,
    const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<SSAGraph> &&graph)
    : SSAGraphExecutor(std::move(graph)),
      pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
      local_scopes_(local_scopes),
      places_(places),
      fetch_ctxs_(places),
      use_event_(use_event) {}

FeedFetchList ThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
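  // Dataflow-style scheduling: each op waits on a count of not-yet-generated
  // inputs; finished ops publish their output vars through a blocking queue,
  // which in turn unblocks consumer ops, until no pending vars remain.
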
  // Remaining not-yet-ready input count for each pending op.
  std::unordered_map<OpHandleBase *, size_t> pending_ops;
  // Vars that have not been generated yet.
  std::unordered_set<VarHandleBase *> pending_vars;

  // Finished ops push their output vars here; the loop below pops them.
  BlockingQueue<VarHandleBase *> ready_vars;

  // Ops whose inputs are all ready but that have not been launched yet.
  std::unordered_set<OpHandleBase *> ready_ops;

  // Register a var as pending; a var that no op generates (e.g. feed data or
  // parameters) is ready from the start.
  auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
    pending_vars.insert(&var);
    if (var.generated_op_ == nullptr) {
      ready_vars.Push(&var);
    }
  };

  // Register an op as pending, keyed by how many inputs it still waits on.
  auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
  };

  // Transform the SSAGraph into pending_ops & pending_vars: every version of
  // every var starts pending, and every op waits on all of its inputs.
  for (auto &var_map : graph_->vars_) {
    for (auto &name_pair : var_map) {
      for (auto &version_pair : name_pair.second) {
        InsertPendingVar(version_pair.second);
      }
    }
  }
  for (auto &var : graph_->dep_vars_) {
    InsertPendingVar(*var);
  }

  for (auto &op : graph_->ops_) {
    // Special case: an op with no inputs is ready immediately.
    if (op->inputs_.empty()) {
      ready_ops.insert(op.get());
    } else {
      InsertPendingOp(*op);
    }
  }
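
  // At this point ready_ops holds the source ops, and ready_vars already
  // contains every var that no op generates; everything else is pending.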

  // Step 2. Insert FetchOps
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
  std::vector<DummyVarHandle> dummy_vars;
  FeedFetchList fetch_data(fetch_tensors.size());

  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  // For each fetched name, record the latest version (rbegin()) of the var
  // from every scope's var map.
  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->vars_) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars.at(var_name);
    auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
    fetch_ops.emplace_back(op);

    // FIXME: Use new device context
    for (auto &p : places_) {
      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }
    InsertPendingOp(*op);
  }
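
  // The fetch ops now take part in normal scheduling: each becomes ready
  // once the latest version of its fetched var is generated in every scope.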

  auto run_all_ready_ops = [&] {
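    // RunOp only schedules the op (on the pool or inline); its completion is
    // observed through ready_vars.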
    for (auto *op : ready_ops) {
      RunOp(ready_vars, op);
    }
    ready_ops.clear();
  };

  // Create a temp child scope under each local scope and expose it to ops
  // through the parent scope's @TMP_SCOPE@ variable.
  for (auto &scope : local_scopes_) {
    auto &local_scope = scope->NewScope();
    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
  }

  // Step 3. Execution
  while (!pending_vars.empty()) {
    // 1. Run all ready ops.
    run_all_ready_ops();

    // 2. Find a ready variable; Pop() blocks until one is available.
    VarHandleBase *ready_var = ready_vars.Pop();

    // 3. Remove the dependency on ready_var and collect the ops that become
    //    ready as a result.
    pending_vars.erase(ready_var);
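    // Decrement each consumer's pending-input count; ops that reach zero are
    // scheduled on the next iteration.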
    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        ready_ops.insert(op);
      }
    }
    // Keep looping until all vars are ready.
  }

  // One more asynchronous run has finished; sync_computation below resets
  // this counter.
  ++computation_count_;

  auto sync_computation = [&] {
    computation_count_ = 0;
    // Wait for all computation streams to finish.
    for (auto p : this->places_) {
      platform::DeviceContextPool::Instance().Get(p)->Wait();
    }
    // Release the temp child scopes created above.
    for (auto &scope : local_scopes_) {
      scope->DropKids();
    }
  };

  // Wait FetchOps: destroying the fetch op handles finalizes their results
  // in fetch_data.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
    sync_computation();
  }

  // Force a full synchronization once max_async_computation asynchronous
  // runs have accumulated.
  if (computation_count_ == max_async_computation) {
    sync_computation();
  }

  // NOTE: the temp scope can be dropped lazily if needed.
  // Drop tmp scopes: reset the @TMP_SCOPE@ pointers; the child scopes
  // themselves are released by DropKids() in sync_computation.
  for (auto &scope : local_scopes_) {
    auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
    kid = nullptr;
  }

  return fetch_data;
}

// Runs a single op, either on the thread pool (when one exists) or inline on
// the calling thread, and publishes its outputs to ready_var_q on success.
void ThreadedSSAGraphExecutor::RunOp(
    BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
  auto op_run = [&ready_var_q, op, this] {
    try {
      VLOG(10) << op->Name() << " : " << op->DebugString();
      op->Run(use_event_);
      // Unblock the consumers of this op's outputs.
      ready_var_q.Extend(op->outputs_);
    } catch (platform::EnforceNotMet &ex) {
      // Catch by reference to avoid slicing; record the error for the caller.
      exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception caught";
    }
  };
  if (pool_) {
    pool_->enqueue(op_run);
  } else {
    op_run();
  }
}
}  // namespace details
}  // namespace framework
}  // namespace paddle