Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
72b8c7c2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72b8c7c2
编写于
6月 27, 2023
作者:
W
Wilber
提交者:
GitHub
6月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR&PASS] part 3-3: Add PatternRewrite Driver code. (#54738)
上级
813266a2
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
620 addition
and
81 deletion
+620
-81
paddle/ir/core/operation.cc
paddle/ir/core/operation.cc
+8
-0
paddle/ir/core/operation.h
paddle/ir/core/operation.h
+8
-0
paddle/ir/core/region.cc
paddle/ir/core/region.cc
+6
-0
paddle/ir/core/region.h
paddle/ir/core/region.h
+3
-0
paddle/ir/core/value.cc
paddle/ir/core/value.cc
+2
-0
paddle/ir/core/value.h
paddle/ir/core/value.h
+4
-2
paddle/ir/core/value_impl.h
paddle/ir/core/value_impl.h
+4
-0
paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h
paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h
+2
-1
paddle/ir/pattern_rewrite/pattern_match.cc
paddle/ir/pattern_rewrite/pattern_match.cc
+50
-39
paddle/ir/pattern_rewrite/pattern_match.h
paddle/ir/pattern_rewrite/pattern_match.h
+29
-23
paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc
paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc
+227
-0
paddle/ir/pattern_rewrite/pattern_rewrite_driver.h
paddle/ir/pattern_rewrite/pattern_rewrite_driver.h
+86
-0
test/cpp/ir/core/ir_op_test.cc
test/cpp/ir/core/ir_op_test.cc
+1
-0
test/cpp/ir/core/ir_value_test.cc
test/cpp/ir/core/ir_value_test.cc
+5
-0
test/cpp/ir/pattern_rewrite/CMakeLists.txt
test/cpp/ir/pattern_rewrite/CMakeLists.txt
+8
-1
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+177
-15
未找到文件。
paddle/ir/core/operation.cc
浏览文件 @
72b8c7c2
...
...
@@ -224,4 +224,12 @@ void Operation::SetParent(Block *parent, const Block::iterator &position) {
position_
=
position
;
}
void
Operation
::
ReplaceAllUsesWith
(
const
std
::
vector
<
Value
>
&
values
)
{
IR_ENFORCE
(
num_results_
==
values
.
size
(),
"the num of result should be the same."
);
for
(
uint32_t
i
=
0
;
i
<
num_results_
;
++
i
)
{
result
(
i
).
ReplaceAllUsesWith
(
values
[
i
]);
}
}
}
// namespace ir
paddle/ir/core/operation.h
浏览文件 @
72b8c7c2
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <ostream>
#include <vector>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
...
...
@@ -102,6 +103,13 @@ class IR_API alignas(8) Operation final {
operator
Block
::
const_iterator
()
const
{
return
position_
;
}
/// Replace all uses of results of this operation with the provided 'values'.
void
ReplaceAllUsesWith
(
const
std
::
vector
<
Value
>
&
values
);
inline
void
ReplaceAllUsesWith
(
Value
value
)
{
ReplaceAllUsesWith
(
std
::
vector
<
Value
>
{
value
});
}
private:
Operation
(
const
AttributeMap
&
attribute
,
ir
::
OpInfo
op_info
,
...
...
paddle/ir/core/region.cc
浏览文件 @
72b8c7c2
...
...
@@ -15,6 +15,7 @@
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
namespace
ir
{
Region
::~
Region
()
{
clear
();
}
...
...
@@ -50,4 +51,9 @@ void Region::clear() {
blocks_
.
pop_back
();
}
}
IrContext
*
Region
::
ir_context
()
const
{
IR_ENFORCE
(
parent_
,
"Region is not attached to a container."
);
return
parent_
->
ir_context
();
}
}
// namespace ir
paddle/ir/core/region.h
浏览文件 @
72b8c7c2
...
...
@@ -23,6 +23,7 @@ namespace ir {
class
Block
;
class
Operation
;
class
IrContext
;
class
IR_API
Region
{
public:
...
...
@@ -55,6 +56,8 @@ class IR_API Region {
Operation
*
GetParent
()
const
{
return
parent_
;
}
IrContext
*
ir_context
()
const
;
private:
Region
(
Region
&
)
=
delete
;
Region
&
operator
=
(
const
Region
&
)
=
delete
;
...
...
paddle/ir/core/value.cc
浏览文件 @
72b8c7c2
...
...
@@ -85,6 +85,8 @@ OpOperand Value::first_use() const { return impl()->first_use(); }
bool
Value
::
use_empty
()
const
{
return
!
first_use
();
}
bool
Value
::
HasOneUse
()
const
{
return
impl
()
->
HasOneUse
();
}
void
Value
::
ReplaceUsesWithIf
(
Value
new_value
,
const
std
::
function
<
bool
(
OpOperand
)
>
&
should_replace
)
const
{
...
...
paddle/ir/core/value.h
浏览文件 @
72b8c7c2
...
...
@@ -158,10 +158,12 @@ class IR_API Value {
OpOperand
first_use
()
const
;
friend
struct
std
::
hash
<
Value
>
;
bool
use_empty
()
const
;
bool
HasOneUse
()
const
;
friend
struct
std
::
hash
<
Value
>
;
void
ReplaceUsesWithIf
(
Value
new_value
,
const
std
::
function
<
bool
(
OpOperand
)
>
&
should_replace
)
const
;
...
...
paddle/ir/core/value_impl.h
浏览文件 @
72b8c7c2
...
...
@@ -98,6 +98,10 @@ class alignas(8) ValueImpl {
bool
use_empty
()
const
{
return
first_use
()
==
nullptr
;
}
bool
HasOneUse
()
const
{
return
(
first_use
()
!=
nullptr
)
&&
(
first_use
()
->
next_use
()
==
nullptr
);
}
std
::
string
PrintUdChain
();
protected:
...
...
paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h
浏览文件 @
72b8c7c2
...
...
@@ -21,12 +21,13 @@
#include <unordered_map>
#include <vector>
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace
ir
{
class
FrozenRewritePatternSet
{
class
IR_API
FrozenRewritePatternSet
{
using
NativePatternListT
=
std
::
vector
<
std
::
unique_ptr
<
RewritePattern
>>
;
public:
...
...
paddle/ir/pattern_rewrite/pattern_match.cc
浏览文件 @
72b8c7c2
...
...
@@ -15,9 +15,9 @@
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
namespace
ir
{
...
...
@@ -90,44 +90,55 @@ RewritePattern::~RewritePattern() = default;
//===----------------------------------------------------------------------===//
RewriterBase
::~
RewriterBase
()
=
default
;
// TODO(wilber): value support replace method.
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor) {
// // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->result(0)
// }
// }
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor) {
// ReplaceOpWithIf(op, new_values, nullptr, functor);
// }
// TODO(wilber): support erase.
// void ReplaceOp(Operation* op, ValueRange new_values) {
// NotifyRootReplaced(op, new_values);
// assert(op->num_results() == new_values.size() && "incorrect # of
// replacement values"); op->ReplaceAllUsesWith(new_values);
// NotifyOperationRemoved(op);
// op->erase();
// }
void
RewriterBase
::
ReplaceOpWithIf
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
,
bool
*
all_uses_replaced
,
const
std
::
function
<
bool
(
OpOperand
)
>&
functor
)
{
IR_ENFORCE
(
op
->
num_results
()
==
new_values
.
size
(),
"incorrect number of values to replace operation"
);
NotifyRootReplaced
(
op
,
new_values
);
// Replace each use of the results when the functor is true.
bool
replace_all_uses
=
true
;
for
(
uint32_t
i
=
0
;
i
<
op
->
num_results
();
++
i
)
{
auto
src_res
=
op
->
result
(
i
);
src_res
.
ReplaceUsesWithIf
(
new_values
[
i
],
functor
);
replace_all_uses
&=
src_res
.
use_empty
();
}
if
(
replace_all_uses
)
{
*
all_uses_replaced
=
replace_all_uses
;
}
}
void
RewriterBase
::
ReplaceOpWithIf
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
,
const
std
::
function
<
bool
(
OpOperand
)
>&
functor
)
{
ReplaceOpWithIf
(
op
,
new_values
,
nullptr
,
functor
);
}
void
RewriterBase
::
ReplaceOp
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
)
{
NotifyRootReplaced
(
op
,
new_values
);
IR_ENFORCE
(
op
->
num_results
()
==
new_values
.
size
(),
"incorrect # of replacement values"
);
op
->
ReplaceAllUsesWith
(
new_values
);
NotifyOperationRemoved
(
op
);
op
->
GetParent
()
->
erase
(
*
op
);
}
void
RewriterBase
::
EraseOp
(
Operation
*
op
)
{
// assert(op->use_empty() && "expected 'op' to have no uses");
// NotifyOperationRemoved(op);
// op->erase();
// TODO(wilber): Operation support use_empty.
// IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses");
NotifyOperationRemoved
(
op
);
op
->
GetParent
()
->
erase
(
*
op
);
}
/// Find uses of `from` and replace it with `to`
void
RewriterBase
::
ReplaceAllUsesWith
(
Value
from
,
Value
to
)
{
// from.
// for (OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// {
// Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); });
// }
// TODO(wilber): Substitue a low level impl.
from
.
ReplaceAllUsesWith
(
to
);
}
// TODO(wilber): iterator maybe should support modify inplace.
...
...
@@ -135,8 +146,8 @@ void RewriterBase::ReplaceUseIf(Value from,
Value
to
,
std
::
function
<
bool
(
OpOperand
&
)
>
functor
)
{
// for (auto it = from.begin(); it != from.end(); ++it) {
// // TODO: need a lvalue.
// if (functor(
it.get()
)) {
//
//
// TODO: need a lvalue.
// if (functor(
*it
)) {
// UpdateRootInplace(it.owner(), [&](){it.get().set(to)});
// }
// }
...
...
@@ -144,8 +155,8 @@ void RewriterBase::ReplaceUseIf(Value from,
void
RewriterBase
::
ReplaceOpWithResultsOfAnotherOp
(
Operation
*
op
,
Operation
*
new_op
)
{
assert
(
op
->
num_results
()
==
new_op
->
num_results
()
&&
"replacement op doesn't match results of original op"
);
IR_ENFORCE
(
op
->
num_results
()
==
new_op
->
num_results
(),
"replacement op doesn't match results of original op"
);
// TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op,
// new_op->result(0)); return ReplaceOp(op, new_op->GetResults());
...
...
paddle/ir/pattern_rewrite/pattern_match.h
浏览文件 @
72b8c7c2
...
...
@@ -25,6 +25,7 @@
#include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
...
...
@@ -36,7 +37,7 @@ namespace ir {
// This class reprensents the benefit of a pattern. The most common
// unit to use is the `numver of operations` in the pattern.
class
PatternBenefit
{
class
IR_API
PatternBenefit
{
public:
PatternBenefit
()
=
default
;
PatternBenefit
(
uint32_t
val
)
:
val_
(
val
)
{}
// NOLINT
...
...
@@ -257,30 +258,21 @@ class RewriterBase : public Builder {
public:
// TODO(wilber): Supplementary methods of block and region.
// TODO(wilber): Support ValueRange.
// virtual void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor);
// void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor);
// virtual void ReplaceOp(Operation* op, ValueRange new_values);
virtual
void
ReplaceOpWithIf
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
,
bool
*
all_uses_replaced
,
const
std
::
function
<
bool
(
OpOperand
)
>&
functor
);
// virtual void ReplaceOpWithNewOp()
void
ReplaceOpWithIf
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
,
const
std
::
function
<
bool
(
OpOperand
)
>&
functor
);
virtual
void
EraseOp
(
Operation
*
op
);
virtual
void
ReplaceOp
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
new_values
);
virtual
void
StartRootUpdate
(
Operation
*
op
)
{}
virtual
void
FinalizeRootUpdate
(
Operation
*
op
)
{}
virtual
void
CancleRootUpdate
(
Operation
*
op
)
{}
// template <typename OpTy, typename... Args>
// OpTy ReplaceOpWithNewOp(Operation *op, Args &&...args);
template
<
typename
CallableT
>
void
UpdateRootInplace
(
Operation
*
root
,
CallableT
&&
callable
)
{
StartRootUpdate
(
root
);
callable
();
FinalizeRootUpdate
(
root
);
}
virtual
void
EraseOp
(
Operation
*
op
);
void
ReplaceAllUsesWith
(
Value
from
,
Value
to
);
...
...
@@ -293,11 +285,25 @@ class RewriterBase : public Builder {
virtual
~
RewriterBase
();
// virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {}
virtual
void
NotifyRootReplaced
(
Operation
*
op
,
const
std
::
vector
<
Value
>&
replacement
)
{}
virtual
void
NotifyOperationRemoved
(
Operation
*
op
)
{}
// virtual bool NotifyMatchFailure()
virtual
void
NotifyOperationInserted
(
Operation
*
op
)
{}
virtual
void
StartRootUpdate
(
Operation
*
op
)
{}
virtual
void
FinalizeRootUpdate
(
Operation
*
op
)
{}
virtual
void
CancleRootUpdate
(
Operation
*
op
)
{}
template
<
typename
CallableT
>
void
UpdateRootInplace
(
Operation
*
root
,
CallableT
&&
callable
)
{
StartRootUpdate
(
root
);
callable
();
FinalizeRootUpdate
(
root
);
}
private:
void
operator
=
(
const
RewriterBase
&
)
=
delete
;
...
...
paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc
0 → 100644
浏览文件 @
72b8c7c2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/value.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace
{
class
GreedyPatternRewriteDriver
:
public
ir
::
PatternRewriter
{
public:
explicit
GreedyPatternRewriteDriver
(
ir
::
IrContext
*
ctx
,
const
ir
::
FrozenRewritePatternSet
&
patterns
,
const
ir
::
GreedyRewriteConfig
&
config
)
:
ir
::
PatternRewriter
(
ctx
),
config_
(
config
),
region_
(
*
config
.
region
),
matcher_
(
patterns
)
{
worklist_
.
reserve
(
128
);
matcher_
.
ApplyDefaultCostModel
();
if
(
config
.
strict_mode
!=
ir
::
GreedyRewriteStrictness
::
AnyOp
)
{
for
(
auto
it
=
region_
.
begin
();
it
!=
region_
.
end
();
++
it
)
{
for
(
auto
op_it
=
(
*
it
)
->
begin
();
op_it
!=
(
*
it
)
->
end
();
++
op_it
)
{
strict_mode_filtered_ops_
.
insert
(
*
op_it
);
}
}
}
}
bool
Simplify
()
{
bool
changed
=
false
;
int64_t
iteration
=
0
;
do
{
// Check if the iteration limit was reached.
if
(
iteration
++
>=
config_
.
max_iterations
&&
config_
.
max_iterations
!=
ir
::
GreedyRewriteConfig
::
kNoLimit
)
break
;
VLOG
(
6
)
<<
"Iteration["
<<
iteration
<<
"] for PatternRewrite"
;
worklist_
.
clear
();
worklist_map_
.
clear
();
for
(
auto
block_it
=
region_
.
begin
();
block_it
!=
region_
.
end
();
++
block_it
)
{
for
(
auto
op_it
=
(
*
block_it
)
->
begin
();
op_it
!=
(
*
block_it
)
->
end
();
++
op_it
)
{
worklist_
.
push_back
(
*
op_it
);
}
}
if
(
config_
.
use_top_down_traversal
)
{
// Reverse the list so out pop-back loop process them in-order.
std
::
reverse
(
worklist_
.
begin
(),
worklist_
.
end
());
}
for
(
size_t
i
=
0
;
i
<
worklist_
.
size
();
++
i
)
{
worklist_map_
[
worklist_
[
i
]]
=
i
;
VLOG
(
6
)
<<
"worklist["
<<
i
<<
"] is "
<<
worklist_
[
i
]
->
name
();
}
changed
=
ProcessWorklist
();
}
while
(
changed
);
return
!
changed
;
}
private:
/// Process ops until the worklist is empty or `config.max_num_rewrites`
/// is reached. Return `true` if any IR was changed.
bool
ProcessWorklist
()
{
bool
changed
=
false
;
int64_t
num_rewrites
=
0
;
while
(
!
worklist_
.
empty
()
&&
(
num_rewrites
<
config_
.
max_num_rewrites
||
config_
.
max_num_rewrites
==
ir
::
GreedyRewriteConfig
::
kNoLimit
))
{
auto
*
op
=
PopFromWorklist
();
if
(
op
==
nullptr
)
continue
;
VLOG
(
6
)
<<
"PopFromWorklist, get op: "
<<
op
->
name
();
// TODO(wilber): ir is dead.
// ...
// TODO(wilber): fold logical.
// ...
bool
match_result
=
matcher_
.
MatchAndRewrite
(
op
,
*
this
);
if
(
match_result
)
{
changed
=
true
;
++
num_rewrites
;
}
}
return
changed
;
}
// TODO(wilber): OpResult support GetUsers method.
void
NotifyRootReplaced
(
ir
::
Operation
*
op
,
const
std
::
vector
<
ir
::
Value
>&
replacement
)
override
{
// for (uint32_t i = 0; i < op->num_results(); ++i) {
// auto res = op->GetResultByIndex(i);
// }
// }
}
void
FinalizeRootUpdate
(
ir
::
Operation
*
op
)
override
{
AddToWorklist
(
op
);
}
void
NotifyOperationRemoved
(
ir
::
Operation
*
op
)
override
{
for
(
uint32_t
i
=
0
;
i
<
op
->
num_operands
();
++
i
)
{
AddOperandToWorklist
(
op
->
operand
(
i
).
source
());
}
for
(
uint32_t
i
=
0
;
i
<
op
->
num_regions
();
++
i
)
{
auto
&
region
=
op
->
region
(
i
);
for
(
auto
it
=
region
.
begin
();
it
!=
region
.
end
();
++
it
)
{
for
(
auto
op_it
=
(
*
it
)
->
begin
();
op_it
!=
(
*
it
)
->
end
();
++
op_it
)
{
RemoveFromWorklist
(
*
op_it
);
}
}
}
if
(
config_
.
strict_mode
!=
ir
::
GreedyRewriteStrictness
::
AnyOp
)
{
strict_mode_filtered_ops_
.
erase
(
op
);
}
}
void
NotifyOperationInserted
(
ir
::
Operation
*
op
)
override
{
if
(
config_
.
strict_mode
==
ir
::
GreedyRewriteStrictness
::
ExistingAndNewOps
)
strict_mode_filtered_ops_
.
insert
(
op
);
AddToWorklist
(
op
);
}
/// Add the given operation to the worklist.
void
AddToWorklist
(
ir
::
Operation
*
op
)
{
if
(
config_
.
strict_mode
==
ir
::
GreedyRewriteStrictness
::
AnyOp
||
strict_mode_filtered_ops_
.
count
(
op
))
{
if
(
worklist_map_
.
count
(
op
))
return
;
worklist_map_
[
op
]
=
worklist_
.
size
();
worklist_
.
push_back
(
op
);
}
}
void
AddOperandToWorklist
(
ir
::
Value
operand
)
{
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
// This is based on the fact that zero use operations may be deleted, and
// that single use values often have more canonicalization opportunities.
if
(
!
operand
||
(
!
operand
.
use_empty
()
&&
!
operand
.
HasOneUse
()))
return
;
if
(
auto
*
def_op
=
operand
.
GetDefiningOp
())
AddToWorklist
(
def_op
);
}
void
AddOperandsToWorklist
(
const
std
::
vector
<
ir
::
Value
>
operands
)
{
for
(
auto
&
v
:
operands
)
{
AddOperandToWorklist
(
v
);
}
}
/// Pop the next operation from the worklist
ir
::
Operation
*
PopFromWorklist
()
{
auto
*
op
=
worklist_
.
back
();
worklist_
.
pop_back
();
if
(
op
)
worklist_map_
.
erase
(
op
);
return
op
;
}
/// If the specified operation is in the worklist, remove it.
void
RemoveFromWorklist
(
ir
::
Operation
*
op
)
{
auto
it
=
worklist_map_
.
find
(
op
);
if
(
it
!=
worklist_map_
.
end
())
{
worklist_
[
it
->
second
]
=
nullptr
;
worklist_map_
.
erase
(
it
);
}
}
private:
std
::
vector
<
ir
::
Operation
*>
worklist_
;
std
::
unordered_map
<
ir
::
Operation
*
,
unsigned
>
worklist_map_
;
ir
::
GreedyRewriteConfig
config_
;
std
::
unordered_set
<
ir
::
Operation
*>
strict_mode_filtered_ops_
;
ir
::
Region
&
region_
;
ir
::
PatternApplicator
matcher_
;
};
}
// namespace
namespace
ir
{
bool
ApplyPatternsGreedily
(
Region
&
region
,
// NOLINT
const
FrozenRewritePatternSet
&
patterns
,
GreedyRewriteConfig
config
)
{
if
(
!
config
.
region
)
config
.
region
=
&
region
;
GreedyPatternRewriteDriver
driver
(
region
.
ir_context
(),
patterns
,
config
);
bool
converged
=
driver
.
Simplify
();
if
(
!
converged
)
{
LOG
(
WARNING
)
<<
"The pattern rewrite did not converge after scaning "
<<
config
.
max_iterations
<<
" times"
;
}
return
converged
;
}
}
// namespace ir
paddle/ir/pattern_rewrite/pattern_rewrite_driver.h
0 → 100644
浏览文件 @
72b8c7c2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
namespace
ir
{
/// This enum will control which ops will be added to the worklist during the
/// match rewrite process
enum
class
IR_API
GreedyRewriteStrictness
{
/// No restrictions wrt. any ops are processed.
AnyOp
,
/// Only pre-existing and newly created ops are processed.
ExistingAndNewOps
,
/// Only pre-existing ops are processed.
ExistingOps
};
/// Control over how the GreedyPatternRewriteDriver works.
class
IR_API
GreedyRewriteConfig
{
public:
/// Control the way op is added to the worklist: bottom-up or top-down.
bool
use_top_down_traversal
=
false
;
/// Control the maximum number of iterations in the process of applying the
/// pattern, use `kNolimit` to represent unlimited.
int64_t
max_iterations
=
10
;
/// Control the upper limit of rewrite times during each iteration, use
/// kNoLimit to represent unlimited.
int64_t
max_num_rewrites
=
kNoLimit
;
/// Only the op inside this region will be added to the worklist.
Region
*
region
{
nullptr
};
/// Limit which ops will be added to the worklist during the Match and Rewrite
/// process.
/// - AnyOp: all ops will be added to the worklist.
/// - ExistingAndNewOps: pre-existing ops and newly created ops are added to
/// the worklist.
/// - ExistingOps: only pre-existing ops are added to the worklist.
GreedyRewriteStrictness
strict_mode
=
GreedyRewriteStrictness
::
AnyOp
;
static
constexpr
int64_t
kNoLimit
=
-
1
;
};
/// Perform the Match and Rewrite process in the specified region, greedily
/// apply the Pattern with the highest benefit, and repeat this process until
/// convergence or the upper limit of iterations.
///
/// Returns true if the iteration converges and no patterns can be applied.
bool
IR_API
ApplyPatternsGreedily
(
Region
&
region
,
// NOLINT
const
FrozenRewritePatternSet
&
patterns
,
GreedyRewriteConfig
config
=
GreedyRewriteConfig
());
/// Perform a match and rewrite process for all regions of a given op.
inline
IR_API
bool
ApplyPatternsGreedily
(
Operation
*
op
,
const
FrozenRewritePatternSet
&
patterns
,
GreedyRewriteConfig
config
=
GreedyRewriteConfig
())
{
bool
failed
=
false
;
for
(
uint32_t
i
=
0
;
i
<
op
->
num_regions
();
++
i
)
{
Region
&
region
=
op
->
region
(
i
);
failed
|=
!
ApplyPatternsGreedily
(
region
,
patterns
,
config
);
}
return
!
failed
;
}
}
// namespace ir
test/cpp/ir/core/ir_op_test.cc
浏览文件 @
72b8c7c2
...
...
@@ -256,6 +256,7 @@ TEST(op_test, region_test) {
block
->
push_front
(
op1
);
block
->
insert
(
block
->
begin
(),
op1_2
);
ir
::
Operation
*
op2
=
ir
::
Operation
::
Create
(
std
::
move
(
argument
));
EXPECT_EQ
(
op2
->
region
(
0
).
ir_context
(),
ctx
);
op2
->
Destroy
();
}
...
...
test/cpp/ir/core/ir_value_test.cc
浏览文件 @
72b8c7c2
...
...
@@ -45,6 +45,7 @@ TEST(value_test, value_test) {
ir
::
OpInfo
());
op1
->
Print
(
std
::
cout
);
ir
::
OpResult
a
=
op1
->
result
(
0
);
EXPECT_TRUE
(
a
.
use_empty
());
// 2. Construct OP2: b = OP2();
std
::
vector
<
ir
::
OpResult
>
op2_inputs
=
{};
std
::
vector
<
ir
::
Type
>
op2_output_types
=
{
ir
::
Float32Type
::
get
(
ctx
)};
...
...
@@ -55,6 +56,7 @@ TEST(value_test, value_test) {
ir
::
OpInfo
());
op2
->
Print
(
std
::
cout
);
ir
::
OpResult
b
=
op2
->
result
(
0
);
EXPECT_TRUE
(
b
.
use_empty
());
// 3. Construct OP3: c = OP3(a, b);
std
::
vector
<
ir
::
OpResult
>
op3_inputs
{
a
,
b
};
std
::
vector
<
ir
::
Type
>
op3_output_types
=
{
ir
::
Float32Type
::
get
(
ctx
)};
...
...
@@ -63,6 +65,9 @@ TEST(value_test, value_test) {
CreateAttributeMap
(
"op3_name"
,
"op3_attr"
),
op3_output_types
,
ir
::
OpInfo
());
EXPECT_TRUE
(
op1
->
result
(
0
).
HasOneUse
());
EXPECT_TRUE
(
op2
->
result
(
0
).
HasOneUse
());
op3
->
Print
(
std
::
cout
);
ir
::
OpResult
c
=
op3
->
result
(
0
);
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
...
...
test/cpp/ir/pattern_rewrite/CMakeLists.txt
浏览文件 @
72b8c7c2
cc_test_old
(
pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ir gtest
)
cc_test_old
(
pattern_rewrite_test
SRCS
pattern_rewrite_test.cc
DEPS
ir
pd_dialect
gtest
)
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
浏览文件 @
72b8c7c2
...
...
@@ -13,27 +13,33 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <numeric>
#include <sstream>
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
TEST
(
pattern_rewrite
,
PatternBenefit
)
{
ir
::
PatternBenefit
benefit1
(
1
);
EXPECT_EQ
(
benefit1
.
benefit
(),
1U
);
ir
::
PatternBenefit
benefit2
(
2
);
EXPECT_EQ
(
benefit2
.
benefit
(),
2U
);
EXPECT_TRUE
(
benefit2
>
benefit1
);
EXPECT_TRUE
(
benefit2
>=
benefit1
);
EXPECT_TRUE
(
benefit1
<
benefit2
);
EXPECT_TRUE
(
benefit1
<=
benefit2
);
EXPECT_TRUE
(
benefit1
!=
benefit2
);
ir
::
PatternBenefit
benefit3
(
2
);
EXPECT_TRUE
(
benefit2
==
benefit3
);
}
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
// Define op1.
class
Operation1
:
public
ir
::
Op
<
Operation1
>
{
...
...
@@ -95,7 +101,22 @@ class TestPatternRewrite2 : public ir::OpRewritePattern<Operation1> {
}
};
TEST
(
pattern_rewrite
,
RewritePatternSet
)
{
TEST
(
PatternRewrite
,
PatternBenefit
)
{
ir
::
PatternBenefit
benefit1
(
1
);
EXPECT_EQ
(
benefit1
.
benefit
(),
1U
);
ir
::
PatternBenefit
benefit2
(
2
);
EXPECT_EQ
(
benefit2
.
benefit
(),
2U
);
EXPECT_TRUE
(
benefit2
>
benefit1
);
EXPECT_TRUE
(
benefit2
>=
benefit1
);
EXPECT_TRUE
(
benefit1
<
benefit2
);
EXPECT_TRUE
(
benefit1
<=
benefit2
);
EXPECT_TRUE
(
benefit1
!=
benefit2
);
ir
::
PatternBenefit
benefit3
(
2
);
EXPECT_TRUE
(
benefit2
==
benefit3
);
}
TEST
(
RewritePattern
,
RewritePatternSet
)
{
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
ir
::
BuiltinDialect
>
();
auto
*
test_dialect
=
ctx
->
GetOrRegisterDialect
<
TestDialect
>
();
...
...
@@ -118,3 +139,144 @@ TEST(pattern_rewrite, RewritePatternSet) {
EXPECT_EQ
(
ps
.
native_patterns
()[
0
]
->
benefit
(),
2U
);
EXPECT_EQ
(
ps
.
native_patterns
()[
1
]
->
benefit
(),
2U
);
}
// TODO(wilber): Add actual case.
// TEST(PatternRewrite, PatternApplicator) {
// ir::IrContext *ctx = ir::IrContext::Instance();
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
// test_dialect->RegisterOp<Operation1>();
// ir::RewritePatternSet ps(ctx);
// ps.Add<TestPatternRewrite, TestPatternRewrite2>(ctx, 2);
// ir::FrozenRewritePatternSet frozen_set(std::move(ps));
// ir::PatternApplicator applicator(frozen_set);
// applicator.ApplyDefaultCostModel();
// }
// // TODO(wilber): Add actual case.
TEST
(
PatternRewrite
,
FrozenRewritePatternSet
)
{
ir
::
FrozenRewritePatternSet
frozen_set
;
EXPECT_TRUE
(
frozen_set
.
match_any_op_native_patterns
().
empty
());
EXPECT_TRUE
(
frozen_set
.
op_specific_native_patterns
().
empty
());
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
ir
::
BuiltinDialect
>
();
auto
*
test_dialect
=
ctx
->
GetOrRegisterDialect
<
TestDialect
>
();
test_dialect
->
RegisterOp
<
Operation1
>
();
ir
::
RewritePatternSet
ps
(
ctx
);
ps
.
Add
<
TestPatternRewrite
,
TestPatternRewrite2
>
(
ctx
,
2
);
ir
::
FrozenRewritePatternSet
frozen_set2
(
std
::
move
(
ps
));
EXPECT_TRUE
(
frozen_set2
.
match_any_op_native_patterns
().
empty
());
const
auto
&
pattern_maps
=
frozen_set2
.
op_specific_native_patterns
();
EXPECT_EQ
(
pattern_maps
.
size
(),
1U
);
EXPECT_EQ
(
pattern_maps
.
at
(
ctx
->
GetRegisteredOpInfo
(
"test.Operation1"
)).
size
(),
2U
);
}
class
TransposePatternRewrite
:
public
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>
{
public:
using
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>::
OpRewritePattern
;
bool
MatchAndRewrite
(
paddle
::
dialect
::
TransposeOp
op
,
ir
::
PatternRewriter
&
rewriter
)
const
override
{
auto
prev_op
=
op
->
operand
(
0
).
source
().
GetDefiningOp
();
std
::
vector
<
int
>
axis_last
=
GetAxis
(
op
);
auto
prev_trans_op
=
prev_op
->
dyn_cast
<
paddle
::
dialect
::
TransposeOp
>
();
if
(
prev_trans_op
)
{
std
::
vector
<
int
>
axis_first
=
GetAxis
(
prev_trans_op
);
IR_ENFORCE
(
axis_first
.
size
()
==
axis_last
.
size
(),
"tranpose op's perm rank should be same."
);
auto
new_perm
=
GetPerm
(
axis_first
,
axis_last
);
rewriter
.
SetInsertionPoint
(
op
);
auto
new_op
=
rewriter
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
prev_op
->
operand
(
0
).
source
().
GetDefiningOp
()
->
result
(
0
),
new_perm
);
rewriter
.
ReplaceOp
(
op
,
{
new_op
.
out
()});
return
true
;
}
return
false
;
}
private:
std
::
vector
<
int
>
GetAxis
(
paddle
::
dialect
::
TransposeOp
op
)
const
{
auto
attr_map
=
op
->
attributes
();
ir
::
ArrayAttribute
array_attr
=
attr_map
.
at
(
"perm"
).
dyn_cast
<
ir
::
ArrayAttribute
>
();
std
::
vector
<
int
>
axis
(
array_attr
.
size
());
for
(
size_t
i
=
0
;
i
<
array_attr
.
size
();
++
i
)
{
axis
[
i
]
=
array_attr
[
i
].
dyn_cast
<
ir
::
Int32Attribute
>
().
data
();
}
return
axis
;
}
std
::
vector
<
int
>
GetPerm
(
const
std
::
vector
<
int
>
&
perm1
,
const
std
::
vector
<
int
>
&
perm2
)
const
{
int
n
=
perm1
.
size
();
std
::
vector
<
int
>
axis
(
n
),
axis1
(
n
),
axis2
(
n
);
std
::
iota
(
axis
.
begin
(),
axis
.
end
(),
0
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
axis1
[
i
]
=
axis
[
perm1
[
i
]];
}
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
axis2
[
i
]
=
axis1
[
perm2
[
i
]];
}
return
axis2
;
}
};
class
TestPass
:
public
ir
::
Pass
{
public:
TestPass
()
:
ir
::
Pass
(
"TestPass"
,
1
)
{}
void
Run
(
ir
::
Operation
*
op
)
override
{
ir
::
RewritePatternSet
ps
(
op
->
ir_context
());
ps
.
Add
<
TransposePatternRewrite
>
(
op
->
ir_context
());
ir
::
FrozenRewritePatternSet
frozen_ps
(
std
::
move
(
ps
));
ir
::
GreedyRewriteConfig
cfg
;
cfg
.
use_top_down_traversal
=
true
;
cfg
.
max_iterations
=
1
;
ir
::
ApplyPatternsGreedily
(
op
->
region
(
0
),
frozen_ps
,
cfg
);
}
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
return
op
->
name
()
==
"builtin.module"
&&
op
->
num_regions
()
>
0
;
}
};
void
BuildProgram
(
ir
::
Builder
&
builder
)
{
// NOLINT
paddle
::
dialect
::
FullOp
full_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
1
,
3
,
16
,
16
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
ir
::
OpResult
full_op_output
=
full_op
->
result
(
0
);
auto
transpose1_op
=
builder
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
full_op_output
,
std
::
vector
<
int
>
{
0
,
2
,
3
,
1
});
builder
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
transpose1_op
.
out
(),
std
::
vector
<
int
>
{
0
,
3
,
1
,
2
});
// builder.Build<paddle::dialect::FetchOp>(transpose2_op.out());
}
// TODO(wilber): Add a normal test.
TEST
(
PatternRewrite
,
GreedyPatternRewriteDriver
)
{
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
ir
::
Program
program
(
ctx
);
ir
::
Builder
builder
=
ir
::
Builder
(
ctx
,
program
.
block
());
BuildProgram
(
builder
);
EXPECT_EQ
(
program
.
block
()
->
size
(),
3u
);
ir
::
PassManager
pm
(
ctx
);
pm
.
AddPass
(
std
::
make_unique
<
TestPass
>
());
std
::
stringstream
o1
,
o2
;
program
.
Print
(
o1
);
LOG
(
INFO
)
<<
o1
.
str
();
pm
.
Run
(
&
program
);
LOG
(
INFO
)
<<
"After Pass."
;
program
.
Print
(
o2
);
LOG
(
INFO
)
<<
o2
.
str
();
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录