diff --git a/paddle/ir/core/block.cc b/paddle/ir/core/block.cc index a0f74cfa6a61b3313744b99c81e5b53604b67383..5012c7a2eb07f43b6b56bd11fb1e4b72f080a271 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/ir/core/block.cc @@ -38,4 +38,10 @@ void Block::clear() { ops_.pop_back(); } } + +void Block::SetParent(Region *parent, Region::iterator position) { + parent_ = parent; + position_ = position; +} + } // namespace ir diff --git a/paddle/ir/core/block.h b/paddle/ir/core/block.h index 5c4bf08019d3578843f218355883114429d0f0f6..ef87ac2871e312dbb6c01d3b3b56e2d28ce0ece8 100644 --- a/paddle/ir/core/block.h +++ b/paddle/ir/core/block.h @@ -17,8 +17,9 @@ #include #include +#include "paddle/ir/core/region.h" + namespace ir { -class Region; class Operation; class Block { @@ -47,16 +48,19 @@ class Block { void push_front(Operation *op); iterator insert(const_iterator iterator, Operation *op); void clear(); + operator Region::iterator() { return position_; } private: Block(Block &) = delete; Block &operator=(const Block &) = delete; + // Allow access to 'SetParent'. friend class Region; - void SetParent(Region *parent) { parent_ = parent; } + void SetParent(Region *parent, Region::iterator position); private: - Region *parent_; // not owned + Region *parent_; // not owned + Region::iterator position_; std::list ops_; // owned }; } // namespace ir diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index db733137ad27143fc09dcd97ac9267458a017341..af62d2b14cf4be958cf7d3012f1d817709b9b725 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -236,4 +236,9 @@ Region &Operation::GetRegion(unsigned index) { return regions_[index]; } +void Operation::SetParent(Block *parent, const Block::iterator &position) { + parent_ = parent; + position_ = position; +} + } // namespace ir diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index b4506ceb659f311b19d852b468de83f3d5282dbd..c887c7025b1f0101d9093ae776441f593cd1a2af 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -112,11 +112,9 @@ class alignas(8) Operation final { } }; + // Allow access to 'SetParent'. friend class Block; - void SetParent(Block *parent, const Block::iterator &position) { - parent_ = parent; - position_ = position; - } + void SetParent(Block *parent, const Block::iterator &position); template struct CastUtil< diff --git a/paddle/ir/core/region.cc b/paddle/ir/core/region.cc index 854df9cf9bb7bfb25e02e97e36b56b70fd130ca5..cc94d6936901f53a88f78a409183849dd867ada2 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/ir/core/region.cc @@ -18,27 +18,22 @@ namespace ir { Region::~Region() { clear(); } -void Region::push_back(Block *block) { - block->SetParent(this); - blocks_.push_back(block); -} +void Region::push_back(Block *block) { insert(blocks_.end(), block); } void Region::emplace_back() { push_back(new Block); } -void Region::push_front(Block *block) { - block->SetParent(this); - blocks_.push_front(block); -} +void Region::push_front(Block *block) { insert(blocks_.begin(), block); } Region::iterator Region::insert(const_iterator position, Block *block) { - block->SetParent(this); - return blocks_.insert(position, block); + Region::iterator iter = blocks_.insert(position, block); + block->SetParent(this, iter); + return iter; } void Region::TakeBody(Region &&other) { clear(); blocks_.swap(other.blocks_); - for (auto &block : blocks_) { - block->SetParent(this); + for (auto iter = blocks_.begin(); iter != blocks_.end(); ++iter) { + (*iter)->SetParent(this, iter); } }