diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
index 19c5c99e7ecb63279e07657f84adef653ca43ad5..b696395f6616f847deaed25a5f0cf1bdda3ee599 100644
--- a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
+++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     bool is_version_ok = CUDNN_VERSION >= 7500;
     bool is_dtype_ok =
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
diff --git a/src/gopt/impl/folding_conv_typecvt.cpp b/src/gopt/impl/folding_conv_typecvt.cpp
index 76c187b3c2197098d74113c4d029d36a3f3e015c..1ca0179d8078788a165175cbc6a5683339b9e867 100644
--- a/src/gopt/impl/folding_conv_typecvt.cpp
+++ b/src/gopt/impl/folding_conv_typecvt.cpp
@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         if (conv_bias == nullptr)
             return false;
         auto inp_dtype_conv = conv_bias->input(0)->dtype(),
-             out_dtype_conv = conv_bias->input(0)->dtype();
+             out_dtype_conv = conv_bias->output(0)->dtype();
         bool is_s8nhwc =
                 inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
                  inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
                 conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC;
-        if (!(is_s8nhwc || is_s4nhwc))
+        bool is_s8nchw =
+                inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW;
+        if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
             return false;
         if (conv_bias->input().size() != 3)
             return false;
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
                                 ? opr::TypeCvt::make(bias, dtype::Float32()).node()
                                 : bias;
-        auto new_param = conv_bias->param();
-        new_param.format = megdnn::param::ConvBias::Format::NHWC;
-        auto conv_bias_typecvt = opr::ConvBias::make(
-                src, filter, new_bias, new_param, conv_bias->execution_policy(),
-                OperatorNodeConfig{out_dtype_typecvt});
-        rewriter.replace_var(
-                opr->output(0), conv_bias_typecvt.node(),
-                mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
-                             "to conv_bias(NHWC)"));
+        if (is_s8nchw && is_s82s4) {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NCHW;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NCHW) + typecvt "
+                                 "to conv_bias(NCHW)"));
+        } else {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NHWC;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                 "to conv_bias(NHWC)"));
+        }
         return true;
     };

diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index 6a72ff9f0be27cf83f4f47edcb402d9c1c2d5559..15f64b4ba2d537ee21667569c6c15a10f47eea90 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
     add_pass<FuseConvBiasNonlinPass>();
     if (options.target == Target::CUDA)
         add_pass<FuseConvBiasZPass>();
+#if CUDA_VERSION >= 10020
+    add_pass<FoldingConvBiasTypecvtPass>();
+#endif
     add_pass(LayoutTransformPass::make(options.target));
     add_pass<ShuffleShuffleRemovePass>();
     if (options.target == Target::CUDA) {
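
Reviewer note on the first hunk: the old condition `dst != QuantizedS4 || dst != Quantized4Asymm` is a tautology (no dtype can equal both constants at once), so `AlgoFallbackNCHWQS8::is_available` never rejected 4-bit destinations; switching to `&&` encodes the intended "neither s4 nor asymm-s4" check. The gopt hunks fix a copy-paste bug (`out_dtype_conv` was read from `input(0)` instead of `output(0)`) and extend the conv_bias + typecvt fold to QuantizedS8 NCHW graphs, and framework.cpp registers the pass for CUDA >= 10.2 builds ahead of the layout-transform pass. Below is a minimal, self-contained sketch of the boolean fix only; the enum and function names are stand-ins for illustration, not MegDNN's actual declarations.

    // sketch.cpp -- stand-in types, not MegDNN's; shows why the `||` form
    // of the dtype check could never reject a 4-bit output dtype.
    #include <cstdio>

    enum class DTypeEnum { QuantizedS8, QuantizedS4, Quantized4Asymm };

    // Old predicate: "x != A || x != B" holds for every x (x can never
    // equal both A and B), so this always returns true.
    bool dtype_ok_old(DTypeEnum dst) {
        return dst != DTypeEnum::QuantizedS4 ||
               dst != DTypeEnum::Quantized4Asymm;
    }

    // Fixed predicate: true only when dst is neither 4-bit dtype.
    bool dtype_ok_new(DTypeEnum dst) {
        return dst != DTypeEnum::QuantizedS4 &&
               dst != DTypeEnum::Quantized4Asymm;
    }

    int main() {
        const DTypeEnum cases[] = {
                DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS4,
                DTypeEnum::Quantized4Asymm};
        for (DTypeEnum dst : cases)
            std::printf("old=%d new=%d\n", (int)dtype_ok_old(dst),
                        (int)dtype_ok_new(dst));
        // Prints old=1 for all three cases; new=1 only for QuantizedS8.
        return 0;
    }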