Commit 6d6b42bb authored by Megvii Engine Team

refactor(gopt): refactor interface of add passes for common optimizations

GitOrigin-RevId: d0f3819c3a6e969430ebf8a65e26fc2e77ae8aeb
Parent 78fe9e55
@@ -474,8 +474,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
}
#endif
gopt::GraphOptimizer optimizer;
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.reset();
optimizer.add_passes_for_optimize_options(options().graph_opt, true);
optimizer.apply_inplace(dest_vars);
const OprNodeArray* opr_seq = nullptr;
......
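In the hunk above, the two apply_optimize_options/reset() lines are the pre-change code and the add_passes_for_optimize_options line is their replacement. A minimal before/after sketch of the caller, using only identifiers shown in the hunk:

gopt::GraphOptimizer optimizer;

// Before: add the passes, then clear the consumed options in a separate step.
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.reset();

// After: a single call; passing reset=true disables each option as its passes are added.
optimizer.add_passes_for_optimize_options(options().graph_opt, true);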
@@ -107,19 +107,17 @@ struct GraphCommonOptimizeOptions {
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
void reset() {
f16_io_f32_comp = false;
f16_io_comp = false;
fuse_conv_bias_nonlinearity = false;
fuse_conv_bias_with_z = false;
layout_transform = LayoutTransform::DEFAULT;
}
#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
n = true; \
return *this; \
} \
GraphCommonOptimizeOptions& disable_##n() { \
n = false; \
return *this; \
} \
bool has_set_##n() { return n == true; }
#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
n = true; \
return *this; \
}
SET(f16_io_f32_comp);
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
@@ -131,7 +129,11 @@ struct GraphCommonOptimizeOptions {
layout_transform = LayoutTransform::_trans_capital; \
return *this; \
} \
bool transform_##_trans() const { \
GraphCommonOptimizeOptions& disable_##_trans() { \
layout_transform = LayoutTransform::DEFAULT; \
return *this; \
} \
bool has_set_##_trans() const { \
return layout_transform == LayoutTransform::_trans_capital; \
}
......
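As context for the framework.cpp hunk further below, which calls has_set_*() and disable_*() on these options, the three-helper form of SET above is the post-change macro. A hedged usage sketch of the generated helpers; the f16_io_comp flag comes from the SET(...) invocations above, while the nchw88 helpers are assumed to come from the layout-transform variant of the macro (the nchw88 name appears in the framework.cpp hunk):

cg::GraphCommonOptimizeOptions opt;
opt.enable_f16_io_comp()      // sets the f16_io_comp flag; returns *this, so calls chain
   .enable_nchw88();          // selects the nchw88 layout transform
if (opt.has_set_nchw88()) {   // true while the nchw88 transform is selected
    opt.disable_nchw88();     // back to LayoutTransform::DEFAULT
}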
@@ -675,7 +675,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
if (inference_opt) {
add_pass<ParamFusePass>();
apply_optimize_options(*inference_opt);
add_passes_for_optimize_options(*inference_opt);
}
@@ -704,56 +704,56 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
}
}
const GraphOptimizer& GraphOptimizer::apply_optimize_options(
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options) {
return add_passes_for_optimize_options(
const_cast<cg::GraphCommonOptimizeOptions&>(options));
}
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cg::GraphCommonOptimizeOptions& options, bool reset) {
bool need_param_fuse = false;
if (options.f16_io_comp) {
add_pass(ConvertF32ToF16Pass::make(false));
need_param_fuse = true;
}
if (options.f16_io_f32_comp) {
add_pass(ConvertF32ToF16Pass::make(true));
need_param_fuse = true;
#define cb(_option, _passes) \
if (options.has_set_##_option()) { \
_passes need_param_fuse = true; \
if (reset) { \
options.disable_##_option(); \
} \
}
if (options.transform_nhwcd4()) {
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
cb(nhwcd4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
need_param_fuse = true;
}
if (options.transform_nchw88()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
need_param_fuse = true;
}
if (options.transform_nchw44()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
need_param_fuse = true;
}
if (options.transform_nchw32()) {
});
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}
if (options.transform_chwn4()) {
});
cb(chwn4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableCHWN4Pass::make_chwn4_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}
});
if (options.fuse_conv_bias_nonlinearity) {
add_pass<FuseConvBiasNonlinPass>();
need_param_fuse = true;
}
if (options.fuse_conv_bias_with_z) {
cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
cb(fuse_conv_bias_with_z, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
need_param_fuse = true;
}
});
#undef cb
if (need_param_fuse) {
add_pass<ParamFusePass>();
}
......
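To make the rewritten body easier to follow, here is a mechanical expansion of one cb(...) invocation from the hunk above, the nchw88 case, using only the macro defined there:

// cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); expands to:
if (options.has_set_nchw88()) {
    { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }
    need_param_fuse = true;
    if (reset) {
        options.disable_nchw88();   // clear the option once it has been handled
    }
}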
@@ -468,9 +468,16 @@ namespace gopt {
static VarNode* var_replace_lookup(VarNode *var);
/**
* \brief apply optimize options
* \brief add the passes indicated by the optimize options.
*
* \param options common options
* \param reset if set to true, each handled option is reset when its passes are added.
*/
const GraphOptimizer& apply_optimize_options(
const GraphOptimizer& add_passes_for_optimize_options(
cg::GraphCommonOptimizeOptions& options,
bool reset = false);
const GraphOptimizer& add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options);
};
......
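A brief usage sketch of the two overloads declared above; the variable graph is an assumed ComputingGraph handle, used only for illustration. The const overload forwards to the mutable one, and since reset defaults to false it leaves the options untouched:

gopt::GraphOptimizer optimizer;

// (a) Mutable options with reset=true: each handled option is disabled once its passes are added.
optimizer.add_passes_for_optimize_options(graph->options().graph_opt, true);

// (b) Const options: the same passes are added, but the options object is not modified.
const cg::GraphCommonOptimizeOptions& opts = graph->options().graph_opt;
optimizer.add_passes_for_optimize_options(opts);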