//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

#include "paddle/fluid/framework/details/fetch_op_handle.h"

namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
    size_t num_threads, bool use_event,
    const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
    : SSAGraphExecutor(std::move(graph)),
      pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
      local_scopes_(local_scopes),
      places_(places),
      fetch_ctxs_(places),
      use_event_(use_event),
      running_ops_(0),
      allow_op_delay_(allow_op_delay) {}

void ThreadedSSAGraphExecutor::RunDelayedOps(
    const std::unordered_set<OpHandleBase *> &delayed_ops) {
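  // Delayed ops are run synchronously on the calling thread, one after
  // another; they are never handed to the thread pool.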
  for (auto op : delayed_ops) {
    op->Run(use_event_);
  }
}

FeedFetchList ThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
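  // Bookkeeping for the dataflow scheduler:
  //  - pending_ops maps each op to the number of its inputs that are not
  //    ready yet;
  //  - pending_vars holds variables that have not been produced so far;
  //  - ready_vars is the queue into which newly available variables are
  //    pushed;
  //  - ready_ops holds ops whose inputs are all available.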
  std::unordered_map<OpHandleBase *, size_t> pending_ops;
  std::unordered_set<VarHandleBase *> pending_vars;
  BlockingQueue<VarHandleBase *> ready_vars;
  std::unordered_set<OpHandleBase *> ready_ops;
  // Ops that need to coordinate multiple streams from multiple GPUs
  // (e.g. nccl_all_reduce) are buffered and scheduled together, which is
  // faster as long as computation and memcpy streams cannot overlap.
  // Revisit this once such overlapping becomes available.
  std::unordered_set<OpHandleBase *> delayed_ops;
  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
  std::unordered_set<VarHandleBase *> delayed_vars;

  auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
    pending_vars.insert(&var);
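    // A variable that no op generates is an input of the graph, so it is
    // ready immediately.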
    if (var.generated_op_ == nullptr) {
      ready_vars.Push(&var);
    }
  };

  auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
  };

  // Step 1. Transform SSAGraph into pending_ops & pending_vars
  for (auto &var_map : graph_->vars_) {
    for (auto &name_pair : var_map) {
      for (auto &version_pair : name_pair.second) {
        InsertPendingVar(version_pair.second);
      }
    }
  }
  for (auto &var : graph_->dep_vars_) {
    InsertPendingVar(*var);
  }

  for (auto &op : graph_->ops_) {
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      ready_ops.insert(op.get());
    } else {
      InsertPendingOp(*op);
    }
  }

  // Step 2. Insert FetchOps
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
  std::vector<DummyVarHandle> dummy_vars;
  FeedFetchList fetch_data(fetch_tensors.size());

  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->vars_) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars.at(var_name);
    auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
    fetch_ops.emplace_back(op);

    // FIXME: Use new device context
    for (auto &p : places_) {
      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }
    InsertPendingOp(*op);
  }

  auto run_all_ready_ops = [&] {
    for (auto *op : ready_ops) {
      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
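        // Buffer this multi-device transfer op instead of launching it, but
        // still publish its outputs so downstream ops can be discovered;
        // those ops are parked in blocked_by_delayed_ops until the buffered
        // ops are flushed.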
        delayed_ops.insert(op);
        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
        ready_vars.Extend(op->outputs_);
        continue;
      }
      running_ops_++;
      RunOp(&ready_vars, op);
    }
    ready_ops.clear();
  };

  // Create local scopes.
  for (auto &scope : local_scopes_) {
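    // Each run gets a fresh child scope, stored in the @TMP_SCOPE@
    // variable; the pointer is cleared again at the end of Run.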
    auto &local_scope = scope->NewScope();
    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
  }

  // Step 3. Execution
  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
    // 1. Run All Ready ops
    run_all_ready_ops();

    // 2. Find ready variable
    bool timeout;
    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
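    // PopAll waits with a short timeout so that the loop can periodically
    // re-check whether a worker thread has recorded an exception.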

    if (timeout) {
      if (exception_) {
        throw * exception_;
      } else {
        continue;
      }
    }
    // 3. Remove ready_var from the dependencies of the ops that consume it,
    //    and collect the ops that become ready as a result.
    for (auto ready_var : cur_ready_vars) {
      pending_vars.erase(ready_var);
      for (auto *op : ready_var->pending_ops_) {
        auto &deps = pending_ops[op];
        --deps;
        if (deps == 0) {
          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
            blocked_by_delayed_ops.insert(op);
          } else {
            ready_ops.insert(op);
          }
        }
      }
    }
    // When there are no other ops to schedule, schedule buffered delayed
    // ops and unblock other ops.
    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
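      // Nothing is running and nothing else is ready, so it is safe to
      // flush the buffered transfer ops on this thread and release the ops
      // they were blocking.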
      RunDelayedOps(delayed_ops);
      delayed_ops.clear();
      for (auto *op : blocked_by_delayed_ops) {
        ready_ops.insert(op);
      }
      blocked_by_delayed_ops.clear();
    }
    // Keep looping until all vars are ready.
  }
  PADDLE_ENFORCE(ready_ops.empty());
  PADDLE_ENFORCE(delayed_ops.empty());
  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
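  // Computation is left asynchronous across runs: device streams are
  // synchronized and the temporary child scopes are dropped after a fetch,
  // or after max_async_computation consecutive runs.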
  ++computation_count_;

  auto sync_computation = [&] {
    computation_count_ = 0;
    // Wait for all computation streams to finish.
    for (auto p : this->places_) {
      platform::DeviceContextPool::Instance().Get(p)->Wait();
    }
    for (auto &scope : local_scopes_) {
      scope->DropKids();
    }
  };

  // Wait for FetchOps to finish.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
    sync_computation();
  }

  if (computation_count_ == max_async_computation) {
    sync_computation();
  }

  // NOTE: the temp scope can be dropped lazily if needed.
  // Drop tmp scopes.
  for (auto &scope : local_scopes_) {
    auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
    kid = nullptr;
  }

  return fetch_data;
}

void ThreadedSSAGraphExecutor::RunOp(
    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
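  // Wrap the op in a closure that runs it and then publishes its outputs to
  // ready_var_q; any exception is stored so the scheduling loop can re-throw
  // it on the main thread.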
  auto op_run = [ready_var_q, op, this] {
    try {
      VLOG(10) << op->Name() << " : " << op->DebugString();
      op->Run(use_event_);
      running_ops_--;
      ready_var_q->Extend(op->outputs_);
    } catch (platform::EnforceNotMet ex) {
      exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception caught";
    }
  };
  if (pool_) {
    pool_->enqueue(op_run);
  } else {
    op_run();
  }
}
}  // namespace details
}  // namespace framework
}  // namespace paddle