Commit fa9d719f authored by Megvii Engine Team

fix(gopt): fix global layout transform fold conv typecvt

GitOrigin-RevId: 66a23a927e1b2355262b78366c690d5bff84da11
Parent 767fa474
@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     bool is_version_ok = CUDNN_VERSION >= 7500;
     bool is_dtype_ok =
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
......
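Besides the gopt change, the hunk above tightens the dtype check of the CUDA fallback NCHW QS8 conv_bias algo: with ||, the parenthesised clause is a tautology, so 4-bit destination dtypes were never rejected. A minimal standalone sketch (simplified enum, not MegEngine code) of the old and fixed predicates:

    #include <cassert>

    enum class DTypeEnum { QuantizedS8, QuantizedS4, Quantized4Asymm };

    // Old predicate: the parenthesised clause is always true, because dst cannot
    // equal both 4-bit dtypes at once, so 4-bit destinations were never rejected.
    bool dtype_ok_old(DTypeEnum src, DTypeEnum dst) {
        return src == DTypeEnum::QuantizedS8 &&
               (dst != DTypeEnum::QuantizedS4 || dst != DTypeEnum::Quantized4Asymm);
    }

    // Fixed predicate: both 4-bit destination dtypes are rejected.
    bool dtype_ok_new(DTypeEnum src, DTypeEnum dst) {
        return src == DTypeEnum::QuantizedS8 &&
               (dst != DTypeEnum::QuantizedS4 && dst != DTypeEnum::Quantized4Asymm);
    }

    int main() {
        assert(dtype_ok_old(DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS4));   // old: accepted
        assert(!dtype_ok_new(DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS4));  // new: rejected
        return 0;
    }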
@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         if (conv_bias == nullptr)
             return false;
         auto inp_dtype_conv = conv_bias->input(0)->dtype(),
-             out_dtype_conv = conv_bias->input(0)->dtype();
+             out_dtype_conv = conv_bias->output(0)->dtype();
         bool is_s8nhwc =
                 inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
                  inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
                 conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC;
-        if (!(is_s8nhwc || is_s4nhwc))
+        bool is_s8nchw =
+                inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW;
+        if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
             return false;
         if (conv_bias->input().size() != 3)
             return false;
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
                                 ? opr::TypeCvt::make(bias, dtype::Float32()).node()
                                 : bias;
-        auto new_param = conv_bias->param();
-        new_param.format = megdnn::param::ConvBias::Format::NHWC;
-        auto conv_bias_typecvt = opr::ConvBias::make(
-                src, filter, new_bias, new_param, conv_bias->execution_policy(),
-                OperatorNodeConfig{out_dtype_typecvt});
-        rewriter.replace_var(
-                opr->output(0), conv_bias_typecvt.node(),
-                mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
-                             "to conv_bias(NHWC)"));
+        if (is_s8nchw && is_s82s4) {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NCHW;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NCHW) + typecvt "
+                                 "to conv_bias(NCHW)"));
+        } else {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NHWC;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                 "to conv_bias(NHWC)"));
+        }
         return true;
     };
......
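The hunks above are the core of the fix: out_dtype_conv is now taken from the conv_bias output rather than its input, an is_s8nchw case is recognised, and when an s8 NCHW conv_bias is folded with an s8-to-s4 typecvt the fused operator keeps the NCHW format instead of being forced to NHWC. A standalone sketch (simplified names, not the MegEngine API; is_s82s4 is approximated here from the typecvt output dtype) of that folding decision:

    #include <cassert>
    #include <optional>

    enum class DType { QuantizedS8, QuantizedS4, Quantized4Asymm, Float32 };
    enum class Format { NCHW, NHWC };

    // Returns the format of the fused conv_bias, or nullopt if no folding applies.
    std::optional<Format> fold_format(
            DType conv_in, DType conv_out, Format conv_fmt, DType typecvt_out) {
        bool same_io = conv_out == conv_in;
        bool is_s8nhwc = conv_in == DType::QuantizedS8 && same_io &&
                         conv_fmt == Format::NHWC;
        bool is_s4nhwc = (conv_in == DType::QuantizedS4 ||
                          conv_in == DType::Quantized4Asymm) &&
                         same_io && conv_fmt == Format::NHWC;
        bool is_s8nchw = conv_in == DType::QuantizedS8 && same_io &&
                         conv_fmt == Format::NCHW;  // case added by this commit
        if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
            return std::nullopt;
        bool is_s82s4 = typecvt_out == DType::QuantizedS4 ||
                        typecvt_out == DType::Quantized4Asymm;
        if (is_s8nchw && is_s82s4)
            return Format::NCHW;  // the fused conv_bias keeps its NCHW layout
        return Format::NHWC;
    }

    int main() {
        // An s8 NCHW conv_bias followed by an s8 -> s4 typecvt now folds and stays NCHW.
        assert(fold_format(DType::QuantizedS8, DType::QuantizedS8, Format::NCHW,
                           DType::QuantizedS4) == Format::NCHW);
        return 0;
    }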
@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
     add_pass<FuseConvBiasNonlinPass>();
     if (options.target == Target::CUDA)
         add_pass<FuseConvBiasZPass>();
+#if CUDA_VERSION >= 10020
+    add_pass<FoldingConvBiasTypecvtPass>();
+#endif
     add_pass(LayoutTransformPass::make(options.target));
     add_pass<ShuffleShuffleRemovePass>();
     if (options.target == Target::CUDA) {
......
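The last hunk registers FoldingConvBiasTypecvtPass in the pipeline built by add_passes_for_graph_tuning_options, guarded by CUDA_VERSION >= 10020. A rough sketch (pass names only, not the GraphOptimizer API) of the resulting pass order for a CUDA target:

    #include <iostream>
    #include <string>
    #include <vector>

    // Pass order for a CUDA target after this change; the folding pass is only
    // appended when building against CUDA >= 10.2 (CUDA_VERSION >= 10020).
    std::vector<std::string> tuning_passes_for_cuda(bool cuda_10_2_or_newer) {
        std::vector<std::string> passes{"FuseConvBiasNonlinPass", "FuseConvBiasZPass"};
        if (cuda_10_2_or_newer)
            passes.push_back("FoldingConvBiasTypecvtPass");  // new in this commit
        passes.push_back("LayoutTransformPass");
        passes.push_back("ShuffleShuffleRemovePass");
        return passes;
    }

    int main() {
        for (const auto& p : tuning_passes_for_cuda(true))
            std::cout << p << "\n";
        return 0;
    }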