提交 cb2d33a8 编写于 作者: X Xin Pan

resolve conflict

test=develop
上级 25123a3b
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( ...@@ -35,10 +36,10 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl( std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const { std::unique_ptr<ir::Graph> ir_graph) const {
auto &all_ops = ir_graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops); OpGraphView graph_view(all_ops);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue; if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free = bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
......
...@@ -20,19 +20,16 @@ namespace paddle { ...@@ -20,19 +20,16 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
OpGraphView::OpGraphView( OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) { void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
for (auto &op : ops) { for (auto &op : ops) {
preceding_ops_[op.get()]; preceding_ops_[op];
pending_ops_[op.get()]; pending_ops_[op];
for (auto &var : op->Outputs()) { for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) { for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op.get()); preceding_ops_[pending_op].insert(op);
pending_ops_[op.get()].insert(pending_op); pending_ops_[op].insert(pending_op);
} }
} }
} }
...@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) { ...@@ -41,8 +38,6 @@ void OpGraphView::Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
"There are duplicate ops in graph."); "There are duplicate ops in graph.");
} }
size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const { std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret; std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) { for (auto &pair : preceding_ops_) {
...@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const { ...@@ -60,12 +55,6 @@ void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
op == nullptr ? "nullptr" : op->DebugString()); op == nullptr ? "nullptr" : op->DebugString());
} }
const std::unordered_set<OpHandleBase *> &OpGraphView::PrecedingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return preceding_ops_.at(op);
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps( const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const { OpHandleBase *op) const {
EnforceHasOp(op); EnforceHasOp(op);
......
...@@ -26,21 +26,16 @@ namespace details { ...@@ -26,21 +26,16 @@ namespace details {
class OpGraphView { class OpGraphView {
public: public:
explicit OpGraphView(const std::vector<std::unique_ptr<OpHandleBase>> &ops); explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
size_t OpNumber() const;
std::unordered_set<OpHandleBase *> AllOps() const; std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PrecedingOps(
OpHandleBase *op) const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const; const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const; bool HasOp(OpHandleBase *op) const;
private: private:
void Build(const std::vector<std::unique_ptr<OpHandleBase>> &ops); void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const; void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>> std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册