提交 096673f6 编写于 作者: S sneaxiy

refactor eager deletion

test=develop
上级 400cf19f
......@@ -33,10 +33,9 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
......@@ -44,10 +43,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
......@@ -20,11 +20,13 @@ namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
platform::Place place)
platform::Place place,
size_t scope_idx)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
scope_(scope),
place_(place) {}
place_(place),
scope_idx_(scope_idx) {}
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
......
......@@ -28,7 +28,8 @@ namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
public:
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx);
std::string Name() const override;
......@@ -38,6 +39,8 @@ struct ComputationOpHandle : public OpHandleBase {
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
......@@ -47,6 +50,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
size_t scope_idx_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace framework {
namespace details {
// Builds an eager-deletion handle bound to `scope`, the garbage collector `gc`
// and the reference-count table `ref_cnts` (neither pointer is owned).
//
// On CUDA builds with a GPU `place`, the CUDA device context of that place is
// cached. If `gc` is a StreamGarbageCollector, a timing-disabled CUDA event is
// also created on that device; ClearTensors() uses it later.
// Every entry of `var_names` is registered through AddVar().
EagerDeletionOpHandle::EagerDeletionOpHandle(
    ir::Node *node, const Scope *scope, const platform::Place &place,
    const std::vector<std::string> &var_names, GarbageCollector<Tensor> *gc,
    AtomicReferenceCountMap *ref_cnts)
    : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place)) {
    auto *pooled_ctx = platform::DeviceContextPool::Instance().Get(place);
    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(pooled_ctx);

    const bool uses_stream_gc =
        dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
    if (uses_stream_gc) {
      // The event has to be created while its owning device is current.
      const auto device_id = boost::get<platform::CUDAPlace>(place).device;
      platform::SetDeviceId(device_id);
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
    }
  }
#endif
  for (auto &var_name : var_names) {
    AddVar(var_name);
  }
}
// Destroys the CUDA event created by the constructor, if there is one.
// The device that owns the event is made current before the event is
// destroyed. On non-CUDA builds the destructor does nothing.
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#ifdef PADDLE_WITH_CUDA
  if (!event_) return;
  // Copy the place out: GetPlace() may return a temporary.
  const platform::CUDAPlace owning_place =
      boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
  platform::SetDeviceId(owning_place.device);
  PADDLE_ENFORCE(cudaEventDestroy(event_));
#endif
}
// Returns the fixed name of this op handle type.
std::string EagerDeletionOpHandle::Name() const {
  return std::string("eager_deletion");
}
// Registers `name` as a variable whose reference count this handle decrements
// when it runs. The names live in a set, so adding the same name twice has no
// further effect.
void EagerDeletionOpHandle::AddVar(const std::string &name) {
  var_names_.emplace(name);
}
void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<Tensor *> tensors;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) {
continue;
}
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
}
if (var->IsType<LoDTensor>()) {
if (it->second.fetch_sub(1) == 1) {
tensors.emplace_back(var->GetMutable<LoDTensor>());
}
} else if (var->IsType<SelectedRows>()) {
if (it->second.fetch_sub(1) == 1) {
tensors.emplace_back(var->GetMutable<SelectedRows>()->mutable_value());
}
} else if (var->IsType<LoDTensorArray>()) {
if (it->second.fetch_sub(1) == 1) {
auto *tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto &t : *tensor_arr) {
tensors.emplace_back(&t);
}
}
}
}
if (!tensors.empty()) {
ClearTensors(tensors);
}
}
// Passes `tensors` to the garbage collector.
//
// When a CUDA event was created (stream garbage collector on a GPU place), a
// callback is passed along with the tensors. The callback records the event on
// the compute stream and makes the collector's stream wait on that event.
// Otherwise the tensors are added without a callback.
void EagerDeletionOpHandle::ClearTensors(const std::vector<Tensor *> &tensors) {
#ifdef PADDLE_WITH_CUDA
  if (event_) {
    auto *stream_gc = static_cast<StreamGarbageCollector<Tensor> *>(gc_);
    auto compute_stream = dev_ctx_->stream();
    auto callback_stream = stream_gc->stream();
    auto callback_func = [=]() {
      PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
      PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
    };
    gc_->Add(tensors, callback_func);
    return;
  }
#endif
  gc_->Add(tensors);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
class Scope;
namespace details {
class EagerDeletionPass;
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
const platform::Place &place,
const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
std::string Name() const override;
protected:
void RunImpl() override;
private:
void ClearTensors(const std::vector<Tensor *> &tensors);
void AddVar(const std::string &name);
const Scope *scope_;
std::unordered_set<std::string> var_names_;
GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
#endif
friend class EagerDeletionPass;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
// Makes `out` run after `in` by connecting them through a dummy
// (control-dependency) variable.
//
// The first DummyVarHandle already among `in`'s outputs is reused. If there is
// none, a new one is created, recorded in the graph's kGraphDepVars set and
// added as an output of `in`. If `out` has no outputs afterwards, it is given
// a dummy leaf output that is also recorded in kGraphDepVars.
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
                                 ir::Graph *graph) {
  VarHandleBase *link_var = nullptr;
  for (auto *candidate : in->Outputs()) {
    if (dynamic_cast<DummyVarHandle *>(candidate) != nullptr) {
      link_var = candidate;
      break;
    }
  }

  if (link_var == nullptr) {
    auto *fresh_dep = new DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(fresh_dep);
    in->AddOutput(fresh_dep);
    link_var = fresh_dep;
  }
  out->AddInput(link_var);

  // Add leaf node to eager_deletion_node
  if (!out->Outputs().empty()) return;
  auto *leaf_var = new DummyVarHandle(graph->CreateControlDepVar());
  graph->Get<GraphDepVars>(kGraphDepVars).emplace(leaf_var);
  out->AddOutput(leaf_var);
}
// Attaches one EagerDeletionOpHandle to every computation op that is a last
// live op of at least one variable.
//
// Reads the pass attributes kLastLiveOpsOfVars and kGarbageCollector, and
// resets kCurReferenceCount to one empty map per entry of the graph's
// kGraphVars. Each new handle is wired to run after its computation op and
// uses the garbage collector and reference-count map at that op's scope
// index. Returns the graph that was passed in.
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &vars = graph->Get<GraphVars>(kGraphVars);

  auto &ref_cnts =
      Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount);
  auto &last_live_ops = Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
  auto &gcs = Get<GarbageCollectorList>(kGarbageCollector);

  ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());

  // Computation op -> the deletion handle already created for it.
  std::unordered_map<ComputationOpHandle *, EagerDeletionOpHandle *> op_map;

  for (auto &ops_of_vars : last_live_ops) {
    for (auto &name_and_ops : ops_of_vars) {
      const std::string &var_name = name_and_ops.first;
      for (ComputationOpHandle *op : name_and_ops.second) {
        auto existing = op_map.find(op);
        if (existing != op_map.end()) {
          // The op already has a handle: just add this variable to it.
          existing->second->AddVar(var_name);
          continue;
        }

        const size_t scope_idx = op->GetScopeIdx();
        auto *eager_deletion_node = graph->CreateEmptyNode(
            "eager_deletion", ir::Node::Type::kOperation);
        auto *eager_deletion_op = new EagerDeletionOpHandle(
            eager_deletion_node, op->GetScope(), op->GetPlace(), {var_name},
            gcs[scope_idx].get(), &(ref_cnts[scope_idx]));
        AddDependencyBetween(op, eager_deletion_op, graph.get());
        op_map[op] = eager_deletion_op;
      }
    }
  }

  VLOG(10) << "Create " << op_map.size() << " EagerDeletionOpHandle(s)";
  return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Registers EagerDeletionPass under the name "eager_deletion_pass" and
// declares the three pass attributes that ApplyImpl() reads through Get():
// the current reference counts, the last live ops of each variable and the
// garbage collectors.
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
// Graph pass whose ApplyImpl() is defined in eager_deletion_pass.cc.
class EagerDeletionPass : public ir::Pass {
protected:
// Takes ownership of `graph`, applies the pass and returns a graph.
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -562,7 +562,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id]));
local_scopes_[dev_id], places_[dev_id], dev_id));
CreateOpHandleIOs(result, node, dev_id);
}
......@@ -685,8 +685,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
result->Get<GraphOps>(kGraphOps).emplace_back(new ComputationOpHandle(
result->CreateOpNode(node->Op()), s, p, scope_idx));
CreateOpHandleIOs(result, node, scope_idx);
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
// Variable name -> reference count.
using ReferenceCountMap = std::unordered_map<std::string, int>;
// Variable name -> reference count held in an atomic;
// ReferenceCountOpHandle::RunImpl() decrements it with fetch_sub.
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<int>>;
// int key -> owned ReferenceCountMap.
// NOTE(review): the key appears to be a CUDA device id — confirm with callers.
using DeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
// int key -> owned AtomicReferenceCountMap (same key convention as above).
using AtomicDeviceReferenceCountMap =
std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
// int key -> owned garbage collector for Tensor objects.
using DeviceGarbageCollectorMap =
std::unordered_map<int,
std::unique_ptr<GarbageCollector<framework::Tensor>>>;
// Op handle that subtracts per-variable usage counts from a shared
// reference-count table and passes the tensors of variables whose count is
// exhausted to a garbage collector.
//
// The handle does not own `gc` or `ref_cnts`. It owns a CUDA event only when
// `gc` is a StreamGarbageCollector.
class ReferenceCountOpHandle : public OpHandleBase {
 public:
  // node:      graph node this handle is attached to.
  // scope:     scope in which the local execution scope is looked up.
  // place:     CUDA place whose device context (and event) is used.
  // var_names: variables to count; a repeated name is counted once per
  //            occurrence.
  // gc:        garbage collector that receives the tensors (not owned).
  // ref_cnts:  reference-count table that RunImpl() decrements (not owned).
  ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
                         const platform::CUDAPlace &place,
                         const std::vector<std::string> &var_names,
                         GarbageCollector<Tensor> *gc,
                         AtomicReferenceCountMap *ref_cnts)
      : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    if (IsStreamGarbageCollector()) {
      // The event has to be created while its owning device is current.
      platform::SetDeviceId(place.device);
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
    }
    for (auto &name : var_names) AddVar(name);
  }

  // Destroys the event if one was created. The decision is based on `event_`
  // itself, not on another dynamic_cast of `gc_`: the collector is not owned,
  // so the destructor must not depend on it still being alive.
  ~ReferenceCountOpHandle() {
    if (event_ != nullptr) {
      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
      platform::SetDeviceId(gpu_place.device);
      PADDLE_ENFORCE(cudaEventDestroy(event_));
    }
  }

  std::string Name() const override { return "reference_count"; }

  // Adds one use of `name`. A name seen for the first time starts at 1,
  // because operator[] value-initializes the counter to 0 before the increment.
  void AddVar(const std::string &name) { ++var_names_[name]; }

 protected:
  // For every counted variable that has an entry in the reference-count table
  // and exists in the local execution scope, subtracts this handle's use
  // count. If the previous value was not larger than that use count, the
  // variable's tensor is collected. Only LoDTensor and SelectedRows variables
  // are handled; other types are left untouched.
  void RunImpl() override {
    auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
    std::vector<Tensor *> tensors;
    for (auto &pair : var_names_) {
      auto &name = pair.first;
      auto it = ref_cnts_->find(name);
      if (it == ref_cnts_->end()) continue;

      auto *var = exec_scope->FindVar(name);
      if (var == nullptr) continue;

      if (var->IsType<LoDTensor>()) {
        if (it->second.fetch_sub(pair.second) <= pair.second) {
          tensors.emplace_back(var->GetMutable<LoDTensor>());
        }
      } else if (var->IsType<SelectedRows>()) {
        if (it->second.fetch_sub(pair.second) <= pair.second) {
          tensors.emplace_back(
              var->GetMutable<SelectedRows>()->mutable_value());
        }
      }
    }

    if (!tensors.empty()) {
      ClearTensors(tensors);
    }
  }

 private:
  // Passes `tensors` to the garbage collector. With a stream collector, the
  // callback records the event on the compute stream and makes the
  // collector's stream wait on it.
  void ClearTensors(const std::vector<Tensor *> &tensors) {
    auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
    if (gc != nullptr) {
      auto compute_stream = dev_ctx_->stream();
      auto callback_stream = gc->stream();
      auto callback_func = [=]() {
        PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
        PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
      };
      gc_->Add(tensors, callback_func);
    } else {
      gc_->Add(tensors);
    }
  }

  // True when `gc_` is a StreamGarbageCollector.
  bool IsStreamGarbageCollector() const {
    return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
  }

  const Scope *scope_;
  platform::CUDADeviceContext *dev_ctx_{nullptr};
  // Variable name -> number of uses this handle accounts for.
  std::unordered_map<std::string, int> var_names_;
  GarbageCollector<Tensor> *gc_;       // not own
  AtomicReferenceCountMap *ref_cnts_;  // not own
  // Stays nullptr unless a stream garbage collector is in use.
  cudaEvent_t event_{nullptr};
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,184 +17,96 @@
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
std::queue<VarHandleBase *> queue;
queue.push(var_in);
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *var = queue.front();
queue.pop();
for (auto *op : var->PendingOps()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
queue.push(out_var);
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
}
}
} while (!queue.empty());
} while (!q.empty());
return nullptr;
}
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
// It is not easy to find the right reference counts of variables in graph
// Step 1: Find all variables in computation ops
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op;
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto &vars = graph->Get<GraphVars>(kGraphVars);
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
ref_cnts = std::vector<ReferenceCountMap>(vars.size());
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
if (name_var_pair.second.empty()) continue;
auto *last_ver_var = name_var_pair.second.back();
VarDesc *var_desc = nullptr;
std::find_if(name_var_pair.second.rbegin(), name_var_pair.second.rend(),
[&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
if (!platform::is_gpu_place(var_handle->place_) ||
boost::get<platform::CUDAPlace>(var_handle->place_) != place)
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name();
// This is weird but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name);
if (var_desc == nullptr) continue;
}
if (var_desc->Persistable()) continue;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
var_type != proto::VarType::SELECTED_ROWS &&
var_type != proto::VarType::LOD_TENSOR_ARRAY) {
continue;
}
// compute op only runs in one device
if (ref_cnts[place.device]->count(var_name))
++(*ref_cnts[place.device])[var_name];
else
(*ref_cnts[place.device])[var_name] = 1;
names.insert(var_name);
var_names_in_op.push_back(var_name);
std::unordered_set<ComputationOpHandle *> last_live_op;
auto add_last_live_op = [&](OpHandleBase *op) {
auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
if (compute_op) {
last_live_op.insert(compute_op);
}
return var_names_in_op;
};
auto update_ref_cnts_from_non_compute_op = [&](
OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
auto var_name = var_handle->Node()->Name();
auto var_place = var_handle->place_;
if (!platform::is_gpu_place(var_place)) continue;
auto place = boost::get<platform::CUDAPlace>(var_place);
if (names.count(var_name) == 0) continue;
if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name];
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name();
} else {
// Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
"reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
}
}
}
const std::string &var_name = name_var_pair.first;
auto &pending_ops = last_ver_var->PendingOps();
if (pending_ops.empty()) {
auto *generated_op = last_ver_var->GeneratedOp();
if (generated_op) {
ref_cnts[i].emplace(var_name, 1);
add_last_live_op(generated_op);
}
};
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
} else {
ref_cnts[i].emplace(var_name, pending_ops.size());
for (auto *pending_op : pending_ops) {
add_last_live_op(pending_op);
}
for (auto &op : all_ops) {
update_ref_cnts_from_non_compute_op(op, op->Inputs());
update_ref_cnts_from_non_compute_op(op, op->Outputs());
}
std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
it->second->AddOutput(dummy_leaf);
new_all_ops.emplace_back(std::move(it->second));
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
}
}
all_ops.swap(new_all_ops);
return graph;
}
......@@ -205,5 +117,4 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -22,10 +21,6 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kGlobalReferenceCount[] = "reference_count";
constexpr char kCurReferenceCount[] = "current_reference_count";
constexpr char kGarbageCollector[] = "garbage_collector";
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace details {
class ComputationOpHandle;
// NOTE(review): std::unordered_set and std::unique_ptr are used below, but
// <unordered_set> and <memory> are not among this header's direct includes —
// confirm they arrive transitively, or include them here.

// Variable name -> reference count.
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
// Variable name -> reference count held in an atomic.
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>;
// List of owned garbage collectors for Tensor objects.
// NOTE(review): users appear to index it by scope/place index — confirm.
using GarbageCollectorList =
std::vector<std::unique_ptr<GarbageCollector<Tensor>>>;
// Attribute names under which the structures above are stored.
const char kGlobalReferenceCount[] = "reference_count";
const char kCurReferenceCount[] = "current_reference_count";
const char kGarbageCollector[] = "garbage_collector";
// Variable name -> set of computation op handles (non-owning pointers).
// NOTE(review): per the attribute name, presumably the ops that use the
// variable last — confirm.
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,6 @@
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"
#endif
namespace paddle {
namespace framework {
......@@ -33,7 +30,11 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)),
var_infos_(std::move(var_infos)),
places_(std::move(places)) {}
places_(std::move(places)) {
if (Graph().Has(details::kGarbageCollector)) {
gc_ = &(Graph().Get<GarbageCollectorList>(details::kGarbageCollector));
}
}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
......@@ -69,27 +70,16 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
#ifdef PADDLE_WITH_CUDA
const std::string gc_name = "garbage_collector";
DeviceGarbageCollectorMap *gc =
Graph().Has(gc_name) ? &(Graph().Get<DeviceGarbageCollectorMap>(gc_name))
: nullptr;
#endif
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
#ifdef PADDLE_WITH_CUDA
if (gc != nullptr && platform::is_gpu_place(p)) {
auto gpu_place = boost::get<platform::CUDAPlace>(p);
auto &gc_at_place = gc->at(gpu_place.device);
gc_at_place->Wait();
gc_at_place->Reset();
for (size_t i = 0; i < places_.size(); ++i) {
platform::DeviceContextPool::Instance().Get(places_[i])->Wait();
if (gc_) {
(*gc_)[i]->Wait();
(*gc_)[i]->Reset();
}
#endif
}
for (auto &scope : local_scopes_) {
auto &local_scope =
......
......@@ -21,9 +21,11 @@
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace details {
......@@ -55,6 +57,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<Scope*> local_scopes_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
GarbageCollectorList* gc_{nullptr};
};
} // namespace details
} // namespace framework
......
......@@ -65,7 +65,7 @@ class GarbageCollector {
if (clear_deque != nullptr) {
callback();
ClearCallback([=]() {
ClearCallback([clear_deque]() {
for (auto *obj : *clear_deque) obj->clear();
});
}
......@@ -109,7 +109,6 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
}
void Wait() const override {
this->dev_ctx_->Wait();
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
......@@ -127,14 +126,14 @@ class StreamGarbageCollector : public GarbageCollector<T> {
StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {
PADDLE_ENFORCE(cudaSetDevice(place.device));
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(place.device));
platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
......@@ -148,8 +147,11 @@ class StreamGarbageCollector : public GarbageCollector<T> {
cudaStream_t stream() const { return stream_; }
protected:
// ClearCallback and Wait()/Reset() must not be called concurrently from
// multiple threads. This is not a problem in practice, because neither
// Executor nor ParallelExecutor calls them from multiple threads.
void ClearCallback(const std::function<void()> &callback) override {
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->AddCallback(callback);
}
......
......@@ -73,14 +73,21 @@ class Graph {
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
return attrs_.count(attr_name) > 0;
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
template <typename AttrType>
......
......@@ -51,11 +51,18 @@ class Pass {
// Returns a reference to the AttrType object pointed to by the attribute
// stored under `attr_name`.
//
// Fails through PADDLE_ENFORCE when the attribute is not registered, and
// through PADDLE_THROW (reporting the expected and the actual stored type)
// when the stored value is not an `AttrType *`.
AttrType &Get(const std::string &attr_name) const {
  // One lookup serves the existence check, the cast and the error message.
  const auto attr_iter = attrs_.find(attr_name);
  PADDLE_ENFORCE(attr_iter != attrs_.end(),
                 "%s attr not registered for pass.", attr_name);
  try {
    return *boost::any_cast<AttrType *>(attr_iter->second);
  } catch (const boost::bad_any_cast &) {
    PADDLE_THROW(
        "Invalid attribute type of %s error, expected: %s, actual: %s",
        attr_name, typeid(AttrType *).name(),
        attr_iter->second.type().name());
  }
}
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
return attrs_.count(attr_name) > 0;
}
void Erase(const std::string &attr_name) {
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -49,6 +50,15 @@ class ParallelExecutorPrivate {
}
}
}
// Copies every (name, count) entry of rt_ref_cnts_[i] into
// rt_cur_ref_cnts_[i], for each index i of rt_ref_cnts_.
// Entries that exist only in rt_cur_ref_cnts_[i] are left untouched.
// NOTE(review): assumes rt_cur_ref_cnts_ has at least as many elements as
// rt_ref_cnts_ -- confirm where both vectors are sized.
void ResetRuntimeReferenceCount() {
  const size_t map_num = rt_ref_cnts_.size();
  for (size_t idx = 0; idx < map_num; ++idx) {
    auto &cur_cnts = rt_cur_ref_cnts_[idx];
    for (auto &name_and_cnt : rt_ref_cnts_[idx]) {
      cur_cnts[name_and_cnt.first] = name_and_cnt.second;
    }
  }
}
std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_;
Scope *global_scope_; // not owned
......@@ -60,6 +70,13 @@ class ParallelExecutorPrivate {
bool own_local_scope_;
bool use_cuda_;
bool use_all_reduce_;
// rt_ref_cnts_ is initialized only once, when the ParallelExecutor is
// constructed, and remains unchanged afterwards.
// Before each iteration, rt_cur_ref_cnts_ is reset to rt_ref_cnts_.
std::vector<details::ReferenceCountMap> rt_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> rt_cur_ref_cnts_;
details::GarbageCollectorList gcs_;
};
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
......@@ -128,35 +145,56 @@ ParallelExecutor::ParallelExecutor(
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
#else
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_);
#endif
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
for (auto &place : member_->places_) {
if (!platform::is_gpu_place(place)) continue;
auto gpu_place = boost::get<platform::CUDAPlace>(place);
if (gcs_[gpu_place.device] == nullptr) {
ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap());
cur_ref_cnts_[gpu_place.device].reset(
new details::AtomicReferenceCountMap());
gcs_[gpu_place.device].reset(
new StreamGarbageCollector<Tensor>(gpu_place, max_memory_size));
}
}
if (!gcs_.empty()) {
size_t place_num = member_->places_.size();
for (size_t i = 0; i < place_num; ++i) {
auto &place = member_->places_[i];
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size));
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
} else if (platform::is_cpu_place(place)) {
#endif
member_->gcs_.emplace_back(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place), max_memory_size));
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#ifdef PADDLE_WITH_CUDA
}
#endif
}
}
if (!member_->gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&(member_->rt_ref_cnts_));
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
VLOG(10) << "ReferenceCountPass Applied";
graph = ref_cnt_pass->Apply(std::move(graph));
graph->SetNotOwned("garbage_collector", &gcs_);
}
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kCurReferenceCount,
&(member_->rt_cur_ref_cnts_));
eager_deletion_pass->SetNotOwned(details::kGarbageCollector,
&(member_->gcs_));
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied";
}
#else
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_);
#endif
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
......@@ -271,18 +309,16 @@ void ParallelExecutor::BCastParamsToDevices(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
platform::RecordBlock b(0);
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
for (auto &pair : cur_ref_cnts_) {
auto &name_map = *(pair.second);
if (!member_->gcs_.empty()) {
member_->ResetRuntimeReferenceCount();
size_t n = member_->rt_ref_cnts_.size();
for (size_t i = 0; i < n; ++i) {
for (auto &fetch_name : fetch_tensors) {
name_map.erase(fetch_name);
member_->rt_cur_ref_cnts_[i].erase(fetch_name);
}
name_map.erase(fetched_var_name);
member_->rt_cur_ref_cnts_[i].erase(fetched_var_name);
}
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
......@@ -326,13 +362,11 @@ ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
// member_ must be destructed before gcs_ since the destructor of
// ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_.reset();
delete member_;
}
} // namespace framework
} // namespace paddle
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
USE_PASS(eager_deletion_pass);
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -29,10 +28,6 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_pass.h"
#endif
namespace paddle {
namespace framework {
......@@ -75,24 +70,7 @@ class ParallelExecutor {
private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
std::unique_ptr<ParallelExecutorPrivate> member_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, cur_ref_cnts_ is reset to ref_cnts_
details::DeviceReferenceCountMap ref_cnts_;
details::AtomicDeviceReferenceCountMap cur_ref_cnts_;
details::DeviceGarbageCollectorMap gcs_;
void ResetReferenceCount() {
for (auto &pair1 : ref_cnts_) {
for (auto &pair2 : *(pair1.second)) {
(*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second;
}
}
}
#endif
ParallelExecutorPrivate *member_;
};
} // namespace framework
......
......@@ -56,9 +56,16 @@ ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
IF(WITH_GPU)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
// Bundle handed to the CUDA stream callback through its void* user-data
// argument: the manager that registered the callback plus the callable to
// run. Instances are allocated in AddCallback and deleted by the task that
// StreamCallbackFunc enqueues.
struct StreamCallbackContext {
  // Takes `callback` by value and moves it into the context.
  StreamCallbackContext(const StreamCallbackManager *manager,
                        std::function<void()> callback)
      : manager_{manager}, callback_{std::move(callback)} {}

  const StreamCallbackManager *manager_;  // do not own
  std::function<void()> callback_;
};
StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
: stream_(stream), thread_pool_(new ::ThreadPool(1)) {}
void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
auto *stream_callback_context =
new StreamCallbackContext(this, std::move(callback));
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE(cudaLaunchHostFunc(stream_,
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context));
#else
PADDLE_ENFORCE(
cudaStreamAddCallback(stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
#endif
}
void StreamCallbackManager::Wait() const {
thread_pool_.reset(new ::ThreadPool(1));
}
// Entry point invoked by CUDA for a callback registered in AddCallback.
//
// `user_data` is the StreamCallbackContext allocated by AddCallback. The user
// callback is not run here; it is enqueued on the owning manager's thread
// pool. The enqueued task takes ownership of the context, runs the callback
// and frees the context when it finishes.
#if CUDA_VERSION >= 10000
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(void *user_data)
#else
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(cudaStream_t stream,
                                                         cudaError_t status,
                                                         void *user_data)
#endif
{
  auto *raw_context = static_cast<StreamCallbackContext *>(user_data);
  auto &worker_pool = *raw_context->manager_->thread_pool_;
  worker_pool.enqueue([raw_context]() {
    // Adopt the raw pointer so the context is deleted once the callback
    // returns (or if it throws).
    std::unique_ptr<StreamCallbackContext> owned_context(raw_context);
    owned_context->callback_();
  });
}
} // namespace platform
} // namespace paddle
......@@ -19,66 +19,29 @@
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
class StreamCallbackManager;
struct StreamCallbackContext {
template <typename Callback>
inline StreamCallbackContext(const StreamCallbackManager *manager,
Callback &&callback)
: manager_(manager), callback_(callback) {}
const StreamCallbackManager *manager_; // do not own
std::function<void()> callback_;
};
// NOTE(zjl): clean StreamCallback to make compilation faster
class StreamCallbackManager {
public:
explicit inline StreamCallbackManager(cudaStream_t stream = nullptr)
: stream_(stream), thread_pool_(new ThreadPool(1)) {}
explicit StreamCallbackManager(const cudaStream_t stream);
template <typename Callback>
inline void AddCallback(Callback &&callback) const {
auto *stream_callback_context =
new StreamCallbackContext(this, std::forward<Callback>(callback));
#if CUDA_VERSION >= 10000
PADDLE_ENFORCE(cudaLaunchHostFunc(stream_,
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context)); // NOLINT
#else
PADDLE_ENFORCE(cudaStreamAddCallback(
stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0)); // NOLINT
#endif
}
void AddCallback(std::function<void()> callback) const;
void Wait() const { thread_pool_.reset(new ThreadPool(1)); }
void Wait() const;
private:
const cudaStream_t stream_;
mutable std::unique_ptr<ThreadPool> thread_pool_;
mutable std::unique_ptr<::ThreadPool> thread_pool_;
// cudaStreamCallback cannot call CUDA API inside, so we have to use
// thread_pool here
#if CUDA_VERSION >= 10000
static void CUDART_CB StreamCallbackFunc(void *user_data)
static void CUDART_CB StreamCallbackFunc(void *user_data);
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data)
cudaError_t status, void *user_data);
#endif
{
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue([=]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_();
});
}
};
} // namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册