未验证 提交 bdaf15ad 编写于 作者: W winter-wang 提交者: GitHub

[IR] add replaceUsesWithIf interface for value (#54795)

上级 98a165bf
......@@ -130,7 +130,7 @@ void Operation::Destroy() {
detail::OpOperandImpl *first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
while (first_use != nullptr) {
first_use->remove_from_ud_chain();
first_use->RemoveFromUdChain();
first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/ir/core/value.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/value_impl.h"
namespace ir {
......@@ -37,6 +38,11 @@ OpOperand OpOperand::next_use() const { return impl_->next_use(); }
Value OpOperand::source() const { return impl_->source(); }
void OpOperand::set_source(Value value) {
IR_ENFORCE(impl_, "Can't set source for a null value.");
impl_->set_source(value);
}
Operation *OpOperand::owner() const { return impl_->owner(); }
// Value
......@@ -55,30 +61,44 @@ bool Value::operator!() const { return impl_ == nullptr; }
Value::operator bool() const { return impl_; }
detail::ValueImpl *Value::impl() const { return impl_; }
ir::Type Value::type() const { return impl_->type(); }
ir::Type Value::type() const { return impl()->type(); }
void Value::SetType(ir::Type type) { impl_->SetType(type); }
void Value::set_type(ir::Type type) { impl()->set_type(type); }
Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr;
}
std::string Value::print_ud_chain() { return impl_->print_ud_chain(); }
std::string Value::PrintUdChain() { return impl()->PrintUdChain(); }
Value::use_iterator Value::begin() const {
return ir::OpOperand(impl_->first_use());
}
Value::use_iterator Value::begin() const { return ir::OpOperand(first_use()); }
Value::use_iterator Value::end() const { return Value::use_iterator(); }
OpOperand Value::first_use() const { return impl()->first_use(); }
bool Value::use_empty() const { return !first_use(); }
void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) {
auto cur = it++;
if (should_replace(*cur)) {
cur->set_source(new_value);
}
}
}
detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_;
}
// OpResult
bool OpResult::classof(Value value) {
return ir::isa<detail::OpResultImpl>(value.impl());
return value && ir::isa<detail::OpResultImpl>(value.impl());
}
Operation *OpResult::owner() const { return impl()->owner(); }
......@@ -103,22 +123,33 @@ ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; }
ir::Value OpOperandImpl::source() const { return source_; }
void OpOperandImpl::release_source() { source_ = nullptr; }
void OpOperandImpl::set_source(Value source) {
RemoveFromUdChain();
if (!source) {
return;
}
source_ = source;
InsertToUdChain();
}
OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
: source_(source), owner_(owner) {
if (!source) {
return;
}
prev_use_addr_ = source.impl()->first_use_addr();
next_use_ = source.impl()->first_use();
InsertToUdChain();
}
void OpOperandImpl::InsertToUdChain() {
prev_use_addr_ = source_.impl()->first_use_addr();
next_use_ = source_.impl()->first_use();
if (next_use_) {
next_use_->prev_use_addr_ = &next_use_;
}
source.impl()->SetFirstUse(this);
source_.impl()->SetFirstUse(this);
}
void OpOperandImpl::remove_from_ud_chain() {
void OpOperandImpl::RemoveFromUdChain() {
if (!source_) return;
if (!prev_use_addr_) return;
if (prev_use_addr_ == source_.impl()->first_use_addr()) {
......@@ -134,10 +165,10 @@ void OpOperandImpl::remove_from_ud_chain() {
}
next_use_ = nullptr;
prev_use_addr_ = nullptr;
release_source();
source_ = nullptr;
}
OpOperandImpl::~OpOperandImpl() { remove_from_ud_chain(); }
OpOperandImpl::~OpOperandImpl() { RemoveFromUdChain(); }
uint32_t ValueImpl::index() const {
uint32_t index =
......@@ -147,7 +178,7 @@ uint32_t ValueImpl::index() const {
->GetResultIndex();
}
std::string ValueImpl::print_ud_chain() {
std::string ValueImpl::PrintUdChain() {
std::stringstream result;
result << "Value[" << this << "] -> ";
OpOperandImpl *tmp = first_use();
......
......@@ -55,9 +55,9 @@ class IR_API OpOperand {
Value source() const;
Operation *owner() const;
void set_source(Value value);
// detail::OpOperandImpl *impl() const { return impl_;}
Operation *owner() const;
private:
detail::OpOperandImpl *impl_{nullptr};
......@@ -80,9 +80,9 @@ class ValueUseIterator {
ir::Operation *owner() const { return current_.owner(); }
OperandType get() const { return current_; }
OperandType &operator*() { return current_; }
OperandType operator*() const { return get(); }
OperandType *operator->() { return &operator*(); }
ValueUseIterator<OperandType> &operator++() {
current_ = current_.next_use();
......@@ -129,15 +129,13 @@ class IR_API Value {
return ir::dyn_cast<U>(*this);
}
detail::ValueImpl *impl() const;
ir::Type type() const;
Type type() const;
void SetType(ir::Type type);
void set_type(Type type);
Operation *GetDefiningOp() const;
std::string print_ud_chain();
std::string PrintUdChain();
///
/// \brief Provide iterator interface to access Value use chain.
......@@ -152,6 +150,16 @@ class IR_API Value {
friend struct std::hash<Value>;
bool use_empty() const;
void ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const;
// The interface shoule ensure impl_ isn't nullptr.
// if the user can accept impl_ is nullptr, shoule use impl_ member directly.
detail::ValueImpl *impl() const;
protected:
detail::ValueImpl *impl_{nullptr};
};
......
......@@ -33,10 +33,10 @@ class OpOperandImpl {
ir::Value source() const;
void release_source();
void set_source(Value value);
/// Remove this operand from the current use list.
void remove_from_ud_chain();
void RemoveFromUdChain();
~OpOperandImpl();
......@@ -45,13 +45,17 @@ class OpOperandImpl {
private:
OpOperandImpl(ir::Value source, ir::Operation *owner);
// Insert self to the UD chain holded by source_;
// It is not safe. So set provate.
void InsertToUdChain();
ir::detail::OpOperandImpl *next_use_ = nullptr;
ir::detail::OpOperandImpl **prev_use_addr_ = nullptr;
ir::Value source_;
ir::Operation *owner_ = nullptr;
ir::Operation *const owner_ = nullptr;
};
///
......@@ -69,7 +73,7 @@ class alignas(8) ValueImpl {
///
ir::Type type() const { return type_; }
void SetType(ir::Type type) { type_ = type; }
void set_type(ir::Type type) { type_ = type; }
///
/// \brief Interface functions of "first_use_offseted_by_index_" attribute.
......@@ -94,7 +98,7 @@ class alignas(8) ValueImpl {
bool use_empty() const { return first_use() == nullptr; }
std::string print_ud_chain();
std::string PrintUdChain();
protected:
///
......
......@@ -44,6 +44,7 @@ TEST(value_test, value_test) {
op1_output_types,
ir::OpInfo());
op1->Print(std::cout);
ir::OpResult a = op1->result(0);
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
......@@ -53,8 +54,9 @@ TEST(value_test, value_test) {
op2_output_types,
ir::OpInfo());
op2->Print(std::cout);
ir::OpResult b = op2->result(0);
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->result(0), op2->result(0)};
std::vector<ir::OpResult> op3_inputs{a, b};
std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op3 =
ir::Operation::Create(op3_inputs,
......@@ -62,8 +64,9 @@ TEST(value_test, value_test) {
op3_output_types,
ir::OpInfo());
op3->Print(std::cout);
ir::OpResult c = op3->result(0);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->result(0), op3->result(0)};
std::vector<ir::OpResult> op4_inputs = {a, c};
std::vector<ir::Type> op4_output_types;
for (size_t i = 0; i < 7; i++) {
op4_output_types.push_back(ir::Float32Type::get(ctx));
......@@ -98,13 +101,20 @@ TEST(value_test, value_test) {
++iter;
EXPECT_EQ(iter.owner(), op3);
// Test 4: Value Replace Use
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
//
c.ReplaceUsesWithIf(a, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1).source(), a);
EXPECT_TRUE(c.use_empty());
// destroy
VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op4->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op3->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op2->Destroy();
VLOG(0) << op1->result(0).print_ud_chain() << std::endl;
VLOG(0) << op1->result(0).PrintUdChain() << std::endl;
op1->Destroy();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册