提交 fa9d719f 编写于 作者: M Megvii Engine Team

fix(gopt): fix global layout transform fold conv typecvt

GitOrigin-RevId: 66a23a927e1b2355262b78366c690d5bff84da11
上级 767fa474
...@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( ...@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
bool is_version_ok = CUDNN_VERSION >= 7500; bool is_version_ok = CUDNN_VERSION >= 7500;
bool is_dtype_ok = bool is_dtype_ok =
(args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 && (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
(args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 || (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm)); args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
bool is_bias_ok = bool is_bias_ok =
args.bias_layout->ndim == 0 || args.bias_layout->ndim == 0 ||
......
...@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { ...@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
if (conv_bias == nullptr) if (conv_bias == nullptr)
return false; return false;
auto inp_dtype_conv = conv_bias->input(0)->dtype(), auto inp_dtype_conv = conv_bias->input(0)->dtype(),
out_dtype_conv = conv_bias->input(0)->dtype(); out_dtype_conv = conv_bias->output(0)->dtype();
bool is_s8nhwc = bool is_s8nhwc =
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
...@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { ...@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) && inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC; conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC;
if (!(is_s8nhwc || is_s4nhwc)) bool is_s8nchw =
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW;
if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
return false; return false;
if (conv_bias->input().size() != 3) if (conv_bias->input().size() != 3)
return false; return false;
...@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { ...@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32) auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
? opr::TypeCvt::make(bias, dtype::Float32()).node() ? opr::TypeCvt::make(bias, dtype::Float32()).node()
: bias; : bias;
auto new_param = conv_bias->param(); if (is_s8nchw && is_s82s4) {
new_param.format = megdnn::param::ConvBias::Format::NHWC; auto new_param = conv_bias->param();
auto conv_bias_typecvt = opr::ConvBias::make( new_param.format = megdnn::param::ConvBias::Format::NCHW;
src, filter, new_bias, new_param, conv_bias->execution_policy(), auto conv_bias_typecvt = opr::ConvBias::make(
OperatorNodeConfig{out_dtype_typecvt}); src, filter, new_bias, new_param, conv_bias->execution_policy(),
rewriter.replace_var( OperatorNodeConfig{out_dtype_typecvt});
opr->output(0), conv_bias_typecvt.node(), rewriter.replace_var(
mgb_cstr_log("replace conv_bias(NHWC) + typecvt " opr->output(0), conv_bias_typecvt.node(),
"to conv_bias(NHWC)")); mgb_cstr_log("replace conv_bias(NCHW) + typecvt "
"to conv_bias(NCHW)"));
} else {
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NHWC;
auto conv_bias_typecvt = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype_typecvt});
rewriter.replace_var(
opr->output(0), conv_bias_typecvt.node(),
mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
"to conv_bias(NHWC)"));
}
return true; return true;
}; };
......
...@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( ...@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
if (options.target == Target::CUDA) if (options.target == Target::CUDA)
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
#if CUDA_VERSION >= 10020
add_pass<FoldingConvBiasTypecvtPass>();
#endif
add_pass(LayoutTransformPass::make(options.target)); add_pass(LayoutTransformPass::make(options.target));
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
if (options.target == Target::CUDA) { if (options.target == Target::CUDA) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册