Unverified · Commit 5406699d authored by Yuanle Liu, committed by GitHub

[IR&PASS] fix constant_folding_pass and add use_count api for Value (#56967)

Parent d121cf29
@@ -92,14 +92,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
     }
     // Execute program
-    paddle::framework::interpreter::ExecutionConfig exe_config;
-    exe_config.create_local_scope = false;
+    exe_config_.create_local_scope = false;
     paddle::framework::InterpreterCore core(
         phi::CPUPlace{},
         fetch_var_names,
         paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
         &scope_,
-        exe_config);
+        exe_config_);
     paddle::framework::FetchList fetch_list = core.Run({});
@@ -112,6 +111,7 @@ class ConstantFoldingPattern : public ir::RewritePattern {
     std::string param_name =
         "@constant_folding_pass@_" + std::to_string(suffix_++);
+    exe_config_.skip_gc_vars.insert(param_name);
     auto* param_var = scope_.Var(param_name);
     auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
@@ -180,13 +180,11 @@ class ConstantFoldingPattern : public ir::RewritePattern {
   }
  private:
-  static size_t suffix_;
-  static paddle::framework::Scope scope_;
+  inline static size_t suffix_{0};
+  inline static paddle::framework::Scope scope_{};
+  inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
 };
-size_t ConstantFoldingPattern::suffix_ = 0;
-paddle::framework::Scope ConstantFoldingPattern::scope_ = {};
 class ConstantFoldingPass : public ir::Pass {
  public:
   // TODO(liuyuanle): Naming convention for pass.
......
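
The pattern's shared state (suffix_, scope_, and the new exe_config_) now uses C++17 inline static data members, which is why the out-of-class definitions of suffix_ and scope_ are deleted above. A minimal standalone sketch of the idiom follows; the class and member names are illustrative, not the Paddle ones.

#include <cstddef>

class FoldingState {
 public:
  static size_t NextSuffix() { return counter_++; }

 private:
  // A C++17 inline static member is defined right here in the class body;
  // no separate "size_t FoldingState::counter_ = 0;" definition is needed
  // in a .cc file, and all translation units share the single instance.
  inline static size_t counter_{0};
};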
@@ -13,6 +13,9 @@
 // limitations under the License.
 #include "paddle/ir/core/value.h"
+#include <cstddef>
 #include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/value_impl.h"
@@ -128,6 +131,12 @@ bool Value::HasOneUse() const {
   return impl_->HasOneUse();
 }
+size_t Value::use_count() const {
+  size_t count = 0;
+  for (auto it = use_begin(); it != use_end(); ++it) count++;
+  return count;
+}
 void Value::ReplaceUsesWithIf(
     Value new_value,
     const std::function<bool(OpOperand)> &should_replace) const {
......
@@ -123,6 +123,8 @@ class IR_API Value {
   bool HasOneUse() const;
+  size_t use_count() const;
   friend struct std::hash<Value>;
   void ReplaceUsesWithIf(
......
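
The header change above declares the new Value::use_count() API. Below is a hedged usage sketch, assuming a built Paddle IR target and that ir::Operation::result(0) returns an OpResult convertible to ir::Value, as elsewhere in the library; CountResultUses is a hypothetical helper, not part of this commit.

#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"

// Count how many OpOperands consume the first result of `op`.
// Per the value.cc hunk above, use_count() simply walks the
// use_begin()/use_end() range and counts the entries.
size_t CountResultUses(ir::Operation* op) {
  ir::Value result = op->result(0);
  return result.use_count();
}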
@@ -78,7 +78,7 @@ class IR_API Pass {
   virtual ~Pass();
-  std::string name() const { return pass_info().name; }
+  const std::string& name() const { return pass_info().name; }
   const detail::PassInfo& pass_info() const { return pass_info_; }
......
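
Pass::name() now hands back a const reference instead of a by-value std::string, so callers no longer pay for a copy on every query. A self-contained sketch of the same accessor idiom, with illustrative (non-Paddle) types:

#include <string>

struct PassInfo {
  std::string name;
};

class Pass {
 public:
  // Returning a const reference avoids copying the stored string on each
  // call; the reference stays valid for the lifetime of the Pass object.
  const std::string& name() const { return info_.name; }

 private:
  PassInfo info_{"constant_folding_pass"};
};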
@@ -131,6 +131,9 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
     for (uint32_t i = 0; i < op->num_operands(); ++i) {
       AddOperandToWorklist(op->operand_source(i));
     }
+    if (op->num_regions() == 0) {
+      RemoveFromWorklist(op);
+    } else {
     for (uint32_t i = 0; i < op->num_regions(); ++i) {
       auto& region = op->region(i);
       for (auto& block : region) {
@@ -139,6 +142,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
         }
       }
     }
+    }
     if (config_.strict_mode != ir::GreedyRewriteStrictness::AnyOp) {
       strict_mode_filtered_ops_.erase(op);
......
@@ -1097,7 +1097,7 @@ TEST(pattern_rewrite, Patterns) {
   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  // pm.AddPass(ir::CreateConstantFoldingPass());
+  pm.AddPass(ir::CreateConstantFoldingPass());
   pm.AddPass(ir::CreateDeadCodeEliminationPass());
   pm.EnablePassTiming();
   pm.EnableIRPrinting();
......