//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

#include "paddle/fluid/framework/details/fetch_op_handle.h"

namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
    size_t num_threads, bool use_event,
    const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<SSAGraph> &&graph)
    : SSAGraphExecutor(std::move(graph)),
      pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
      local_scopes_(local_scopes),
      places_(places),
      fetch_ctxs_(places),
      use_event_(use_event),
      running_ops_(0) {}

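// Runs the buffered delayed ops synchronously on the calling thread; Run()
// only invokes this once no other op is in flight.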
void ThreadedSSAGraphExecutor::RunDelayedOps(
    const std::unordered_set<OpHandleBase *> &delayed_ops) {
  for (auto op : delayed_ops) {
    op->Run(use_event_);
  }
}

FeedFetchList ThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  std::unordered_map<OpHandleBase *, size_t> pending_ops;
  std::unordered_set<VarHandleBase *> pending_vars;
  BlockingQueue<VarHandleBase *> ready_vars;
  std::unordered_set<OpHandleBase *> ready_ops;
  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
  // streams from multiple GPUs, it is faster to buffer them and schedule
  // them together, since we currently cannot overlap computation and memcpy
  // streams. Revisit this once overlapping becomes available.
  std::unordered_set<OpHandleBase *> delayed_ops;
  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
  std::unordered_set<VarHandleBase *> delayed_vars;

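  // A variable with no generating op is ready immediately; otherwise it stays
  // in pending_vars until the op that produces it has run.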
  auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
    pending_vars.insert(&var);
    if (var.generated_op_ == nullptr) {
      ready_vars.Push(&var);
    }
  };

  auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
  };

  // Step 1. Transform the SSAGraph into pending_ops and pending_vars.
  for (auto &var_map : graph_->vars_) {
    for (auto &name_pair : var_map) {
      for (auto &version_pair : name_pair.second) {
        InsertPendingVar(version_pair.second);
      }
    }
  }
  for (auto &var : graph_->dep_vars_) {
    InsertPendingVar(*var);
  }

  for (auto &op : graph_->ops_) {
    if (op->inputs_.empty()) {  // Special case: op has no inputs, so it is ready immediately.
      ready_ops.insert(op.get());
    } else {
      InsertPendingOp(*op);
    }
  }

  // Step 2. Insert FetchOps
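  // One FetchOpHandle is created per fetched variable; it takes the newest
  // version of that variable from every device's var map as inputs and writes
  // the result into fetch_data[i].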
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
  std::vector<DummyVarHandle> dummy_vars;
  FeedFetchList fetch_data(fetch_tensors.size());

  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->vars_) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars.at(var_name);
    auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
    fetch_ops.emplace_back(op);

    // FIXME: Use new device context
    for (auto &p : places_) {
      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }
    InsertPendingOp(*op);
  }

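  // Launch every op whose inputs are all ready. Multi-device transfer ops
  // (IsMultiDeviceTransfer) are buffered into delayed_ops instead of being
  // launched right away.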
  auto run_all_ready_ops = [&] {
    for (auto *op : ready_ops) {
      if (op->IsMultiDeviceTransfer()) {
        delayed_ops.insert(op);
        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
        ready_vars.Extend(op->outputs_);
        continue;
      }
      running_ops_++;
      RunOp(&ready_vars, op);
    }
    ready_ops.clear();
  };

  // Create local scopes.
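  // Each local scope gets a child scope for this iteration; a pointer to the
  // child is stored in the parent under @TMP_SCOPE@ so it can be found later.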
  for (auto &scope : local_scopes_) {
    auto &local_scope = scope->NewScope();
    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
  }

  // Step 3. Execution
  while (!pending_vars.empty()) {
    // 1. Run all ready ops.
    run_all_ready_ops();

    // 2. Find ready variables.
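    // PopAll blocks with a short timeout so that an exception raised by a
    // worker thread (stored in exception_) can be rethrown promptly.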
    bool timeout;
    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);

    if (timeout) {
      if (exception_) {
        throw *exception_;
      } else {
        continue;
      }
    }
    // 3. Remove the dependency on each ready_var and collect the ops that
    // become ready as a result.
    for (auto ready_var : cur_ready_vars) {
      pending_vars.erase(ready_var);
      for (auto *op : ready_var->pending_ops_) {
        auto &deps = pending_ops[op];
        --deps;
        if (deps == 0) {
          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
            blocked_by_delayed_ops.insert(op);
          } else {
            ready_ops.insert(op);
          }
        }
      }
    }
    // When there are no other ops to schedule, schedule buffered delayed
    // ops and unblock other ops.
    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
      RunDelayedOps(delayed_ops);
      delayed_ops.clear();
      for (auto *op : blocked_by_delayed_ops) {
        ready_ops.insert(op);
      }
      blocked_by_delayed_ops.clear();
    }
    // Keep looping until all vars are ready.
  }
  ++computation_count_;

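  // Synchronize every device's stream and drop the child scopes created
  // above; called after fetches and after every max_async_computation runs.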
  auto sync_computation = [&] {
    computation_count_ = 0;
    // Wait for all computation streams.
    for (auto p : this->places_) {
      platform::DeviceContextPool::Instance().Get(p)->Wait();
    }
    for (auto &scope : local_scopes_) {
      scope->DropKids();
    }
  };

  // Wait for FetchOps.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
    sync_computation();
  }

  if (computation_count_ == max_async_computation) {
    sync_computation();
  }

  // NOTE: the temp scope can be dropped lazily if needed.
  // Drop tmp scopes.
  for (auto &scope : local_scopes_) {
    auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
    kid = nullptr;
  }

  return fetch_data;
}

void ThreadedSSAGraphExecutor::RunOp(
    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
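  // The closure runs the op, publishes its outputs to the scheduler's ready
  // queue, and records any EnforceNotMet so Run() can rethrow it.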
  auto op_run = [ready_var_q, op, this] {
    try {
      VLOG(10) << op->Name() << " : " << op->DebugString();
      op->Run(use_event_);
      running_ops_--;
      ready_var_q->Extend(op->outputs_);
    } catch (const platform::EnforceNotMet &ex) {
      exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception caught";
    }
  };
  if (pool_) {
    pool_->enqueue(op_run);
  } else {
    op_run();
  }
}
}  // namespace details
}  // namespace framework
}  // namespace paddle