未验证 提交 bdaf15ad 编写于 作者: W winter-wang 提交者: GitHub

[IR] add replaceUsesWithIf interface for value (#54795)

上级 98a165bf
...@@ -130,7 +130,7 @@ void Operation::Destroy() { ...@@ -130,7 +130,7 @@ void Operation::Destroy() {
detail::OpOperandImpl *first_use = detail::OpOperandImpl *first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use(); reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
while (first_use != nullptr) { while (first_use != nullptr) {
first_use->remove_from_ud_chain(); first_use->RemoveFromUdChain();
first_use = first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use(); reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/value_impl.h" #include "paddle/ir/core/value_impl.h"
namespace ir { namespace ir {
...@@ -37,6 +38,11 @@ OpOperand OpOperand::next_use() const { return impl_->next_use(); } ...@@ -37,6 +38,11 @@ OpOperand OpOperand::next_use() const { return impl_->next_use(); }
Value OpOperand::source() const { return impl_->source(); } Value OpOperand::source() const { return impl_->source(); }
void OpOperand::set_source(Value value) {
IR_ENFORCE(impl_, "Can't set source for a null value.");
impl_->set_source(value);
}
Operation *OpOperand::owner() const { return impl_->owner(); } Operation *OpOperand::owner() const { return impl_->owner(); }
// Value // Value
...@@ -55,30 +61,44 @@ bool Value::operator!() const { return impl_ == nullptr; } ...@@ -55,30 +61,44 @@ bool Value::operator!() const { return impl_ == nullptr; }
Value::operator bool() const { return impl_; } Value::operator bool() const { return impl_; }
detail::ValueImpl *Value::impl() const { return impl_; } ir::Type Value::type() const { return impl()->type(); }
ir::Type Value::type() const { return impl_->type(); }
void Value::SetType(ir::Type type) { impl_->SetType(type); } void Value::set_type(ir::Type type) { impl()->set_type(type); }
Operation *Value::GetDefiningOp() const { Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner(); if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr; return nullptr;
} }
std::string Value::print_ud_chain() { return impl_->print_ud_chain(); } std::string Value::PrintUdChain() { return impl()->PrintUdChain(); }
Value::use_iterator Value::begin() const { Value::use_iterator Value::begin() const { return ir::OpOperand(first_use()); }
return ir::OpOperand(impl_->first_use());
}
Value::use_iterator Value::end() const { return Value::use_iterator(); } Value::use_iterator Value::end() const { return Value::use_iterator(); }
OpOperand Value::first_use() const { return impl()->first_use(); } OpOperand Value::first_use() const { return impl()->first_use(); }
bool Value::use_empty() const { return !first_use(); }
void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) {
auto cur = it++;
if (should_replace(*cur)) {
cur->set_source(new_value);
}
}
}
detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_;
}
// OpResult // OpResult
bool OpResult::classof(Value value) { bool OpResult::classof(Value value) {
return ir::isa<detail::OpResultImpl>(value.impl()); return value && ir::isa<detail::OpResultImpl>(value.impl());
} }
Operation *OpResult::owner() const { return impl()->owner(); } Operation *OpResult::owner() const { return impl()->owner(); }
...@@ -103,22 +123,33 @@ ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; } ...@@ -103,22 +123,33 @@ ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; }
ir::Value OpOperandImpl::source() const { return source_; } ir::Value OpOperandImpl::source() const { return source_; }
void OpOperandImpl::release_source() { source_ = nullptr; } void OpOperandImpl::set_source(Value source) {
RemoveFromUdChain();
if (!source) {
return;
}
source_ = source;
InsertToUdChain();
}
OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
: source_(source), owner_(owner) { : source_(source), owner_(owner) {
if (!source) { if (!source) {
return; return;
} }
prev_use_addr_ = source.impl()->first_use_addr(); InsertToUdChain();
next_use_ = source.impl()->first_use(); }
void OpOperandImpl::InsertToUdChain() {
prev_use_addr_ = source_.impl()->first_use_addr();
next_use_ = source_.impl()->first_use();
if (next_use_) { if (next_use_) {
next_use_->prev_use_addr_ = &next_use_; next_use_->prev_use_addr_ = &next_use_;
} }
source.impl()->SetFirstUse(this); source_.impl()->SetFirstUse(this);
} }
void OpOperandImpl::remove_from_ud_chain() { void OpOperandImpl::RemoveFromUdChain() {
if (!source_) return; if (!source_) return;
if (!prev_use_addr_) return; if (!prev_use_addr_) return;
if (prev_use_addr_ == source_.impl()->first_use_addr()) { if (prev_use_addr_ == source_.impl()->first_use_addr()) {
...@@ -134,10 +165,10 @@ void OpOperandImpl::remove_from_ud_chain() { ...@@ -134,10 +165,10 @@ void OpOperandImpl::remove_from_ud_chain() {
} }
next_use_ = nullptr; next_use_ = nullptr;
prev_use_addr_ = nullptr; prev_use_addr_ = nullptr;
release_source(); source_ = nullptr;
} }
OpOperandImpl::~OpOperandImpl() { remove_from_ud_chain(); } OpOperandImpl::~OpOperandImpl() { RemoveFromUdChain(); }
uint32_t ValueImpl::index() const { uint32_t ValueImpl::index() const {
uint32_t index = uint32_t index =
...@@ -147,7 +178,7 @@ uint32_t ValueImpl::index() const { ...@@ -147,7 +178,7 @@ uint32_t ValueImpl::index() const {
->GetResultIndex(); ->GetResultIndex();
} }
std::string ValueImpl::print_ud_chain() { std::string ValueImpl::PrintUdChain() {
std::stringstream result; std::stringstream result;
result << "Value[" << this << "] -> "; result << "Value[" << this << "] -> ";
OpOperandImpl *tmp = first_use(); OpOperandImpl *tmp = first_use();
......
...@@ -55,9 +55,9 @@ class IR_API OpOperand { ...@@ -55,9 +55,9 @@ class IR_API OpOperand {
Value source() const; Value source() const;
Operation *owner() const; void set_source(Value value);
// detail::OpOperandImpl *impl() const { return impl_;} Operation *owner() const;
private: private:
detail::OpOperandImpl *impl_{nullptr}; detail::OpOperandImpl *impl_{nullptr};
...@@ -80,9 +80,9 @@ class ValueUseIterator { ...@@ -80,9 +80,9 @@ class ValueUseIterator {
ir::Operation *owner() const { return current_.owner(); } ir::Operation *owner() const { return current_.owner(); }
OperandType get() const { return current_; } OperandType &operator*() { return current_; }
OperandType operator*() const { return get(); } OperandType *operator->() { return &operator*(); }
ValueUseIterator<OperandType> &operator++() { ValueUseIterator<OperandType> &operator++() {
current_ = current_.next_use(); current_ = current_.next_use();
...@@ -129,15 +129,13 @@ class IR_API Value { ...@@ -129,15 +129,13 @@ class IR_API Value {
return ir::dyn_cast<U>(*this); return ir::dyn_cast<U>(*this);
} }
detail::ValueImpl *impl() const; Type type() const;
ir::Type type() const;
void SetType(ir::Type type); void set_type(Type type);
Operation *GetDefiningOp() const; Operation *GetDefiningOp() const;
std::string print_ud_chain(); std::string PrintUdChain();
/// ///
/// \brief Provide iterator interface to access Value use chain. /// \brief Provide iterator interface to access Value use chain.
...@@ -152,6 +150,16 @@ class IR_API Value { ...@@ -152,6 +150,16 @@ class IR_API Value {
friend struct std::hash<Value>; friend struct std::hash<Value>;
bool use_empty() const;
void ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const;
// The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly.
detail::ValueImpl *impl() const;
protected: protected:
detail::ValueImpl *impl_{nullptr}; detail::ValueImpl *impl_{nullptr};
}; };
......
...@@ -33,10 +33,10 @@ class OpOperandImpl { ...@@ -33,10 +33,10 @@ class OpOperandImpl {
ir::Value source() const; ir::Value source() const;
void release_source(); void set_source(Value value);
/// Remove this operand from the current use list. /// Remove this operand from the current use list.
void remove_from_ud_chain(); void RemoveFromUdChain();
~OpOperandImpl(); ~OpOperandImpl();
...@@ -45,13 +45,17 @@ class OpOperandImpl { ...@@ -45,13 +45,17 @@ class OpOperandImpl {
private: private:
OpOperandImpl(ir::Value source, ir::Operation *owner); OpOperandImpl(ir::Value source, ir::Operation *owner);
// Insert self to the UD chain holded by source_;
// It is not safe. So set provate.
void InsertToUdChain();
ir::detail::OpOperandImpl *next_use_ = nullptr; ir::detail::OpOperandImpl *next_use_ = nullptr;
ir::detail::OpOperandImpl **prev_use_addr_ = nullptr; ir::detail::OpOperandImpl **prev_use_addr_ = nullptr;
ir::Value source_; ir::Value source_;
ir::Operation *owner_ = nullptr; ir::Operation *const owner_ = nullptr;
}; };
/// ///
...@@ -69,7 +73,7 @@ class alignas(8) ValueImpl { ...@@ -69,7 +73,7 @@ class alignas(8) ValueImpl {
/// ///
ir::Type type() const { return type_; } ir::Type type() const { return type_; }
void SetType(ir::Type type) { type_ = type; } void set_type(ir::Type type) { type_ = type; }
/// ///
/// \brief Interface functions of "first_use_offseted_by_index_" attribute. /// \brief Interface functions of "first_use_offseted_by_index_" attribute.
...@@ -94,7 +98,7 @@ class alignas(8) ValueImpl { ...@@ -94,7 +98,7 @@ class alignas(8) ValueImpl {
bool use_empty() const { return first_use() == nullptr; } bool use_empty() const { return first_use() == nullptr; }
std::string print_ud_chain(); std::string PrintUdChain();
protected: protected:
/// ///
......
...@@ -44,6 +44,7 @@ TEST(value_test, value_test) { ...@@ -44,6 +44,7 @@ TEST(value_test, value_test) {
op1_output_types, op1_output_types,
ir::OpInfo()); ir::OpInfo());
op1->Print(std::cout); op1->Print(std::cout);
ir::OpResult a = op1->result(0);
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
...@@ -53,8 +54,9 @@ TEST(value_test, value_test) { ...@@ -53,8 +54,9 @@ TEST(value_test, value_test) {
op2_output_types, op2_output_types,
ir::OpInfo()); ir::OpInfo());
op2->Print(std::cout); op2->Print(std::cout);
ir::OpResult b = op2->result(0);
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->result(0), op2->result(0)}; std::vector<ir::OpResult> op3_inputs{a, b};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 = ir::Operation *op3 =
ir::Operation::Create(op3_inputs, ir::Operation::Create(op3_inputs,
...@@ -62,8 +64,9 @@ TEST(value_test, value_test) { ...@@ -62,8 +64,9 @@ TEST(value_test, value_test) {
op3_output_types, op3_output_types,
ir::OpInfo()); ir::OpInfo());
op3->Print(std::cout); op3->Print(std::cout);
ir::OpResult c = op3->result(0);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->result(0), op3->result(0)}; std::vector<ir::OpResult> op4_inputs = {a, c};
std::vector<ir::Type> op4_output_types; std::vector<ir::Type> op4_output_types;
for (size_t i = 0; i < 7; i++) { for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx)); op4_output_types.push_back(ir::Float32Type::get(ctx));
...@@ -98,13 +101,20 @@ TEST(value_test, value_test) { ...@@ -98,13 +101,20 @@ TEST(value_test, value_test) {
++iter; ++iter;
EXPECT_EQ(iter.owner(), op3); EXPECT_EQ(iter.owner(), op3);
// Test 4: Value Replace Use
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
//
c.ReplaceUsesWithIf(a, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1).source(), a);
EXPECT_TRUE(c.use_empty());
// destroy // destroy
VLOG(0) << op1->result(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op4->Destroy(); op4->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op3->Destroy(); op3->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op2->Destroy(); op2->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl; VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op1->Destroy(); op1->Destroy();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册