From cb2d33a8518bf68780196898e13782e892f55ea5 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 6 Nov 2018 17:48:45 +0800 Subject: [PATCH] resolve conflict test=develop --- .../modify_op_lock_and_record_event_pass.cc | 5 ++-- .../fluid/framework/details/op_graph_view.cc | 23 +++++-------------- .../fluid/framework/details/op_graph_view.h | 9 ++------ 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc index 169ce3ae7..67aad9f94 100644 --- a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc +++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/op_graph_view.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( std::unique_ptr ModifyOpLockAndRecordEventPass::ApplyImpl( std::unique_ptr ir_graph) const { - auto &all_ops = ir_graph->Get(kGraphOps); + auto all_ops = ir::FilterByNodeWrapper(*ir_graph); OpGraphView graph_view(all_ops); for (auto &op : all_ops) { - auto *compute_op = dynamic_cast(op.get()); + auto *compute_op = dynamic_cast(op); if (compute_op == nullptr) continue; bool is_lock_and_record_event_free = IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); diff --git a/paddle/fluid/framework/details/op_graph_view.cc b/paddle/fluid/framework/details/op_graph_view.cc index 65dafd376..4838c4198 100644 --- a/paddle/fluid/framework/details/op_graph_view.cc +++ b/paddle/fluid/framework/details/op_graph_view.cc @@ -20,19 +20,16 @@ namespace paddle { namespace framework { namespace details { -OpGraphView::OpGraphView( - const std::vector> &ops) { - Build(ops); -} +OpGraphView::OpGraphView(const std::vector &ops) { Build(ops); } -void OpGraphView::Build(const std::vector> &ops) { +void OpGraphView::Build(const std::vector &ops) { for (auto &op : ops) { - preceding_ops_[op.get()]; - pending_ops_[op.get()]; + preceding_ops_[op]; + pending_ops_[op]; for (auto &var : op->Outputs()) { for (auto &pending_op : var->PendingOps()) { - preceding_ops_[pending_op].insert(op.get()); - pending_ops_[op.get()].insert(pending_op); + preceding_ops_[pending_op].insert(op); + pending_ops_[op].insert(pending_op); } } } @@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector> &ops) { "There are duplicate ops in graph."); } -size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); } - std::unordered_set OpGraphView::AllOps() const { std::unordered_set ret; for (auto &pair : preceding_ops_) { @@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const { op == nullptr ? "nullptr" : op->DebugString()); } -const std::unordered_set &OpGraphView::PrecedingOps( - OpHandleBase *op) const { - EnforceHasOp(op); - return preceding_ops_.at(op); -} - const std::unordered_set &OpGraphView::PendingOps( OpHandleBase *op) const { EnforceHasOp(op); diff --git a/paddle/fluid/framework/details/op_graph_view.h b/paddle/fluid/framework/details/op_graph_view.h index 398c019be..afb3e8e59 100644 --- a/paddle/fluid/framework/details/op_graph_view.h +++ b/paddle/fluid/framework/details/op_graph_view.h @@ -26,21 +26,16 @@ namespace details { class OpGraphView { public: - explicit OpGraphView(const std::vector> &ops); - - size_t OpNumber() const; + explicit OpGraphView(const std::vector &ops); std::unordered_set AllOps() const; - const std::unordered_set &PrecedingOps( - OpHandleBase *op) const; - const std::unordered_set &PendingOps(OpHandleBase *op) const; bool HasOp(OpHandleBase *op) const; private: - void Build(const std::vector> &ops); + void Build(const std::vector &ops); void EnforceHasOp(OpHandleBase *op) const; std::unordered_map> -- GitLab