提交 d49763a8 编写于 作者: Y yuyang18

Stash

上级 83c85f34
...@@ -42,3 +42,5 @@ cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_b ...@@ -42,3 +42,5 @@ cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_b
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor) cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle ) # device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace paddle {
namespace framework {
namespace details {
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph)
: strategy_(strategy),
local_scopes_(local_scopes),
places_(places),
graph_(std::move(graph)),
pool_(strategy.num_threads_ +
1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
auto &ops = graph_->Get<details::GraphOps>("ops");
for (auto &op : ops) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op.get(), dep);
if (dep == 0) {
bootstrap_ops_.emplace_back(op.get());
}
}
PrepareAtomicOpDeps();
}
FeedFetchList FastThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();
paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program "
"is not set to ParallelExecutor)");
auto &vars = fetched_var_it->second;
fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i,
&local_scopes_);
fetch_ops.emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
}
size_t num_complete = 0;
remaining_ = 0;
BlockingQueue<size_t> complete_q;
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, &complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q.Pop();
if (num_comp == -1UL) {
int remaining = remaining_;
for (int i = 0; i < remaining; ++i) {
complete_q.Pop();
}
LOG(FATAL) << "On exception thrown, not implemented";
}
num_complete += num_comp;
}
// Wait FetchOps.
if (!fetch_ops.empty()) {
fetch_ops.clear();
}
return fetches;
}
void FastThreadedSSAGraphExecutor::RunOpAsync(
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q) {
++remaining_;
this->pool_.enqueue([=] {
OpHandleBase *op_to_run = op;
size_t complete = 0;
while (op_to_run != nullptr) {
try {
op_to_run->Run(strategy_.use_cuda_);
++complete;
} catch (...) {
--remaining_;
complete_q->Push(-1UL);
return;
}
auto &outputs = op_to_run->Outputs();
op_to_run = nullptr;
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op);
if (deps.fetch_sub(1) == 1) { // pending_op ready
if (op_to_run == nullptr) {
op_to_run = pending_op;
} else {
this->RunOpAsync(op_deps, pending_op, complete_q);
}
}
}
}
}
--remaining_;
complete_q->Push(complete);
});
}
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = pool_.enqueue([&] {
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps =
new std::unordered_map<OpHandleBase *, std::atomic<int>>;
for (auto &pair : op_deps_) {
(*op_deps)[pair.first] = pair.second;
}
return std::unique_ptr<
std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
});
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle {
namespace framework {
class Scope;
namespace details {
class OpHandleBase;
class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph);
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
std::unique_ptr<ir::Graph> graph_;
std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_;
platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q);
void PrepareAtomicOpDeps();
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -158,6 +158,16 @@ void OpHandleBase::RunAndRecordEvent(platform::Place p, ...@@ -158,6 +158,16 @@ void OpHandleBase::RunAndRecordEvent(platform::Place p,
#endif #endif
} }
size_t OpHandleBase::NotReadyInputSize() const {
std::unordered_set<VarHandleBase *> res;
for (auto *var : inputs_) {
if (var->GeneratedOp() != nullptr) {
res.emplace(var);
}
}
return res.size();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -81,6 +81,8 @@ class OpHandleBase { ...@@ -81,6 +81,8 @@ class OpHandleBase {
return res.size(); return res.size();
} }
size_t NotReadyInputSize() const;
const std::vector<VarHandleBase *> &Outputs() const { return outputs_; } const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
size_t NoDummyInputSize() const; size_t NoDummyInputSize() const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册