Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4905a247
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4905a247
编写于
7月 10, 2023
作者:
Y
Yuanle Liu
提交者:
GitHub
7月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PASS] add constant folding pass (#55099)
上级
df311526
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
482 addition
and
219 deletion
+482
-219
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+6
-8
paddle/fluid/ir/transforms/CMakeLists.txt
paddle/fluid/ir/transforms/CMakeLists.txt
+6
-0
paddle/fluid/ir/transforms/constant_folding_pass.cc
paddle/fluid/ir/transforms/constant_folding_pass.cc
+203
-0
paddle/fluid/ir/transforms/constant_folding_pass.h
paddle/fluid/ir/transforms/constant_folding_pass.h
+2
-1
paddle/fluid/ir/transforms/transform_general_functions.cc
paddle/fluid/ir/transforms/transform_general_functions.cc
+33
-2
paddle/fluid/ir/transforms/transform_general_functions.h
paddle/fluid/ir/transforms/transform_general_functions.h
+21
-22
paddle/ir/CMakeLists.txt
paddle/ir/CMakeLists.txt
+1
-1
paddle/ir/builtin_transforms/CMakeLists.txt
paddle/ir/builtin_transforms/CMakeLists.txt
+0
-0
paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
paddle/ir/builtin_transforms/dead_code_elimination_pass.cc
+28
-12
paddle/ir/builtin_transforms/dead_code_elimination_pass.h
paddle/ir/builtin_transforms/dead_code_elimination_pass.h
+26
-0
paddle/ir/pass/pass.h
paddle/ir/pass/pass.h
+2
-2
paddle/ir/pass/pass_manager.h
paddle/ir/pass/pass_manager.h
+2
-1
paddle/ir/pattern_rewrite/pattern_match.h
paddle/ir/pattern_rewrite/pattern_match.h
+7
-4
test/cpp/ir/pass/pass_manager_test.cc
test/cpp/ir/pass/pass_manager_test.cc
+67
-124
test/cpp/ir/pattern_rewrite/CMakeLists.txt
test/cpp/ir/pattern_rewrite/CMakeLists.txt
+9
-9
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+69
-28
test/cpp/new_executor/standalone_executor_new_ir_test.cc
test/cpp/new_executor/standalone_executor_new_ir_test.cc
+0
-5
未找到文件。
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
浏览文件 @
4905a247
...
@@ -139,14 +139,12 @@ void HandleForSpecialOp(ir::Operation* op,
...
@@ -139,14 +139,12 @@ void HandleForSpecialOp(ir::Operation* op,
if
(
op_name
==
"pd.fetch"
)
{
if
(
op_name
==
"pd.fetch"
)
{
// fetch is a very special op, with no output
// fetch is a very special op, with no output
VLOG
(
6
)
<<
"Handle for pd.fetch:"
;
VLOG
(
6
)
<<
"Handle for pd.fetch:"
;
for
(
size_t
i
=
0
;
i
<
input_num
;
++
i
)
{
auto
var
=
scope
->
Var
(
"fetch"
);
auto
var
=
scope
->
Var
(
"fetch"
);
VLOG
(
6
)
<<
"Create var: fetch in scope "
<<
scope
;
VLOG
(
6
)
<<
"Create var: fetch in scope "
<<
scope
;
auto
fetch_list
=
var
->
GetMutable
<
paddle
::
framework
::
FetchList
>
();
auto
fetch_list
=
var
->
GetMutable
<
paddle
::
framework
::
FetchList
>
();
int
index
=
int
index
=
op
->
attributes
().
at
(
"col"
).
dyn_cast
<
ir
::
Int32Attribute
>
().
data
();
op
->
attributes
().
at
(
"col"
).
dyn_cast
<
ir
::
Int32Attribute
>
().
data
();
fetch_list
->
resize
(
index
+
1
);
fetch_list
->
resize
(
index
+
1
);
}
}
}
if
(
op_name
==
"pd.feed"
)
{
if
(
op_name
==
"pd.feed"
)
{
...
...
paddle/fluid/ir/transforms/CMakeLists.txt
浏览文件 @
4905a247
...
@@ -7,3 +7,9 @@ cc_library(
...
@@ -7,3 +7,9 @@ cc_library(
pd_op_to_kernel_pass
pd_op_to_kernel_pass
SRCS pd_op_to_kernel_pass.cc
SRCS pd_op_to_kernel_pass.cc
DEPS phi_utils pd_interface pd_trait ir
)
DEPS phi_utils pd_interface pd_trait ir
)
cc_library
(
_constant_folding_pass
SRCS constant_folding_pass.cc
DEPS standalone_executor phi pd_op_to_kernel_pass transform_general_functions
ir
)
paddle/fluid/ir/transforms/constant_folding_pass.cc
0 → 100644
浏览文件 @
4905a247
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace
{
class
ConstantFoldingPattern
:
public
ir
::
RewritePattern
{
public:
ConstantFoldingPattern
(
ir
::
IrContext
*
context
,
ir
::
PatternBenefit
benefit
=
1
,
const
std
::
vector
<
std
::
string
>&
generated_names
=
{})
:
RewritePattern
(
MatchAnyOpTypeTag
(),
benefit
,
context
,
generated_names
)
{
}
bool
Match
(
ir
::
Operation
*
op
)
const
override
{
// TODO(liuyuanle): Use trait to improve robustness.
if
(
op
->
dyn_cast
<
ir
::
GetParameterOp
>
()
||
op
->
dyn_cast
<
ir
::
SetParameterOp
>
()
||
op
->
dyn_cast
<
paddle
::
dialect
::
FetchOp
>
())
return
false
;
// Inputs must come from get parameter op.
for
(
uint32_t
i
=
0
;
i
<
op
->
num_operands
();
++
i
)
if
(
ir
::
GetDefiningOpForInput
(
op
,
i
)
->
dyn_cast
<
ir
::
GetParameterOp
>
()
==
nullptr
)
return
false
;
return
true
;
}
void
Rewrite
(
ir
::
Operation
*
op
,
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
ir
::
Program
*
program
=
op
->
GetParentProgram
();
auto
temp_program
=
BuildProgramFromOperation
(
op
);
// Execute program
paddle
::
framework
::
interpreter
::
ExecutionConfig
exe_config
;
exe_config
.
create_local_scope
=
false
;
paddle
::
framework
::
InterpreterCore
core
(
phi
::
CPUPlace
{},
paddle
::
dialect
::
PdOpLowerToKernelPass
(
temp_program
.
get
()),
&
scope_
,
exe_config
);
paddle
::
framework
::
FetchList
fetch_list
=
core
.
Run
({});
// TODO(liuyuanle): Support multiple output.
auto
out_tensor
=
PADDLE_GET_CONST
(
phi
::
DenseTensor
,
fetch_list
[
0
]);
std
::
unique_ptr
<
ir
::
Parameter
>
parameter
=
std
::
make_unique
<
ir
::
Parameter
>
(
reinterpret_cast
<
void
*>
(
out_tensor
.
data
()),
out_tensor
.
numel
()
*
phi
::
SizeOf
(
out_tensor
.
dtype
()),
op
->
result
(
0
).
type
());
std
::
string
param_name
=
"@constant_folding_pass@_"
+
std
::
to_string
(
suffix_
++
);
auto
*
param_var
=
scope_
.
Var
(
param_name
);
auto
*
param_tensor
=
param_var
->
GetMutable
<
phi
::
DenseTensor
>
();
*
param_tensor
=
out_tensor
;
program
->
SetParameter
(
param_name
,
std
::
move
(
parameter
));
// rewriter.SetInsertionPoint(op);
auto
get_parameter_op
=
rewriter
.
Build
<
ir
::
GetParameterOp
>
(
param_name
,
op
->
result
(
0
).
type
());
rewriter
.
ReplaceAllUsesWith
(
op
->
result
(
0
),
get_parameter_op
->
result
(
0
));
rewriter
.
EraseOp
(
op
);
}
private:
std
::
unique_ptr
<
ir
::
Program
>
BuildProgramFromOperation
(
ir
::
Operation
*
op
)
const
{
auto
program
=
std
::
make_unique
<
ir
::
Program
>
(
ir_context
());
ir
::
Builder
builder
=
ir
::
Builder
(
ir_context
(),
program
->
block
());
// prepare op inputs
std
::
vector
<
ir
::
OpResult
>
op_inputs
;
for
(
uint32_t
i
=
0
;
i
<
op
->
num_operands
();
i
++
)
{
PADDLE_ENFORCE_EQ
(
op
->
operand
(
i
).
type
().
isa
<
paddle
::
dialect
::
DenseTensorType
>
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Op's input must be a dense tensor type."
));
auto
[
param_name
,
param
]
=
ir
::
GetParameterFromValue
(
op
->
operand
(
i
));
program
->
SetParameter
(
param_name
,
std
::
make_unique
<
ir
::
Parameter
>
(
*
param
));
auto
*
param_var
=
scope_
.
FindVar
(
param_name
);
PADDLE_ENFORCE_NOT_NULL
(
param_var
,
phi
::
errors
::
InvalidArgument
(
"Parameter var not in scope."
));
auto
get_parameter_op
=
builder
.
Build
<
ir
::
GetParameterOp
>
(
param_name
,
op
->
operand
(
i
).
type
());
op_inputs
.
push_back
(
get_parameter_op
->
result
(
0
));
}
// prepare op outputs
std
::
vector
<
ir
::
Type
>
output_types
;
for
(
uint32_t
i
=
0
;
i
<
op
->
num_results
();
i
++
)
{
output_types
.
push_back
(
op
->
result
(
i
).
type
());
}
auto
*
temp_op
=
builder
.
Build
(
op_inputs
,
op
->
attributes
(),
output_types
,
op
->
info
());
// TODO(liuyuanle): Support multiple output.
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ
(
temp_op
->
result
(
0
).
type
().
isa
<
paddle
::
dialect
::
DenseTensorType
>
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Op's output must be a dense tensor type."
));
builder
.
Build
<
paddle
::
dialect
::
FetchOp
>
(
temp_op
->
result
(
0
),
"fetch_"
+
std
::
to_string
(
suffix_
++
),
0
);
// }
return
program
;
}
private:
static
size_t
suffix_
;
static
paddle
::
framework
::
Scope
scope_
;
};
size_t
ConstantFoldingPattern
::
suffix_
=
0
;
paddle
::
framework
::
Scope
ConstantFoldingPattern
::
scope_
=
{};
class
ConstantFoldingPass
:
public
ir
::
Pass
{
public:
// TODO(liuyuanle): Naming convention for pass.
ConstantFoldingPass
()
:
ir
::
Pass
(
"ConstantFoldingPass"
,
1
)
{}
bool
Initialize
(
ir
::
IrContext
*
context
)
override
{
ir
::
RewritePatternSet
ps
(
context
);
ps
.
Add
<
ConstantFoldingPattern
>
(
context
);
patterns_
=
ir
::
FrozenRewritePatternSet
(
std
::
move
(
ps
));
return
true
;
}
void
Run
(
ir
::
Operation
*
op
)
override
{
ir
::
GreedyRewriteConfig
cfg
;
cfg
.
use_top_down_traversal
=
true
;
cfg
.
max_iterations
=
10
;
ir
::
ApplyPatternsGreedily
(
op
->
region
(
0
),
patterns_
,
cfg
);
}
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
return
op
->
name
()
==
"builtin.module"
&&
op
->
num_regions
()
>
0
;
}
private:
ir
::
FrozenRewritePatternSet
patterns_
;
};
}
// namespace
namespace
ir
{
std
::
unique_ptr
<
Pass
>
CreateConstantFoldingPass
()
{
return
std
::
make_unique
<
ConstantFoldingPass
>
();
}
}
// namespace ir
paddle/
ir/transforms/dce
.h
→
paddle/
fluid/ir/transforms/constant_folding_pass
.h
浏览文件 @
4905a247
...
@@ -18,8 +18,9 @@
...
@@ -18,8 +18,9 @@
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/dll_decl.h"
namespace
ir
{
namespace
ir
{
class
Pass
;
class
Pass
;
IR_API
std
::
unique_ptr
<
Pass
>
Create
Dce
Pass
();
IR_API
std
::
unique_ptr
<
Pass
>
Create
ConstantFolding
Pass
();
}
// namespace ir
}
// namespace ir
paddle/fluid/ir/transforms/transform_general_functions.cc
浏览文件 @
4905a247
...
@@ -17,22 +17,28 @@
...
@@ -17,22 +17,28 @@
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/program.h"
namespace
ir
{
namespace
ir
{
ir
::
Parameter
*
GetParameterFromValue
(
ir
::
Value
value
)
{
std
::
pair
<
std
::
string
,
ir
::
Parameter
*>
GetParameterFromValue
(
ir
::
Value
value
)
{
ir
::
GetParameterOp
op
=
value
.
GetDefiningOp
()
->
dyn_cast
<
ir
::
GetParameterOp
>
();
ir
::
GetParameterOp
op
=
value
.
GetDefiningOp
()
->
dyn_cast
<
ir
::
GetParameterOp
>
();
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
op
,
op
,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"Value must be a weight from a GetParameter op."
));
"Value must be a weight from a GetParameter op."
));
ir
::
Program
*
program
=
op
->
GetParentProgram
();
ir
::
Program
*
program
=
op
->
GetParentProgram
();
PADDLE_ENFORCE_NOT_NULL
(
program
,
phi
::
errors
::
InvalidArgument
(
"Program should not be null."
));
std
::
string
name
=
op
->
attributes
()
std
::
string
name
=
op
->
attributes
()
.
at
(
op
.
attributes_name
[
0
])
.
at
(
op
.
attributes_name
[
0
])
.
dyn_cast
<
ir
::
StrAttribute
>
()
.
dyn_cast
<
ir
::
StrAttribute
>
()
.
data
();
.
data
();
return
program
->
GetParameter
(
name
);
ir
::
Parameter
*
param
=
program
->
GetParameter
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
param
,
phi
::
errors
::
InvalidArgument
(
"Parameter should not be null."
));
return
{
name
,
param
};
}
}
const
phi
::
DDim
&
GetShapeFromValue
(
ir
::
Value
value
)
{
const
phi
::
DDim
&
GetShapeFromValue
(
ir
::
Value
value
)
{
...
@@ -44,4 +50,29 @@ const phi::DDim& GetShapeFromValue(ir::Value value) {
...
@@ -44,4 +50,29 @@ const phi::DDim& GetShapeFromValue(ir::Value value) {
return
value
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
return
value
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
}
}
ir
::
Type
GetDataTypeFromValue
(
ir
::
Value
value
)
{
// TODO(dev): Support other types like DenseTensor.
PADDLE_ENFORCE_EQ
(
value
.
type
().
isa
<
paddle
::
dialect
::
DenseTensorType
>
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Value's type must be a DenseTensorType."
));
return
value
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dtype
();
}
Operation
*
GetDefiningOpForInput
(
Operation
*
op
,
uint32_t
index
)
{
PADDLE_ENFORCE_EQ
(
index
<
op
->
num_operands
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Intput operand's index must be valid."
));
return
op
->
operand
(
index
).
GetDefiningOp
();
}
Operation
*
GetFirstUseOperationForOutput
(
Operation
*
op
,
uint32_t
index
)
{
PADDLE_ENFORCE_EQ
(
index
<
op
->
num_results
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Output op result's index must be valid."
));
return
op
->
result
(
index
).
first_use
().
owner
();
}
}
// namespace ir
}
// namespace ir
paddle/fluid/ir/transforms/transform_general_functions.h
浏览文件 @
4905a247
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
...
@@ -24,15 +25,16 @@
...
@@ -24,15 +25,16 @@
namespace
ir
{
namespace
ir
{
/**
/**
* @brief Get the
para
meter from a value.
* @brief Get the
[name, parameter] pair of parar
meter from a value.
*
*
* @note The value must be a output of a GetParameterOp.
* @note The value must be a output of a GetParameterOp.
*
*
* @param ir::Value
* @param ir::Value
*
*
* @return
ir::Parameter*
* @return
std::pair<std::string, ir::Parameter*>
*/
*/
ir
::
Parameter
*
GetParameterFromValue
(
ir
::
Value
value
);
std
::
pair
<
std
::
string
,
ir
::
Parameter
*>
GetParameterFromValue
(
ir
::
Value
value
);
/**
/**
* @brief Get tensor's shape from a value.
* @brief Get tensor's shape from a value.
...
@@ -43,37 +45,34 @@ ir::Parameter* GetParameterFromValue(ir::Value value);
...
@@ -43,37 +45,34 @@ ir::Parameter* GetParameterFromValue(ir::Value value);
*/
*/
const
phi
::
DDim
&
GetShapeFromValue
(
ir
::
Value
value
);
const
phi
::
DDim
&
GetShapeFromValue
(
ir
::
Value
value
);
/**
* @brief Get tensor's data type from a value.
*
* @param ir::Value
*
* @return ir::Type
*/
ir
::
Type
GetDataTypeFromValue
(
ir
::
Value
value
);
/**
/**
* @brief Get an operation that defines the specific input of the operation.
* @brief Get an operation that defines the specific input of the operation.
*
*
* @param Operation*
* @param Operation* pointer to an operation
* @param uint32_t index of operand of the operation
*
*
* @return Operation*
* @return Operation*
*/
*/
template
<
uint32_t
Index
=
0
>
Operation
*
GetDefiningOpForInput
(
Operation
*
op
,
uint32_t
index
);
Operation
*
GetDefiningOpForInput
(
Operation
*
op
)
{
PADDLE_ENFORCE_EQ
(
Index
<
op
->
num_operands
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Intput operand's index must be valid."
));
return
op
->
operand
(
Index
).
GetDefiningOp
();
}
/**
/**
* @brief Get an operation that is the first to use the specific output of the
* @brief Get an operation that is the first to use the specific output of the
* operation.
* operation.
*
*
* @param Operation*
* @param Operation* pointer to an operation
*
* @param uint32_t index of result of the operation
* @return Operation*
* @return Operation*
*/
*/
template
<
uint32_t
Index
=
0
>
Operation
*
GetFirstUseOperationForOutput
(
Operation
*
op
,
uint32_t
index
);
Operation
*
GetFirstUseOperationForOutput
(
Operation
*
op
)
{
PADDLE_ENFORCE_EQ
(
Index
<
op
->
num_results
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Output op result's index must be valid."
));
return
op
->
result
(
Index
).
first_use
().
owner
();
}
}
// namespace ir
}
// namespace ir
paddle/ir/CMakeLists.txt
浏览文件 @
4905a247
...
@@ -37,7 +37,7 @@ endfunction()
...
@@ -37,7 +37,7 @@ endfunction()
add_subdirectory
(
core
)
add_subdirectory
(
core
)
add_subdirectory
(
pass
)
add_subdirectory
(
pass
)
add_subdirectory
(
pattern_rewrite
)
add_subdirectory
(
pattern_rewrite
)
add_subdirectory
(
transforms
)
add_subdirectory
(
builtin_
transforms
)
if
(
WIN32
)
if
(
WIN32
)
if
(
WITH_SHARED_IR
)
if
(
WITH_SHARED_IR
)
...
...
paddle/ir/transforms/CMakeLists.txt
→
paddle/ir/
builtin_
transforms/CMakeLists.txt
浏览文件 @
4905a247
文件已移动
paddle/ir/
transforms/dce
.cc
→
paddle/ir/
builtin_transforms/dead_code_elimination_pass
.cc
浏览文件 @
4905a247
...
@@ -12,9 +12,10 @@
...
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/ir/
transforms/dce
.h"
#include "paddle/ir/
builtin_transforms/dead_code_elimination_pass
.h"
#include <memory>
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass.h"
namespace
{
namespace
{
...
@@ -22,30 +23,43 @@ namespace {
...
@@ -22,30 +23,43 @@ namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// removed by dce pass.
// Now just a naive implementation.
// Now just a naive implementation.
class
D
ce
Pass
:
public
ir
::
Pass
{
class
D
eadCodeElimination
Pass
:
public
ir
::
Pass
{
public:
public:
D
cePass
()
:
ir
::
Pass
(
"Dce
Pass"
,
0
)
{}
D
eadCodeEliminationPass
()
:
ir
::
Pass
(
"DeadCodeElimination
Pass"
,
0
)
{}
void
Run
(
ir
::
Operation
*
op
)
override
{
void
Run
(
ir
::
Operation
*
op
)
override
{
auto
module_op
=
op
->
dyn_cast
<
ir
::
ModuleOp
>
();
auto
module_op
=
op
->
dyn_cast
<
ir
::
ModuleOp
>
();
IR_ENFORCE
(
module_op
,
"DcePass should run on module op."
);
IR_ENFORCE
(
module_op
,
"DcePass should run on module op."
);
auto
*
block
=
module_op
.
block
();
auto
*
block
=
module_op
.
block
();
std
::
vector
<
ir
::
Operation
>
erased_op
;
std
::
vector
<
ir
::
Operation
*
>
erased_op
;
for
(
auto
it
=
block
->
begin
();
it
!=
block
->
end
();
++
it
)
{
for
(
auto
it
=
block
->
begin
();
it
!=
block
->
end
();
++
it
)
{
auto
&
op
=
*
it
;
// TODO(wilber): Support NoSideEffect trait.
// TODO(wilber): Support NoSideEffect trait.
// if (!
(*it)
->HasTrait<NoSideEffect>()) continue;
// if (!
op
->HasTrait<NoSideEffect>()) continue;
bool
use_empty
=
true
;
bool
use_empty
=
true
;
for
(
uint32_t
i
=
0
;
i
<
(
*
it
)
->
num_results
();
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
op
->
num_results
();
++
i
)
{
use_empty
&=
(
*
it
)
->
result
(
i
).
use_empty
();
use_empty
&=
op
->
result
(
i
).
use_empty
();
}
}
// TODO(wilber): Support Terminator trait.
// TODO(wilber): Support Terminator trait.
if
(
use_empty
&&
(
*
it
)
->
name
()
!=
"pd.fetch"
)
{
if
(
use_empty
&&
op
->
name
()
!=
"pd.fetch"
)
{
erased_op
.
push_back
(
**
it
);
erased_op
.
push_back
(
op
);
}
}
}
}
for
(
auto
ep
:
erased_op
)
block
->
erase
(
ep
);
for
(
auto
*
op
:
erased_op
)
{
if
(
op
->
dyn_cast
<
ir
::
GetParameterOp
>
())
{
// Delete parameter from program.
ir
::
GetParameterOp
get_parameter_op
=
op
->
dyn_cast
<
ir
::
GetParameterOp
>
();
get_parameter_op
->
GetParentProgram
()
->
parameters
().
erase
(
get_parameter_op
->
attributes
()
.
at
(
get_parameter_op
.
attributes_name
[
0
])
.
dyn_cast
<
ir
::
StrAttribute
>
()
.
data
());
}
block
->
erase
(
*
op
);
}
}
}
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
...
@@ -57,6 +71,8 @@ class DcePass : public ir::Pass {
...
@@ -57,6 +71,8 @@ class DcePass : public ir::Pass {
namespace
ir
{
namespace
ir
{
std
::
unique_ptr
<
Pass
>
CreateDcePass
()
{
return
std
::
make_unique
<
DcePass
>
();
}
std
::
unique_ptr
<
Pass
>
CreateDeadCodeEliminationPass
()
{
return
std
::
make_unique
<
DeadCodeEliminationPass
>
();
}
}
// namespace ir
}
// namespace ir
paddle/ir/builtin_transforms/dead_code_elimination_pass.h
0 → 100644
浏览文件 @
4905a247
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/dll_decl.h"
namespace
ir
{
class
Pass
;
IR_API
std
::
unique_ptr
<
Pass
>
CreateDeadCodeEliminationPass
();
}
// namespace ir
paddle/ir/pass/pass.h
浏览文件 @
4905a247
...
@@ -55,8 +55,8 @@ struct PassInfo {
...
@@ -55,8 +55,8 @@ struct PassInfo {
std
::
string
name
;
std
::
string
name
;
// opt_level=0: the basic pass which framework need.
// opt_level=0: the basic pass which framework need.
// opt_level=1:
the fusion logical pass
.
// opt_level=1:
constant fold, cse, memory optimize, etc
.
// opt_level=2:
constant fold, cse, memory optimize, etc
.
// opt_level=2:
the fusion logical pass
.
// opt_level=3: layout, etc.
// opt_level=3: layout, etc.
uint8_t
opt_level
;
uint8_t
opt_level
;
...
...
paddle/ir/pass/pass_manager.h
浏览文件 @
4905a247
...
@@ -110,7 +110,8 @@ class IR_API PassManager {
...
@@ -110,7 +110,8 @@ class IR_API PassManager {
// TODO(liuyuanle): Add flags to control printing behavior.
// TODO(liuyuanle): Add flags to control printing behavior.
};
};
void
EnableIRPrinting
(
std
::
unique_ptr
<
IRPrinterOption
>
config
);
void
EnableIRPrinting
(
std
::
unique_ptr
<
IRPrinterOption
>
option
=
std
::
make_unique
<
IRPrinterOption
>
());
void
EnablePassTiming
(
bool
print_module
=
true
);
void
EnablePassTiming
(
bool
print_module
=
true
);
...
...
paddle/ir/pattern_rewrite/pattern_match.h
浏览文件 @
4905a247
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/dll_decl.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/operation.h"
...
@@ -148,6 +149,8 @@ class IR_API Pattern {
...
@@ -148,6 +149,8 @@ class IR_API Pattern {
const
PatternBenefit
benefit_
;
const
PatternBenefit
benefit_
;
IrContext
*
context_
;
IrContext
*
context_
;
// A list of the potential operations that may be generated when rewriting an
// op with this pattern.
std
::
vector
<
OpInfo
>
generated_ops_
;
std
::
vector
<
OpInfo
>
generated_ops_
;
std
::
string
debug_name_
;
std
::
string
debug_name_
;
...
@@ -162,13 +165,13 @@ class IR_API RewritePattern : public Pattern {
...
@@ -162,13 +165,13 @@ class IR_API RewritePattern : public Pattern {
virtual
void
Rewrite
(
Operation
*
op
,
virtual
void
Rewrite
(
Operation
*
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
PatternRewriter
&
rewriter
)
const
{
// NOLINT
throw
(
IR_THROW
(
"need to implement either MatchAndRewrite or one of the rewrite "
"need to implement either MatchAndRewrite or one of the rewrite "
"functions."
);
"functions."
);
}
}
virtual
bool
Match
(
Operation
*
op
)
const
{
virtual
bool
Match
(
Operation
*
op
)
const
{
throw
(
"need to implement either MatchAndRewrite or Match."
);
IR_THROW
(
"need to implement either MatchAndRewrite or Match."
);
return
false
;
return
false
;
}
}
...
@@ -220,10 +223,10 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
...
@@ -220,10 +223,10 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
virtual
void
Rewrite
(
SourceOp
op
,
virtual
void
Rewrite
(
SourceOp
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
PatternRewriter
&
rewriter
)
const
{
// NOLINT
throw
(
"must override Rewrite or MatchAndRewrite"
);
IR_THROW
(
"must override Rewrite or MatchAndRewrite"
);
}
}
virtual
bool
Match
(
SourceOp
op
)
const
{
virtual
bool
Match
(
SourceOp
op
)
const
{
throw
(
"must override Match or MatchAndRewrite"
);
IR_THROW
(
"must override Match or MatchAndRewrite"
);
}
}
virtual
bool
MatchAndRewrite
(
SourceOp
op
,
virtual
bool
MatchAndRewrite
(
SourceOp
op
,
PatternRewriter
&
rewriter
)
const
{
// NOLINT
PatternRewriter
&
rewriter
)
const
{
// NOLINT
...
...
test/cpp/ir/pass/pass_manager_test.cc
浏览文件 @
4905a247
...
@@ -15,6 +15,10 @@
...
@@ -15,6 +15,10 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "glog/logging.h"
#include "glog/logging.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/dialect/utils.h"
...
@@ -125,7 +129,7 @@ class TestPass : public ir::Pass {
...
@@ -125,7 +129,7 @@ class TestPass : public ir::Pass {
pass_state
().
preserved_analyses
.
Preserve
<
CountOpAnalysis
>
();
pass_state
().
preserved_analyses
.
Preserve
<
CountOpAnalysis
>
();
CHECK_EQ
(
pass_state
().
preserved_analyses
.
IsPreserved
<
CountOpAnalysis
>
(),
CHECK_EQ
(
pass_state
().
preserved_analyses
.
IsPreserved
<
CountOpAnalysis
>
(),
true
);
true
);
CHECK_EQ
(
count_op_analysis
.
count
,
4
);
CHECK_EQ
(
count_op_analysis
.
count
,
11
);
auto
module_op
=
op
->
dyn_cast
<
ir
::
ModuleOp
>
();
auto
module_op
=
op
->
dyn_cast
<
ir
::
ModuleOp
>
();
CHECK_EQ
(
module_op
.
operation
(),
op
);
CHECK_EQ
(
module_op
.
operation
(),
op
);
...
@@ -143,139 +147,78 @@ class TestPass : public ir::Pass {
...
@@ -143,139 +147,78 @@ class TestPass : public ir::Pass {
}
}
};
};
TEST
(
pass_manager
,
PassManager
)
{
void
BuildProgram
(
ir
::
Builder
&
builder
)
{
// NOLINT
//
paddle
::
dialect
::
FullOp
full_input_op
=
// TODO(liuyuanle): remove test code other than pass manager
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
4
,
3
,
16
,
16
},
//
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
FullOp
full_filter_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
64
,
3
,
3
,
3
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
FullOp
full_mean_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
64
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
FullOp
full_variance_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
64
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
FullOp
full_scale_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
64
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
FullOp
full_bias_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
64
},
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
paddle
::
dialect
::
Conv2dOp
conv2d_op
=
builder
.
Build
<
paddle
::
dialect
::
Conv2dOp
>
(
full_input_op
.
out
(),
full_filter_op
.
out
());
paddle
::
dialect
::
BatchNormOp
batch_norm_op
=
builder
.
Build
<
paddle
::
dialect
::
BatchNormOp
>
(
conv2d_op
.
out
(),
full_mean_op
.
out
(),
full_variance_op
.
out
(),
full_scale_op
.
out
(),
full_bias_op
.
out
(),
true
,
0.9
,
1e-6
,
"NCHW"
,
false
,
false
);
auto
transpose1_op
=
builder
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
batch_norm_op
.
out
(),
std
::
vector
<
int
>
{
0
,
2
,
3
,
1
});
auto
transpose2_op
=
builder
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
transpose1_op
.
out
(),
std
::
vector
<
int
>
{
0
,
3
,
1
,
2
});
builder
.
Build
<
paddle
::
dialect
::
FetchOp
>
(
transpose2_op
.
out
(),
"out"
,
0
);
}
// (1) Init environment.
TEST
(
pass_manager
,
PassManager
)
{
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ir
::
Dialect
*
builtin_dialect
=
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
ctx
->
GetOrRegisterDialect
<
ir
::
BuiltinDialect
>
();
builtin_dialect
->
RegisterOp
<
AddOp
>
();
ir
::
Dialect
*
paddle_dialect
=
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
// (2) Create an empty program object
ir
::
Program
program
(
ctx
);
ir
::
Program
program
(
ctx
);
ir
::
Builder
builder
=
ir
::
Builder
(
ctx
,
program
.
block
());
BuildProgram
(
builder
);
// (3) Create a float32 DenseTensor Parameter and save into Program
EXPECT_EQ
(
program
.
block
()
->
size
(),
11u
);
ir
::
Type
fp32_dtype
=
ir
::
Float32Type
::
get
(
ctx
);
phi
::
DDim
dims
=
{
2
,
2
};
phi
::
DataLayout
data_layout
=
phi
::
DataLayout
::
NCHW
;
phi
::
LoD
lod
=
{{
0
,
1
,
2
}};
size_t
offset
=
0
;
ir
::
Type
dense_tensor_dtype
=
paddle
::
dialect
::
DenseTensorType
::
get
(
ctx
,
fp32_dtype
,
dims
,
data_layout
,
lod
,
offset
);
std
::
vector
<
float
>
data_a
=
{
1
,
2
,
3
,
4
};
std
::
unique_ptr
<
ir
::
Parameter
>
parameter_a
=
std
::
make_unique
<
ir
::
Parameter
>
(
reinterpret_cast
<
void
*>
(
data_a
.
data
()),
4
*
sizeof
(
float
),
dense_tensor_dtype
);
program
.
SetParameter
(
"a"
,
std
::
move
(
parameter_a
));
EXPECT_EQ
(
program
.
parameters_num
()
==
1
,
true
);
std
::
vector
<
float
>
data_b
=
{
5
,
6
,
7
,
8
};
std
::
unique_ptr
<
ir
::
Parameter
>
parameter_b
=
std
::
make_unique
<
ir
::
Parameter
>
(
reinterpret_cast
<
void
*>
(
data_b
.
data
()),
4
*
sizeof
(
float
),
dense_tensor_dtype
);
program
.
SetParameter
(
"b"
,
std
::
move
(
parameter_b
));
EXPECT_EQ
(
program
.
parameters_num
()
==
2
,
true
);
// (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
ir
::
Builder
builder
(
ctx
,
program
.
block
());
auto
op1
=
builder
.
Build
<
ir
::
GetParameterOp
>
(
"a"
,
dense_tensor_dtype
);
EXPECT_EQ
(
&
program
,
op1
->
GetParentProgram
());
EXPECT_EQ
(
op1
->
result
(
0
).
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
using
Interface
=
paddle
::
dialect
::
ParameterConvertInterface
;
Interface
*
a_interface
=
op1
->
result
(
0
).
type
().
dialect
().
GetRegisteredInterface
<
Interface
>
();
std
::
shared_ptr
<
paddle
::
framework
::
Variable
>
a_var
=
a_interface
->
ParameterToVariable
(
program
.
GetParameter
(
"a"
));
const
phi
::
DenseTensor
&
a_tensor
=
a_var
->
Get
<
phi
::
DenseTensor
>
();
EXPECT_EQ
(
a_tensor
.
numel
(),
4
);
EXPECT_EQ
(
a_tensor
.
dims
(),
dims
);
EXPECT_EQ
(
a_tensor
.
dtype
(),
paddle
::
dialect
::
TransToPhiDataType
(
fp32_dtype
));
EXPECT_EQ
(
a_tensor
.
layout
(),
data_layout
);
EXPECT_EQ
(
a_tensor
.
lod
(),
lod
);
EXPECT_EQ
(
a_tensor
.
offset
(),
offset
);
for
(
int64_t
i
=
0
;
i
<
a_tensor
.
numel
();
i
++
)
{
EXPECT_EQ
(
*
(
a_tensor
.
data
<
float
>
()
+
i
),
data_a
[
i
]);
}
// (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
auto
op2
=
builder
.
Build
<
ir
::
GetParameterOp
>
(
"b"
,
dense_tensor_dtype
);
EXPECT_EQ
(
op2
->
result
(
0
).
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
Interface
*
b_interface
=
op2
->
result
(
0
).
type
().
dialect
().
GetRegisteredInterface
<
Interface
>
();
std
::
shared_ptr
<
paddle
::
framework
::
Variable
>
b_var
=
b_interface
->
ParameterToVariable
(
program
.
GetParameter
(
"b"
));
const
phi
::
DenseTensor
&
b_tensor
=
b_var
->
Get
<
phi
::
DenseTensor
>
();
EXPECT_EQ
(
b_tensor
.
numel
(),
4
);
EXPECT_EQ
(
b_tensor
.
dims
(),
dims
);
EXPECT_EQ
(
b_tensor
.
dtype
(),
paddle
::
dialect
::
TransToPhiDataType
(
fp32_dtype
));
EXPECT_EQ
(
b_tensor
.
layout
(),
data_layout
);
EXPECT_EQ
(
b_tensor
.
lod
(),
lod
);
EXPECT_EQ
(
b_tensor
.
offset
(),
offset
);
for
(
int64_t
i
=
0
;
i
<
b_tensor
.
numel
();
i
++
)
{
EXPECT_EQ
(
*
(
b_tensor
.
data
<
float
>
()
+
i
),
data_b
[
i
]);
}
// (6) Def c = AddOp(a, b), execute this op.
auto
op3
=
builder
.
Build
<
AddOp
>
(
op1
->
result
(
0
),
op2
->
result
(
0
),
dense_tensor_dtype
);
phi
::
CPUContext
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
paddle
::
platform
::
CPUPlace
()));
phi
::
DenseTensor
c_tensor
=
phi
::
Add
<
float
,
phi
::
CPUContext
>
(
*
dev_ctx
,
a_tensor
,
b_tensor
);
std
::
shared_ptr
<
paddle
::
framework
::
Variable
>
variable_c
=
std
::
make_shared
<
paddle
::
framework
::
Variable
>
();
auto
*
dst_tensor
=
variable_c
->
GetMutable
<
phi
::
DenseTensor
>
();
*
dst_tensor
=
c_tensor
;
EXPECT_EQ
(
dst_tensor
->
numel
(),
b_tensor
.
numel
());
EXPECT_EQ
(
dst_tensor
->
dims
(),
b_tensor
.
dims
());
EXPECT_EQ
(
dst_tensor
->
dtype
(),
b_tensor
.
dtype
());
EXPECT_EQ
(
dst_tensor
->
layout
(),
b_tensor
.
layout
());
EXPECT_EQ
(
dst_tensor
->
lod
(),
b_tensor
.
lod
());
EXPECT_EQ
(
dst_tensor
->
offset
(),
b_tensor
.
offset
());
for
(
int64_t
i
=
0
;
i
<
dst_tensor
->
numel
();
i
++
)
{
EXPECT_EQ
(
*
(
dst_tensor
->
data
<
float
>
()
+
i
),
data_a
[
i
]
+
data_b
[
i
]);
}
// (7) Def SetParameterOp(c, "c")
auto
op4
=
builder
.
Build
<
ir
::
SetParameterOp
>
(
op3
->
result
(
0
),
"c"
);
EXPECT_EQ
(
op4
->
operand
(
0
).
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
Interface
*
c_interface
=
op4
->
op_operand
(
0
).
type
().
dialect
().
GetRegisteredInterface
<
Interface
>
();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std
::
unique_ptr
<
ir
::
Parameter
>
parameter_c
=
c_interface
->
VariableToParameter
(
variable_c
.
get
());
EXPECT_EQ
(
parameter_c
->
type
(),
dense_tensor_dtype
);
for
(
int64_t
i
=
0
;
i
<
dst_tensor
->
numel
();
i
++
)
{
EXPECT_EQ
(
*
(
dst_tensor
->
data
<
float
>
()
+
i
),
*
(
static_cast
<
float
*>
(
parameter_c
->
data
())
+
i
));
}
program
.
SetParameter
(
"c"
,
std
::
move
(
parameter_c
));
// (8) Traverse Program
EXPECT_EQ
(
program
.
block
()
->
size
()
==
4
,
true
);
EXPECT_EQ
(
program
.
parameters_num
()
==
3
,
true
);
//
// TODO(liuyuanle): remove the code above.
//
// (9) Test pass manager for program.
// (9) Test pass manager for program.
ir
::
PassManager
pm
(
ctx
);
ir
::
PassManager
pm
(
ctx
);
pm
.
AddPass
(
std
::
make_unique
<
TestPass
>
());
pm
.
AddPass
(
std
::
make_unique
<
TestPass
>
());
// pm.EnableIRPrinting();
pm
.
EnableIRPrinting
(
std
::
make_unique
<
ir
::
PassManager
::
IRPrinterOption
>
(
pm
.
EnableIRPrinting
(
std
::
make_unique
<
ir
::
PassManager
::
IRPrinterOption
>
(
[](
ir
::
Pass
*
pass
,
ir
::
Operation
*
op
)
{
[](
ir
::
Pass
*
pass
,
ir
::
Operation
*
op
)
{
return
pass
->
name
()
==
"TestPass"
;
return
pass
->
name
()
==
"TestPass"
;
...
...
test/cpp/ir/pattern_rewrite/CMakeLists.txt
浏览文件 @
4905a247
cc_test_old
(
set
(
PATTERN_REWRITE_TEST_DEPS _constant_folding_pass
pattern_rewrite_test
transform_general_functions gtest pd_dialect ir
)
SRCS
pattern_rewrite_test.cc
if
(
WITH_DISTRIBUTE
)
DEPS
set
(
PATTERN_REWRITE_TEST_DEPS
${
PATTERN_REWRITE_TEST_DEPS
}
fleet_executor
)
ir
endif
()
pd_dialect
transform_general_functions
cc_test_old
(
pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
gtest
)
${
PATTERN_REWRITE_TEST_DEPS
}
)
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
浏览文件 @
4905a247
...
@@ -15,12 +15,15 @@
...
@@ -15,12 +15,15 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <cstdint>
#include <cstdint>
#include <iostream>
#include <iostream>
#include <memory>
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include <vector>
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/transforms/constant_folding_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_dialect.h"
...
@@ -39,7 +42,7 @@
...
@@ -39,7 +42,7 @@
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_applicator.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/
ir/transforms/dce
.h"
#include "paddle/
phi/core/kernel_registry
.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
// paddle/fluid/ir/dialect/CMakeLists.txt.
...
@@ -56,6 +59,18 @@
...
@@ -56,6 +59,18 @@
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/multiary.h"
PD_DECLARE_KERNEL
(
full
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
add
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
sqrt
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
divide
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
multiply
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
subtract
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
full_int_array
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
reshape
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
fetch
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
conv2d
,
CPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
transpose
,
CPU
,
ALL_LAYOUT
);
// Define op1.
// Define op1.
class
Operation1
:
public
ir
::
Op
<
Operation1
>
{
class
Operation1
:
public
ir
::
Op
<
Operation1
>
{
public:
public:
...
@@ -197,7 +212,7 @@ class RedundantTransposeFusePattern
...
@@ -197,7 +212,7 @@ class RedundantTransposeFusePattern
bool
MatchAndRewrite
(
paddle
::
dialect
::
TransposeOp
op
,
bool
MatchAndRewrite
(
paddle
::
dialect
::
TransposeOp
op
,
ir
::
PatternRewriter
&
rewriter
)
const
override
{
ir
::
PatternRewriter
&
rewriter
)
const
override
{
auto
prev_op
=
ir
::
GetDefiningOpForInput
<
0
>
(
op
);
auto
prev_op
=
ir
::
GetDefiningOpForInput
(
op
,
0
);
std
::
vector
<
int
>
axis_last
=
GetAxis
(
op
);
std
::
vector
<
int
>
axis_last
=
GetAxis
(
op
);
auto
prev_trans_op
=
prev_op
->
dyn_cast
<
paddle
::
dialect
::
TransposeOp
>
();
auto
prev_trans_op
=
prev_op
->
dyn_cast
<
paddle
::
dialect
::
TransposeOp
>
();
if
(
prev_trans_op
)
{
if
(
prev_trans_op
)
{
...
@@ -207,7 +222,7 @@ class RedundantTransposeFusePattern
...
@@ -207,7 +222,7 @@ class RedundantTransposeFusePattern
auto
new_perm
=
GetPerm
(
axis_first
,
axis_last
);
auto
new_perm
=
GetPerm
(
axis_first
,
axis_last
);
rewriter
.
SetInsertionPoint
(
op
);
rewriter
.
SetInsertionPoint
(
op
);
auto
new_transpose_op
=
rewriter
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
auto
new_transpose_op
=
rewriter
.
Build
<
paddle
::
dialect
::
TransposeOp
>
(
ir
::
GetDefiningOpForInput
<
0
>
(
prev_trans_op
)
->
result
(
0
),
new_perm
);
ir
::
GetDefiningOpForInput
(
prev_trans_op
,
0
)
->
result
(
0
),
new_perm
);
rewriter
.
ReplaceOp
(
op
,
{
new_transpose_op
.
out
()});
rewriter
.
ReplaceOp
(
op
,
{
new_transpose_op
.
out
()});
return
true
;
return
true
;
}
}
...
@@ -249,7 +264,7 @@ class Conv2dBnFusePattern
...
@@ -249,7 +264,7 @@ class Conv2dBnFusePattern
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
// The next op should be batch_norm.
// The next op should be batch_norm.
paddle
::
dialect
::
Conv2dOp
conv2d_op
=
paddle
::
dialect
::
Conv2dOp
conv2d_op
=
ir
::
GetDefiningOpForInput
(
op
)
->
dyn_cast
<
paddle
::
dialect
::
Conv2dOp
>
();
ir
::
GetDefiningOpForInput
(
op
,
0
)
->
dyn_cast
<
paddle
::
dialect
::
Conv2dOp
>
();
if
(
!
conv2d_op
)
return
false
;
if
(
!
conv2d_op
)
return
false
;
ir
::
OpResult
conv2d_out
=
conv2d_op
.
out
();
ir
::
OpResult
conv2d_out
=
conv2d_op
.
out
();
...
@@ -320,7 +335,6 @@ class Conv2dBnFusePattern
...
@@ -320,7 +335,6 @@ class Conv2dBnFusePattern
std
::
string
data_format
=
std
::
string
data_format
=
new_conv2d_op
.
attribute
<
ir
::
StrAttribute
>
(
"data_format"
).
data
();
new_conv2d_op
.
attribute
<
ir
::
StrAttribute
>
(
"data_format"
).
data
();
IR_ENFORCE
(
data_format
==
"NCHW"
,
"Only support NCHW now."
);
IR_ENFORCE
(
data_format
==
"NCHW"
,
"Only support NCHW now."
);
new_bias_new_shape
[
0
]
=
new_conv2d_out_shape
[
0
];
new_bias_new_shape
[
1
]
=
new_conv2d_out_shape
[
1
];
new_bias_new_shape
[
1
]
=
new_conv2d_out_shape
[
1
];
paddle
::
dialect
::
ReshapeOp
reshape_bias_op
=
paddle
::
dialect
::
ReshapeOp
reshape_bias_op
=
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
sub_op
.
out
(),
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
sub_op
.
out
(),
...
@@ -895,7 +909,7 @@ class Conv2dAddFusePattern
...
@@ -895,7 +909,7 @@ class Conv2dAddFusePattern
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
ir
::
PatternRewriter
&
rewriter
)
const
override
{
// NOLINT
// The next op should be add.
// The next op should be add.
paddle
::
dialect
::
Conv2dOp
conv2d_op
=
paddle
::
dialect
::
Conv2dOp
conv2d_op
=
ir
::
GetDefiningOpForInput
(
op
)
->
dyn_cast
<
paddle
::
dialect
::
Conv2dOp
>
();
ir
::
GetDefiningOpForInput
(
op
,
0
)
->
dyn_cast
<
paddle
::
dialect
::
Conv2dOp
>
();
if
(
!
conv2d_op
)
return
false
;
if
(
!
conv2d_op
)
return
false
;
ir
::
OpResult
conv2d_out
=
conv2d_op
.
out
();
ir
::
OpResult
conv2d_out
=
conv2d_op
.
out
();
...
@@ -929,12 +943,10 @@ class Conv2dAddFusePattern
...
@@ -929,12 +943,10 @@ class Conv2dAddFusePattern
conv2d_attributes
.
at
(
"dilations"
),
conv2d_attributes
.
at
(
"dilations"
),
conv2d_attributes
.
at
(
"groups"
),
conv2d_attributes
.
at
(
"groups"
),
conv2d_attributes
.
at
(
"data_format"
),
conv2d_attributes
.
at
(
"data_format"
),
ir
::
StrAttribute
::
get
(
ir
::
IrContext
::
Instance
(),
"identity"
),
rewriter
.
str_attr
(
"identity"
),
ir
::
BoolAttribute
::
get
(
ir
::
IrContext
::
Instance
(),
true
),
rewriter
.
bool_attr
(
true
),
ir
::
ArrayAttribute
::
get
(
ir
::
IrContext
::
Instance
(),
rewriter
.
array_attr
(
std
::
vector
<
ir
::
Attribute
>
{}),
std
::
vector
<
ir
::
Attribute
>
()),
rewriter
.
int32_attr
(
0
)};
ir
::
Int32Attribute
::
get
(
ir
::
IrContext
::
Instance
(),
int32_t
(
0
)),
};
ir
::
AttributeMap
conv2d_fusion_attributes
;
ir
::
AttributeMap
conv2d_fusion_attributes
;
for
(
size_t
i
=
0
;
i
<
conv2d_fusion_attrStr
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
conv2d_fusion_attrStr
.
size
();
++
i
)
{
conv2d_fusion_attributes
[
conv2d_fusion_attrStr
[
i
]]
=
con2d_fusing_attr
[
i
];
conv2d_fusion_attributes
[
conv2d_fusion_attrStr
[
i
]]
=
con2d_fusing_attr
[
i
];
...
@@ -943,7 +955,7 @@ class Conv2dAddFusePattern
...
@@ -943,7 +955,7 @@ class Conv2dAddFusePattern
ir
::
OpResult
tmpResidual
;
ir
::
OpResult
tmpResidual
;
auto
conv2d_fuse_op
=
rewriter
.
Build
<
paddle
::
dialect
::
Conv2dFusionOpTest
>
(
auto
conv2d_fuse_op
=
rewriter
.
Build
<
paddle
::
dialect
::
Conv2dFusionOpTest
>
(
ir
::
GetDefiningOpForInput
<
0
>
(
conv2d_op
)
->
result
(
0
),
ir
::
GetDefiningOpForInput
(
conv2d_op
,
0
)
->
result
(
0
),
conv2d_filter_result
,
conv2d_filter_result
,
bias
,
bias
,
tmpResidual
,
tmpResidual
,
...
@@ -956,27 +968,48 @@ class Conv2dAddFusePattern
...
@@ -956,27 +968,48 @@ class Conv2dAddFusePattern
class
TestPass
:
public
ir
::
Pass
{
class
TestPass
:
public
ir
::
Pass
{
public:
public:
TestPass
()
:
ir
::
Pass
(
"TestPass"
,
1
)
{}
TestPass
()
:
ir
::
Pass
(
"TestPass"
,
1
)
{}
void
Run
(
ir
::
Operation
*
op
)
override
{
ir
::
RewritePatternSet
ps
(
op
->
ir_context
());
ps
.
Add
<
RedundantTransposeFusePattern
>
(
op
->
ir_context
());
ps
.
Add
<
Conv2dBnFusePattern
>
(
op
->
ir_context
());
ps
.
Add
<
Conv2dAddFusePattern
>
(
op
->
ir_context
());
ir
::
FrozenRewritePatternSet
frozen_ps
(
std
::
move
(
ps
));
bool
Initialize
(
ir
::
IrContext
*
context
)
override
{
ir
::
RewritePatternSet
ps
(
context
);
ps
.
Add
<
RedundantTransposeFusePattern
>
(
context
);
auto
conv_bn_pattern
=
std
::
make_unique
<
Conv2dBnFusePattern
>
(
context
,
1
,
std
::
vector
<
std
::
string
>
{
paddle
::
dialect
::
FullOp
::
name
(),
paddle
::
dialect
::
AddOp
::
name
(),
paddle
::
dialect
::
SqrtOp
::
name
(),
paddle
::
dialect
::
DivideOp
::
name
(),
paddle
::
dialect
::
ReshapeOp
::
name
(),
paddle
::
dialect
::
MultiplyOp
::
name
(),
paddle
::
dialect
::
SubtractOp
::
name
(),
paddle
::
dialect
::
Conv2dOp
::
name
()});
LOG
(
INFO
)
<<
"Conv2dBnFusePattern will generate the following operations: "
;
for
(
auto
op_info
:
conv_bn_pattern
->
generated_ops
())
{
LOG
(
INFO
)
<<
"--- "
<<
op_info
.
name
();
}
ps
.
Add
(
std
::
move
(
conv_bn_pattern
));
patterns_
=
ir
::
FrozenRewritePatternSet
(
std
::
move
(
ps
));
return
true
;
}
void
Run
(
ir
::
Operation
*
op
)
override
{
ir
::
GreedyRewriteConfig
cfg
;
ir
::
GreedyRewriteConfig
cfg
;
cfg
.
use_top_down_traversal
=
true
;
cfg
.
use_top_down_traversal
=
true
;
cfg
.
max_iterations
=
10
;
cfg
.
max_iterations
=
10
;
ir
::
ApplyPatternsGreedily
(
op
->
region
(
0
),
frozen_ps
,
cfg
);
ir
::
ApplyPatternsGreedily
(
op
->
region
(
0
),
patterns_
,
cfg
);
}
}
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
bool
CanApplyOn
(
ir
::
Operation
*
op
)
const
override
{
return
op
->
name
()
==
"builtin.module"
&&
op
->
num_regions
()
>
0
;
return
op
->
name
()
==
"builtin.module"
&&
op
->
num_regions
()
>
0
;
}
}
private:
ir
::
FrozenRewritePatternSet
patterns_
;
};
};
void
BuildProgram
(
ir
::
Builder
&
builder
)
{
// NOLINT
void
BuildProgram
(
ir
::
Builder
&
builder
)
{
// NOLINT
paddle
::
dialect
::
FullOp
full_input_op
=
paddle
::
dialect
::
FullOp
full_input_op
=
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
1
,
3
,
16
,
16
},
builder
.
Build
<
paddle
::
dialect
::
FullOp
>
(
std
::
vector
<
int64_t
>
{
4
,
3
,
16
,
16
},
1.5
,
1.5
,
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT32
,
phi
::
CPUPlace
());
phi
::
CPUPlace
());
...
@@ -1045,11 +1078,19 @@ TEST(pattern_rewrite, Patterns) {
...
@@ -1045,11 +1078,19 @@ TEST(pattern_rewrite, Patterns) {
ir
::
PassManager
pm
(
ctx
);
ir
::
PassManager
pm
(
ctx
);
pm
.
AddPass
(
std
::
make_unique
<
TestPass
>
());
pm
.
AddPass
(
std
::
make_unique
<
TestPass
>
());
pm
.
AddPass
(
ir
::
CreateDcePass
());
pm
.
AddPass
(
ir
::
CreateConstantFoldingPass
());
program
.
Print
(
std
::
cout
);
pm
.
AddPass
(
ir
::
CreateDeadCodeEliminationPass
());
std
::
cout
<<
std
::
endl
;
pm
.
EnablePassTiming
();
pm
.
Run
(
&
program
);
pm
.
EnableIRPrinting
();
LOG
(
INFO
)
<<
"After Pass."
;
// pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
program
.
Print
(
std
::
cout
);
// [](ir::Pass *pass, ir::Operation *op) {
std
::
cout
<<
std
::
endl
;
// return pass->name() == "ConstantFoldingPass";
// },
// [](ir::Pass *pass, ir::Operation *op) {
// return pass->name() == "ConstantFoldingPass";
// },
// true,
// true));
CHECK_EQ
(
pm
.
Run
(
&
program
),
true
);
}
}
test/cpp/new_executor/standalone_executor_new_ir_test.cc
浏览文件 @
4905a247
...
@@ -69,7 +69,6 @@ TEST(StandaloneExecutor, run) {
...
@@ -69,7 +69,6 @@ TEST(StandaloneExecutor, run) {
auto
place
=
platform
::
CPUPlace
();
auto
place
=
platform
::
CPUPlace
();
Scope
scope
;
Scope
scope
;
ProgramDesc
prog_desc
;
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
test_core
.
Run
({});
test_core
.
Run
({});
...
@@ -141,8 +140,6 @@ TEST(StandaloneExecutor, run_2) {
...
@@ -141,8 +140,6 @@ TEST(StandaloneExecutor, run_2) {
auto
place
=
platform
::
CPUPlace
();
auto
place
=
platform
::
CPUPlace
();
Scope
scope
;
Scope
scope
;
ProgramDesc
prog_desc
;
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
test_core
.
Run
({});
test_core
.
Run
({});
...
@@ -216,8 +213,6 @@ TEST(StandaloneExecutor, data_transfer) {
...
@@ -216,8 +213,6 @@ TEST(StandaloneExecutor, data_transfer) {
auto
place
=
platform
::
CPUPlace
();
auto
place
=
platform
::
CPUPlace
();
Scope
scope
;
Scope
scope
;
ProgramDesc
prog_desc
;
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
InterpreterCore
test_core
(
place
,
std
::
move
(
kernel_program
),
&
scope
);
test_core
.
Run
({});
test_core
.
Run
({});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录