Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f20d3fc2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f20d3fc2
编写于
6月 07, 2023
作者:
W
Wilber
提交者:
GitHub
6月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR&PASS] part 3-1: Add PatternMatch base class. (#54385)
上级
c15e53d6
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
633 addition
and
0 deletion
+633
-0
paddle/ir/CMakeLists.txt
paddle/ir/CMakeLists.txt
+1
-0
paddle/ir/pattern_rewrite/CMakeLists.txt
paddle/ir/pattern_rewrite/CMakeLists.txt
+6
-0
paddle/ir/pattern_rewrite/pattern_match.cc
paddle/ir/pattern_rewrite/pattern_match.cc
+144
-0
paddle/ir/pattern_rewrite/pattern_match.h
paddle/ir/pattern_rewrite/pattern_match.h
+356
-0
test/cpp/ir/CMakeLists.txt
test/cpp/ir/CMakeLists.txt
+1
-0
test/cpp/ir/pattern_rewrite/CMakeLists.txt
test/cpp/ir/pattern_rewrite/CMakeLists.txt
+10
-0
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+115
-0
未找到文件。
paddle/ir/CMakeLists.txt
浏览文件 @
f20d3fc2
...
...
@@ -4,3 +4,4 @@ endif()
add_subdirectory
(
core
)
add_subdirectory
(
pass
)
add_subdirectory
(
pattern_rewrite
)
paddle/ir/pattern_rewrite/CMakeLists.txt
0 → 100644
浏览文件 @
f20d3fc2
file
(
GLOB PATTERN_SRCS
"*.cc"
)
cc_library
(
pattern_rewrite
SRCS
${
PATTERN_SRCS
}
DEPS new_ir
)
paddle/ir/pattern_rewrite/pattern_match.cc
0 → 100644
浏览文件 @
f20d3fc2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include <cassert>
#include <cstdint>
#include "paddle/ir/core/operation.h"
namespace
ir
{
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//
// Pattern::Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context)
// : benefit_(benefit), context_(context), generated_names_(generated_names)
// {}
Pattern
::
Pattern
(
const
std
::
string
&
root_name
,
PatternBenefit
benefit
,
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
)
:
op_name_
(
root_name
),
root_kind_
(
RootKind
::
OperationName
),
benefit_
(
benefit
),
context_
(
context
),
generated_names_
(
generated_names
)
{}
Pattern
::
Pattern
(
MatchAnyOpTypeTag
tag
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
)
:
root_kind_
(
RootKind
::
Any
),
benefit_
(
benefit
),
context_
(
context
),
generated_names_
(
generated_names
)
{}
Pattern
::
Pattern
(
MatchInterfaceOpTypeTag
tag
,
ir
::
TypeId
interface_id
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
)
:
interface_id_
(
interface_id
),
root_kind_
(
RootKind
::
InterfaceId
),
benefit_
(
benefit
),
context_
(
context
),
generated_names_
(
generated_names
)
{}
Pattern
::
Pattern
(
MatchTraitOpTypeTag
tag
,
ir
::
TypeId
trait_id
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
)
:
trait_id_
(
trait_id
),
root_kind_
(
RootKind
::
TraitId
),
benefit_
(
benefit
),
context_
(
context
),
generated_names_
(
generated_names
)
{}
RewritePattern
::~
RewritePattern
()
=
default
;
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
RewriterBase
::~
RewriterBase
()
=
default
;
// TODO(wilber): value support replace method.
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor) {
// // assert(op->num_results() == new_values.size() && "incorrect number of
// values to replace operation"); NotifyRootReplaced(op, new_values); bool
// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) {
// // op->GetResultByIndex(0)
// }
// }
// void RewriterBase::ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor) {
// ReplaceOpWithIf(op, new_values, nullptr, functor);
// }
// TODO(wilber): support erase.
// void ReplaceOp(Operation* op, ValueRange new_values) {
// NotifyRootReplaced(op, new_values);
// assert(op->num_results() == new_values.size() && "incorrect # of
// replacement values"); op->ReplaceAllUsesWith(new_values);
// NotifyOperationRemoved(op);
// op->erase();
// }
void
RewriterBase
::
EraseOp
(
Operation
*
op
)
{
// assert(op->use_empty() && "expected 'op' to have no uses");
// NotifyOperationRemoved(op);
// op->erase();
}
void
RewriterBase
::
ReplaceAllUsesWith
(
Value
from
,
Value
to
)
{
// from.
// for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses()))
// {
// mlir::Operation* op = operand.getOwner();
// UpdateRootInPlace(op, [&]() { operand.set(to); });
// }
}
// TODO(wilber): iterator maybe should support modify inplace.
void
RewriterBase
::
ReplaceUseIf
(
Value
from
,
Value
to
,
std
::
function
<
bool
(
OpOperand
&
)
>
functor
)
{
// for (auto it = from.begin(); it != from.end(); ++it) {
// // TODO: need a lvalue.
// if (functor(it.get())) {
// UpdateRootInplace(it.owner(), [&](){it.get().set(to)});
// }
// }
}
void
RewriterBase
::
ReplaceOpWithResultsOfAnotherOp
(
Operation
*
op
,
Operation
*
new_op
)
{
assert
(
op
->
num_results
()
==
new_op
->
num_results
()
&&
"replacement op doesn't match results of original op"
);
// TODO(wilber): Op support results method.
// if (op->num_results() == 1) return ReplaceOp(op,
// new_op->GetResultByIndex(0)); return ReplaceOp(op, new_op->GetResults());
}
}
// namespace ir
paddle/ir/pattern_rewrite/pattern_match.h
0 → 100644
浏览文件 @
f20d3fc2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/type_id.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/ir/core/value.h"
namespace
ir
{
/// The design is mainly from MLIR, very thanks to the greate project.
/// This class reprensents the benefit of a pattern. The most common
/// unit to use is the `numver of operations` in the pattern.
class
PatternBenefit
{
public:
PatternBenefit
(
unsigned
val
)
:
val_
(
val
)
{}
// NOLINT
unsigned
benefit
()
{
return
val_
;
}
bool
operator
==
(
const
PatternBenefit
&
rhs
)
const
{
return
val_
==
rhs
.
val_
;
}
bool
operator
!=
(
const
PatternBenefit
&
rhs
)
const
{
return
!
(
*
this
==
rhs
);
}
bool
operator
<
(
const
PatternBenefit
&
rhs
)
const
{
return
val_
<
rhs
.
val_
;
}
bool
operator
>
(
const
PatternBenefit
&
rhs
)
const
{
return
rhs
<
*
this
;
}
bool
operator
<=
(
const
PatternBenefit
&
rhs
)
const
{
return
!
(
*
this
>
rhs
);
}
bool
operator
>=
(
const
PatternBenefit
&
rhs
)
const
{
return
!
(
*
this
<=
rhs
);
}
private:
unsigned
int
val_
{
0
};
};
/// This class contains all of the data related to a Pattern, but not contains
/// any methods for the matching. This class is used to interface with the
/// metadata of a pattern, such as benefit or root operation.
class
Pattern
{
enum
class
RootKind
{
Any
,
OperationName
,
InterfaceId
,
TraitId
};
public:
PatternBenefit
benefit
()
const
{
return
benefit_
;
}
IrContext
*
context
()
const
{
return
context_
;
}
std
::
string
debug_name
()
const
{
return
debug_name_
;
}
void
SetDebugName
(
const
std
::
string
&
name
)
{
debug_name_
=
name
;
}
const
std
::
vector
<
std
::
string
>&
debug_labels
()
const
{
return
debug_labels_
;
}
void
AddDebugLabels
(
const
std
::
vector
<
std
::
string
>&
labels
)
{
debug_labels_
.
insert
(
debug_labels_
.
end
(),
labels
.
begin
(),
labels
.
end
());
}
void
AddDebugLabels
(
const
std
::
string
&
label
)
{
debug_labels_
.
push_back
(
label
);
}
protected:
struct
MatchAnyOpTypeTag
{};
struct
MatchInterfaceOpTypeTag
{};
struct
MatchTraitOpTypeTag
{};
Pattern
(
const
std
::
string
&
root_name
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{});
Pattern
(
MatchAnyOpTypeTag
tag
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{});
Pattern
(
MatchInterfaceOpTypeTag
tag
,
ir
::
TypeId
interface_id
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{});
Pattern
(
MatchTraitOpTypeTag
tag
,
ir
::
TypeId
trait_id
,
PatternBenefit
benefit
,
ir
::
IrContext
*
context
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{});
private:
// TODO(wilber): How to uniform variables and constructor.
// Pattern(const void* root_val,
// RootKind root_kind,
// const std::vector<std::string>& generated_names,
// PatternBenefit benefit,
// ir::IrContext* context);
std
::
string
op_name_
;
ir
::
TypeId
interface_id_
;
ir
::
TypeId
trait_id_
;
RootKind
root_kind_
;
const
PatternBenefit
benefit_
;
ir
::
IrContext
*
context_
;
std
::
vector
<
std
::
string
>
generated_names_
;
std
::
string
debug_name_
;
std
::
vector
<
std
::
string
>
debug_labels_
;
};
class
PatternRewriter
;
class
RewritePattern
:
public
Pattern
{
public:
virtual
~
RewritePattern
();
virtual
void
Rewrite
(
ir
::
Operation
*
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
throw
(
"need to implement either MatchAndRewrite or one of the rewrite "
"functions."
);
}
virtual
bool
Match
(
ir
::
Operation
*
op
)
const
{
throw
(
"need to implement either MatchAndRewrite or Match."
);
return
false
;
}
virtual
bool
MatchAndRewrite
(
ir
::
Operation
*
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
if
(
Match
(
op
))
{
Rewrite
(
op
,
rewriter
);
return
true
;
}
return
false
;
}
virtual
void
Initialize
()
{}
template
<
typename
T
,
typename
...
Args
>
static
std
::
unique_ptr
<
T
>
Create
(
Args
&&
...
args
)
{
std
::
unique_ptr
<
T
>
pattern
=
std
::
make_unique
<
T
>
(
std
::
forward
<
Args
>
(
args
)...);
pattern
->
Initialize
();
if
(
pattern
->
debug_name
().
empty
())
pattern
->
SetDebugName
(
get_type_name
<
T
>
());
return
pattern
;
}
protected:
using
Pattern
::
Pattern
;
};
namespace
detail
{
/// A wrapper around PatternWrite that allows for matching and rewriting
/// against an instance of a derived operation class or Interface.
template
<
typename
SourceOp
>
struct
OpOrInterfaceRewritePatternBase
:
public
RewritePattern
{
using
RewritePattern
::
RewritePattern
;
void
Rewrite
(
Operation
*
op
,
PatternRewriter
&
rewriter
)
const
final
{
// NOLINT
Rewrite
(
op
->
dyn_cast
<
SourceOp
>
(),
rewriter
);
}
bool
Match
(
Operation
*
op
)
const
final
{
return
Match
(
op
->
dyn_cast
<
SourceOp
>
());
}
bool
MatchAndRewrite
(
Operation
*
op
,
PatternRewriter
&
rewriter
)
const
final
{
// NOLINT
return
MatchAndRewrite
(
op
->
dyn_cast
<
SourceOp
>
(),
rewriter
);
}
virtual
void
Rewrite
(
SourceOp
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
throw
(
"must override Rewrite or MatchAndRewrite"
);
}
virtual
bool
Match
(
SourceOp
op
)
const
{
throw
(
"must override Match or MatchAndRewrite"
);
}
virtual
bool
MatchAndRewrite
(
SourceOp
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
if
(
Match
(
op
))
{
Rewrite
(
op
,
rewriter
);
return
true
;
}
return
false
;
}
};
}
// namespace detail
/// OpRewritePattern is a wrapper around RewritePattern that allows for
/// matching and rewriting against an instance of a derived operation
/// class as opposed to a raw Operation.
template
<
typename
SourceOp
>
struct
OpRewritePattern
:
public
detail
::
OpOrInterfaceRewritePatternBase
<
SourceOp
>
{
OpRewritePattern
(
ir
::
IrContext
*
context
,
PatternBenefit
benefit
=
1
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{})
:
detail
::
OpOrInterfaceRewritePatternBase
<
SourceOp
>
(
"NeedToFix"
,
// TODO(wilber): Need to fix. SourceOp maybe should
// have a getOperationName static method.
benefit
,
context
,
generated_names
)
{}
};
// TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern.
// ...
/// This class provides a series of interfaces for modifying IR and tracking IR
/// changes. This class provides a unified API for IR modification.
///
class
RewriterBase
{
// maybe should inherit OpBuilder.
public:
// TODO(wilber): Supplementary methods of block and region.
// TODO(wilber): Support ValueRange.
// virtual void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// bool* all_uses_replaced,
// std::function<bool(OpOperand&)> functor);
// void ReplaceOpWithIf(Operation* op,
// ValueRange new_values,
// std::function<bool(OpOperand&)> functor);
// virtual void ReplaceOp(Operation* op, ValueRange new_values);
// virtual void ReplaceOpWithNewOp()
virtual
void
EraseOp
(
Operation
*
op
);
virtual
void
StartRootUpdate
(
Operation
*
op
)
{}
virtual
void
FinalizeRootUpdate
(
Operation
*
op
)
{}
virtual
void
CancleRootUpdate
(
Operation
*
op
)
{}
template
<
typename
CallableT
>
void
UpdateRootInplace
(
Operation
*
root
,
CallableT
&&
callable
)
{
StartRootUpdate
(
root
);
callable
();
FinalizeRootUpdate
(
root
);
}
void
ReplaceAllUsesWith
(
Value
from
,
Value
to
);
void
ReplaceUseIf
(
Value
from
,
Value
to
,
std
::
function
<
bool
(
OpOperand
&
)
>
functor
);
protected:
explicit
RewriterBase
(
IrContext
*
ctx
)
:
ctx_
(
ctx
)
{}
virtual
~
RewriterBase
();
// virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {}
virtual
void
NotifyOperationRemoved
(
Operation
*
op
)
{}
// virtual bool NotifyMatchFailure()
private:
void
operator
=
(
const
RewriterBase
&
)
=
delete
;
RewriterBase
(
const
RewriterBase
&
)
=
delete
;
void
ReplaceOpWithResultsOfAnotherOp
(
Operation
*
op
,
Operation
*
new_op
);
private:
IrContext
*
ctx_
;
};
class
PatternRewriter
:
public
RewriterBase
{
public:
using
RewriterBase
::
RewriterBase
;
};
/// A pattern collection, easy to add patterns.
class
RewritePatternSet
{
using
NativePatternListT
=
std
::
vector
<
std
::
unique_ptr
<
RewritePattern
>>
;
public:
explicit
RewritePatternSet
(
IrContext
*
context
)
:
context_
(
context
)
{}
RewritePatternSet
(
IrContext
*
context
,
std
::
unique_ptr
<
RewritePattern
>
pattern
)
:
context_
(
context
)
{
native_patterns_
.
emplace_back
(
std
::
move
(
pattern
));
}
IrContext
*
context
()
const
{
return
context_
;
}
NativePatternListT
&
native_patterns
()
{
return
native_patterns_
;
}
void
Clear
()
{
native_patterns_
.
clear
();
}
// 'add' methods for adding patterns to the set.
template
<
typename
...
Ts
,
typename
ConstructorArg
,
typename
...
ConstructorArgs
,
typename
=
std
::
enable_if_t
<
sizeof
...(
Ts
)
!=
0
>
>
RewritePatternSet
&
Add
(
ConstructorArg
&&
arg
,
ConstructorArgs
&&
...
args
)
{
std
::
initializer_list
<
int
>
{
(
AddImpl
<
Ts
>
({},
std
::
forward
<
ConstructorArg
>
(
arg
),
std
::
forward
<
ConstructorArgs
>
(
args
)...),
0
)...};
return
*
this
;
}
template
<
typename
...
Ts
,
typename
ConstructorArg
,
typename
...
ConstructorArgs
,
typename
=
std
::
enable_if_t
<
sizeof
...(
Ts
)
!=
0
>
>
RewritePatternSet
&
AddWithLabel
(
const
std
::
vector
<
std
::
string
>&
debug_labels
,
ConstructorArg
&&
arg
,
ConstructorArgs
&&
...
args
)
{
std
::
initializer_list
<
int
>
{
(
AddImpl
<
Ts
>
(
debug_labels
,
std
::
forward
<
ConstructorArg
>
(
arg
),
std
::
forward
<
ConstructorArgs
>
(
args
)...),
0
)...};
return
*
this
;
}
RewritePatternSet
&
Add
(
std
::
unique_ptr
<
RewritePattern
>
pattern
)
{
native_patterns_
.
emplace_back
(
std
::
move
(
pattern
));
return
*
this
;
}
private:
template
<
typename
T
,
typename
...
Args
>
std
::
enable_if_t
<
std
::
is_base_of
<
RewritePattern
,
T
>::
value
>
AddImpl
(
const
std
::
vector
<
std
::
string
>&
debug_labels
,
Args
&&
...
args
)
{
std
::
unique_ptr
<
T
>
pattern
=
RewritePattern
::
Create
<
T
>
(
std
::
forward
<
Args
>
(
args
)...);
pattern
->
AddDebugLabels
(
debug_labels
);
native_patterns_
.
emplace_back
(
std
::
move
(
pattern
));
}
private:
IrContext
*
const
context_
;
NativePatternListT
native_patterns_
;
};
}
// namespace ir
test/cpp/ir/CMakeLists.txt
浏览文件 @
f20d3fc2
...
...
@@ -4,3 +4,4 @@ endif()
add_subdirectory
(
core
)
add_subdirectory
(
pass
)
add_subdirectory
(
pattern_rewrite
)
test/cpp/ir/pattern_rewrite/CMakeLists.txt
0 → 100644
浏览文件 @
f20d3fc2
cc_test_old
(
pattern_rewrite_test
SRCS
pattern_rewrite_test.cc
DEPS
new_pass
pattern_rewrite
pd_dialect
phi
gtest
)
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
0 → 100644
浏览文件 @
f20d3fc2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
TEST
(
PatternBenefit
,
PatternBenefit
)
{
ir
::
PatternBenefit
benefit1
(
1
);
EXPECT_EQ
(
benefit1
.
benefit
(),
1U
);
ir
::
PatternBenefit
benefit2
(
2
);
EXPECT_EQ
(
benefit2
.
benefit
(),
2U
);
EXPECT_TRUE
(
benefit2
>
benefit1
);
EXPECT_TRUE
(
benefit2
>=
benefit1
);
EXPECT_TRUE
(
benefit1
<
benefit2
);
EXPECT_TRUE
(
benefit1
<=
benefit2
);
EXPECT_TRUE
(
benefit1
!=
benefit2
);
ir
::
PatternBenefit
benefit3
(
2
);
EXPECT_TRUE
(
benefit2
==
benefit3
);
}
// Define op1.
class
Operation1
:
public
ir
::
Op
<
Operation1
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"test.Operation1"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
if
(
attributes
.
count
(
"op2_attr1"
)
==
0
||
(
!
attributes
.
at
(
"op2_attr1"
).
isa
<
ir
::
StrAttribute
>
()))
{
throw
(
"Type of attribute: parameter_name is not right."
);
}
if
(
attributes
.
count
(
"op2_attr2"
)
==
0
||
(
!
attributes
.
at
(
"op2_attr2"
).
isa
<
ir
::
StrAttribute
>
()))
{
throw
(
"Type of attribute: parameter_name is not right."
);
}
}
static
void
InferShape
()
{
VLOG
(
2
)
<<
"This is op2's InferShape interface."
;
}
};
const
char
*
Operation1
::
attributes_name
[
attributes_num
]
=
{
"op2_attr1"
,
"op2_attr2"
};
// Define a dialect, op1 and op2 will be registered by this dialect.
class
TestDialect
:
public
ir
::
Dialect
{
public:
explicit
TestDialect
(
ir
::
IrContext
*
context
)
:
ir
::
Dialect
(
name
(),
context
,
ir
::
TypeId
::
get
<
TestDialect
>
())
{
initialize
();
}
static
const
char
*
name
()
{
return
"test"
;
}
private:
void
initialize
()
{
RegisterOps
<
Operation1
>
();
}
};
// TODO(wilber): Add logical when ir support erase, replace or update.
class
TestPatternRewrite
:
public
ir
::
OpRewritePattern
<
Operation1
>
{
public:
using
ir
::
OpRewritePattern
<
Operation1
>::
OpRewritePattern
;
void
Rewrite
(
Operation1
op
,
ir
::
PatternRewriter
&
rewriter
)
const
override
{}
bool
Match
(
Operation1
op
)
const
override
{
return
false
;
}
};
class
TestPatternRewrite2
:
public
ir
::
OpRewritePattern
<
Operation1
>
{
public:
using
ir
::
OpRewritePattern
<
Operation1
>::
OpRewritePattern
;
bool
MatchAndRewrite
(
Operation1
op
,
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
return
false
;
}
};
TEST
(
RewritePattern
,
OpRewritePattern
)
{
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
ir
::
BuiltinDialect
>
();
auto
*
test_dialect
=
ctx
->
GetOrRegisterDialect
<
TestDialect
>
();
test_dialect
->
RegisterOp
<
Operation1
>
();
ir
::
RewritePatternSet
ps
(
ctx
);
ps
.
Add
<
TestPatternRewrite
>
(
ctx
,
1
);
EXPECT_EQ
(
ps
.
native_patterns
().
size
(),
1U
);
EXPECT_TRUE
(
ps
.
native_patterns
().
back
()
->
debug_labels
().
empty
());
EXPECT_EQ
(
ps
.
native_patterns
().
back
()
->
benefit
(),
1U
);
ps
.
AddWithLabel
<
TestPatternRewrite2
>
({
"TestPatternRewrite2"
},
ctx
,
2
);
EXPECT_EQ
(
ps
.
native_patterns
().
size
(),
2U
);
EXPECT_EQ
(
ps
.
native_patterns
().
back
()
->
debug_labels
()[
0
],
"TestPatternRewrite2"
);
EXPECT_EQ
(
ps
.
native_patterns
().
back
()
->
benefit
(),
2U
);
ps
.
Clear
();
ps
.
Add
<
TestPatternRewrite
,
TestPatternRewrite2
>
(
ctx
,
2
);
EXPECT_EQ
(
ps
.
native_patterns
().
size
(),
2U
);
EXPECT_EQ
(
ps
.
native_patterns
()[
0
]
->
benefit
(),
2U
);
EXPECT_EQ
(
ps
.
native_patterns
()[
1
]
->
benefit
(),
2U
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录