未验证 提交 5406699d 编写于 作者: Y Yuanle Liu 提交者: GitHub

[IR&PASS] fix constant_folding_pass and add use_count api for Value (#56967)

上级 d121cf29
......@@ -92,14 +92,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
}
// Execute program
paddle::framework::interpreter::ExecutionConfig exe_config;
exe_config.create_local_scope = false;
exe_config_.create_local_scope = false;
paddle::framework::InterpreterCore core(
phi::CPUPlace{},
fetch_var_names,
paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
&scope_,
exe_config);
exe_config_);
paddle::framework::FetchList fetch_list = core.Run({});
......@@ -112,6 +111,7 @@ class ConstantFoldingPattern : public ir::RewritePattern {
std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
exe_config_.skip_gc_vars.insert(param_name);
auto* param_var = scope_.Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
......@@ -180,13 +180,11 @@ class ConstantFoldingPattern : public ir::RewritePattern {
}
private:
static size_t suffix_;
static paddle::framework::Scope scope_;
inline static size_t suffix_{0};
inline static paddle::framework::Scope scope_{};
inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
};
size_t ConstantFoldingPattern::suffix_ = 0;
paddle::framework::Scope ConstantFoldingPattern::scope_ = {};
class ConstantFoldingPass : public ir::Pass {
public:
// TODO(liuyuanle): Naming convention for pass.
......
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/ir/core/value.h"
#include <cstddef>
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value_impl.h"
......@@ -128,6 +131,12 @@ bool Value::HasOneUse() const {
return impl_->HasOneUse();
}
size_t Value::use_count() const {
size_t count = 0;
for (auto it = use_begin(); it != use_end(); ++it) count++;
return count;
}
void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
......
......@@ -123,6 +123,8 @@ class IR_API Value {
bool HasOneUse() const;
size_t use_count() const;
friend struct std::hash<Value>;
void ReplaceUsesWithIf(
......
......@@ -78,7 +78,7 @@ class IR_API Pass {
virtual ~Pass();
std::string name() const { return pass_info().name; }
const std::string& name() const { return pass_info().name; }
const detail::PassInfo& pass_info() const { return pass_info_; }
......
......@@ -131,11 +131,15 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand_source(i));
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
for (auto& block : region) {
for (auto& op_item : *block) {
RemoveFromWorklist(op_item);
if (op->num_regions() == 0) {
RemoveFromWorklist(op);
} else {
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
for (auto& block : region) {
for (auto& op_item : *block) {
RemoveFromWorklist(op_item);
}
}
}
}
......
......@@ -1097,7 +1097,7 @@ TEST(pattern_rewrite, Patterns) {
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
// pm.AddPass(ir::CreateConstantFoldingPass());
pm.AddPass(ir::CreateConstantFoldingPass());
pm.AddPass(ir::CreateDeadCodeEliminationPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册