From c82d88751a2e3cc0b81875817e0934977bd3f669 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 10 Mar 2021 17:58:08 +0800 Subject: [PATCH] fix(dnn/cuda): add cuda nchw int8 conv impl with nchw4 to fix cu111 compatibility GitOrigin-RevId: 771968f9ace72aa6215b5fa95fc00277077eb1c6 --- .gitattributes | 1 + dnn/include/megdnn/oprs/general.h | 1 + dnn/scripts/opr_param_defs.py | 10 +- dnn/src/common/relayout_format.cpp | 124 ++++-- dnn/src/cuda/conv_bias/algo.cpp | 1 + dnn/src/cuda/conv_bias/algo.h | 27 +- dnn/src/cuda/conv_bias/conv_nchwqs8.cpp | 162 ++++++++ dnn/src/cuda/conv_bias/opr_impl.cpp | 5 + dnn/src/cuda/conv_bias/opr_impl.h | 1 + dnn/src/cuda/relayout_format/opr_impl.cpp | 15 +- .../cuda/relayout_format/relayout_format.cpp | 25 +- .../cuda/relayout_format/relayout_format.cu | 368 +++++++++++++++--- .../cuda/relayout_format/relayout_format.cuh | 21 +- .../cuda/relayout_format/relayout_format.h | 4 +- dnn/src/cuda/utils.cuh | 1 + dnn/src/naive/relayout_format/opr_impl.cpp | 186 ++++++++- dnn/test/cuda/conv_bias.cpp | 94 ++++- dnn/test/cuda/relayout_format.cpp | 95 ++++- dnn/test/naive/relayout_format.cpp | 133 ++++++- src/gopt/test/inference.cpp | 17 +- src/opr/impl/tensor_manip.sereg.h | 4 +- src/tensorrt/test/opr_replace.cpp | 5 +- 22 files changed, 1184 insertions(+), 116 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/conv_nchwqs8.cpp diff --git a/.gitattributes b/.gitattributes index 2d3a97507..e189efca8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -11,3 +11,4 @@ imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -tex ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/models/float/shufflenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/dump/roi_align_backward_8.8.0.mdl filter=lfs diff=lfs merge=lfs -text +ci/resource/dump/relayout_format_8.10.0.mdl filter=lfs diff=lfs merge=lfs -text diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 24745889b..d8e954e33 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1233,6 +1233,7 @@ protected: void check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes); void deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst, + TensorLayout& exec_workspace, TensorLayout& exec_src, TensorLayout& exec_dst); }; } // namespace megdnn diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 9f02b7518..2167fe3b2 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -935,8 +935,9 @@ Relayout mode. Note: the axis column means the corresponding ``align_axis`` for image format when the ``I`` suffix is present. +Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later opr by seting group and oc param with NCHW4_NCHW """ -(pdef('RelayoutFormat', 'Change the tensor layout format'). +(pdef('RelayoutFormat', 'Change the tensor layout format', version=0, is_legacy=True). add_enum( Doc('Mode', RELAYOUT_FORMAT_MODE_DOC), 'NHWC_NHWCD4', @@ -964,9 +965,16 @@ when the ``I`` suffix is present. 'NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', 'NCHW_NCHW4', + 'NCHW4_NCHW', + 'NCHW_NCHW4_WEIGHT', ) ) +(pdef('RelayoutFormat', 'Change the tensor layout format', version=1). + add_enum_alias('Mode', 'RelayoutFormatV0'). + add_fields('uint32', 'oc', '0'). + add_fields('uint32', 'group', '1') +) (pdef('SeparableFilter', version=0, is_legacy=True). 
add_enum_alias('Format', 'ConvolutionV0'). diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index 7bd3ac258..645e6bac8 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -208,14 +208,48 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, dst[3] = src[2]; dst[4] = src[4]; break; - case Param::Mode::NCHW_NCHW4: + case Param::Mode::NCHW_NCHW4: { megdnn_assert(src.ndim == 4); + const size_t group = param().group; + megdnn_assert(src[1] % group == 0); + const size_t icpg = src[1] / group; dst.ndim = 5; dst[0] = src[0]; - dst[1] = div_ceil(src[1], 4); + dst[1] = group * div_ceil(icpg, 4); dst[2] = src[2]; dst[3] = src[3]; dst[4] = 4; + }; break; + case Param::Mode::NCHW_NCHW4_WEIGHT:; + { + if (src.ndim == 4) { + //! dense case + dst.ndim = 5; + dst[0] = div_ceil(src[0], 4) * 4; + dst[1] = div_ceil(src[1], 4); + dst[2] = src[2]; + dst[3] = src[3]; + dst[4] = 4; + } else if (src.ndim == 5) { + //! group case + dst.ndim = 6; + dst[0] = src[0]; + dst[1] = div_ceil(src[1], 4) * 4; + dst[2] = div_ceil(src[2], 4); + dst[3] = src[3]; + dst[4] = src[4]; + dst[5] = 4; + } + }; + break; + case Param::Mode::NCHW4_NCHW: + megdnn_assert(src.ndim == 5); + dst.ndim = 4; + dst[0] = src[0]; + dst[1] = param().oc == 0 ? src[1] * 4 : param().oc; + dst[2] = src[2]; + dst[3] = src[3]; + megdnn_assert(dst[1] % param().group == 0); break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); @@ -258,16 +292,13 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { CHECK_SRC(DefaultTensorFormat::make()); dst = src; break; - case Param::Mode::NCHW_NCHW4: - CHECK_SRC(DefaultTensorFormat::make()); - dst = src; - break; case Param::Mode::NCHW_NHWCD4I: CHECK_SRC(DefaultTensorFormat::make()); dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); break; case Param::Mode::NHWCD4I_NCHW: - CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); + CHECK_SRC( + Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); dst = DefaultTensorFormat::make(); break; case Param::Mode::NHWCD4_NCHW: @@ -308,6 +339,9 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { CHECK_SRC(DefaultTensorFormat::make()); dst = src; break; + case Param::Mode::NCHW4_NCHW: + case Param::Mode::NCHW_NCHW4: + case Param::Mode::NCHW_NCHW4_WEIGHT: case Param::Mode::NCHW_NCHW88: case Param::Mode::NCHW88_NCHW: case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: @@ -354,6 +388,7 @@ void RelayoutFormat::check_exec(const TensorLayout& src, void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst, + TensorLayout& exec_workspace, TensorLayout& exec_src, TensorLayout& exec_dst) { check_layout_fwd(src, dst); @@ -362,10 +397,10 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, case Param::Mode::NCHW_NCHW88: // nchw to nchw8c { - TensorLayout work_space_layout( + exec_workspace = TensorLayout( {src[0], round_up(src[1], 8_z), src[2], src[3]}, src.dtype, src.format); - exec_src = work_space_layout + exec_src = exec_workspace .reshape({src[0], div_ceil(src[1], 8_z), 8, src[2], src[3]}) .dimshuffle({0, 1, 3, 4, 2}); @@ -375,13 +410,56 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, case Param::Mode::NCHW_NCHW4: // nchw to nchw4 { - TensorLayout work_space_layout( - {src[0], round_up(src[1], 4_z), src[2], src[3]}, + const size_t group = param().group; + const size_t icpg = src[1] / group; + exec_workspace = TensorLayout( + {src[0], group * 
round_up(icpg, 4_z), src[2], src[3]}, src.dtype, src.format); - exec_src = work_space_layout - .reshape({src[0], div_ceil(src[1], 4_z), 4, - src[2], src[3]}) - .dimshuffle({0, 1, 3, 4, 2}); + exec_src = + exec_workspace + .reshape({src[0], group * div_ceil(icpg, 4_z), + 4, src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_dst = dst; + } + break; + case Param::Mode::NCHW_NCHW4_WEIGHT: + // nchw to nchw4_weight + { + if (src.ndim == 4) { + exec_workspace = TensorLayout( + {round_up(src[0], 4_z), round_up(src[1], 4_z), + src[2], src[3]}, + src.dtype, src.format); + exec_src = exec_workspace + .reshape({round_up(src[0], 4_z), + div_ceil(src[1], 4_z), 4, + src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_dst = dst; + } else if (src.ndim == 5) { + exec_workspace = TensorLayout( + {src[0], round_up(src[1], 4_z), + round_up(src[2], 4_z), src[3], src[4]}, + src.dtype, src.format); + exec_src = exec_workspace + .reshape({src[0], round_up(src[1], 4_z), + div_ceil(src[2], 4_z), 4, + src[3], src[4]}) + .dimshuffle({0, 1, 2, 4, 5, 3}); + exec_dst = dst; + } + } + break; + case Param::Mode::NCHW4_NCHW: + // nchw to nchw4 + { + exec_workspace = + TensorLayout({src[0], src[1] * 4, src[2], src[3]}, + src.dtype, src.format) + .reshape({src[0], src[1], 4, src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_src = src; exec_dst = dst; } break; @@ -396,11 +474,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, { megdnn_assert(src.ndim == 4); megdnn_assert(src[0] % 8 == 0); - TensorLayout work_space_layout( + exec_workspace = TensorLayout( {src[0], round_up(src[1], 8_z), src[2], src[3]}, src.dtype, src.format); exec_src = - work_space_layout + exec_workspace .reshape({src[0] / 8, 8, div_ceil(src[1], 8_z), 8, src[2], src[3]}) .dimshuffle({0, 2, 4, 5, 3, 1}); @@ -411,10 +489,10 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // goihw to goihw8g { megdnn_assert(src.ndim == 5); - TensorLayout work_space_layout( + exec_workspace = TensorLayout( {round_up(src[0], 8_z), src[1], src[2], src[3], src[4]}, src.dtype, src.format); - exec_src = work_space_layout + exec_src = exec_workspace .reshape({div_ceil(src[0], 8_z), 8, src[1], src[2], src[3], src[4]}) .dimshuffle({0, 2, 3, 4, 5, 1}); @@ -426,10 +504,10 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, { megdnn_assert(src.ndim == 5); megdnn_assert(src[1] % 8 == 0); - TensorLayout work_space_layout( + exec_workspace = TensorLayout( {src[0], src[1], round_up(src[2], 8_z), src[3], src[4]}, src.dtype, src.format); - exec_src = work_space_layout + exec_src = exec_workspace .reshape({src[0], src[1] / 8, 8, div_ceil(src[2], 8_z), 8, src[3], src[4]}) @@ -442,10 +520,10 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: // nchw to nchw4c or oihw to oihw4i { - TensorLayout work_space_layout( + exec_workspace = TensorLayout( {src[0], round_up(src[1], 4_z), src[2], src[3]}, src.dtype, src.format); - exec_src = work_space_layout + exec_src = exec_workspace .reshape({src[0], div_ceil(src[1], 4_z), 4, src[2], src[3]}) .dimshuffle({0, 1, 3, 4, 2}); diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 4236e031d..5367305e2 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -91,6 +91,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo); } all_algos.push_back(&int8_chwn4_dotprod); + all_algos.push_back(&fallback_nchw_qs8); for (size_t i = all_algo_size; i 
< all_algos.size(); ++i) { non_cudnn_algos.push_back(all_algos[i]); } diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index f764a6975..af9292173 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -15,14 +15,14 @@ #include "megdnn/oprs.h" #include "src/common/algo_base.h" -#include "src/common/utils.h" #include "src/common/metahelper.h" +#include "src/common/utils.h" #include "src/cuda/conv_bias/conv_bias_int8.cuh" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/convolution_helper/parameter.cuh" -#include "src/cuda/handle.h" #include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/handle.h" #include #include @@ -547,6 +547,28 @@ private: std::string m_name; }; +class ConvBiasForwardImpl::AlgoFallbackNCHWQS8 final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + const char* name() const override { + return "FALLBACK_CONV_NCHW_QS8"; + } + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + } + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) + +private: + void make_inner_layout(const SizeArgs& args, TensorLayout& inner_src_layout, + TensorLayout& inner_weight_layout, + TensorLayout& inner_dst_layout, + TensorLayout& inner_bias_layout, + TensorLayout& inner_z_layout) const; + WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; +}; + #if CUDA_VERSION >= 10000 class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final : public AlgoBase { @@ -763,6 +785,7 @@ public: non_cudnn_algos, bfloat16_algos; std::vector cudnn_conv_bias_activations; std::vector cudnn_convs; + AlgoFallbackNCHWQS8 fallback_nchw_qs8; AlgoChanwise chanwise; AlgoChanwiseSmall chanwise_small; AlgoChanwise8x8x32 chanwise8x8x32; diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp new file mode 100644 index 000000000..fb5202c05 --- /dev/null +++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp @@ -0,0 +1,162 @@ +/** + * \file dnn/src/cuda/conv_bias/conv_nchwqs8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/relayout_format/opr_impl.h" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace conv_bias; + +namespace { +inline void deduce_reformat_layout(std::unique_ptr& relayout, + const TensorLayout& src_layout, + TensorLayout& dst_layout, + RelayoutFormat::Param::Mode mode, + const int oc = 0, const int group = 1) { + if (src_layout.ndim > 0) { + RelayoutFormat::Param trans_param; + trans_param.mode = mode; + trans_param.oc = oc; + trans_param.group = group; + relayout->param() = trans_param; + relayout->deduce_layout(src_layout, dst_layout); + } else { + dst_layout = src_layout; + } +} +} // namespace + +void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout( + const SizeArgs& args, TensorLayout& inner_src_layout, + TensorLayout& inner_weight_layout, TensorLayout& inner_dst_layout, + TensorLayout& inner_bias_layout, TensorLayout& inner_z_layout) const { + auto relayout_src = args.handle->create_operator(); + deduce_reformat_layout(relayout_src, *args.src_layout, inner_src_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, + args.filter_meta.group); + deduce_reformat_layout(relayout_src, *args.filter_layout, + inner_weight_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT); + deduce_reformat_layout(relayout_src, *args.dst_layout, inner_dst_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, + args.filter_meta.group); + deduce_reformat_layout(relayout_src, *args.bias_layout, inner_bias_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, + args.filter_meta.group); + deduce_reformat_layout(relayout_src, *args.z_layout, inner_z_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, + args.filter_meta.group); +}; + +bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( + const SizeArgs& args) const { + auto&& param = args.opr->param(); + bool is_format_ok = param.format == param::ConvBias::Format::NCHW; + bool is_version_ok = CUDNN_VERSION >= 7500; + bool is_dtype_ok = + args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8; + bool is_bias_ok = + args.bias_layout->ndim == 0 || + (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 && + args.bias_layout->shape[2] == 1 && + args.bias_layout->shape[3] == 1); + bool is_ok = is_format_ok && is_version_ok && is_dtype_ok && is_bias_ok; + return is_ok; +} + +WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle( + void* ptr, const SizeArgs& args) const { + TensorLayout inner_src_layout; + TensorLayout inner_weight_layout; + TensorLayout inner_dst_layout; + TensorLayout inner_bias_layout; + TensorLayout inner_z_layout; + make_inner_layout(args, inner_src_layout, inner_weight_layout, + inner_dst_layout, inner_bias_layout, inner_z_layout); + auto opr = args.handle->create_operator(); + Param inner_conv_param = args.opr->param(); + inner_conv_param.format = Param::Format::NCHW4; + opr->param() = inner_conv_param; + return WorkspaceBundle(ptr, {inner_src_layout.span().dist_byte(), + inner_weight_layout.span().dist_byte(), + inner_dst_layout.span().dist_byte(), + inner_bias_layout.span().dist_byte(), + inner_z_layout.span().dist_byte(), + opr->get_workspace_in_bytes( + inner_src_layout, inner_weight_layout, + inner_bias_layout, inner_z_layout, + inner_dst_layout, nullptr)}); +} + +size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_in_bytes( + const SizeArgs& args) const { + auto trans_bundle = 
get_workspace_bundle(nullptr, args); + return trans_bundle.total_size_in_bytes(); +} + +void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( + const ExecArgs& args) const { + auto relayout_nchw_nchw4 = args.handle->create_operator(); + RelayoutFormat::Param in_trans; + in_trans.mode = RelayoutFormat::Param::Mode::NCHW_NCHW4; + in_trans.group = args.filter_meta.group; + relayout_nchw_nchw4->param() = in_trans; + + auto relayout_weight = args.handle->create_operator(); + RelayoutFormat::Param weight_trans; + weight_trans.mode = RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT; + relayout_weight->param() = weight_trans; + + auto relayout_nchw4_nchw = args.handle->create_operator(); + RelayoutFormat::Param nchw4_nchw_trans; + nchw4_nchw_trans.mode = RelayoutFormat::Param::Mode::NCHW4_NCHW; + nchw4_nchw_trans.oc = args.dst_layout->shape[1]; + nchw4_nchw_trans.group = args.filter_meta.group; + relayout_nchw4_nchw->param() = nchw4_nchw_trans; + + auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); + TensorLayout inner_src_layout; + TensorLayout inner_weight_layout; + TensorLayout inner_dst_layout; + TensorLayout inner_bias_layout; + TensorLayout inner_z_layout; + make_inner_layout(args, inner_src_layout, inner_weight_layout, + inner_dst_layout, inner_bias_layout, inner_z_layout); + TensorND inner_src(bundle.get(0), inner_src_layout); + TensorND inner_weight(bundle.get(1), inner_weight_layout); + TensorND inner_dst(bundle.get(2), inner_dst_layout); + TensorND inner_bias(bundle.get(3), inner_bias_layout); + TensorND inner_z(bundle.get(4), inner_z_layout); + + Param inner_conv_param = args.opr->param(); + inner_conv_param.format = Param::Format::NCHW4; + auto inner_opr = args.handle->create_operator(); + inner_opr->param() = inner_conv_param; + + relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {}); + relayout_weight->exec(*args.filter_tensor, inner_weight, {}); + if (inner_bias_layout.ndim > 0) { + relayout_nchw_nchw4->exec(*args.bias_tensor, inner_bias, {}); + } + if (inner_z_layout.ndim > 0) { + relayout_nchw_nchw4->exec(*args.z_tensor, inner_z, {}); + } + inner_opr->exec(inner_src, inner_weight, inner_bias, inner_z, inner_dst, + nullptr, Workspace((dt_byte*)bundle.get(5), bundle.get_size(5))); + relayout_nchw4_nchw->exec(inner_dst, *args.dst_tensor, {}); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 84453117b..17383adc2 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -201,6 +201,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( return algo; } + if (sm_algo_pack.fallback_nchw_qs8.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes)) { + return &sm_algo_pack.fallback_nchw_qs8; + } + if (args.src_layout->dtype.enumv() != DTypeTrait::enumv) { if (reproducible) { return megdnn::get_reproducible_algo( diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 916e61c37..66327d3f4 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -49,6 +49,7 @@ public: class AlgoChanwiseSmall; class AlgoChanwise8x8x32; class AlgoCUDNNConv; + class AlgoFallbackNCHWQS8; class AlgoInplaceMatmul; class AlgoMatmul; class AlgoMatmul8x8x32; diff --git a/dnn/src/cuda/relayout_format/opr_impl.cpp b/dnn/src/cuda/relayout_format/opr_impl.cpp index f1333ecce..61cd58960 100644 --- a/dnn/src/cuda/relayout_format/opr_impl.cpp +++ b/dnn/src/cuda/relayout_format/opr_impl.cpp 
@@ -24,6 +24,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, megdnn_assert( param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || param().mode == param::RelayoutFormat::Mode::NCHW_NCHW4 || + param().mode == + param::RelayoutFormat::Mode::NCHW_NCHW4_WEIGHT || + param().mode == param::RelayoutFormat::Mode::NCHW4_NCHW || param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || param().mode == @@ -76,7 +79,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); } - if (param().mode == Param::Mode::NCHW_NCHW4) { + if (param().mode == Param::Mode::NCHW_NCHW4 || + param().mode == Param::Mode::NCHW4_NCHW || + param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) { bool is_usable = relayout_format::RelayoutFormatFast::usable( src.layout, dst.layout); megdnn_assert(is_usable, @@ -85,10 +90,12 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, src.layout.to_string().c_str(), src.layout.dtype.name(), dst.layout.to_string().c_str(), dst.layout.dtype.name()); relayout_format::RelayoutFormatFast::exec(src, dst, - cuda_stream(this->handle())); + cuda_stream(this->handle()), + param().mode, param().group); } else { - TensorLayout exec_src, exec_dst; - deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); + TensorLayout exec_src, exec_dst, exec_workspace; + deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, + exec_dst); TensorND exec_src_nd{src.raw_ptr, exec_src}; TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; handle()->create_operator()->exec(exec_src_nd, diff --git a/dnn/src/cuda/relayout_format/relayout_format.cpp b/dnn/src/cuda/relayout_format/relayout_format.cpp index f8d322bbd..be8612d99 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cpp +++ b/dnn/src/cuda/relayout_format/relayout_format.cpp @@ -36,7 +36,9 @@ bool relayout_format::RelayoutFormatFast::usable( void relayout_format::RelayoutFormatFast::exec(const TensorND& src, const TensorND& dst, - cudaStream_t stream) { + cudaStream_t stream, + RelayoutFormat::Param::Mode mode, + int group) { size_t ih = src.layout[2]; size_t iw = src.layout[3]; size_t hw = ih * iw; @@ -49,11 +51,22 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, if (src.layout.dtype.enumv() == DTypeEnum::Uint8) { src_zero_point = 128; } - if (hw % 4 == 0) { - relayout_format_cuda_exec<4>(src, dst, stream, src_scale, dst_scale, - src_zero_point, dst_zero_point); + if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) { + if (hw % 4 == 0) { + relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale, + dst_scale, src_zero_point, + dst_zero_point, group); + } else { + relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale, + dst_scale, src_zero_point, + dst_zero_point, group); + } + + } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { + relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); + } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { + relayout_format_cuda_nchw4_nchw(src, dst, stream, group); } else { - relayout_format_cuda_exec<1>(src, dst, stream, src_scale, dst_scale, - src_zero_point, dst_zero_point); + megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format"); } } diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu index 438bf2678..c2637c604 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cu 
+++ b/dnn/src/cuda/relayout_format/relayout_format.cu @@ -80,10 +80,30 @@ struct CudaPostProcess { template <> struct CudaPostProcess { + CudaPostProcess(){}; CudaPostProcess(float, uint8_t, float, uint8_t){}; inline __device__ int8_t operator()(int8_t val) { return val; } }; +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, int, float dst_scale, int) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + }; + inline __device__ int operator()(int val) { + float med_var = m_src_type_cvt.dequantize(dt_qint32(val)); + return m_dst_type_cvt.quantize(med_var).as_int32(); + } +}; +template <> +struct CudaPostProcess { + CudaPostProcess(float, int, float, int){}; + inline __device__ int operator()(int val) { return val; } +}; + template struct DTypeRWHelper; template <> @@ -98,6 +118,18 @@ struct DTypeRWHelper { using DstDtype = char4; }; +template <> +struct DTypeRWHelper { + using InnerDtype = int; + using DstDtype = int4; +}; + +template <> +struct DTypeRWHelper { + using InnerDtype = int4; + using DstDtype = int4; +}; + template struct Translayout { @@ -165,6 +197,11 @@ inline __device__ char4 make_zero_pad(const char zero_point) { return {zero_point, zero_point, zero_point, zero_point}; } +template <> +inline __device__ int4 make_zero_pad(const char zero_point) { + return {zero_point, zero_point, zero_point, zero_point}; +} + template inline __device__ void write_helper(DstDtype* ptr, DstDtype val) { *ptr = val; @@ -176,14 +213,14 @@ inline __device__ void write_helper(char4* ptr, char4 val) { *rel_ptr = *(int32_t*)(&val); } -template struct RelayoutKern { using InnerDtype = typename DTypeRWHelper::InnerDtype; using DstDtype = typename DTypeRWHelper::DstDtype; static inline __device__ void write(DstType* dst_ptr, - char4 (&dst_width)[pack_w]) { + DstDtype (&dst_width)[pack_w]) { DstDtype* dst_inner_ptr = (DstDtype*)dst_ptr; #pragma unroll for (int iw_idx = 0; iw_idx < pack_w; ++iw_idx) { @@ -213,6 +250,17 @@ struct RelayoutKern { } } + static inline __device__ void fake_read(const SrcType* src_ptr, + InnerDtype (&read_channel)[pack_c], + const int ic_stride, + const int remain_ic, + const InnerDtype zero_point) { +#pragma unroll + for (int ic_idx = 0; ic_idx < pack_c; ++ic_idx) { + read_channel[ic_idx] = zero_point; + } + } + static inline __device__ void core_relayout_kern( const SrcType* src, DstType* dst, const int src_offset_base, const int dst_offset_base, const int ic_offset, const int ic_stride, @@ -220,12 +268,20 @@ struct RelayoutKern { CudaPostProcess& post_process, const char zero_point) { InnerDtype read_channel[pack_c]; - if (with_pad) { + if (all_pad) { const InnerDtype zero_pad = make_zero_pad(zero_point); - read_with_pad(src + ic_offset + src_offset_base, read_channel, - ic_stride, remain_ic, zero_pad); + fake_read(src + ic_offset + src_offset_base, read_channel, + ic_stride, remain_ic, zero_pad); } else { - read(src + ic_offset + src_offset_base, read_channel, ic_stride); + if (with_pad) { + const InnerDtype zero_pad = + make_zero_pad(zero_point); + read_with_pad(src + ic_offset + src_offset_base, read_channel, + ic_stride, remain_ic, zero_pad); + } else { + read(src + ic_offset + src_offset_base, read_channel, + ic_stride); + } } DstDtype dst_width[pack_w]; Translayout __global__ void kern_nchw_nchw4( - const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src, - int ic_stride, int n_stride_dst, + const SrcType* src, 
DstType* dst, int in_n, int ic, int ihw, + int n_stride_src, int ic_stride, int n_stride_dst, CudaPostProcess post_process, - const char zero_point) { + const char zero_point, const int group, const int ocpg) { constexpr int pack_c = 4; const int n_idx = blockIdx.y; const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; const int ihw_offset = ihw_block_idx * pack_w; if (ihw_offset < ihw) { - const int ic_block = ic / pack_c; - const int remain_ic = ic % pack_c; const int src_offset_base = n_idx * n_stride_src + ihw_offset; const int dst_offset_base = n_idx * n_stride_dst + ihw_offset * pack_c; + if (n_idx < in_n) { + const int icpg = ic / group; + const int ic_block = icpg / pack_c; + const int remain_ic = icpg % pack_c; + const int src_group_stride = icpg * ic_stride; + const int dst_group_stride = ocpg * ic_stride; + for (int g_idx = 0; g_idx < group; ++g_idx) { + const int src_offset = + src_offset_base + g_idx * src_group_stride; + const int dst_offset = + dst_offset_base + g_idx * dst_group_stride; + for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { + const int ic_offset = ic_blk_idx * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset, + dst_offset, + ic_offset, + ic_stride, + remain_ic, + post_process, + zero_point); + } + + if (remain_ic > 0) { + const int ic_offset = ic_block * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset, + dst_offset, + ic_offset, + ic_stride, + remain_ic, + post_process, + zero_point); + } + } + } else { + //! pad n + const int ic_full_block = group * ocpg / pack_c; + for (int ic_blk_idx = 0; ic_blk_idx < ic_full_block; ++ic_blk_idx) { + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, 0, + ic_stride, 0, + post_process, + zero_point); + } + } + } +} - for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { - const int ic_offset = ic_blk_idx * pack_c * ic_stride; - RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, - ic_offset, ic_stride, - remain_ic, - post_process, - zero_point); +__global__ void kern_nchw4_nchw(const int8_t* src, int8_t* dst, int n, int ic, + int oc, int oh, int ow, int group) { + constexpr int pack_w = 1; + constexpr int pack_ic = 4; + const int n_idx = blockIdx.y; + const int hw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int hw_offset = hw_block_idx * pack_w; + const int hw = oh * ow; + const int n_stride_src = ic * hw; + const int n_stride_dst = oc * hw; + const int c_stride = hw; + + if (hw_offset < hw) { + const int icpg = ic / group; + const int ocpg = oc / group; + const int src_group_stride = icpg * c_stride; + const int dst_group_stride = ocpg * c_stride; + for (int g_idx = 0; g_idx < group; ++g_idx) { + const int oc_block = ocpg / pack_ic; + const int remain_oc = ocpg % pack_ic; + const int src_offset_base = n_idx * n_stride_src + + g_idx * src_group_stride + + hw_offset * pack_ic; + const int dst_offset_base = + n_idx * n_stride_dst + g_idx * dst_group_stride + hw_offset; + + for (int ic_blk_idx = 0; ic_blk_idx < oc_block; ++ic_blk_idx) { + const int oc_offset = ic_blk_idx * pack_ic * c_stride; + char4 temp = *(char4*)(src + src_offset_base + oc_offset); + dst[dst_offset_base + oc_offset + 0 * c_stride] = temp.x; + dst[dst_offset_base + oc_offset + 1 * c_stride] = temp.y; + dst[dst_offset_base + oc_offset + 2 * c_stride] = temp.z; + dst[dst_offset_base + oc_offset + 3 * c_stride] = temp.w; + } + + if (remain_oc > 0) { + const int oc_offset = oc_block * pack_ic * 
c_stride; + char4 temp = *(char4*)(src + src_offset_base + oc_offset); + dst[dst_offset_base + oc_offset + 0 * c_stride] = temp.x; + if (remain_oc > 1) { + dst[dst_offset_base + oc_offset + 1 * c_stride] = temp.y; + } + if (remain_oc > 2) { + dst[dst_offset_base + oc_offset + 2 * c_stride] = temp.z; + } + } } + } +} - if (remain_ic > 0) { - const int ic_offset = ic_block * pack_c * ic_stride; - RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, - ic_offset, ic_stride, - remain_ic, - post_process, - zero_point); +__global__ void kern_nchw_nchw4_weight( + const char* src, char* dst, int in_oc, int ic, int ihw, + int oc_stride_src, int ic_stride, int oc_stride_dst, + int group_stride_src, int group_stride_dst, const char zero_point, + CudaPostProcess + post_process) { + typedef char SrcType; + typedef char DstType; + typedef dtype::QuantizedS8 DnnSrcType; + typedef dtype::QuantizedS8 DnnDstType; + constexpr int pack_c = 4; + constexpr int pack_w = 1; + constexpr bool same_scale = true; + + const int group_idx = blockIdx.z; + const int oc_idx = blockIdx.y; + const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int ihw_offset = ihw_block_idx * pack_w; + + if (ihw_offset < ihw) { + const int ic_block = ic / pack_c; + const int remain_ic = ic % pack_c; + const int src_offset_base = group_idx * group_stride_src + + oc_idx * oc_stride_src + ihw_offset; + const int dst_offset_base = group_idx * group_stride_dst + + oc_idx * oc_stride_dst + + ihw_offset * pack_c; + if (oc_idx < in_oc) { + for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { + const int ic_offset = ic_blk_idx * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, + ic_offset, + ic_stride, + remain_ic, + post_process, + zero_point); + } + + if (remain_ic > 0) { + const int ic_offset = ic_block * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, + ic_offset, + ic_stride, + remain_ic, + post_process, + zero_point); + } + } else { + //! 
pad oc per group + const int ic_full_block = (ic + pack_c - 1) / pack_c; + for (int ic_blk_idx = 0; ic_blk_idx < ic_full_block; ++ic_blk_idx) { + const int ic_offset = ic_blk_idx * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, + ic_offset, + ic_stride, + remain_ic, + post_process, + zero_point); + } } } } @@ -284,20 +490,23 @@ __global__ void kern_nchw_nchw4( } // namespace template -void relayout_format::relayout_format_cuda_exec( +void relayout_format::relayout_format_cuda_nchw_nchw4( const TensorND& src, const TensorND& dst, const cudaStream_t& stream, const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point) { + const uint8_t src_zero_point, const uint8_t dst_zero_point, + const int group) { constexpr int pack_oc = 4; - const int n = src.layout[0]; - const int c = src.layout[1]; + const int in_n = src.layout[0]; + const int out_n = dst.layout[0]; + const int ic = src.layout[1]; const int h = src.layout[2]; const int w = src.layout[3]; + const int oc = dst.layout[1] * pack_oc; const int hw = h * w; - const int oc_block = DIVUP(c, pack_oc); - const int n_stride_src = c * hw; + const int ocpg = oc / group; + const int n_stride_src = ic * hw; const int ic_stride = hw; - const int n_stride_dst = oc_block * pack_oc * h * w; + const int n_stride_dst = oc * hw; auto& src_layout = src.layout; auto& dst_layout = dst.layout; @@ -307,26 +516,26 @@ void relayout_format::relayout_format_cuda_exec( int nr_threads = query_blocksize_for_kernel( \ kern_nchw_nchw4); \ - const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \ + const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), out_n); \ const dim3 thread_dim(nr_threads); \ kern_nchw_nchw4<<>>( \ - (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \ - n_stride_src, ic_stride, n_stride_dst, \ + (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, in_n, ic, \ + hw, n_stride_src, ic_stride, n_stride_dst, \ CudaPostProcess( \ src_scale, src_zero_point, dst_scale, dst_zero_point), \ - src_zero_point); \ + src_zero_point, group, ocpg); \ } else { \ int nr_threads = query_blocksize_for_kernel( \ kern_nchw_nchw4); \ - const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \ + const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), out_n); \ const dim3 thread_dim(nr_threads); \ kern_nchw_nchw4<<>>( \ - (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \ - n_stride_src, ic_stride, n_stride_dst, \ + (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, in_n, ic, \ + hw, n_stride_src, ic_stride, n_stride_dst, \ CudaPostProcess( \ src_scale, src_zero_point, dst_scale, dst_zero_point), \ - src_zero_point); \ + src_zero_point, group, ocpg); \ } if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 && @@ -340,6 +549,10 @@ void relayout_format::relayout_format_cuda_exec( dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { RUN_KERNEL(same_scale, dtype::QuantizedS8, dtype::QuantizedS8, char, char); + } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32) { + RUN_KERNEL(same_scale, dtype::QuantizedS32, dtype::QuantizedS32, int, + int); } else { megdnn_assert(0, "not support dtype %s %s", src_layout.dtype.name(), dst_layout.dtype.name()); @@ -356,16 +569,65 @@ bool relayout_format::relayout_format_cuda_usable( (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm && dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || (src_layout.dtype.enumv().ev == 
DTypeEnum::Ev::QuantizedS8 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8); + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32); return is_all_continue && is_all_int8; } -template void relayout_format::relayout_format_cuda_exec<1>( +void relayout_format::relayout_format_cuda_nchw4_nchw( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const int group) { + constexpr int pack_w = 1; + const int n = src.layout[0]; + const int ic = src.layout[1] * 4; + const int h = src.layout[2]; + const int w = src.layout[3]; + const int oc = dst.layout[1]; + const int hw = h * w; + int nr_threads = query_blocksize_for_kernel(kern_nchw4_nchw); + const dim3 block_dim(DIVUP(hw, nr_threads * pack_w), n); + const dim3 thread_dim(nr_threads); + kern_nchw4_nchw<<>>( + (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, n, ic, oc, h, w, group); +} + +void relayout_format::relayout_format_cuda_nchw_nchw4_weight( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream) { + constexpr int pack_c = 4; + const bool is_group = src.layout.ndim == 5; + const int group = is_group ? src.layout[0] : 1; + const int oc = is_group ? src.layout[1] : src.layout[0]; + const int ic = is_group ? src.layout[2] : src.layout[1]; + const int kh = is_group ? src.layout[3] : src.layout[2]; + const int kw = is_group ? src.layout[4] : src.layout[3]; + const int hw = kh * kw; + const int oc_round = ROUNDUP(oc, pack_c); + const int ic_round = ROUNDUP(ic, pack_c); + const int ic_stride = hw; + const int oc_stride_src = ic * ic_stride; + const int oc_stride_dst = ic_round * ic_stride; + const int group_stride_src = oc * oc_stride_src; + const int group_stride_dst = oc_round * oc_stride_dst; + + int nr_threads = 32; + const dim3 block_dim(DIVUP(hw, nr_threads), oc_round, group); + const dim3 thread_dim(nr_threads); + + kern_nchw_nchw4_weight<<>>( + (char*)src.raw_ptr, (char*)dst.raw_ptr, oc, ic, hw, oc_stride_src, + ic_stride, oc_stride_dst, group_stride_src, group_stride_dst, 0, + {}); +} + +template void relayout_format::relayout_format_cuda_nchw_nchw4<1>( const TensorND& src, const TensorND& dst, const cudaStream_t& stream, const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point); + const uint8_t src_zero_point, const uint8_t dst_zero_point, + const int group); -template void relayout_format::relayout_format_cuda_exec<4>( +template void relayout_format::relayout_format_cuda_nchw_nchw4<4>( const TensorND& src, const TensorND& dst, const cudaStream_t& stream, const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point); + const uint8_t src_zero_point, const uint8_t dst_zero_point, + const int group); diff --git a/dnn/src/cuda/relayout_format/relayout_format.cuh b/dnn/src/cuda/relayout_format/relayout_format.cuh index 4f7d614f9..3dd2058fb 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cuh +++ b/dnn/src/cuda/relayout_format/relayout_format.cuh @@ -20,16 +20,25 @@ namespace cuda { namespace relayout_format { template -void relayout_format_cuda_exec(const TensorND& src, const TensorND& dst, - const cudaStream_t& stream, - const float src_scale = 1.f, - const float dst_scale = 1.f, - const uint8_t src_zero_point = 0, - const uint8_t dst_zero_point = 0); +void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, 
+ const float src_scale = 1.f, + const float dst_scale = 1.f, + const uint8_t src_zero_point = 0, + const uint8_t dst_zero_point = 0, + const int group = 1); bool relayout_format_cuda_usable(const TensorLayout& src_layout, const TensorLayout& dst_layout); +void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, + const int group); + +void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, + const TensorND& dst, + const cudaStream_t& stream); + } // namespace relayout_format } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/relayout_format.h b/dnn/src/cuda/relayout_format/relayout_format.h index 460fb6a70..ba3905b77 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.h +++ b/dnn/src/cuda/relayout_format/relayout_format.h @@ -13,6 +13,7 @@ #pragma once #include "megdnn/basic_types.h" +#include "megdnn/oprs.h" #include "src/cuda/utils.cuh" namespace megdnn { @@ -23,7 +24,8 @@ struct RelayoutFormatFast { static bool usable(const TensorLayout& src_layout, const TensorLayout& dst_layout); static void exec(const TensorND& src, const TensorND& dst, - cudaStream_t stream); + cudaStream_t stream, RelayoutFormat::Param::Mode mode, + int group); }; } // namespace relayout_format diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh index 6c28bf6f2..9cac97ec1 100644 --- a/dnn/src/cuda/utils.cuh +++ b/dnn/src/cuda/utils.cuh @@ -86,6 +86,7 @@ #endif #define DIVUP(x, y) (((x) + (y)-1) / (y)) +#define ROUNDUP(x, y) (DIVUP(x, y) * (y)) #define KERN_FOR(i, n) \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp index d970696f3..f226ea582 100644 --- a/dnn/src/naive/relayout_format/opr_impl.cpp +++ b/dnn/src/naive/relayout_format/opr_impl.cpp @@ -23,6 +23,71 @@ using namespace megdnn; using namespace naive; namespace { + +template +void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0, + size_t src_offset = 0, size_t dst_offset = 0) { + if (idx < (src.layout.ndim - 1)) { + for (size_t i = 0; i < src.layout[idx]; ++i) { + recursive_cp(dst, src, idx + 1, + src_offset + i * src.layout.stride[idx], + dst_offset + i * dst.layout.stride[idx]); + } + } else { + for (size_t i = 0; i < src.layout[idx]; ++i) { + ((ctype*)dst.raw_ptr)[dst_offset + i * dst.layout.stride[idx]] = + ((ctype*)src + .raw_ptr)[src_offset + i * src.layout.stride[idx]]; + } + } +} + +void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) { + switch (src.layout.dtype.enumv()) { +#define cb(name, ctype) \ + case (DTypeEnum::name): { \ + recursive_cp(dst, src); \ + break; \ + } + + cb(Float32, dt_float32); + cb(Int32, dt_int32); + cb(QuantizedS32, dt_int32); + cb(QuantizedS8, dt_qint8); + + default: + megdnn_assert(0, "not support dtype %s", src.layout.dtype.name()); +#undef cb + } +} + +void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, + size_t group) { + megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(), + "dst %s, src %s", dst.layout.to_string().c_str(), + src.layout.to_string().c_str()); + const size_t type_size = dst.layout.dtype.size(); + const size_t n = dst.layout[0]; + const size_t n_stride_dst = dst.layout.stride[0]; + const size_t n_stride_src = src.layout.stride[0]; + const size_t ocpg = dst.layout[1] / group; + const size_t icpg = src.layout[1] / group; + const size_t dst_hw = dst.layout[2] * dst.layout[3]; + const size_t src_hw = 
src.layout[2] * src.layout[3]; + megdnn_assert(dst_hw == src_hw); + for (size_t nid = 0; nid < n; ++nid) { + const size_t n_offset_dst = nid * n_stride_dst * type_size; + const size_t n_offset_src = nid * n_stride_src * type_size; + for (size_t gid = 0; gid < group; ++gid) { + memcpy((char*)dst.raw_ptr + n_offset_dst + + gid * ocpg * dst_hw * type_size, + (char*)src.raw_ptr + n_offset_src + + gid * icpg * src_hw * type_size, + ocpg * dst_hw * type_size); + } + } +}; + template void padding_src_to_workspace(dtype* dptr, const dtype* sptr, size_t N, size_t IC, size_t IH, size_t IW) { @@ -86,6 +151,8 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, } cb(Float32, dt_float32); + cb(Int32, dt_int32); + cb(QuantizedS32, dt_int32); cb(QuantizedS8, dt_qint8); case (DTypeEnum::Quantized8Asymm): { @@ -159,6 +226,19 @@ void do_copy_diff_q8_q8(const TensorND& dst, const TensorND& src) { ++isrc; } } +void do_copy_diff_q32_q32(const TensorND& dst, const TensorND& src) { + auto isrc = tensor_iter_valonly::ctype>(src) + .begin(); + auto idst = tensor_iter_valonly::ctype>(dst) + .begin(); + auto src_dt_parm = src.layout.dtype.param(); + auto dst_dt_parm = dst.layout.dtype.param(); + for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) { + *idst = dst_dt_parm.quantize(src_dt_parm.dequantize(*isrc)); + ++idst; + ++isrc; + } +} void do_copy_diff_u8_q8(const TensorND& dst, const TensorND& src) { auto isrc = @@ -263,12 +343,34 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, return n * c * h * w * src.dtype.size(); } case Param::Mode::NCHW_NCHW4: { + size_t group = param().group; size_t n = src[0]; - size_t c = round_up(src[1], 4_z); + size_t c = group * round_up(src[1] / group, 4_z); size_t h = src[2]; size_t w = src[3]; return n * c * h * w * src.dtype.size(); } + case Param::Mode::NCHW4_NCHW: { + return src.total_nr_elems() * src.dtype.size(); + } + case Param::Mode::NCHW_NCHW4_WEIGHT: { + if (src.ndim == 4) { + size_t oc = round_up(src[0], 4_z); + size_t ic = round_up(src[1], 4_z); + size_t h = src[2]; + size_t w = src[3]; + return oc * ic * h * w * src.dtype.size(); + } else if (src.ndim == 5) { + size_t group = src[0]; + size_t oc = round_up(src[1], 4_z); + size_t ic = round_up(src[2], 4_z); + size_t h = src[3]; + size_t w = src[4]; + return group * oc * ic * h * w * src.dtype.size(); + } else { + megdnn_throw("no support nchw_nchw4_weight"); + } + } case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); if (src[1] % 4 == 0) @@ -288,13 +390,15 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { megdnn_assert(src.layout.dtype.category() == DTypeCategory::FLOAT || + src.layout.dtype.enumv() == DTypeEnum::Int32 || (src.layout.dtype.enumv() == DTypeEnum::Uint8 && dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) || src.layout.dtype.category() == DTypeCategory::QUANTIZED); check_exec(src.layout, dst.layout, workspace.size); HandleImpl* m_handle = static_cast(handle()); - TensorLayout exec_src, exec_dst; - deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); + TensorLayout exec_src, exec_dst, exec_workspace; + deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, + exec_dst); TensorND exec_src_nd{src.raw_ptr, exec_src}; TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; // clean dst @@ -371,6 +475,19 @@ void 
RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, } \ } \ MIDOUT_END(); + +#define cb2(_idx, _pack_size, _mode, _src_layout, _workspace_layout) \ + MIDOUT_BEGIN(megdnn_naive_relayout_format, \ + midout_iv(Param::Mode::_mode)) { \ + size_t val = _src_layout[_idx]; \ + if (val % _pack_size != 0) { \ + memset(workspace.raw_ptr, 0, exec_src.span().dist_byte()); \ + padding_to_workspace({workspace.raw_ptr, _workspace_layout}, \ + {src.raw_ptr, _src_layout}); \ + exec_src_nd.raw_ptr = workspace.raw_ptr; \ + } \ + } \ + MIDOUT_END(); cb(1, 8, NCHW_NCHW88); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { @@ -384,10 +501,62 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { cb(1, 4, NCHW_NCHW4_IC_SMALL); } else if (param().mode == Param::Mode::NCHW_NCHW4) { - cb(1, 4, NCHW_NCHW4); + if (param().group == 1) { + cb(1, 4, NCHW_NCHW4); + } else { + TensorLayout group_src_layout{{src.layout[0], param().group, + src.layout[1] / param().group, + src.layout[2], src.layout[3]}, + src.layout.dtype, + src.layout.format}; + TensorLayout workspace_layout{ + {src.layout[0], param().group, + div_ceil(src.layout[1] / param().group, 4_z) * 4_z, + src.layout[2], src.layout[3]}, + src.layout.dtype, + src.layout.format}; + cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout); + + } } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); + } else if (param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) { +#undef cb +#define cb(_idx0, _idx1, _pack_size, _mode) \ + MIDOUT_BEGIN(megdnn_naive_relayout_format, \ + midout_iv(Param::Mode::_mode)) { \ + size_t val0 = src.layout[_idx0]; \ + size_t val1 = src.layout[_idx1]; \ + if (val0 % _pack_size != 0 || val1 % _pack_size != 0) { \ + memset(workspace.raw_ptr, 0, exec_src.span().dist_byte()); \ + padding_to_workspace({workspace.raw_ptr, exec_workspace}, src); \ + exec_src_nd.raw_ptr = workspace.raw_ptr; \ + } \ + } \ + MIDOUT_END(); + if (src.layout.ndim == 4) { + cb(0, 1, 4, NCHW_NCHW4_WEIGHT); + } else if (src.layout.ndim == 5) { + cb(1, 2, 4, NCHW_NCHW4_WEIGHT); + } + } else if (param().mode == Param::Mode::NCHW4_NCHW) { + if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) { + m_handle->relayout_opr()->exec( + exec_src_nd, {dst.raw_ptr, exec_workspace}, handle()); + return; + } else { + m_handle->relayout_opr()->exec( + exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle()); + TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4, + src.layout[2], src.layout[3]}, + src.layout.dtype, + src.layout.format}; + extract_from_workspace(exec_dst_nd, + {workspace.raw_ptr, workspace_layout}, + param().group); + return; + } } if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && @@ -417,6 +586,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, }; MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0)); return; + } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) { + TensorND src0 = exec_src_nd, dst0 = exec_dst_nd; + check_layout_and_canonize(src0.layout, src0.layout); + auto func = [](const TensorND& dst, const TensorND& src) { + do_copy_diff_q32_q32(dst, src); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0)); + return; } else { m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); } diff --git a/dnn/test/cuda/conv_bias.cpp 
b/dnn/test/cuda/conv_bias.cpp index 495041056..95be9b716 100644 --- a/dnn/test/cuda/conv_bias.cpp +++ b/dnn/test/cuda/conv_bias.cpp @@ -215,8 +215,7 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) { .execs({src_shape, filter_shape, bias_shape, {}, {}}); } } -//! close for cu111 ci, reopen it when bug fixed -#if CUDA_VERSION < 11000 + TEST_F(CUDA, CONV_BIAS_NCHW_QS8) { //! not support NonlineMode::SIGMOID and NonlineMode::H_SWISH require_compute_capability(6, 1); @@ -274,8 +273,97 @@ TEST_F(CUDA, CONV_BIAS_NCHW_QS8) { } } } + + for (NonlineMode mode : {NonlineMode::RELU, + NonlineMode::IDENTITY, NonlineMode::H_SWISH}) { + for (size_t g : {13}) { + for (size_t b : {1, 2}) { + for (size_t ic : {13}) { + for (size_t oc : {13}) { + for (size_t fh : {1, 3}) { + for (int ph : {static_cast(fh / 2)}) { + for (int sh : {1, 2}) { + size_t ih = 16, iw = 16; + param.nonlineMode = mode; + param.stride_h = param.stride_w = sh; + param.pad_h = param.pad_w = ph; + param.sparse = + ConvBias::Param::Sparse::GROUP; + checker.set_param(param) + .execs({{b, ic, ih, iw}, + {g, oc/g, ic/g, fh, fh}, + {1, oc, 1, 1}, + {}, + {}}); + } + } + } + } + } + } + } + } + { + size_t ih = 16, iw = 16, b = 1, oc = 14, ic = 14; + size_t fh = 3, sh = 1, ph = 1; + param.nonlineMode = NonlineMode::IDENTITY; + param.stride_h = param.stride_w = sh; + param.pad_h = param.pad_w = ph; + param.sparse = ConvBias::Param::Sparse::DENSE; + checker.set_param(param).execs( + {{b, ic, ih, iw}, {oc, ic, fh, fh}, {}, {}, {}}); + } } -#endif + +TEST_F(CUDA, CONV_BIAS_NCHW_QS8_FUSE_Z) { + require_compute_capability(6, 1); + Checker checker(handle_cuda()); + UniformIntRNG int_rng{-128, 127}; + using NonlineMode = ConvBias::Param::NonlineMode; + + ConvBias::Param param; + param.format = ConvBias::Param::Format::NCHW; + + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(3, dtype::QuantizedS8(0.25f)) + .set_dtype(4, dtype::QuantizedS8(0.25f)) + .set_rng(0, &int_rng) + .set_rng(1, &int_rng) + .set_rng(2, &int_rng) + .set_rng(3, &int_rng); + + for (NonlineMode mode : + {NonlineMode::RELU, NonlineMode::IDENTITY, NonlineMode::H_SWISH}) { + for (size_t b : {2}) { + for (size_t ic : {6, 16}) { + for (size_t oc : {4}) { + for (size_t fh : {1, 3}) { + for (int ph : {static_cast(fh / 2)}) { + for (int sh : {1, 2}) { + size_t ih = 16, iw = 16; + param.nonlineMode = mode; + param.stride_h = param.stride_w = sh; + param.pad_h = param.pad_w = ph; + param.sparse = ConvBias::Param::Sparse::DENSE; + const size_t oh = (ih - fh + 2 * ph) / sh + 1; + const size_t ow = (iw - fh + 2 * ph) / sh + 1; + checker.set_param(param).execs( + {{b, ic, ih, iw}, + {oc, ic, fh, fh}, + {1, oc, 1, 1}, + {b, oc, oh, ow}, + {}}); + } + } + } + } + } + } + } +} + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_CONV_BIAS_NCHW4_INT8) { require_compute_capability(6, 1); diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp index 6af7cf04f..a2752becb 100644 --- a/dnn/test/cuda/relayout_format.cpp +++ b/dnn/test/cuda/relayout_format.cpp @@ -34,9 +34,43 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { checker.execs({{22, 23, 24, 25, 4}, {}}); } +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4_NCHW) { + Checker checker(handle_cuda()); + UniformIntRNG rng{-50, 50}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW4_NCHW; + + checker.set_dtype(0, dtype::QuantizedS8{0.1f}) + .set_dtype(1, dtype::QuantizedS8{0.1f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{1, 1, 2, 2, 
4}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{0.1f}) + .set_dtype(1, dtype::QuantizedS8{0.1f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{22, 23, 24, 25, 4}, {}}); + + param.oc = 90; + checker.set_dtype(0, dtype::QuantizedS8{0.1f}) + .set_dtype(1, dtype::QuantizedS8{0.1f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{22, 23, 24, 25, 4}, {}}); + + param.oc = 16; + param.group = 8; + checker.set_dtype(0, dtype::QuantizedS8{0.1f}) + .set_dtype(1, dtype::QuantizedS8{0.1f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{11, 16, 22, 33, 4}, {}}); +} + TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4) { Checker checker(handle_cuda()); - UniformIntRNG rng{0, 50}; + UniformIntRNG rng{-50, 50}; param::RelayoutFormat param; param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4; @@ -55,6 +89,12 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4) { .set_rng(0, &rng) .set_param(param) .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::QuantizedS32{1.f}) + .set_dtype(1, dtype::QuantizedS32{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); } } } @@ -77,6 +117,59 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4) { .set_rng(0, &rng) .set_param(param) .execs({{1, 6, 768, 1280}, {}}); + + param.group = 2; + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{8, 6, 300, 300}, {}}); + + param.group = 3; + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{8, 6, 300, 300}, {}}); +} + +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_WEIGHT) { + Checker checker(handle_cuda()); + UniformIntRNG rng{-50, 50}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_WEIGHT; + + for (size_t oc : {1, 3, 4, 16, 33}) { + for (size_t ic : {1, 2, 3, 4, 8, 9, 11, 16, 33}) { + for (size_t h : {3, 5, 7}) { + for (size_t w : {3, 5, 7}) { + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{oc, ic, h, w}, {}}); + } + } + } + } + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{13, 13, 5, 5}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{4, 16, 16, 3, 3}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{4, 13, 11, 3, 3}, {}}); } TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_DEFAULT) { diff --git a/dnn/test/naive/relayout_format.cpp b/dnn/test/naive/relayout_format.cpp index 8165d273d..1c6bc3cc3 100644 --- a/dnn/test/naive/relayout_format.cpp +++ b/dnn/test/naive/relayout_format.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "test/naive/fixture.h" @@ -17,6 +18,136 @@ using namespace megdnn; using namespace test; +TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW4_NCHW) { + Checker checker(handle(), /* check_dispatch */ false); + + { + auto tensor_nchw4 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 3, 5, 7, 2, 4, 6, 8, 9, 11, 13, 15, 10, 12, 14, 16}); + auto tensor_nchw = TensorValue( + {1, 8, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW4_NCHW}; + + checker.set_param(param).exect(Testcase{tensor_nchw4, {}}, + Testcase{{}, tensor_nchw}); + } + { + auto tensor_nchw4 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 3, 5, 7, 2, 4, 6, 8, 9, 11, 13, 15, 10, 12, 14, 16}); + auto tensor_nchw = + TensorValue({1, 7, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW4_NCHW}; + param.oc = 7; + + checker.set_param(param).exect(Testcase{tensor_nchw4, {}}, + Testcase{{}, tensor_nchw}); + } + { + auto tensor_nchw4 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 3, 5, 7, 2, 4, 6, 8, 9, 11, 13, 15, 10, 12, 14, 16}); + auto tensor_nchw = + TensorValue({1, 6, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14}); + + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW4_NCHW}; + param.oc = 6; + param.group = 2; + + checker.set_param(param).exect(Testcase{tensor_nchw4, {}}, + Testcase{{}, tensor_nchw}); + } +} + +TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW_NCHW4_WEIGHT) { + Checker checker(handle(), /* check_dispatch */ false); + + { + auto tensor_nchw = TensorValue({2, 2, 2, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, + + 9, 10, 11, 12, 13, 14, 15, 16}); + auto tensor_nchw4 = TensorValue( + {4, 1, 2, 2, 4}, dtype::Float32(), + {1, 5, 0, 0, 2, 6, 0, 0, 3, 7, 0, 0, 4, 8, 0, 0, + 9, 13, 0, 0, 10, 14, 0, 0, 11, 15, 0, 0, 12, 16, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + RelayoutFormat::Param param{ + RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT}; + + checker.set_param(param).exect(Testcase{tensor_nchw, {}}, + Testcase{{}, tensor_nchw4}); + } + { + auto tensor_nchw = TensorValue({2, 2, 1, 2, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, + + 9, 10, 11, 12, 13, 14, 15, 16}); + auto tensor_nchw4 = TensorValue( + {2, 4, 1, 2, 2, 4}, dtype::Float32(), + {1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, + 0, 6, 0, 0, 0, 7, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, + 12, 0, 0, 0, 13, 0, 0, 0, 14, 0, 0, 0, 15, 0, 0, 0, 16, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + RelayoutFormat::Param param{ + RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT}; + + checker.set_param(param).exect(Testcase{tensor_nchw, {}}, + Testcase{{}, tensor_nchw4}); + } +} + +TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW_NCHW4) { + Checker checker(handle(), /* check_dispatch */ false); + + { + auto tensor_nchw = TensorValue( + {1, 8, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto tensor_nchw4 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 3, 5, 7, 2, 4, 6, 8, 9, 11, 13, 15, 10, 12, 14, 16}); + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW4}; + + checker.set_param(param).exect(Testcase{tensor_nchw, {}}, + 
Testcase{{}, tensor_nchw4}); + } + { + auto tensor_nchw = TensorValue( + {1, 8, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto tensor_nchw4 = TensorValue( + {1, 4, 1, 2, 4}, dtype::Float32(), + {1, 3, 0, 0, 2, 4, 0, 0, 5, 7, 0, 0, 6, 8, 0, 0, + 9, 11, 0, 0, 10, 12, 0, 0, 13, 15, 0, 0, 14, 16, 0, 0}); + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW4}; + param.group = 4; + checker.set_param(param).exect(Testcase{tensor_nchw, {}}, + Testcase{{}, tensor_nchw4}); + } + { + auto tensor_nchw = TensorValue({1, 6, 1, 2}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto tensor_nchw4 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 3, 5, 0, 2, 4, 6, 0, 7, 9, 11, 0, 8, 10, 12, 0}); + RelayoutFormat::Param param{RelayoutFormat::Param::Mode::NCHW_NCHW4}; + param.group = 2; + checker.set_param(param).exect(Testcase{tensor_nchw, {}}, + Testcase{{}, tensor_nchw4}); + } +} + TEST_F(NAIVE, RELAYOUT_FORMAT_NCHW88) { Checker checker(handle(), /* check_dispatch */ false); diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 50823e9fc..4ef7808e7 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1914,8 +1914,17 @@ TEST(TestEnableTensorCore, Nchw4Nchw) { unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc); } auto nr_dimshuffle = find_opr_num(y_opt); + if (format == opr::ConvBias::Param::Format::NCHW4) { +#if CUDA_VERSION >= 10020 +//! try_conv_reformat_nchw322nchw4 used when cuda_version >= 10020 + ASSERT_EQ(1u, nr_dimshuffle); +#else + ASSERT_EQ(2u, nr_dimshuffle); +#endif + } else { + ASSERT_EQ(2u, nr_dimshuffle); + } std::string json_name; - ASSERT_EQ(2u, nr_dimshuffle); if (format == opr::ConvBias::Param::Format::NCHW4) { json_name = "TestGoptInference.Nchw4Nchw.NCHW4.json"; } else { @@ -2856,8 +2865,6 @@ TEST(TestGoptInference, EnableCHWN4ShuffleRemove) { } #endif -//! close for cu111 ci, reopen it when bug fixed -#if CUDA_VERSION < 11000 TEST(TestGoptInference, ConvertFormatNCHW4GPU) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); @@ -2936,7 +2943,6 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { func->execute(); MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); } -#endif #endif @@ -3076,8 +3082,6 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } -//! 
close for cu111 ci, reopen it when bug fixed -#if CUDA_VERSION < 11000 TEST(TestGoptInference, ConvertFormatNCHW4Ic3) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); @@ -3139,7 +3143,6 @@ TEST(TestGoptInference, ConvertFormatNCHW4Ic3) { func->execute(); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } -#endif TEST(TestGoptInference, ConvertFormatNCHW88) { HostTensorGenerator<> gen; diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h index 1926136d4..f64c1b571 100644 --- a/src/opr/impl/tensor_manip.sereg.h +++ b/src/opr/impl/tensor_manip.sereg.h @@ -183,7 +183,9 @@ namespace opr { } MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat); - MGB_SEREG_OPR(RelayoutFormat, 1); + + using RelayoutFormatV1 = opr::RelayoutFormat; + MGB_SEREG_OPR(RelayoutFormatV1, 1); } // namespace opr } // namespace mgb diff --git a/src/tensorrt/test/opr_replace.cpp b/src/tensorrt/test/opr_replace.cpp index 2635a4827..e38183757 100644 --- a/src/tensorrt/test/opr_replace.cpp +++ b/src/tensorrt/test/opr_replace.cpp @@ -1977,8 +1977,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) { MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[2], 1e-3); MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); } -//! close for cu111 ci, reopen it when bug fixed -#if CUDA_VERSION < 11000 + TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); @@ -2044,7 +2043,7 @@ TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) { MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3); } -#endif + #endif // MGB_ENABLE_TENSOR_RT -- GitLab
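
For reference, the layout rule exercised by the NCHW_NCHW4 naive tests above can be summarised as: within each group, channels are packed four at a time into a trailing dimension of size 4 and zero-padded up to a multiple of four, so the packed channel dimension becomes group * ceil(icpg / 4). The following standalone sketch reproduces the expected tensors of those tests; the function and variable names are invented here and this is a reference description of the packing rule read off the tests, not the MegDNN kernel.

#include <cstddef>
#include <vector>

// src: N x C x H x W, contiguous. dst: N x (group * ceil((C/group)/4)) x H x W x 4,
// zero-initialised so that the per-group channel padding is already in place.
std::vector<float> nchw_to_nchw4(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W, size_t group) {
    const size_t icpg = C / group;            // channels per group
    const size_t blocks_pg = (icpg + 3) / 4;  // packed blocks per group
    const size_t Cb = group * blocks_pg;      // dst channel-block dimension
    std::vector<float> dst(N * Cb * H * W * 4, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t g = 0; g < group; ++g)
            for (size_t c = 0; c < icpg; ++c)
                for (size_t h = 0; h < H; ++h)
                    for (size_t w = 0; w < W; ++w) {
                        const size_t src_c = g * icpg + c;               // original channel
                        const size_t cb = g * blocks_pg + c / 4;         // block index in dst
                        const size_t lane = c % 4;                       // slot inside the block
                        dst[(((n * Cb + cb) * H + h) * W + w) * 4 + lane] =
                                src[((n * C + src_c) * H + h) * W + w];
                    }
    return dst;
}

For the {1, 6, 1, 2} input with group = 2 this yields exactly the {1, 3, 5, 0, 2, 4, 6, 0, 7, 9, 11, 0, 8, 10, 12, 0} tensor checked in the last NCHW_NCHW4 case above.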
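
The inverse NCHW4_NCHW mode uses the oc and group fields of RelayoutFormatV1: oc == 0 keeps all src[1] * 4 channels, otherwise each group keeps oc / group channels and the zero-padded ones are dropped. A sketch under the same assumptions as above (invented names, reference-only, not the CUDA implementation):

#include <cstddef>
#include <vector>

// src: N x Cb x H x W x 4 (Cb packed blocks, Cb % group == 0). dst: N x OC x H x W,
// where OC = oc, or Cb * 4 when oc == 0; each group reads OC/group channels from
// the front of its own Cb/group blocks.
std::vector<float> nchw4_to_nchw(const std::vector<float>& src, size_t N,
                                 size_t Cb, size_t H, size_t W, size_t oc,
                                 size_t group) {
    const size_t OC = oc == 0 ? Cb * 4 : oc;  // dst channel count
    const size_t ocpg = OC / group;           // kept channels per group
    const size_t blocks_pg = Cb / group;      // packed blocks per group
    std::vector<float> dst(N * OC * H * W, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t g = 0; g < group; ++g)
            for (size_t c = 0; c < ocpg; ++c)
                for (size_t h = 0; h < H; ++h)
                    for (size_t w = 0; w < W; ++w) {
                        const size_t cb = g * blocks_pg + c / 4;
                        const size_t lane = c % 4;
                        dst[((n * OC + g * ocpg + c) * H + h) * W + w] =
                                src[(((n * Cb + cb) * H + h) * W + w) * 4 + lane];
                    }
    return dst;
}

With Cb = 2, oc = 6, group = 2 this reproduces the {1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14} expectation of the third NCHW4_NCHW naive case, and with oc = 7, group = 1 the {1, ..., 14} expectation of the second.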
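
NCHW_NCHW4_WEIGHT pads both dimensions of a dense conv weight with zeros: the output channels are rounded up to a multiple of 4 and the input channels are packed into blocks of 4, as the {2, 2, 2, 2} -> {4, 1, 2, 2, 4} expectation above shows. The sketch below covers only that dense 4-dim case (the grouped 5-dim case adds a leading group dimension and is not handled here); names are illustrative and this is not the MegDNN code path.

#include <cstddef>
#include <vector>

// src: OC x IC x FH x FW dense weight. dst: round_up(OC, 4) x ceil(IC/4) x FH x FW x 4,
// zero-initialised so that both the padded output channels and the padded input
// lanes stay zero.
std::vector<float> nchw_to_nchw4_dense_weight(const std::vector<float>& src,
                                              size_t OC, size_t IC, size_t FH,
                                              size_t FW) {
    const size_t OCp = (OC + 3) / 4 * 4;  // padded output channels
    const size_t ICb = (IC + 3) / 4;      // input-channel blocks of 4
    std::vector<float> dst(OCp * ICb * FH * FW * 4, 0.f);
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t ic = 0; ic < IC; ++ic)
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW; ++fw)
                    dst[(((oc * ICb + ic / 4) * FH + fh) * FW + fw) * 4 + ic % 4] =
                            src[((oc * IC + ic) * FH + fh) * FW + fw];
    return dst;
}

Because the padded output channels survive in the converted weight, a later NCHW4_NCHW relayout with oc and group set (as in the grouped test cases above) is what removes them again.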