diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index b5248c35182616cde8cfbff8cfc8b6e4e5b4c7f0..7e4390a256d43b1af661ffd011bcf6a463f4b7a7 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -10,12 +10,12 @@ * implied. */ -#include "src/fallback/convolution/opr_impl.h" #include "src/common/algo_chooser.h" #include "src/common/metahelper.h" #include "src/common/opr_delegate.h" #include "src/common/utils.h" #include "src/fallback/convolution/algos.h" +#include "src/fallback/convolution/opr_impl.h" #include "src/fallback/convolution/run_conv.h" #include "src/naive/convolution/helper.h" #include "src/naive/handle.h" @@ -100,10 +100,10 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, } void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout, - _megdnn_tensor_in filter, - const TensorLayout& dst_layout, - PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) { + _megdnn_tensor_in filter, + const TensorLayout& dst_layout, + PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) { //! exec_preprocess currently only support preprocess weights before exec, //! src/dst will be ignored, just set to nullptr TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}; @@ -151,7 +151,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes( SmallVector ConvolutionImpl::deduce_preprocessed_filter_layout( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst){ + const TensorLayout& dst) { auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); Algorithm* algo = get_algorithm(fparam); if (is_naive_algo(algo)) { @@ -257,7 +257,8 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, param.filter_meta.format == Param::Format::NCHW || param.filter_meta.format == Param::Format::NHWC || param.filter_meta.format == Param::Format::NCHW88 || - param.filter_meta.format == Param::Format::NCHW44, + param.filter_meta.format == Param::Format::NCHW44 || + param.filter_meta.format == Param::Format::NCHW44_DOT, "invalid conv format"); auto run = [param, kernel](size_t index, size_t thread_id) { CpuNDRange ndrange_id(kernel.global_size, index); @@ -277,7 +278,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, param.filter_meta.format == Param::Format::NCHW || param.filter_meta.format == Param::Format::NHWC || param.filter_meta.format == Param::Format::NCHW88 || - param.filter_meta.format == Param::Format::NCHW44, + param.filter_meta.format == Param::Format::NCHW44 || + param.filter_meta.format == Param::Format::NCHW44_DOT, "invalid conv format"); auto run = [param, kernel](size_t index, size_t thread_id) { CpuNDRange ndrange_id(kernel.global_size, index); diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index b88fabc8171f96f2c3c9290220379cb1209e96d1..505399faa9a7795790e6936992b62cab70319138 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -31,6 +31,8 @@ #include "megdnn/opr_param_defs.h" #include "megdnn/tensor_format.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + #if MGB_ENABLE_TENSOR_RT #include "megbrain/tensorrt/tensorrt_opr.h" #endif @@ -392,7 +394,8 @@ void TensorReformatPass::insert_pass(OptState& opt) const { auto new_opr = (it->second)(opr, new_inp); auto &&out0 = opr->output(), &&out1 = new_opr->output(); mgb_assert(out0.size() == out1.size(), - "bad opr 
replace: src=%s{%s} dst=%s{%s}, src.size=%zu " + "bad opr replace: src=%s{%s} dst=%s{%s}, " + "src.size=%zu " "dst.size=%zu", opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(), new_opr->dyn_typeinfo()->name, @@ -1811,14 +1814,156 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, } return new_var; } -//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv -static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic, - const size_t pack_c_size, const size_t fh, - const size_t fw, const size_t stride_h, - const size_t stride_w) { - return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw && - stride_h == stride_w && (stride_h == 1 || stride_h == 2) && - (fh == 2 || fh == 3 || fh == 5 || fh == 7); + +template +static inline bool nchw_nchwxx_valid(const OprType& opr, + const VarNodeArray& new_inp, + const size_t pack_size, bool is_dense, + bool is_dot = false); +template <> +inline bool nchw_nchwxx_valid( + const opr::ConvolutionForward& opr, const VarNodeArray& new_inp, + const size_t pack_size, bool is_dense, bool is_dot) { + auto& filter_shape = new_inp[1]->shape(); + auto filter_dtype = new_inp[1]->dtype(); + bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || + filter_dtype.enumv() == DTypeEnum::Int8; + const size_t oc = filter_shape[0]; + const size_t ic = filter_shape[1]; + bool is_like_nchw_nchwxx = + is_dense && oc % pack_size == 0 && ic < pack_size; + if (!is_like_nchw_nchwxx) { + return false; + } + SmallVector layouts; + + //! src + layouts.push_back( + {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); + + //! weight + layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], + filter_shape[3], filter_shape[1], pack_size}, + new_inp[1]->dtype(), + new_inp[1]->format()}); + + auto out0 = opr.output(0); + auto& out_shape = out0->shape(); + //! FIXME: return false if oc is invalid + layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], + out_shape[3], pack_size}, + out0->dtype(), + out0->format()}); + + auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) + ->create_operator(); + megdnn_conv.get()->param() = opr.param(); + //! set by dtype + switch (pack_size) { + case 4: + if (is_dot && is_int8) { + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW44_DOT; + } else { + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW44; + } + break; + case 8: + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW88; + break; + + default: + break; + } + + bool find_valid_algo = false; + auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1], + layouts[2]); + for (auto i : algos) { + if (i->type() != nullptr) { + find_valid_algo = true; + } + } + + return find_valid_algo; +} +template <> +inline bool nchw_nchwxx_valid( + const opr::ConvBiasForward& opr, const VarNodeArray& new_inp, + const size_t pack_size, bool is_dense, bool is_dot) { + auto& filter_shape = new_inp[1]->shape(); + auto filter_dtype = new_inp[1]->dtype(); + bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || + filter_dtype.enumv() == DTypeEnum::Int8; + const size_t oc = filter_shape[0]; + const size_t ic = filter_shape[1]; + bool is_like_nchw_nchwxx = + is_dense && oc % pack_size == 0 && ic < pack_size; + if (!is_like_nchw_nchwxx) { + return false; + } + SmallVector layouts; + + //! src + layouts.push_back( + {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); + + //! 
weight + layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], + filter_shape[3], filter_shape[1], pack_size}, + new_inp[1]->dtype(), + new_inp[1]->format()}); + + auto& bias_shape = new_inp[2]->shape(); + layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2], + bias_shape[3], pack_size}, + new_inp[2]->dtype(), + new_inp[2]->format()}); + + auto out0 = opr.output(0); + auto& out_shape = out0->shape(); + //! FIXME: return false if oc is invalid + layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], + out_shape[3], pack_size}, + out0->dtype(), + out0->format()}); + + // megdnn::ConvolutionForward + auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) + ->create_operator(); + megdnn_conv.get()->param() = opr.param(); + + //! FIXME: set by dtype + switch (pack_size) { + case 4: + if (is_dot && is_int8) { + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW44_DOT; + } else { + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW44; + } + break; + case 8: + megdnn_conv.get()->param().format = + megdnn::param::Convolution::Format::NCHW88; + break; + + default: + break; + } + bool find_valid_algo = false; + auto algos = megdnn_conv.get()->get_all_algorithms( + layouts[0], layouts[1], layouts[2], {}, layouts[3]); + for (auto i : algos) { + if (i->type() != nullptr) { + find_valid_algo = true; + } + } + + return find_valid_algo; } void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { using RelayoutMode = RelayoutPlaceholder::LayoutType; @@ -1839,6 +1984,20 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { megdnn::param::Pooling::Format pooling_format = megdnn::param::Pooling::Format::NCHW88; std::string convter_pass_name = "conv_format_nchw88"; +#if MEGDNN_AARCH64 || MEGDNN_ARMv7 + if (pack_c_size == 8) { + mgb_log_error( + "runtime backend is ARM, but nchw88 only support X86, you may " + "have performance loss\n"); + } +#elif MEGDNN_X86 + if (pack_c_size == 4) { + mgb_log_error( + "runtime backend is X86, but nchw44 only support arm, you may " + "have performance loss\n"); + } +#endif + if (pack_c_size == 4) { weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; @@ -1857,18 +2016,16 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { hybrid_nchw_nchwxx]( const megdnn::param::Convolution::Sparse conv_mode, const VarNode* filter, const size_t stride_h, - const size_t stride_w) -> TestFilterResult { + const size_t stride_w, + bool valid_nchw_nchw44) -> TestFilterResult { TestFilterResult ret{TransType::TRANS_NONE, {}}; if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { size_t OC = filter->shape()[0]; size_t IC = filter->shape()[1]; - size_t FH = filter->shape()[2]; - size_t FW = filter->shape()[3]; if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { ret.first = TransType::TRANS_PURE_NCHWXX; ret.second = weight_to_nchwxx_mode_dense; - } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, - stride_w)) { + } else if (valid_nchw_nchw44) { ret.first = TransType::TRANS_HYBIRD_NCHWXX; ret.second = hybrid_nchw_nchwxx; } @@ -1888,16 +2045,21 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { return ret; }; auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode, - src_to_nchw_mode](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + src_to_nchw_mode, + 
pack_c_size](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); auto& conv_opr = opr->cast_final_safe(); mgb_assert(conv_opr.param().format == megdnn::param::Convolution::Format::NCHW, "ConvertFormat Pass only support converting NCHW to NCHWXX"); - auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1], - conv_opr.param().stride_h, - conv_opr.param().stride_w); + bool is_dense = conv_opr.param().sparse == + megdnn::param::Convolution::Sparse::DENSE; + bool valid_nchw_nchw44 = + nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); + auto is_trans = test_trans_nchwxx( + conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, + conv_opr.param().stride_w, valid_nchw_nchw44); //! can not trans to nchwxx if (is_trans.first == TransType::TRANS_NONE) { mgb_assert(new_inp[1]->shape().ndim == 4 || @@ -1963,17 +2125,23 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { }; auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format, - src_to_nchwxx_mode, src_to_nchw_mode]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + src_to_nchwxx_mode, src_to_nchw_mode, + pack_c_size](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); auto& conv_bias_opr = opr->cast_final_safe(); mgb_assert(conv_bias_opr.param().format == megdnn::param::ConvBias::Format::NCHW, "ConvertFormat Pass only support converting NCHW to NCHWXX"); + bool is_dense = conv_bias_opr.param().sparse == + megdnn::param::Convolution::Sparse::DENSE; + bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, + pack_c_size, is_dense); auto is_trans = test_trans_nchwxx( conv_bias_opr.param().sparse, new_inp[1], - conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); + conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, + valid_nchw_nchw44); + //! can not trans to nchwxx if (is_trans.first == TransType::TRANS_NONE) { mgb_assert(new_inp[1]->shape().ndim == 4 || @@ -2203,8 +2371,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { MIDOUT_B("EnableNchw44DotPass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); - //! First is whether the conv can trans to nchwxx, second is the filter - //! trans mode +//! First is whether the conv can trans to nchwxx, second is the filter +//! 
trans mode +#if MEGDNN_X86 + mgb_log_error( + "backend is X86, but nchw44_dot only support arm, you may have " + "performance loss\n"); +#endif + using RelayoutMode = RelayoutPlaceholder::LayoutType; struct TestTransResult { TransType trans_type; @@ -2215,23 +2389,35 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { auto test_trans_nchw44_dot = [](const megdnn::param::Convolution::Sparse conv_mode, const VarNode* filter, const size_t stride_h, - const size_t stride_w) -> TestTransResult { + const size_t stride_w, + const bool valid_nchw_nchw44) -> TestTransResult { TestTransResult ret{TransType::TRANS_NONE, {}, {}}; + bool is_int8 = filter->dtype().enumv() == DTypeEnum::QuantizedS8 || + filter->dtype().enumv() == DTypeEnum::Int8; if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { size_t OC = filter->shape()[0]; size_t IC = filter->shape()[1]; - size_t FH = filter->shape()[2]; - size_t FW = filter->shape()[3]; if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { ret.trans_type = TransType::TRANS_PURE_NCHWXX; - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; - ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; - } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, - stride_w)) { + if (is_int8) { + ret.relayout_mod = + RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; + ret.conv_format = + megdnn::param::ConvBias::Format::NCHW44_DOT; + } else { + ret.relayout_mod = + RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; + ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; + } + } else if (valid_nchw_nchw44) { ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; - ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; + if (is_int8) { + ret.conv_format = + megdnn::param::ConvBias::Format::NCHW44_DOT; + } else { + ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; + } } } else { mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); @@ -2244,9 +2430,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { ret.trans_type = TransType::TRANS_PURE_NCHWXX; - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; - ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; + if (is_int8) { + ret.relayout_mod = + RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; + ret.conv_format = + megdnn::param::ConvBias::Format::NCHW44_DOT; + } else { + ret.relayout_mod = + RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; + ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; + } } } return ret; @@ -2260,9 +2453,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { megdnn::param::Convolution::Format::NCHW, "ConvertFormat Pass only support converting NCHW to " "NCHW44_DOT"); + bool is_dense = conv_opr.param().sparse == + megdnn::param::Convolution::Sparse::DENSE; + bool valid_nchw_nchw44 = + nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); auto is_trans = test_trans_nchw44_dot( conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, - conv_opr.param().stride_w); + conv_opr.param().stride_w, valid_nchw_nchw44); + //! 
can not trans to nchwxx if (is_trans.trans_type == TransType::TRANS_NONE) { mgb_assert(new_inp[1]->shape().ndim == 4 || @@ -2335,9 +2533,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { mgb_assert(conv_bias_opr.param().format == megdnn::param::ConvBias::Format::NCHW, "ConvertFormat Pass only support converting NCHW to NCHWXX"); + bool is_dense = conv_bias_opr.param().sparse == + megdnn::param::Convolution::Sparse::DENSE; + bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, + pack_c_size, is_dense); auto is_trans = test_trans_nchw44_dot( conv_bias_opr.param().sparse, new_inp[1], - conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); + conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, + valid_nchw_nchw44); + auto megdnn_conv = + opr::intl::get_megdnn_handle(conv_bias_opr.comp_node()) + ->create_operator(); + SmallVector layouts; + //! can not trans to nchwxx if (is_trans.trans_type == TransType::TRANS_NONE) { mgb_assert(new_inp[1]->shape().ndim == 4 || @@ -2350,6 +2558,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { new_inp[0], RelayoutMode::NCHW4_TO_NCHW); temp_inp[0] = new_src.node(); } + //! the bias is nchwxx if (temp_inp[2]->shape().ndim == 5) { auto new_bias = RelayoutPlaceholder::make( diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index a8bbb9ffb5cb88c274fa9703b210894fd3d8624c..093f704f0b23a6529eae4bcea48920a27a085c0b 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -32,6 +32,7 @@ #include "./helper.h" #include "megbrain/comp_node_env.h" +#include "cpuinfo.h" #include "megdnn/tensor_format.h" #include @@ -49,7 +50,21 @@ T& find_opr(SymbolVar endpoint) { } }; cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); - mgb_assert(found); + mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str()); + return *found; +} + +template +T& find_opr(SymbolVar endpoint, const std::string& node_name) { + T* found = nullptr; + auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) { + if (!found && opr->same_type() && opr->name() == node_name) { + found = &opr->cast_final_safe(); + } + }; + cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); + mgb_assert(found, "not found opr %s from %s", node_name.c_str(), + endpoint.node()->name().c_str()); return *found; } @@ -2973,25 +2988,48 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) .rename(name); }; + auto mkcvar_dtype = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; auto host_x = gen({2, 3, 16, 16}, cn); auto x = opr::Host2DeviceCopy::make(*graph, host_x); //! Hybrid nchw44 mode opr::Convolution::Param param_conv; param_conv.pad_h = param_conv.pad_w = 1; - opr::ConvBias::Param param_conv_bias_stride4; - param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; auto w1 = mkcvar("w1", {8, 3, 3, 3}), - conv1 = opr::Convolution::make(x, w1, param_conv); - auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), - conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); - auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); - //! channel wise + conv1 = opr::Convolution::make(x, w1, param_conv, {}, + OperatorNodeConfig("conv1")); + + //! 
no supported hybrid nchw44 + opr::ConvBias::Param param_conv_bias_pad0; + param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0; + auto b1 = mkcvar("b1", {1, 8, 1, 1}); + auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1}); + auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {}, + OperatorNodeConfig("conv1_f1")); + + auto conv1_add = conv1_f1 * conv1; + auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f)); + + //! s8 dense conv opr::ConvBias::Param param_conv_bias; param_conv_bias.pad_h = param_conv_bias.pad_w = 1; + auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f)); + auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); + auto conv_1_2 = opr::ConvBias::make( + conv_1_q8, w1_2, b1_2, param_conv_bias, {}, + OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}}); + auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32()); + + //! channel wise param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), - conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); + conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias); //! group auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); @@ -3013,42 +3051,67 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { //! Dense param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; param_conv_bias.pad_h = param_conv_bias.pad_w = 1; - auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), - conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); + auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}), + b3_2 = mkcvar("b3_2", {1, 16, 1, 1}), + conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {}, + OperatorNodeConfig("conv3_2")); + //! s8 group conv + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; + auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f)); + auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)), + conv3_3_q = opr::ConvBias::make( + conv3_2_q8, w3_3, b3_3, param_conv_bias, {}, + OperatorNodeConfig{"conv_3_3_q", cn, + dtype::QuantizedS8{6.25f}}); + auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32()); + + //! 
Dense + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; + auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), + conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {}, + OperatorNodeConfig("conv4")); + auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), - conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); + conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {}, + OperatorNodeConfig("conv5")); auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), - y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); + y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {}, + OperatorNodeConfig("conv6")); - SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt; + SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; options.enable_nchw44(); - unpack_vector(gopt::optimize_for_inference( - {y, conv1, conv1_f4, conv1_s4, conv2}, options), - y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt); - - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, - find_opr(conv1_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, - find_opr(conv1_s4_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, - find_opr(conv1_f4_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, - find_opr(conv2_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, - find_opr(y_opt).param().format); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); - graph->compile({{y_opt, {}}, {conv2, {}}}) +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv1").param().format); +#else + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv1").param().format); +#endif + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv1_f1").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv_1_2").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv3_2").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv_3_3_q").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv4").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv5").param().format); + + graph->compile({{y_opt, {}}}) ->to_json() ->writeto_fpath( output_file("TestGoptInference.ConvertFormatNCHW44.json")); HostTensorND host_y_opt, host_y; - HostTensorND host_conv1; auto func = graph->compile({make_callback_copy(y, host_y), - make_callback_copy(y_opt, host_y_opt), - make_callback_copy(conv1, host_conv1)}); - + make_callback_copy(y_opt, host_y_opt)}); func->execute(); //! meybe go to winograd in x86-32, so set error 1e-1 MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); @@ -3155,25 +3218,58 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) .rename(name); }; + auto mkcvar_dtype = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; auto host_x = gen({2, 3, 16, 16}, cn); auto x = opr::Host2DeviceCopy::make(*graph, host_x); //! 
Hybrid nchw44 mode opr::Convolution::Param param_conv; param_conv.pad_h = param_conv.pad_w = 1; - opr::ConvBias::Param param_conv_bias_stride4; - param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; auto w1 = mkcvar("w1", {8, 3, 3, 3}), - conv1 = opr::Convolution::make(x, w1, param_conv); - auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), - conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); - auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); - //! channel wise + conv1 = opr::Convolution::make(x, w1, param_conv, {}, + OperatorNodeConfig("conv1")); + printf("create conv1 %s\n", + conv1.node()->owner_opr()->dyn_typeinfo()->name); + param_conv.pad_h = param_conv.pad_w = 1; + //! no supported hybrid nchw44 + opr::ConvBias::Param param_conv_bias_pad0; + param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0; + auto b1 = mkcvar("b1", {1, 8, 1, 1}); + auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1}); + auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {}, + OperatorNodeConfig("conv1_f1")); + + //! hybrid dot + auto x_s = opr::TypeCvt::make(x, dtype::QuantizedS8(2.5f)); + auto w1_3 = mkcvar_dtype("w1_3", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)); + auto conv1_3_q = opr::Convolution::make( + x_s, w1_3, param_conv, {}, + OperatorNodeConfig{"conv1_3_q", cn, dtype::QuantizedS8{6.25f}}); + auto conv1_3 = opr::TypeCvt::make(conv1_3_q, dtype::Float32()); + + auto conv1_add = conv1_f1 * conv1 * conv1_3; + auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f)); + + //! s8 dense conv opr::ConvBias::Param param_conv_bias; param_conv_bias.pad_h = param_conv_bias.pad_w = 1; + auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f)); + auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); + auto conv_1_2 = opr::ConvBias::make( + conv_1_q8, w1_2, b1_2, param_conv_bias, {}, + OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}}); + auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32()); + + //! channel wise param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), - conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); + conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias); //! group auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); @@ -3195,35 +3291,68 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { //! Dense param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; param_conv_bias.pad_h = param_conv_bias.pad_w = 1; - auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), - conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); + auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}), + b3_2 = mkcvar("b3_2", {1, 16, 1, 1}), + conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {}, + OperatorNodeConfig("conv3_2")); + //! 
s8 group conv + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; + auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f)); + auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)), + conv3_3_q = opr::ConvBias::make( + conv3_2_q8, w3_3, b3_3, param_conv_bias, {}, + OperatorNodeConfig{"conv_3_3_q", cn, + dtype::QuantizedS8{6.25f}}); + auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32()); + + //! Dense + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; + auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), + conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {}, + OperatorNodeConfig("conv4")); + auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), - conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); + conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {}, + OperatorNodeConfig("conv5")); auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), - y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); + y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {}, + OperatorNodeConfig("conv6")); - SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt; + SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_fuse_conv_bias_nonlinearity(); options.enable_nchw44_dot(); - unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4}, - options), - y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt); - - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, - find_opr(conv1_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, - find_opr(conv1_s4_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, - find_opr(conv1_f4_opt).param().format); - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, - find_opr(y_opt).param().format); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, - find_opr(y_opt).param().format); + find_opr(y_opt, "conv1").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, + find_opr(y_opt, "conv1_3_q").param().format); +#else + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv1").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv1_3_q").param().format); +#endif + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv1_f1").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, + find_opr(y_opt, "conv_1_2").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv3_2").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, + find_opr(y_opt, "conv_3_3_q").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, + find_opr(y_opt, "conv4").param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt, "conv5").param().format); graph->compile({{y_opt, {}}}) ->to_json() - ->writeto_fpath( - output_file("TestGoptInference.ConvertFormatNCHW44.json")); + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNCHW44_DOT.json")); HostTensorND host_y_opt, host_y; auto func = graph->compile({make_callback_copy(y, host_y), diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 
044577bceb2a6b1a5a99e984de52ab5715748dee..588464738c59d4c4330edd5f779086cb793c5b7c 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -608,11 +608,16 @@ public:
         auto algo = get_algo(ctx);
         size_t workspace = ctx.get_workspace_size_bytes(algo);
         mgb_log_debug(
-                "%s: input shapes (%s, %s): algo=%s "
+                "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
                 layouts[0].TensorShape::to_string().c_str(),
-                layouts[1].TensorShape::to_string().c_str(), algo->name(),
+                layouts[0].dtype.name(),
+                layouts[1].TensorShape::to_string().c_str(),
+                layouts[1].dtype.name(),
+                layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
+                layouts[layouts.size() - 1].dtype.name(),
+                algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());
         megdnn_opr->execution_policy() = {algo};
         return workspace;
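
Note on the validity check introduced in tensor_reformat.cpp (illustrative sketch, not part of the patch): the new nchw_nchwxx_valid specializations decide whether a hybrid NCHW->NCHWxx conversion is worthwhile by building the post-conversion layouts, cloning the operator's param with the target format (NCHW44, NCHW44_DOT or NCHW88), and asking megdnn whether any non-naive algorithm accepts those layouts. A condensed version of that pattern, using only the calls that appear in the diff and a hypothetical helper name has_non_naive_algo, looks like this:

// Illustrative sketch only -- mirrors the pattern used by nchw_nchwxx_valid
// in this patch; has_non_naive_algo is a hypothetical helper, not in the diff.
static bool has_non_naive_algo(const opr::ConvolutionForward& opr,
                               const megdnn::TensorLayout& src,
                               const megdnn::TensorLayout& filter,
                               const megdnn::TensorLayout& dst,
                               megdnn::param::Convolution::Format format) {
    auto conv = opr::intl::get_megdnn_handle(opr.comp_node())
                        ->create_operator<megdnn::ConvolutionForward>();
    conv->param() = opr.param();
    conv->param().format = format;  // e.g. NCHW44, NCHW44_DOT or NCHW88
    // Naive fallback algorithms report a null type(); the layout conversion
    // only pays off if at least one optimized kernel can run on these layouts.
    for (auto* algo : conv->get_all_algorithms(src, filter, dst)) {
        if (algo->type() != nullptr)
            return true;
    }
    return false;
}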