Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ce57365d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ce57365d
编写于
1月 12, 2023
作者:
J
jiangcheng
提交者:
GitHub
1月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CINN] temp fix batch_norm check as inplace op bug (#49738)
上级
30f5e39b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
65 addition
and
19 deletion
+65
-19
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
+57
-18
paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
+8
-1
未找到文件。
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
浏览文件 @
ce57365d
...
...
@@ -52,6 +52,16 @@ using framework::ir::Node;
using
GraphNodeVec
=
std
::
vector
<
Node
*>
;
using
GraphNodeMap
=
std
::
unordered_map
<
Node
*
,
Node
*>
;
std
::
string
GetDebugInfo
(
const
std
::
unordered_set
<
std
::
string
>&
var_names
)
{
std
::
string
debug_info
=
"["
;
for
(
auto
&
var
:
var_names
)
{
debug_info
.
append
(
var
);
debug_info
.
append
(
", "
);
}
debug_info
.
append
(
"]"
);
return
debug_info
;
}
OpTransInfo
::
OpTransInfo
()
{
// judgment condition for the dynamic slice
dynamic_op_cond_
.
emplace
(
"slice"
,
[](
const
ir
::
Node
&
node
)
->
bool
{
...
...
@@ -115,16 +125,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const
GraphNodeSet
&
cluster
)
const
{
std
::
unordered_set
<
std
::
string
>
deny_var_set
;
auto
get_debug_info
=
[](
const
std
::
unordered_set
<
std
::
string
>&
var_names
)
{
std
::
string
debug_info
=
"["
;
for
(
auto
&
var
:
var_names
)
{
debug_info
.
append
(
var
);
debug_info
.
append
(
", "
);
}
debug_info
.
append
(
"]"
);
return
debug_info
;
};
for
(
auto
*
op
:
cluster
)
{
if
(
deny_param_cond_
.
count
(
op
->
Name
()))
{
const
auto
*
desc
=
op
->
Op
();
...
...
@@ -136,7 +136,7 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
op
->
Name
().
c_str
()));
auto
deny_param_names
=
deny_param_cond_
.
at
(
op
->
Name
());
VLOG
(
4
)
<<
"We found deny param "
<<
get_debug_i
nfo
(
deny_param_names
)
VLOG
(
4
)
<<
"We found deny param "
<<
GetDebugI
nfo
(
deny_param_names
)
<<
" in op ["
<<
op
->
Name
()
<<
"]."
;
for
(
const
auto
&
param_name
:
deny_param_names
)
{
...
...
@@ -161,16 +161,51 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
}
}
VLOG
(
4
)
<<
"All deny var names are "
<<
get_debug_i
nfo
(
deny_var_set
);
VLOG
(
4
)
<<
"All deny var names are "
<<
GetDebugI
nfo
(
deny_var_set
);
return
deny_var_set
;
}
bool
OpTransInfo
::
IsInplaceOp
(
const
OpDesc
&
op_desc
)
{
std
::
unordered_set
<
std
::
string
>
OpTransInfo
::
GetIgnoreInplaceVarNames
(
const
OpDesc
&
op_desc
)
const
{
if
(
!
ignore_inplace_param_cond_
.
count
(
op_desc
.
Type
()))
{
return
{};
}
const
auto
&
ignore_inplace_names
=
ignore_inplace_param_cond_
.
at
(
op_desc
.
Type
());
VLOG
(
4
)
<<
"We found ignore inplace param "
<<
GetDebugInfo
(
ignore_inplace_names
)
<<
" in op ["
<<
op_desc
.
Type
()
<<
"]."
;
std
::
unordered_set
<
std
::
string
>
ignore_inplace_set
;
for
(
const
auto
&
param_name
:
ignore_inplace_names
)
{
if
(
op_desc
.
HasOutput
(
param_name
))
{
const
auto
&
arg_names
=
op_desc
.
Output
(
param_name
);
ignore_inplace_set
.
insert
(
arg_names
.
begin
(),
arg_names
.
end
());
}
}
VLOG
(
4
)
<<
"All ignore inplace var names are "
<<
GetDebugInfo
(
ignore_inplace_set
);
return
ignore_inplace_set
;
}
bool
OpTransInfo
::
IsInplaceOp
(
const
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
deny_var_names
)
const
{
const
auto
&
ignore_inplace_set
=
GetIgnoreInplaceVarNames
(
op_desc
);
auto
inputs
=
op_desc
.
InputArgumentNames
();
std
::
unordered_set
<
std
::
string
>
input_set
(
inputs
.
begin
(),
inputs
.
end
());
for
(
auto
&
name
:
op_desc
.
OutputArgumentNames
())
{
if
(
input_set
.
count
(
name
)
>
0
)
return
true
;
if
(
input_set
.
count
(
name
)
>
0
&&
!
deny_var_names
.
count
(
name
)
&&
!
ignore_inplace_set
.
count
(
name
))
{
VLOG
(
4
)
<<
"The argument "
<<
name
<<
" in op "
<<
op_desc
.
Type
()
<<
" is a inplace op, skip!"
;
return
true
;
}
}
return
false
;
}
...
...
@@ -630,8 +665,11 @@ void ReplaceSubGraphWithCinnOpNode(
void
SearchAllSubgraphs
(
Graph
*
graph
,
bool
is_inference_stage
)
{
auto
allow_ops
=
StringSplit
(
FLAGS_allow_cinn_ops
,
kDelim
);
auto
deny_ops
=
StringSplit
(
FLAGS_deny_cinn_ops
,
kDelim
);
OpTransInfo
trans_info
;
auto
teller
=
[
&
allow_ops
,
&
deny_ops
,
&
trans_info
](
const
Node
*
node
)
{
const
auto
&
deny_var_set
=
trans_info
.
GetDenyVarNames
(
graph
->
Nodes
());
auto
teller
=
[
&
allow_ops
,
&
deny_ops
,
&
trans_info
,
&
deny_var_set
](
const
Node
*
node
)
{
const
auto
&
node_name
=
node
->
Name
();
bool
registered
=
::
cinn
::
frontend
::
OpMapperRegistry
::
Global
()
->
Find
(
node_name
)
!=
nullptr
;
...
...
@@ -643,7 +681,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
bool
is_support
=
registered
&&
!
trans_info
.
default_deny_ops
().
count
(
node_name
)
&&
!
is_dynamic
&&
(
node
->
IsOp
()
&&
!
trans_info
.
IsInplaceOp
(
*
node
->
Op
()));
!
is_dynamic
&&
(
node
->
IsOp
()
&&
!
trans_info
.
IsInplaceOp
(
*
node
->
Op
(),
deny_var_set
));
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if
(
!
allow_ops
.
empty
())
{
...
...
@@ -659,8 +698,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
// return true only when it is registered in CINN
return
is_support
;
};
VLOG
(
4
)
<<
"The allowed Cinn Ops: "
<<
FLAGS_allow_cinn_ops
;
VLOG
(
4
)
<<
"The denied Cinn Ops: "
<<
FLAGS_deny_cinn_ops
;
VLOG
(
4
)
<<
"The allowed Cinn Ops: "
<<
GetDebugInfo
(
allow_ops
)
;
VLOG
(
4
)
<<
"The denied Cinn Ops: "
<<
GetDebugInfo
(
deny_ops
)
;
std
::
vector
<
GraphNodeVec
>
clusters
=
CinnSubgraphDetector
(
graph
,
teller
)();
LOG
(
INFO
)
<<
"--- [build_cinn_pass] detected "
<<
clusters
.
size
()
<<
" cinn supported subgraphs"
;
...
...
paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
浏览文件 @
ce57365d
...
...
@@ -67,7 +67,11 @@ class OpTransInfo {
std
::
unordered_set
<
std
::
string
>
GetDenyVarNames
(
const
GraphNodeSet
&
cluster
)
const
;
static
bool
IsInplaceOp
(
const
OpDesc
&
op_desc
);
std
::
unordered_set
<
std
::
string
>
GetIgnoreInplaceVarNames
(
const
OpDesc
&
op_desc
)
const
;
bool
IsInplaceOp
(
const
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>&
deny_var_names
)
const
;
private:
DyOpCondT
dynamic_op_cond_
;
...
...
@@ -75,6 +79,9 @@ class OpTransInfo {
DeParamCondT
deny_param_cond_
{{
"batch_norm"
,
{
"ReserveSpace"
}},
{
"batch_norm_grad"
,
{
"ReserveSpace"
}}};
DeParamCondT
ignore_inplace_param_cond_
{
{
"batch_norm"
,
{
"MeanOut"
,
"VarianceOut"
}}};
std
::
unordered_set
<
std
::
string
>
default_deny_ops_
{
"feed"
,
"fetch"
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录