Commit b82e8f00 authored by Megvii Engine Team

refactor(gopt): refactor the padding channel opt pass

GitOrigin-RevId: ee3f55aa66f21fe2d4a042298aafe4a0a02915f7
Parent f444d4fe
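In short: `PaddingChannelPass` was previously registered through the templated `add_pass<PaddingChannelPass>()`, which leaves no way to tell the pass which layout conversion it is padding channels for. This commit adds a static `make` factory taking a `cg::GraphCommonOptimizeOptions::LayoutTransform`, so the target layout becomes an explicit construction-time parameter. The caller-side change, as the hunks below show:

    // before: no layout parameter, NCHW64 behavior baked in
    .add_pass<gopt::PaddingChannelPass>()

    // after: the target layout is passed explicitly
    .add_pass(gopt::PaddingChannelPass::make(
            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))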
......@@ -783,7 +783,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
});
cb(nchw64, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<PaddingChannelPass>();
add_pass(PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64));
add_pass<FuseConvBiasZPass>();
add_pass(EnableNCHW64Pass::make_nchw64_converter());
add_pass<ShuffleShuffleRemovePass>();
......
This diff is collapsed.
......@@ -509,8 +509,38 @@ public:
*/
class PaddingChannelPass final : public Pass {
public:
using ChannelAlignmentMap =
ThinHashMap<DTypeEnum, std::function<size_t(size_t, bool)>>;
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
const char* name() const override;
void apply(OptState& opt) const override;
void fill_opr_convert_fun(LayoutTrans layout_trans);
using ReplaceFuncs = ThinHashMap<
Typeinfo*,
thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>;
//! make channel padding opt pass with given tensor format
static std::unique_ptr<PaddingChannelPass> make(LayoutTrans layout_transform);
private:
VarNode* extract_subtensor(VarNode* inp, const TensorShape& orig_shape) const;
VarNode* pad_in_channels(VarNode* inp, size_t pad_channels);
VarNode* pad_out_channels(VarNode* inp, size_t pad_channels);
OperatorNodeBase* padding_policy(
OperatorNodeBase* opr, const VarNodeArray& new_inp);
void add_convbias_replace_func(LayoutTrans layout_transform);
void add_conv_backward_data_replace_func(LayoutTrans layout_transform);
void add_format_aware_opr_replace_func(LayoutTrans layout_transform);
void add_elemwise_like_opr_replace_func(LayoutTrans layout_transform);
void add_nonpadding_oprs_replace_func(LayoutTrans layout_transform);
ChannelAlignmentMap m_alignment_map;
ReplaceFuncs m_opr_replace_funcs;
mutable ThinHashSet<OperatorNodeBase*> m_padding_oprs;
};
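The `.cpp` diff for the pass body is collapsed above, so the wiring of `make` is not visible here. A minimal sketch of what the factory plausibly does, given only the declarations in this header (an assumption, not the committed implementation):

    std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make(
            LayoutTrans layout_trans) {
        // assumption: construct the pass, then let fill_opr_convert_fun
        // populate m_opr_replace_funcs / m_alignment_map for this layout
        auto ret = std::make_unique<PaddingChannelPass>();
        ret->fill_opr_convert_fun(layout_trans);
        return ret;
    }

Each `add_*_replace_func` helper would then register a callback in `m_opr_replace_funcs`, keyed by the operator's `Typeinfo*`, so `apply` can dispatch a per-operator-type rewrite.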
/*!
......
#include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
......@@ -5037,7 +5038,8 @@ TEST(TestGoptInference, PaddingChannels) {
SymbolVar y3_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y3}})
.endpoint_vars(),
y3_pad);
......@@ -5101,7 +5103,8 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) {
SymbolVar y2_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y2}})
.endpoint_vars(),
y2_pad);
......@@ -5166,7 +5169,8 @@ TEST(TestGoptInference, PaddingChannelsWithPooling) {
SymbolVar y1_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y1}})
.endpoint_vars(),
y1_pad);
......@@ -5232,7 +5236,8 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
SymbolVar y1_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y1}})
.endpoint_vars(),
y1_pad);
......