Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0ee230a7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0ee230a7
编写于
2月 21, 2022
作者:
W
wenbin
提交者:
GitHub
2月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
shuffle_channel pass fix (#39735)
上级
e764cde6
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
92 addition
and
0 deletion
+92
-0
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
+92
-0
未找到文件。
paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
浏览文件 @
0ee230a7
...
...
@@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
auto
*
input_node
=
subgraph
.
at
(
x
);
auto
reshape1_desc
=
reshape1_op
->
Op
();
auto
reshape2_desc
=
reshape2_op
->
Op
();
auto
trans_desc
=
transpose_op
->
Op
();
std
::
string
input_name
=
input_node
->
Name
();
std
::
string
output_name
=
reshape2_out
->
Name
();
...
...
@@ -101,10 +102,101 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
reshape1_desc
->
GetAttr
(
"shape"
));
auto
reshape2_shape
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
reshape2_desc
->
GetAttr
(
"shape"
));
auto
trans_axis
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
trans_desc
->
GetAttr
(
"axis"
));
auto
*
block1
=
reshape1_desc
->
Block
();
auto
*
block2
=
reshape2_desc
->
Block
();
if
(
block1
&&
block2
)
{
auto
x_var_name
=
reshape1_desc
->
Input
(
"X"
)[
0
];
auto
*
x_var_desc
=
block1
->
FindVar
(
x_var_name
);
auto
x_shape1
=
x_var_desc
->
GetShape
();
x_var_name
=
reshape2_desc
->
Input
(
"X"
)[
0
];
x_var_desc
=
block2
->
FindVar
(
x_var_name
);
auto
x_shape2
=
x_var_desc
->
GetShape
();
// now shuffle_channel is 4D(NCHW) only.
if
(
x_shape1
.
size
()
!=
4
||
reshape1_shape
.
size
()
!=
5
||
reshape2_shape
.
size
()
!=
4
||
trans_axis
.
size
()
!=
5
)
{
return
;
}
// process 0 and -1 in reshape.
constexpr
int64_t
copy_dim_val
=
0
;
for
(
size_t
i
=
0
;
i
<
reshape1_shape
.
size
();
i
++
)
{
if
(
reshape1_shape
[
i
]
==
copy_dim_val
)
{
reshape1_shape
[
i
]
=
x_shape1
[
i
];
}
}
for
(
size_t
i
=
0
;
i
<
reshape2_shape
.
size
();
i
++
)
{
if
(
reshape2_shape
[
i
]
==
copy_dim_val
)
{
reshape2_shape
[
i
]
=
x_shape2
[
i
];
}
}
constexpr
int64_t
unk_dim_idx
=
-
1
;
bool
all_positive
=
std
::
all_of
(
x_shape1
.
cbegin
(),
x_shape1
.
cend
(),
[](
int64_t
i
)
{
return
i
>
0
;
});
for
(
size_t
i
=
0
;
i
<
reshape1_shape
.
size
();
++
i
)
{
// if -1 is not in batch dim, try to calculate number
if
((
reshape1_shape
[
i
]
==
unk_dim_idx
)
&&
(
i
!=
0
))
{
// there is no sufficient info
if
(
!
all_positive
)
return
;
reshape1_shape
[
i
]
=
std
::
accumulate
(
x_shape1
.
begin
(),
x_shape1
.
end
(),
static_cast
<
int64_t
>
(
1
),
std
::
multiplies
<
int64_t
>
())
/
std
::
accumulate
(
reshape1_shape
.
begin
(),
reshape1_shape
.
end
(),
static_cast
<
int64_t
>
(
-
1
),
std
::
multiplies
<
int64_t
>
());
break
;
}
}
all_positive
=
std
::
all_of
(
x_shape2
.
cbegin
(),
x_shape2
.
cend
(),
[](
int64_t
i
)
{
return
i
>
0
;
});
for
(
size_t
i
=
0
;
i
<
reshape2_shape
.
size
();
++
i
)
{
// if -1 is not in batch dim, try to calculate number
if
((
reshape2_shape
[
i
]
==
unk_dim_idx
)
&&
(
i
!=
0
))
{
// there is no sufficient info
if
(
!
all_positive
)
return
;
reshape2_shape
[
i
]
=
std
::
accumulate
(
x_shape2
.
begin
(),
x_shape2
.
end
(),
static_cast
<
int64_t
>
(
1
),
std
::
multiplies
<
int64_t
>
())
/
std
::
accumulate
(
reshape2_shape
.
begin
(),
reshape2_shape
.
end
(),
static_cast
<
int64_t
>
(
-
1
),
std
::
multiplies
<
int64_t
>
());
break
;
}
}
// shuffle_channel dosen't change shape
if
((
reshape2_shape
[
0
]
!=
-
1
)
&&
(
x_shape1
[
0
]
!=
reshape2_shape
[
0
]))
{
return
;
}
for
(
size_t
i
=
1
;
i
<
x_shape1
.
size
();
i
++
)
{
if
(
x_shape1
[
i
]
!=
reshape2_shape
[
i
])
{
return
;
}
}
if
((
reshape2_shape
[
3
]
!=
reshape1_shape
[
4
])
||
(
reshape2_shape
[
2
]
!=
reshape1_shape
[
3
]))
{
return
;
}
}
else
{
return
;
// conservative judgement
}
int
i_c
=
reshape1_shape
[
2
];
int
o_c
=
reshape2_shape
[
1
];
int
group
=
o_c
/
i_c
;
// should split on channel dim
if
(
reshape2_shape
[
1
]
!=
reshape1_shape
[
2
]
*
reshape1_shape
[
1
])
return
;
// trans on channel dim
if
(
trans_axis
[
0
]
!=
0
||
trans_axis
[
3
]
!=
3
||
trans_axis
[
4
]
!=
4
)
return
;
if
(
group
!=
1
&&
i_c
!=
1
)
{
if
(
trans_axis
[
1
]
!=
2
&&
trans_axis
[
2
]
!=
1
)
{
return
;
}
}
framework
::
OpDesc
new_op_desc
;
new_op_desc
.
SetType
(
"shuffle_channel"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录