提交 2e149999 编写于 作者: X Xin Pan

clean1

test=develop
上级 34b401fc
...@@ -36,9 +36,9 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -36,9 +36,9 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
for (auto &op : ops) { for (auto &op : ops) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op.get(), dep); op_deps_.emplace(op, dep);
if (dep == 0) { if (dep == 0) {
bootstrap_ops_.emplace_back(op.get()); bootstrap_ops_.emplace_back(op);
} }
} }
...@@ -54,13 +54,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -54,13 +54,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
......
...@@ -31,8 +31,8 @@ struct TestGatherOpHandle { ...@@ -31,8 +31,8 @@ struct TestGatherOpHandle {
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_; std::vector<Scope*> param_scopes_;
Scope g_scope_; Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase* op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase*> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
void WaitAll() { void WaitAll() {
...@@ -84,8 +84,8 @@ struct TestGatherOpHandle { ...@@ -84,8 +84,8 @@ struct TestGatherOpHandle {
nodes.emplace_back( nodes.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset( op_handle_ =
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_);
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
...@@ -102,7 +102,7 @@ struct TestGatherOpHandle { ...@@ -102,7 +102,7 @@ struct TestGatherOpHandle {
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
in_dummy_var_handle->ClearGeneratedOp(); in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
...@@ -119,7 +119,7 @@ struct TestGatherOpHandle { ...@@ -119,7 +119,7 @@ struct TestGatherOpHandle {
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
op_handle_->AddOutput(dummy_var_handle); op_handle_->AddOutput(dummy_var_handle);
} }
......
...@@ -36,20 +36,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -36,20 +36,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair);
} }
} }
} }
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var);
} }
for (auto &op : graph->Get<GraphOps>(kGraphOps)) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); pending_ops.insert({op, op->NoDupInputSize()});
} }
} }
......
...@@ -93,7 +93,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, ...@@ -93,7 +93,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
} }
var_holder.emplace_back(var); var_holder.emplace_back(var);
} else { } else {
var = var_holder.rbegin()->get(); var = *var_holder.rbegin();
} }
return var; return var;
} }
...@@ -155,7 +155,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ...@@ -155,7 +155,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -498,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, ...@@ -498,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
auto *in = auto *in =
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
...@@ -535,7 +535,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( ...@@ -535,7 +535,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) { for (auto &p_name : bcast_varnames[dev_id]) {
auto *in = auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get(); result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) { for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id]; auto &p = places_[out_dev_id];
...@@ -571,7 +571,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -571,7 +571,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -579,7 +579,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, ...@@ -579,7 +579,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad);
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
...@@ -600,14 +600,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -600,14 +600,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
for (const std::string &d_name : datas) { for (const std::string &d_name : datas) {
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
op_handle->AddInput(vars.back().get()); op_handle->AddInput(vars.back());
auto var = new VarHandle( auto var = new VarHandle(
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
vars.size(), i, d_name, p); vars.size(), i, d_name, p);
...@@ -691,7 +691,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -691,7 +691,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -699,7 +699,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -699,7 +699,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad);
} }
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var = auto var =
...@@ -760,14 +760,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp( ...@@ -760,14 +760,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
} }
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (ir::Node *input : node->inputs) { for (ir::Node *input : node->inputs) {
VarHandle *var = nullptr; VarHandle *var = nullptr;
for (int place_offset = 0; place_offset < num_places; ++place_offset) { for (int place_offset = 0; place_offset < num_places; ++place_offset) {
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset]; auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[input->Name()]; auto &var_holder = var_holders[input->Name()];
if (!var_holder.empty()) { if (!var_holder.empty()) {
var = var_holder.rbegin()->get(); var = *var_holder.rbegin();
op_handle->AddInput(var); op_handle->AddInput(var);
} }
} }
...@@ -840,7 +840,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( ...@@ -840,7 +840,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// all places // all places
auto p = places_[op_dev_id]; auto p = places_[op_dev_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
......
...@@ -36,18 +36,17 @@ namespace details { ...@@ -36,18 +36,17 @@ namespace details {
// map from variable name to variables. The variables, who have the same name, // map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the // will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles. // `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef std::vector< typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is // all operators. NOTE that even we use a vector here, the operators is
// unordered. // unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<OpHandleBase*> GraphOps;
const char kGraphOps[] = "ops"; const char kGraphOps[] = "ops";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -31,7 +31,9 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; ...@@ -31,7 +31,9 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node. // It's responsible for populating necessary fields of ir::Node.
class OpHandleBase { class OpHandleBase {
public: public:
explicit OpHandleBase(ir::Node *node) : node_(node) {} explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~OpHandleBase(); virtual ~OpHandleBase();
......
...@@ -71,14 +71,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -71,14 +71,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables // Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>> std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map; compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op; std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr || if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace())) !platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op; return var_names_in_op;
...@@ -121,9 +120,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -121,9 +120,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}; };
auto update_ref_cnts_from_non_compute_op = [&]( auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) { if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) { for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base); auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
...@@ -151,7 +149,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -151,7 +149,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
} }
} }
} }
...@@ -165,7 +163,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -165,7 +163,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if (in_var_names.empty() && out_var_names.empty()) continue; if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(), in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end()); out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace()); auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node = ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
...@@ -173,7 +171,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -173,7 +171,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); compute_ref_cnt_map[compute_op] = ref_cnt_handle;
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -181,11 +179,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -181,11 +179,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs()); update_ref_cnts_from_non_compute_op(op, op->Outputs());
} }
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops; std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) { for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
...@@ -19,8 +19,7 @@ namespace framework { ...@@ -19,8 +19,7 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
......
...@@ -38,8 +38,7 @@ class SSAGraphExecutor { ...@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -51,25 +51,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,25 +51,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
InsertPendingOp(&pending_ops, op.get()); InsertPendingOp(&pending_ops, op);
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
...@@ -140,8 +140,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -140,8 +140,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
...@@ -151,7 +151,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -151,7 +151,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
......
...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars, BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps( void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
const std::vector<std::string> &fetch_tensors, std::vector<FetchOpHandle *> *fetch_ops,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data); BlockingQueue<VarHandleBase *> *ready_vars,
FeedFetchList *fetch_data);
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -35,7 +35,9 @@ class OpHandleBase; ...@@ -35,7 +35,9 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {} explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
......
...@@ -27,6 +27,8 @@ namespace ir { ...@@ -27,6 +27,8 @@ namespace ir {
// Node should normally created by Graph::CreateXXXNode(). // Node should normally created by Graph::CreateXXXNode().
class Node { class Node {
public: public:
virtual ~Node() {}
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
...@@ -44,6 +46,20 @@ class Node { ...@@ -44,6 +46,20 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
template <typename T>
void WrappedBy(T* wrapper) {
if (!wrapper_.empty()) {
wrapper_deleter_();
}
wrapper_ = wrapper;
wrapper_deleter_ = [wrapper]() { delete wrapper; };
}
template <typename T>
T& Wrapper() {
return *boost::any_cast<T*>(wrapper_);
}
// Please don't use this API! // Please don't use this API!
int id() const { return id_; } int id() const { return id_; }
...@@ -95,6 +111,10 @@ class Node { ...@@ -95,6 +111,10 @@ class Node {
static int count_; static int count_;
// Please don't use this API or make this public. // Please don't use this API or make this public.
static void ResetId() { count_ = 0; } static void ResetId() { count_ = 0; }
boost::any wrapper_;
std::function<void(void)> wrapper_deleter_;
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册