Commit 19345fa7 (unverified), authored by Yuanle Liu on Jun 30, 2023 and committed via GitHub on Jun 30, 2023.
[IR&PASS] add conv + bn fuse pattern, and other works (#54933)
Parent: 0f69d932

Showing 20 changed files with 331 additions and 44 deletions (+331 −44):
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc (+1 −1)
paddle/fluid/framework/new_executor/standalone_executor.cc (+1 −1)
paddle/fluid/ir/CMakeLists.txt (+1 −1)
paddle/fluid/ir/pass/CMakeLists.txt (+0 −7, deleted)
paddle/fluid/ir/transforms/CMakeLists.txt (+9 −0, new file)
paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc (+1 −1, moved from paddle/fluid/ir/pass/)
paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h (+0 −0, moved from paddle/fluid/ir/pass/)
paddle/fluid/ir/transforms/transform_general_functions.cc (+47 −0, new file)
paddle/fluid/ir/transforms/transform_general_functions.h (+79 −0, new file)
paddle/ir/core/operation.cc (+5 −3)
paddle/ir/core/program.cc (+2 −2)
paddle/ir/core/program.h (+3 −2)
paddle/ir/pass/pass.cc (+3 −3)
paddle/ir/pattern_rewrite/pattern_match.h (+1 −1)
test/cpp/ir/core/ir_exe_test.cc (+1 −1)
test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc (+1 −1)
test/cpp/ir/pass/pass_manager_test.cc (+2 −1)
test/cpp/ir/pattern_rewrite/CMakeLists.txt (+1 −0)
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc (+172 −18)
test/cpp/new_executor/standalone_executor_new_ir_test.cc (+1 −1)
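The headline change is the new Conv2dBnFusePattern exercised in test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc; the remaining files mostly rename paddle/fluid/ir/pass to paddle/fluid/ir/transforms, add the transform_general_functions helpers, enable IR verification after each pass run, and make small logging and signature cleanups. For background, the arithmetic a conv + batch_norm fold computes per output channel is: scale = gamma / sqrt(variance + epsilon), fused filter = filter * scale, fused bias = beta − mean * scale. A minimal standalone C++ sketch of that arithmetic follows; it is an illustration only, not code from this commit, and the names are hypothetical:

#include <cmath>
#include <cstddef>
#include <vector>

// Per-output-channel conv+BN folding: the conv filter for channel c is scaled
// by gamma[c] / sqrt(variance[c] + epsilon), and a bias of
// beta[c] - mean[c] * scale[c] is added to the conv output.
struct FoldedBn {
  std::vector<float> filter_scale;  // multiply into the conv filter, per channel
  std::vector<float> bias;          // add to the conv output, per channel
};

FoldedBn FoldBatchNorm(const std::vector<float> &gamma,
                       const std::vector<float> &beta,
                       const std::vector<float> &mean,
                       const std::vector<float> &variance,
                       float epsilon) {
  FoldedBn folded{std::vector<float>(gamma.size()),
                  std::vector<float>(gamma.size())};
  for (std::size_t c = 0; c < gamma.size(); ++c) {
    const float scale = gamma[c] / std::sqrt(variance[c] + epsilon);
    folded.filter_scale[c] = scale;
    folded.bias[c] = beta[c] - mean[c] * scale;
  }
  return folded;
}

The Conv2dBnFusePattern further down builds this same computation out of pd_dialect ops rather than folding constants on the host, since the BN statistics are IR values rather than host-side arrays.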
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

@@ -347,7 +347,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
       return;
     }
-    // conv_weight fp32 --> fp16
+    // conv_weight fp16 --> fp32
     auto* conv_weight_tensor =
         scope->FindVar(conv_weight->Name())->GetMutable<phi::DenseTensor>();
     auto tensor_type = conv_weight_tensor->dtype();
paddle/fluid/framework/new_executor/standalone_executor.cc

@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
-#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/ir_adaptor/translator/translate.h"
paddle/fluid/ir/CMakeLists.txt

 add_subdirectory(interface)
 add_subdirectory(dialect)
-add_subdirectory(pass)
+add_subdirectory(transforms)
 add_subdirectory(phi_kernel_adaptor)
paddle/fluid/ir/pass/CMakeLists.txt (deleted, 100644 → 0)

-# All source files of pd_dialect, except for the source file of op, which is generated in the compilation directory.
-file(GLOB PD_PASS_SRCS "*.cc")
-cc_library(
-  pd_op_to_kernel_pass
-  SRCS ${PD_PASS_SRCS}
-  DEPS ir phi_utils pd_interface)
paddle/fluid/ir/transforms/CMakeLists.txt (new file, 0 → 100644)

+cc_library(
+  transform_general_functions
+  SRCS transform_general_functions.cc
+  DEPS ir phi pd_dialect)
+
+cc_library(
+  pd_op_to_kernel_pass
+  SRCS pd_op_to_kernel_pass.cc
+  DEPS ir phi_utils pd_interface)
paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc → paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc

@@ -14,7 +14,7 @@
 #include <iostream>
-#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/ir/dialect/kernel_attribute.h"
 #include "paddle/fluid/ir/dialect/kernel_dialect.h"
paddle/fluid/ir/pass/pd_op_to_kernel_pass.h → paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h (file moved, no content changes)
paddle/fluid/ir/transforms/transform_general_functions.cc (new file, 0 → 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir/transforms/transform_general_functions.h"

#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"

namespace ir {

ir::Parameter* GetParameterFromValue(ir::Value value) {
  ir::GetParameterOp op =
      value.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
  PADDLE_ENFORCE_NOT_NULL(
      op,
      phi::errors::InvalidArgument(
          "Value must be a weight from a GetParameter op."));
  ir::Program* program = op->GetParentProgram();
  std::string name = op->attributes()
                         .at(op.attributes_name[0])
                         .dyn_cast<ir::StrAttribute>()
                         .data();
  return program->GetParameter(name);
}

const phi::DDim& GetShapeFromValue(ir::Value value) {
  // TODO(dev): Support other types like DenseTensor.
  PADDLE_ENFORCE_EQ(
      value.type().isa<paddle::dialect::DenseTensorType>(),
      true,
      phi::errors::InvalidArgument("Value's type must be a DenseTensorType."));
  return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
}

}  // namespace ir
paddle/fluid/ir/transforms/transform_general_functions.h (new file, 0 → 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"

namespace ir {

/**
 * @brief Get the parameter from a value.
 *
 * @note The value must be an output of a GetParameterOp.
 *
 * @param ir::Value
 *
 * @return ir::Parameter*
 */
ir::Parameter* GetParameterFromValue(ir::Value value);

/**
 * @brief Get a tensor's shape from a value.
 *
 * @param ir::Value
 *
 * @return const phi::DDim&
 */
const phi::DDim& GetShapeFromValue(ir::Value value);

/**
 * @brief Get the operation that defines the specific input of the operation.
 *
 * @param Operation*
 *
 * @return Operation*
 */
template <uint32_t Index = 0>
Operation* GetDefiningOpForInput(Operation* op) {
  PADDLE_ENFORCE_EQ(
      Index < op->num_operands(),
      true,
      phi::errors::InvalidArgument("Input operand's index must be valid."));
  return op->operand(Index).GetDefiningOp();
}

/**
 * @brief Get the operation that is the first to use the specific output of the
 * operation.
 *
 * @param Operation*
 *
 * @return Operation*
 */
template <uint32_t Index = 0>
Operation* GetFirstUseOperationForOutput(Operation* op) {
  PADDLE_ENFORCE_EQ(
      Index < op->num_results(),
      true,
      phi::errors::InvalidArgument("Output op result's index must be valid."));
  return op->result(Index).first_use().owner();
}

}  // namespace ir
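To make the intended use of these helpers concrete, here is a hypothetical fragment (not part of the commit) showing how a rewrite pattern might call them, given a paddle::dialect::Conv2dOp named conv2d_op as in the test further down; the variable names are illustrative:

#include "paddle/fluid/ir/transforms/transform_general_functions.h"

// Inside a MatchAndRewrite body, with `conv2d_op` a paddle::dialect::Conv2dOp:
ir::Value filter = conv2d_op.filter();
// Static shape of the filter value (its type must be a DenseTensorType).
const phi::DDim &filter_shape = ir::GetShapeFromValue(filter);
// Producer of input #0 and first consumer of result #0 of the op.
ir::Operation *producer = ir::GetDefiningOpForInput<0>(conv2d_op);
ir::Operation *first_user = ir::GetFirstUseOperationForOutput<0>(conv2d_op);
// Only valid when `filter` is defined by an ir::GetParameterOp; otherwise the
// PADDLE_ENFORCE_NOT_NULL check inside GetParameterFromValue fires.
// ir::Parameter *filter_param = ir::GetParameterFromValue(filter);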
paddle/ir/core/operation.cc

@@ -107,6 +107,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
 // Call destructors for Region, OpResults, Operation, and OpOperands in
 // sequence, and finally free memory.
 void Operation::Destroy() {
+  VLOG(6) << "Destroy Operation [" << name() << "] ...";
   // 1. Deconstruct Regions.
   if (num_regions_ > 0) {
     for (size_t idx = 0; idx < num_regions_; idx++) {
@@ -117,7 +118,8 @@ void Operation::Destroy() {
   // 2. Deconstruct Result.
   for (size_t idx = 0; idx < num_results_; ++idx) {
     detail::OpResultImpl *impl = result(idx).impl();
-    IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses.");
+    IR_ENFORCE(impl->use_empty(),
+               name() + " operation destroyed but still has uses.");
     if (detail::OpOutlineResultImpl::classof(*impl)) {
       static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
     } else {
@@ -143,8 +145,8 @@ void Operation::Destroy() {
                         : sizeof(detail::OpInlineResultImpl) * num_results_;
   void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
-  VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr
-          << ", size = " << result_mem_size << "}";
+  VLOG(6) << "Destroy Operation [" << name() << "]: {ptr = " << aligned_ptr
+          << ", size = " << result_mem_size << "} done.";
   aligned_free(aligned_ptr);
 }
paddle/ir/core/program.cc

@@ -27,14 +27,14 @@ Program::~Program() {
   }
 }

-Parameter *Program::GetParameter(std::string name) const {
+Parameter *Program::GetParameter(const std::string &name) const {
   if (parameters_.count(name) != 0) {
     return parameters_.at(name).get();
   }
   return nullptr;
 }

-void Program::SetParameter(std::string name,
+void Program::SetParameter(const std::string &name,
                            std::unique_ptr<Parameter> &&parameter) {
   parameters_[name].reset(parameter.release());
 }
paddle/ir/core/program.h

@@ -54,8 +54,9 @@ class IR_API Program {
   Block *block() { return module_.block(); }

-  Parameter *GetParameter(std::string name) const;
-  void SetParameter(std::string name, std::unique_ptr<Parameter> &&parameter);
+  Parameter *GetParameter(const std::string &name) const;
+  void SetParameter(const std::string &name,
+                    std::unique_ptr<Parameter> &&parameter);

   ParameterMap &parameters() { return parameters_; }
   void set_parameters(ParameterMap &&parameters) {
paddle/ir/pass/pass.cc

@@ -18,6 +18,7 @@
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/program.h"
 #include "paddle/ir/core/region.h"
+#include "paddle/ir/core/verify.h"
 #include "paddle/ir/pass/pass_adaptor.h"
 #include "paddle/ir/pass/pass_instrumentation.h"
 #include "paddle/ir/pass/pass_manager.h"
@@ -109,10 +110,9 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
   bool pass_failed = pass->pass_state().pass_failed;
-  // TODO(liuyuanle): Support verification of operation
   if (!pass_failed && verify) {
-    // bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
-    // pass_failed = ir::Verify(op, verify_recursively);
+    bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
+    ir::Verify(op, verify_recursively);
   }
   return !pass_failed;
paddle/ir/pattern_rewrite/pattern_match.h

@@ -274,7 +274,7 @@ class RewriterBase : public Builder {
   virtual void EraseOp(Operation* op);

-  void ReplaceAllUsesWith(Value from, Value to);
+  IR_API void ReplaceAllUsesWith(Value from, Value to);

   void ReplaceUseIf(Value from,
                     Value to,
test/cpp/ir/core/ir_exe_test.cc

@@ -42,8 +42,8 @@
 #include "paddle/fluid/ir/dialect/pd_attribute.h"
-#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/ir/core/attribute.h"
 #include "paddle/phi/core/kernel_registry.h"
test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc

@@ -26,8 +26,8 @@
 #include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/fluid/ir/dialect/utils.h"
 #include "paddle/fluid/ir/interface/op_yaml_info.h"
-#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/platform/init.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_dialect.h"
test/cpp/ir/pass/pass_manager_test.cc

@@ -94,7 +94,8 @@ IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
 struct CountOpAnalysis {
   explicit CountOpAnalysis(ir::Operation *container_op) {
-    IR_ENFORCE(container_op->num_regions() > 0, true);
+    IR_ENFORCE(container_op->num_regions() > 0,
+               "op must be a container with zero or multiple regions.");
     LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n";
     for (size_t i = 0; i < container_op->num_regions(); ++i) {
test/cpp/ir/pattern_rewrite/CMakeLists.txt

@@ -5,4 +5,5 @@ cc_test_old(
   DEPS
   ir
   pd_dialect
+  transform_general_functions
   gtest)
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc

@@ -13,12 +13,14 @@
 // limitations under the License.

 #include <gtest/gtest.h>
+#include <cstdint>
 #include <iostream>
 #include <numeric>
 #include <sstream>
 #include <vector>

 #include "paddle/fluid/ir/dialect/pd_attribute.h"
+#include "paddle/fluid/ir/transforms/transform_general_functions.h"
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_dialect.h"
@@ -28,7 +30,9 @@
 #include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/op_info.h"
+#include "paddle/ir/core/parameter.h"
 #include "paddle/ir/core/program.h"
+#include "paddle/ir/core/value.h"
 #include "paddle/ir/pass/pass.h"
 #include "paddle/ir/pass/pass_manager.h"
 #include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
@@ -39,9 +43,11 @@
 // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
 // paddle/fluid/ir/dialect/CMakeLists.txt.
+#include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_op.h"
-#include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_type.h"
+#include "paddle/phi/core/ddim.h"

 // Define op1.
 class Operation1 : public ir::Op<Operation1> {
@@ -53,6 +59,7 @@ class Operation1 : public ir::Op<Operation1> {
   void Verify();
   static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
 };

 void Operation1::Verify() {
   auto &attributes = this->attributes();
   if (attributes.count("op2_attr1") == 0 ||
@@ -183,7 +190,7 @@ class TransposePatternRewrite
   bool MatchAndRewrite(paddle::dialect::TransposeOp op,
                        ir::PatternRewriter &rewriter) const override {
-    auto prev_op = op->operand(0).GetDefiningOp();
+    auto prev_op = ir::GetDefiningOpForInput<0>(op);
     std::vector<int> axis_last = GetAxis(op);
     auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
     if (prev_trans_op) {
@@ -192,9 +199,9 @@ class TransposePatternRewrite
           "transpose op's perm rank should be the same.");
       auto new_perm = GetPerm(axis_first, axis_last);
       rewriter.SetInsertionPoint(op);
-      auto new_op = rewriter.Build<paddle::dialect::TransposeOp>(
-          prev_op->operand(0).GetDefiningOp()->result(0), new_perm);
-      rewriter.ReplaceOp(op, {new_op.out()});
+      auto new_transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
+          ir::GetDefiningOpForInput<0>(prev_trans_op)->result(0), new_perm);
+      rewriter.ReplaceOp(op, {new_transpose_op.out()});
       return true;
     }
@@ -203,9 +210,7 @@ class TransposePatternRewrite
  private:
   std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
-    auto attr_map = op->attributes();
-    ir::ArrayAttribute array_attr =
-        attr_map.at("perm").dyn_cast<ir::ArrayAttribute>();
+    auto array_attr = op.attribute<ir::ArrayAttribute>("perm").data();
     std::vector<int> axis(array_attr.size());
     for (size_t i = 0; i < array_attr.size(); ++i) {
       axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
@@ -228,12 +233,121 @@ class TransposePatternRewrite
   }
 };

+class Conv2dBnFusePattern
+    : public ir::OpRewritePattern<paddle::dialect::BatchNormOp> {
+ public:
+  using ir::OpRewritePattern<paddle::dialect::BatchNormOp>::OpRewritePattern;
+  bool MatchAndRewrite(
+      paddle::dialect::BatchNormOp op,
+      ir::PatternRewriter &rewriter) const override {  // NOLINT
+    // The next op should be batch_norm.
+    paddle::dialect::Conv2dOp conv2d_op =
+        ir::GetDefiningOpForInput(op)->dyn_cast<paddle::dialect::Conv2dOp>();
+    if (!conv2d_op) return false;
+
+    ir::OpResult conv2d_out = conv2d_op.out();
+    if (!conv2d_out.HasOneUse()) return false;
+
+    ir::Value conv2d_filter = conv2d_op.filter();
+    // ir::GetParameterOp filter_parameter_op =
+    //     conv2d_filter.GetDefiningOp()->dyn_cast<ir::GetParameterOp>();
+    // if (!filter_parameter_op) return false;
+
+    ir::OpResult conv2d_filter_result = conv2d_filter.dyn_cast<ir::OpResult>();
+    IR_ENFORCE(conv2d_filter_result);
+
+    ir::Value bn_input = op.x();
+    IR_ENFORCE(bn_input == conv2d_out);
+
+    ir::Value bn_mean = op.mean();
+    ir::Value bn_variance = op.variance();
+    ir::Value bn_scale = op.scale();
+    ir::Value bn_bias = op.bias();
+
+    ir::OpResult bn_mean_result = bn_mean.dyn_cast<ir::OpResult>();
+    IR_ENFORCE(bn_mean_result);
+    ir::OpResult bn_variance_result = bn_variance.dyn_cast<ir::OpResult>();
+    IR_ENFORCE(bn_variance_result);
+    ir::OpResult bn_scale_result = bn_scale.dyn_cast<ir::OpResult>();
+    IR_ENFORCE(bn_scale_result);
+    ir::OpResult bn_bias_result = bn_bias.dyn_cast<ir::OpResult>();
+    IR_ENFORCE(bn_bias_result);
+
+    // --- deal with filter ---
+    rewriter.SetInsertionPoint(conv2d_op);
+    phi::DDim bn_variance_shape =
+        bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
+    float epsilon = op.attribute<ir::FloatAttribute>("epsilon").data();
+    paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
+        phi::vectorize(bn_variance_shape), epsilon);
+    paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
+        bn_variance_result, full_op.out());
+    paddle::dialect::SqrtOp sqrt_op =
+        rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
+    paddle::dialect::DivideOp div_op =
+        rewriter.Build<paddle::dialect::DivideOp>(bn_scale_result,
+                                                  sqrt_op.out());
+    // reshape scale
+    phi::DDim conv2d_filter_shape = ir::GetShapeFromValue(conv2d_filter);
+    phi::DDim bn_scale_shape =
+        bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
+    std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
+    bn_scale_new_shape[0] = bn_scale_shape[0];
+    paddle::dialect::ReshapeOp reshape_scale_op =
+        rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
+                                                   bn_scale_new_shape);
+    // new filter --> mul_op.out()
+    paddle::dialect::MultiplyOp mul_op =
+        rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
+                                                    reshape_scale_op.out());
+    // TODO(liuyuanle): Use rewriter.
+    conv2d_op->op_operand(1).set_source(mul_op.out());
+
+    // --- deal with bias ---
+    rewriter.SetInsertionPoint(op);
+    paddle::dialect::MultiplyOp mul_bias_op =
+        rewriter.Build<paddle::dialect::MultiplyOp>(bn_mean_result,
+                                                    div_op.out());
+    // new bias --> sub_op.out()
+    paddle::dialect::SubtractOp sub_op =
+        rewriter.Build<paddle::dialect::SubtractOp>(bn_bias_result,
+                                                    mul_bias_op.out());
+    // reshape new bias
+    phi::DDim conv2d_out_shape = ir::GetShapeFromValue(conv2d_out);
+    std::vector<int64_t> new_bias_new_shape(conv2d_out_shape.size(), 1);
+    std::string data_format =
+        conv2d_op.attribute<ir::StrAttribute>("data_format").data();
+    IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
+    new_bias_new_shape[0] = conv2d_out_shape[0];
+    new_bias_new_shape[1] = conv2d_out_shape[1];
+    paddle::dialect::ReshapeOp reshape_bias_op =
+        rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
+                                                   new_bias_new_shape);
+    paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
+        conv2d_out, reshape_bias_op.out());
+
+    auto next_op = ir::GetFirstUseOperationForOutput<0>(op);
+    rewriter.ReplaceAllUsesWith(next_op->operand(0), add_bias_op.out());
+
+    rewriter.EraseOp(op);
+    return true;
+  }
+};
+
 class TestPass : public ir::Pass {
  public:
   TestPass() : ir::Pass("TestPass", 1) {}
   void Run(ir::Operation *op) override {
     ir::RewritePatternSet ps(op->ir_context());
     ps.Add<TransposePatternRewrite>(op->ir_context());
+    ps.Add<Conv2dBnFusePattern>(op->ir_context());
     ir::FrozenRewritePatternSet frozen_ps(std::move(ps));
     ir::GreedyRewriteConfig cfg;
     cfg.use_top_down_traversal = true;
@@ -247,15 +361,55 @@ class TestPass : public ir::Pass {
 };

 void BuildProgram(ir::Builder &builder) {  // NOLINT
-  paddle::dialect::FullOp full_op =
+  paddle::dialect::FullOp full_input_op =
       builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 3, 16, 16},
                                              1.5,
                                              phi::DataType::FLOAT32,
                                              phi::CPUPlace());
-  ir::OpResult full_op_output = full_op->result(0);
+
+  paddle::dialect::FullOp full_filter_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 3, 3, 3},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullOp full_mean_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::FullOp full_variance_op =
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32,
+          phi::CPUPlace());
+
+  paddle::dialect::FullOp full_scale_op =
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32,
+          phi::CPUPlace());
+
+  paddle::dialect::FullOp full_bias_op = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{64}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  paddle::dialect::Conv2dOp conv2d_op =
+      builder.Build<paddle::dialect::Conv2dOp>(full_input_op.out(),
+                                               full_filter_op.out());
+
+  paddle::dialect::BatchNormOp batch_norm_op =
+      builder.Build<paddle::dialect::BatchNormOp>(conv2d_op.out(),
+                                                  full_mean_op.out(),
+                                                  full_variance_op.out(),
+                                                  full_scale_op.out(),
+                                                  full_bias_op.out(),
+                                                  true,
+                                                  0.9,
+                                                  1e-6,
+                                                  "NCHW",
+                                                  false,
+                                                  false);
+
   auto transpose1_op = builder.Build<paddle::dialect::TransposeOp>(
-      full_op_output, std::vector<int>{0, 2, 3, 1});
+      batch_norm_op.out(), std::vector<int>{0, 2, 3, 1});

   auto transpose2_op = builder.Build<paddle::dialect::TransposeOp>(
       transpose1_op.out(), std::vector<int>{0, 3, 1, 2});
@@ -264,22 +418,22 @@ void BuildProgram(ir::Builder &builder) {  // NOLINT
 }

 // TODO(wilber): Add a normal test.
-TEST(PatternRewrite, GreedyPatternRewriteDriver) {
+TEST(pattern_rewrite, Patterns) {
   ir::IrContext *ctx = ir::IrContext::Instance();
   ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
   ir::Program program(ctx);
   ir::Builder builder = ir::Builder(ctx, program.block());
   BuildProgram(builder);

-  EXPECT_EQ(program.block()->size(), 4u);
+  EXPECT_EQ(program.block()->size(), 11u);

   ir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
   pm.AddPass(ir::CreateDCEPass());
-  program.Print(std::cout);
-  std::cout << std::endl;
+  std::stringstream o1, o2;
+  program.Print(o1);
+  LOG(INFO) << o1.str();
   pm.Run(&program);
   LOG(INFO) << "After Pass.";
-  program.Print(std::cout);
-  std::cout << std::endl;
+  program.Print(o2);
+  LOG(INFO) << o2.str();
 }
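Reading the new pattern against the folding arithmetic sketched earlier: the full(epsilon) → add(variance, ·) → sqrt → divide(scale, ·) chain materializes the per-channel scale; a reshape of that scale followed by multiply produces the rescaled filter, and conv2d's filter operand is redirected to it via set_source. The multiply/subtract/reshape/add chain then applies the folded bias to conv2d's output, the downstream use of the batch_norm result is rewired to add_bias_op.out(), and the batch_norm op is erased. The updated BuildProgram constructs full ops for the conv input, the filter, and the four BN tensors, followed by conv2d, batch_norm, and the two transposes (the block-size check accordingly grows from 4 to 11 ops), so both Conv2dBnFusePattern and TransposePatternRewrite fire in a single TestPass run.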
test/cpp/new_executor/standalone_executor_new_ir_test.cc

@@ -24,7 +24,7 @@
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir/dialect/pd_op.h"
-#include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/program.h"