MegEngine 天元 / MegEngine

Commit b82e8f00
Authored on Sep 16, 2022 by Megvii Engine Team
refactor(gopt): refact the padding channel opt pass
GitOrigin-RevId: ee3f55aa66f21fe2d4a042298aafe4a0a02915f7
Parent: f444d4fe

Showing 4 changed files with 327 additions and 307 deletions:

- src/gopt/impl/framework.cpp (+2, -1)
- src/gopt/impl/padding_channel.cpp (+286, -302)
- src/gopt/include/megbrain/gopt/inference.h (+30, -0)
- src/gopt/test/inference.cpp (+9, -4)
src/gopt/impl/framework.cpp
@@ -783,7 +783,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
    });
    cb(nchw64, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<PaddingChannelPass>();
        add_pass(PaddingChannelPass::make(
                cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64));
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW64Pass::make_nchw64_converter());
        add_pass<ShuffleShuffleRemovePass>();
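Together with the test changes at the bottom of this commit, the hunk above captures the call-site migration: instead of default-constructing the pass with `add_pass<PaddingChannelPass>()`, callers now build it through `PaddingChannelPass::make()` with an explicit target layout. A minimal usage sketch of that pattern (not part of the diff; it assumes the `mgb` namespace is in scope and `y` is an existing endpoint SymbolVar):

// Hedged sketch: register the reworked pass for the NCHW64 layout and run it
// on an existing graph endpoint `y` (placeholder name).
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;

SymbolVar y_pad;  // receives the rewritten endpoint
unpack_vector(
        gopt::GraphOptimizer{}
                .add_pass(gopt::PaddingChannelPass::make(LayoutTrans::NCHW64))
                .apply({{y}})
                .endpoint_vars(),
        y_pad);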
src/gopt/impl/padding_channel.cpp
...
...
@@ -33,6 +33,54 @@ using namespace gopt;

using ReformatKey = ReformatManager::ReformatKey;

/* ==================== PaddingChannelPass ================= */
namespace {

size_t padding_int4(size_t in_channel, bool flag) {
    static_cast<void>(flag);
    if (in_channel <= 32) {
        return (8 - (in_channel % 8)) % 8;
    } else {
        return (64 - (in_channel % 64)) % 64;
    }
}

size_t padding_int8(size_t in_channel, bool flag) {
    if (flag) {
        if (in_channel <= 16) {
            return (4 - (in_channel % 4)) % 4;
        } else {
            return (32 - (in_channel % 32)) % 32;
        }
    } else {
        return (4 - (in_channel % 4)) % 4;
    }
}

size_t padding_4(size_t in_channel, bool) {
    return (4 - (in_channel % 4)) % 4;
};

}  // namespace

std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make(
        cg::GraphCommonOptimizeOptions::LayoutTransform layout_transform) {
    MIDOUT_B("PaddingChannelPass::make")
    using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
    auto ret = std::make_unique<PaddingChannelPass>();
    auto& alignment_map = ret->m_alignment_map;
    if (layout_transform == LayoutTrans::NCHW64) {
        alignment_map[DTypeEnum::QuantizedS4] = padding_int4;
        alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4;
        alignment_map[DTypeEnum::QuantizedS8] = padding_int8;
    } else if (
            layout_transform == LayoutTrans::NCHW44 ||
            layout_transform == LayoutTrans::NCHW44_DOT) {
        alignment_map[DTypeEnum::QuantizedS8] = padding_4;
        alignment_map[DTypeEnum::Quantized8Asymm] = padding_4;
        alignment_map[DTypeEnum::Float32] = padding_4;
    }
    ret->fill_opr_convert_fun(layout_transform);
    return ret;
    MIDOUT_E
}

const char* PaddingChannelPass::name() const {
    return mgb_cstr_log("padding output channel to multiple of 4/32");
}
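All three helpers above reduce to the same round-up-to-alignment arithmetic, `(align - c % align) % align`, with the alignment chosen from the dtype and the channel count. A small standalone sketch (not MegEngine code) that reproduces the arithmetic and a few of the resulting values:

#include <cassert>
#include <cstddef>

// Padding needed to round `c` up to the next multiple of `align`
// (0 when `c` is already aligned), as used by padding_int4/int8/4 above.
static size_t pad_to(size_t c, size_t align) {
    return (align - (c % align)) % align;
}

int main() {
    assert(pad_to(15, 4) == 1);    // padding_int8(15, true): 15 <= 16 -> align 4
    assert(pad_to(20, 32) == 12);  // padding_int8(20, true): 20 > 16  -> align 32
    assert(pad_to(24, 8) == 0);    // padding_int4(24, ...): 24 <= 32, already aligned
    assert(pad_to(33, 64) == 31);  // padding_int4(33, ...): 33 > 32  -> align 64
}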
...
...
@@ -42,267 +90,240 @@ void PaddingChannelPass::apply(OptState& opt) const {
    // do not check shape
    opt.set_var_replace_check_flag(
            VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_SHAPE);
    ThinHashSet<OperatorNodeBase*> padding_oprs;
    ThinHashMap<
            Typeinfo*,
            thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
            opr_replace_funcs;
    m_padding_oprs.clear();
    auto rewriter = opt.graph().make_rewriter();
    auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(
                inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{
                inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
        std::memset(ptr, 0, size_bytes);
        auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 1);
        return out.node();
    };
    auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(
                inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{
                pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
        std::memset(ptr, 0, size_bytes);
        auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 0);
        return out.node();
    };
    auto extract_subtensor = [](VarNode* inp,
                                const TensorShape& orig_shape) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->shape()[0] == orig_shape[0]);
        mgb_assert(inp->shape()[2] == orig_shape[2]);
        mgb_assert(inp->shape()[3] == orig_shape[3]);
        size_t orig_channels = orig_shape[1];
        auto x = SymbolVar(inp);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        using AIdx = opr::Subtensor::AxisIndexer;
        auto sub = opr::Subtensor::make(
                x, {AIdx::make_interval(0, None, None, cv(1)),
                    AIdx::make_interval(1, None, cv(orig_channels), None),
                    AIdx::make_interval(2, None, None, cv(1)),
                    AIdx::make_interval(3, None, None, cv(1))});
        return sub.node();
    };
    // padding policy for conv bias with data type qint8
    auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, &pad_out_channels](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            size_t pad_channels = new_in_channels - in_channels;
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        } else {
            size_t pad_channels = 0;
            mgb_assert(new_in_channels == in_channels);
            if (in_channels <= 16) {
                if (in_channels % 4)
                    pad_channels = 4 - (in_channels % 4);  // pad to use dp4a
            } else {
                if (in_channels % 32)
                    pad_channels = 32 - (in_channels % 32);  // pad to use tensorcore
    auto on_opr = [this, &opt, &rewriter](OperatorNodeBase* opr) {
        auto it = m_opr_replace_funcs.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_funcs.end()) {
            VarNodeArray new_inp;
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(new_inp[0], pad_channels);
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            auto new_opr = (it->second)(opr, new_inp);
            auto&& out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(
                    out0.size() == out1.size(),
                    "bad opr replace: src=%s{%s} dst=%s{%s}, "
                    "src.size=%zu "
                    "dst.size=%zu",
                    opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(
                            !out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src) &&
                        !src->shape().eq_shape(dst->shape())) {
                        dst = extract_subtensor(dst, src->shape());
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        }
        out_channels = inps[1]->shape()[0];
        in_channels = inps[1]->shape()[1];
        size_t pad_channels = 0;
        if (out_channels <= 16) {
            if (out_channels % 4)
                pad_channels = 4 - (out_channels % 4);
        } else {
            if (out_channels % 32)
                pad_channels = 32 - (out_channels % 32);
        }
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            padding_oprs.insert(opr);
            rewriter.auto_replace_outputs(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    // padding policy for conv bias with data type qint4 and quint4
    auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, &pad_out_channels](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            if (new_in_channels <= 32) {
                if (new_in_channels % 8 == 0) {
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    size_t pad_channels_0 = 8 - (new_in_channels % 8);
                    size_t pad_channels_1 = 8 - (in_channels % 8);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            } else {
                if (new_in_channels % 64 == 0) {
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
                    size_t pad_channels_1 = 64 - (in_channels % 64);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            }
    MIDOUT_E
}

VarNode* PaddingChannelPass::extract_subtensor(
        VarNode* inp, const TensorShape& orig_shape) const {
    mgb_assert(inp->shape().ndim == 4);
    mgb_assert(inp->shape()[0] == orig_shape[0]);
    mgb_assert(inp->shape()[2] == orig_shape[2]);
    mgb_assert(inp->shape()[3] == orig_shape[3]);
    size_t orig_channels = orig_shape[1];
    auto x = SymbolVar(inp);
    auto cv = [&x](int v) { return x.make_scalar(v); };
    using AIdx = opr::Subtensor::AxisIndexer;
    auto sub = opr::Subtensor::make(
            x, {AIdx::make_interval(0, None, None, cv(1)),
                AIdx::make_interval(1, None, cv(orig_channels), None),
                AIdx::make_interval(2, None, None, cv(1)),
                AIdx::make_interval(3, None, None, cv(1))});
    return sub.node();
};

VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) {
    mgb_assert(inp->shape().ndim == 4);
    TensorShape shape{
            inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]};
    std::shared_ptr<HostTensorND> host_val =
            std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
    host_val->resize(shape);
    auto ptr = host_val->raw_ptr();
    size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
    std::memset(ptr, 0, size_bytes);
    auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
    auto out = opr::Concat::make({inp, padding}, 1);
    return out.node();
};

VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) {
    mgb_assert(inp->shape().ndim == 4);
    TensorShape shape{
            pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]};
    std::shared_ptr<HostTensorND> host_val =
            std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
    host_val->resize(shape);
    auto ptr = host_val->raw_ptr();
    size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
    std::memset(ptr, 0, size_bytes);
    auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
    auto out = opr::Concat::make({inp, padding}, 0);
    return out.node();
};

// padding policy for conv bias with data type qint8
OperatorNodeBase* PaddingChannelPass::padding_policy(
        OperatorNodeBase* opr, const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    mgb_assert(new_inp.size() == 3);
    //! new weights and old weights are same shape
    mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
    auto inps = new_inp;
    size_t out_channels = opr->input(1)->shape()[0];
    size_t in_channels = opr->input(1)->shape()[1];
    size_t new_in_channels = new_inp[0]->shape()[1];
    auto it = m_alignment_map.find(opr->input(0)->dtype().enumv());
    if (it != m_alignment_map.end()) {
        mgb_assert(it->second);
    } else {
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    }
    // pad input channels
    if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
        //! as the opr of input var is padding, but the dtype of input and output of
        //! the input opr maybe different, so the alignment is not the same
        size_t pad_channels_0 = it->second(new_in_channels, true);
        size_t pad_channels_1 = it->second(in_channels, true);
        if (pad_channels_0) {
            inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
        }
    } else {
        size_t pad_channels = 0;
        mgb_assert(new_in_channels == in_channels);
        if (in_channels <= 32) {
            if (in_channels % 8)
                pad_channels = 8 - (in_channels % 8);
        } else {
            if (in_channels % 64)
                pad_channels = 64 - (in_channels % 64);
        }
        if (pad_channels > 0) {
            inps[0] = pad_in_channels(new_inp[0], pad_channels);
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        }
        pad_channels_1 = new_in_channels - in_channels;
    }
    out_channels = inps[1]->shape()[0];
    in_channels = inps[1]->shape()[1];
    size_t pad_channels = 0;
    if (out_channels <= 32) {
        if (out_channels % 8)
            pad_channels = 8 - (out_channels % 8);
    } else {
        if (out_channels % 64)
            pad_channels = 64 - (out_channels % 64);
        if (pad_channels_1) {
            inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
        }
    } else {
        mgb_assert(new_in_channels == in_channels);
        size_t pad_channels = it->second(in_channels, true);
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            padding_oprs.insert(opr);
            inps[0] = pad_in_channels(new_inp[0], pad_channels);
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    }
    out_channels = inps[1]->shape()[0];
    size_t pad_channels = it->second(out_channels, true);
    if (pad_channels > 0) {
        inps[1] = pad_out_channels(inps[1], pad_channels);
        inps[2] = pad_in_channels(inps[2], pad_channels);
        m_padding_oprs.insert(opr);
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
    opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
            [&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
                    return padding_policy_qint8(opr, new_inp);
                } else if (
                        opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                        opr->input(0)->dtype().enumv() ==
                                DTypeEnum::Quantized4Asymm) {
                    return padding_policy_int4(opr, new_inp);
                } else {
                    mgb_assert(
                            padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                            "conv bias operator for data type(%s) cannot be "
                            "padded channel. "
                            "consumer(%s), producer(%s)",
                            opr->input(0)->dtype().name(), opr->cname(),
                            opr->input(0)->owner_opr()->cname());
                    return serialization::copy_opr_shallow(
                            *opr, new_inp, opr->config());
                }
            };
    opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
            [&padding_oprs, &pad_in_channels, &pad_out_channels](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
void PaddingChannelPass::fill_opr_convert_fun(LayoutTrans layout_trans) {
    add_convbias_replace_func(layout_trans);
    add_conv_backward_data_replace_func(layout_trans);
    add_format_aware_opr_replace_func(layout_trans);
    add_elemwise_like_opr_replace_func(layout_trans);
    add_nonpadding_oprs_replace_func(layout_trans);
}

void PaddingChannelPass::add_convbias_replace_func(LayoutTrans layout_trans) {
    if (layout_trans == LayoutTrans::NCHW64) {
        m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
                        return padding_policy(opr, new_inp);
                    } else if (
                            opr->input(0)->dtype().enumv() ==
                                    DTypeEnum::QuantizedS4 ||
                            opr->input(0)->dtype().enumv() ==
                                    DTypeEnum::Quantized4Asymm) {
                        return padding_policy(opr, new_inp);
                    } else {
                        mgb_assert(
                                m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                                "conv bias operator for data type(%s) cannot be "
                                "padded channel. "
                                "consumer(%s), producer(%s)",
                                opr->input(0)->dtype().name(), opr->cname(),
                                opr->input(0)->owner_opr()->cname());
                        return serialization::copy_opr_shallow(
                                *opr, new_inp, opr->config());
                    }
                };
    } else if (layout_trans == LayoutTrans::NCHW44) {
        m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    return padding_policy(opr, new_inp);
                };
    }
}

void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_trans) {
    if (layout_trans == LayoutTrans::NCHW64) {
        m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
                        mgb_assert(
                                m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                                "conv bwd data operator for data type(%s) cannot "
                                "be "
                                "padded channel. "
                                "consumer(%s), producer(%s)",
                                opr->input(0)->dtype().name(), opr->cname(),
                                opr->input(0)->owner_opr()->cname());
                        return serialization::copy_opr_shallow(
                                *opr, new_inp, opr->config());
                    }
                mgb_assert(opr->input().size() == new_inp.size());
                mgb_assert(
                        padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                        "conv bwd data operator for data type(%s) cannot "
                        "be "
                        "padded channel. "
                        "consumer(%s), producer(%s)",
                        opr->input(0)->dtype().name(), opr->cname(),
                        opr->input(0)->owner_opr()->cname());
                return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
            }
            mgb_assert(opr->input().size() == new_inp.size());
            mgb_assert(
                    new_inp.size() == 2,
                    "deconv (conv bwd data) operator for inference can "
                    "only have 2 input vars(got:%zu)",
                    new_inp.size());
            mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
            auto inps = new_inp;
            size_t out_channels = opr->input(0)->shape()[0];
            size_t in_channels = opr->input(0)->shape()[1];
            size_t new_out_channels = new_inp[1]->shape()[1];
            // pad output channels
            if (padding_oprs.count(opr->input(1)->owner_opr())) {
                size_t pad_channels = new_out_channels - out_channels;
                inps[0] = pad_out_channels(new_inp[0], pad_channels);
            } else {
                size_t pad_channels = 0;
                if (out_channels % 4)
                    pad_channels = 4 - (out_channels % 4);
                if (pad_channels > 0) {
                            new_inp.size() == 2,
                            "deconv (conv bwd data) operator for inference can "
                            "only have 2 input vars(got:%zu)",
                            new_inp.size());
                    mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
                    auto inps = new_inp;
                    size_t out_channels = opr->input(0)->shape()[0];
                    size_t in_channels = opr->input(0)->shape()[1];
                    size_t new_out_channels = new_inp[1]->shape()[1];
                    auto it = m_alignment_map.find(opr->input(1)->dtype().enumv());
                    // pad output channels
                    if (m_padding_oprs.count(opr->input(1)->owner_opr())) {
                        size_t pad_channels = new_out_channels - out_channels;
                        inps[0] = pad_out_channels(new_inp[0], pad_channels);
                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
                    } else {
                        size_t pad_channels = it->second(out_channels, false);
                        if (pad_channels > 0) {
                            inps[0] = pad_out_channels(new_inp[0], pad_channels);
                            inps[1] = pad_in_channels(new_inp[1], pad_channels);
                        }
                    }
                }
        out_channels = inps[0]->shape()[0];
        in_channels = inps[0]->shape()[1];
        // pad input channels
        size_t pad_channels = 0;
        if (in_channels % 4)
            pad_channels = 4 - (in_channels % 4);
        if (pad_channels > 0) {
            inps[0] = pad_in_channels(inps[0], pad_channels);
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    auto replace_format_aware_opr = [&padding_oprs](
                                            OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
                    out_channels = inps[0]->shape()[0];
                    // pad input channels
                    size_t pad_channels = it->second(in_channels, false);
                    if (pad_channels > 0) {
                        inps[0] = pad_in_channels(inps[0], pad_channels);
                        m_padding_oprs.insert(opr);
                    }
                    return serialization::copy_opr_shallow(*opr, inps, opr->config());
                };
    }
}

void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans) {
    auto replace_format_aware_opr = [this](OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
        if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
            mgb_assert(
                    padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                    m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                    "operator(type:%s,name:%s) for data type(%s) cannot be "
                    "padded channel. extra info:"
                    "consumer(%s), producer(%s)",
...
...
@@ -312,18 +333,19 @@ void PaddingChannelPass::apply(OptState& opt) const {
            return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
        }
        mgb_assert(opr->input().size() == new_inp.size());
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            padding_oprs.insert(opr);
        if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
            m_padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr;
    opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
    m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr;
    m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
            replace_format_aware_opr;
}

    auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
                                             OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) {
    auto replace_elemwise_like_opr = [this](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool have_padding_inp = false;
        bool padding_all_inps = true;
...
...
@@ -331,7 +353,7 @@ void PaddingChannelPass::apply(OptState& opt) const {
        size_t channels_after_padding = 0;
        size_t i = 0;
        for (auto&& cur_inp : opr->input()) {
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                if (!have_padding_inp)
                    have_padding_inp = true;
...
...
@@ -349,7 +371,7 @@ void PaddingChannelPass::apply(OptState& opt) const {
            auto inps = new_inp;
            for (size_t i = 0; i < new_inp.size(); ++i) {
                auto cur_inp = opr->input(i);
                bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
                bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
                if (padding_cur_inp) {
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
...
...
@@ -357,72 +379,34 @@ void PaddingChannelPass::apply(OptState& opt) const {
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        if (padding_all_inps) {
            padding_oprs.insert(opr);
            m_padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr;
    opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
}

    auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
                                           OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) {
    auto replace_nonpadding_oprs = [this](OperatorNodeBase* opr,
                                          const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                inps[i] = extract_subtensor(inps[i], cur_inp->shape());
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
    auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
                   &extract_subtensor](OperatorNodeBase* opr) {
        auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
        if (it != opr_replace_funcs.end()) {
            VarNodeArray new_inp;
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto&& out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(
                    out0.size() == out1.size(),
                    "bad opr replace: src=%s{%s} dst=%s{%s}, "
                    "src.size=%zu "
                    "dst.size=%zu",
                    opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(
                            !out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src) &&
                        !src->shape().eq_shape(dst->shape())) {
                        dst = extract_subtensor(dst, src->shape());
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
    m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/include/megbrain/gopt/inference.h
...
...
@@ -509,8 +509,38 @@ public:
 */
class PaddingChannelPass final : public Pass {
public:
    using ChannelAlignmentMap =
            ThinHashMap<DTypeEnum, std::function<size_t(size_t, bool)>>;
    using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
    const char* name() const override;
    void apply(OptState& opt) const override;
    void fill_opr_convert_fun(LayoutTrans layout_trans);
    using ReplaceFuncs = ThinHashMap<
            Typeinfo*,
            thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>;

    //! make channel padding opt pass with given tensor format
    static std::unique_ptr<PaddingChannelPass> make(LayoutTrans layout_transform);

private:
    VarNode* extract_subtensor(VarNode* inp, const TensorShape& orig_shape) const;
    VarNode* pad_in_channels(VarNode* inp, size_t pad_channels);
    VarNode* pad_out_channels(VarNode* inp, size_t pad_channels);
    OperatorNodeBase* padding_policy(OperatorNodeBase* opr, const VarNodeArray& new_inp);

    void add_convbias_replace_func(LayoutTrans layout_transform);
    void add_conv_backward_data_replace_func(LayoutTrans layout_transform);
    void add_format_aware_opr_replace_func(LayoutTrans layout_transform);
    void add_elemwise_like_opr_replace_func(LayoutTrans layout_transform);
    void add_nonpadding_oprs_replace_func(LayoutTrans layout_transform);

    ChannelAlignmentMap m_alignment_map;
    ReplaceFuncs m_opr_replace_funcs;
    mutable ThinHashSet<OperatorNodeBase*> m_padding_oprs;
};
/*!
...
...
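Since `make()` and the `LayoutTrans` alias are public, the same pass can also be built for the NCHW44/NCHW44_DOT case, where `make()` registers `padding_4` for QuantizedS8, Quantized8Asymm and Float32 so channels are padded to multiples of 4. A hedged sketch of that construction (not taken from this commit; `y` is a placeholder endpoint):

// Hypothetical: build the pass for an NCHW44 layout; the alignment map then
// pads QuantizedS8 / Quantized8Asymm / Float32 channels to multiples of 4.
auto pass = gopt::PaddingChannelPass::make(
        gopt::PaddingChannelPass::LayoutTrans::NCHW44);
// e.g. gopt::GraphOptimizer{}.add_pass(std::move(pass)).apply({{y}});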
src/gopt/test/inference.cpp
#include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
...
...
@@ -5037,7 +5038,8 @@ TEST(TestGoptInference, PaddingChannels) {
    SymbolVar y3_pad;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::PaddingChannelPass>()
                    .add_pass(gopt::PaddingChannelPass::make(
                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
                    .apply({{y3}})
                    .endpoint_vars(),
            y3_pad);
...
@@ -5101,7 +5103,8 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) {
    SymbolVar y2_pad;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::PaddingChannelPass>()
                    .add_pass(gopt::PaddingChannelPass::make(
                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
                    .apply({{y2}})
                    .endpoint_vars(),
            y2_pad);
...
@@ -5166,7 +5169,8 @@ TEST(TestGoptInference, PaddingChannelsWithPooling) {
    SymbolVar y1_pad;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::PaddingChannelPass>()
                    .add_pass(gopt::PaddingChannelPass::make(
                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
                    .apply({{y1}})
                    .endpoint_vars(),
            y1_pad);
...
@@ -5232,7 +5236,8 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
    SymbolVar y1_pad;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::PaddingChannelPass>()
                    .add_pass(gopt::PaddingChannelPass::make(
                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
                    .apply({{y1}})
                    .endpoint_vars(),
            y1_pad);
...
...