From ea93927d199a80070953dc8ef40947b21ec9f963 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 20 Mar 2020 17:41:18 +0800 Subject: [PATCH] fix(dnn/fallbackls): fix fallback convolution and all conv_bias algos GitOrigin-RevId: e37fbe0ffe9f8ab6325f41d9ee6857115630b17a --- dnn/src/fallback/conv_bias/algos.cpp | 6 + dnn/src/fallback/conv_bias/im2col/algos.cpp | 3 +- dnn/src/fallback/conv_bias/opr_impl.cpp | 98 ++++++----- dnn/src/fallback/conv_bias/opr_impl.h | 37 ++-- dnn/src/fallback/convolution/algos.cpp | 7 +- dnn/src/fallback/convolution/algos.h | 27 ++- dnn/src/fallback/convolution/opr_impl.cpp | 34 +--- dnn/src/fallback/convolution/opr_impl.h | 37 ++++ dnn/src/x86/conv_bias/f32/algos.cpp | 165 ++++++++++-------- dnn/src/x86/conv_bias/f32/algos.h | 12 +- dnn/src/x86/conv_bias/int8/algos.cpp | 23 +-- .../int8/avx2_direct_conv_stride1.cpp | 50 +++--- .../int8/avx2_direct_conv_stride2.cpp | 51 +++--- 13 files changed, 312 insertions(+), 238 deletions(-) diff --git a/dnn/src/fallback/conv_bias/algos.cpp b/dnn/src/fallback/conv_bias/algos.cpp index bb4eceb5..70cb6a1c 100644 --- a/dnn/src/fallback/conv_bias/algos.cpp +++ b/dnn/src/fallback/conv_bias/algos.cpp @@ -213,11 +213,17 @@ SmallVector ConvBiasImpl::AlgoNaive::dispatch_kerns( const NCBKernParam& param, const NCBKernIndex& ncb_index) { MIDOUT_BEGIN(megdnn_fallback_naive, 2) { + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; size_t thread_id = ncb_index.thread_id; auto thread_param = param; thread_param.workspace_ptr = reinterpret_cast( reinterpret_cast(param.workspace_ptr) + thread_id * workspace_per_thread); + thread_param.filter_ptr = param.filter(group_id); + thread_param.dst_ptr = param.dst(batch_id, group_id); + thread_param.src_ptr = param.src(batch_id, group_id); + thread_param.bias_ptr = param.bias(batch_id, group_id); kern_default(opr_param, thread_param); } MIDOUT_END(); diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index b2712c8c..dc1c785d 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -111,7 +111,6 @@ static void copy_padding_kern(WorkspaceBundle bundle, size_t channel_id = ncb_index.ndrange_id[2]; size_t padding_group_size = IH2 * IW2 * IC; - size_t input_channel_offset = IH * IW * channel_id; size_t workspace_channel_offset = IH2 * IW2 * channel_id; size_t workspace_group_offset = group_id * padding_group_size; size_t workspace_batch_offset = @@ -123,7 +122,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, src_zp = param.src_type.param().zero_point; } src_ctype* src = const_cast( - param.src(batch_id, group_id) + input_channel_offset); + param.src(batch_id, group_id, channel_id)); src_ctype* src2; src2 = static_cast( bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 7985770b..f9f1a2ad 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -246,10 +246,9 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) { auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); for (auto&& kernel : ncb_kerns) { - auto run = [=](size_t index, size_t thread_id) { - auto copy_param = param; + auto run = [kernel, param](size_t index, size_t thread_id) { CpuNDRange ndrange_id(kernel.global_size, index); - kernel.kern(copy_param, {thread_id, ndrange_id}); + 
kernel.kern(param, {thread_id, ndrange_id}); }; static_cast(handle())->dispatch_kern( run, kernel.global_size.total_size()); @@ -328,28 +327,29 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { namespace megdnn{ namespace fallback { -//! when format is nchwxx and channel wise mode, multi group will pack - //! together, so pack_group_size is the number of packed group + template -const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id, - size_t group_pack_size) const { - src_type.assert_is_compatible_ctype(); +const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id, + size_t group_pack_size, + size_t channel_pack_size) const { size_t batch_offset = batch_id * inp_bs * src_type.size(); - size_t group_offset = group_pack_size * group_id * filter_meta.icpg * + size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg * isz[0] * isz[1] * src_type.size(); + size_t channel_offset = channel_pack_size * channel_pack_id * isz[0] * + isz[1] * src_type.size(); return reinterpret_cast(reinterpret_cast(src_ptr) + - batch_offset + group_offset); + batch_offset + group_offset + channel_offset); } -//! when format is nchwxx and channel wise mode, multi group will pack -//! together, so pack_group_size is the number of packed group + template -const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, +const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, size_t pack_group_size) const { size_t group_offset = 0_z; switch (filter_meta.format) { case Param::Format::NCHW: { - group_offset = pack_group_size * group_id * filter_meta.icpg * + group_offset = pack_group_size * group_pack_id * filter_meta.icpg * filter_meta.ocpg * filter_meta.spatial[0] * filter_meta.spatial[1] * filter_type.size(); break; @@ -359,15 +359,15 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, size_t icpg = filter_meta.icpg; size_t ocpg = filter_meta.ocpg; //! four format of weight layout - //! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh, - //! fw, 8, 8} - //! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8} + //! 1. {oc/8, ic/8, fh, fw, 8, 8}, + //! 2. {g, oc/8, ic/8, fh, fw, 8, 8}, + //! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8} megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) || (group % 8 == 0 && icpg == 1 && ocpg == 1 && pack_group_size > 1) || (group == 1 && ocpg % 8 == 0), "The filter shepe is not right of nchw88"); - group_offset = pack_group_size * group_id * filter_meta.icpg * + group_offset = pack_group_size * group_pack_id * filter_meta.icpg * filter_meta.ocpg * filter_meta.spatial[0] * filter_meta.spatial[1] * filter_type.size(); @@ -380,7 +380,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, //! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8} //! 3. {g, alpha, alpha, oc, ic, 8, 8} //! 4. {alpha, alpha, oc, ic} - group_offset = pack_group_size * group_id * filter_meta.icpg * + group_offset = pack_group_size * group_pack_id * filter_meta.icpg * filter_meta.ocpg * (filter_meta.spatial[0] + output_block_size - 1) * (filter_meta.spatial[1] + output_block_size - 1) * @@ -388,58 +388,66 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, break; } default: - megdnn_assert("other filter format is not support yet"); + megdnn_assert(0, "other filter format is not support yet"); } return reinterpret_cast(reinterpret_cast(filter_ptr) + group_offset); } -//! when format is nchwxx and channel wise mode, multi group will pack -//! 
together, so pack_group_size is the number of packed group template -const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id, - size_t group_pack_size) const { - bias_type.assert_is_compatible_ctype(); +const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id, + size_t group_pack_size, + size_t channel_pack_size) const { size_t batch_offset = 0_z; size_t group_offset = 0_z; + size_t channel_offset = 0_z; if (bias_mode == BiasMode::BIAS) { batch_offset = batch_id * bias_bs * bias_type.size(); - group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] * - osz[1] * bias_type.size(); + group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * + osz[0] * osz[1] * bias_type.size(); + channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] * + bias_type.size(); } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - group_offset = group_pack_size * group_id * filter_meta.ocpg * + group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * bias_type.size(); + channel_offset = channel_pack_size * channel_pack_id * bias_type.size(); } return reinterpret_cast(reinterpret_cast(bias_ptr) + - batch_offset + group_offset); + batch_offset + group_offset + channel_offset); } -//! when format is nchwxx and channel wise mode, multi group will pack -//! together, so pack_group_size is the number of packed group template -T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id, - size_t group_pack_size) const { - dst_type.assert_is_compatible_ctype(); +T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id, + size_t group_pack_size, + size_t channel_pack_size) const { size_t batch_offset = batch_id * out_bs * dst_type.size(); - size_t group_offset = group_pack_size * group_id * filter_meta.ocpg * + size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * osz[0] * osz[1] * dst_type.size(); + size_t channel_offset = channel_pack_size * channel_pack_id * osz[0] * + osz[1] * dst_type.size(); return reinterpret_cast(reinterpret_cast(dst_ptr) + - batch_offset + group_offset); + batch_offset + group_offset + channel_offset); } -#define INST(T) \ - template const T* ConvBiasImpl::NCBKernParam::src( \ - size_t batch_id, size_t group_id, size_t group_pack_size) const; \ - template const T* ConvBiasImpl::NCBKernParam::bias( \ - size_t batch_id, size_t group_id, size_t group_pack_size) const; \ - template const T* ConvBiasImpl::NCBKernParam::filter( \ - size_t group_id, size_t group_pack_size) const; \ - template T* ConvBiasImpl::NCBKernParam::dst( \ - size_t batch_id, size_t group_id, size_t group_pack_size) const; +#define INST(T) \ + template const T* ConvBiasImpl::NCBKernParam::src( \ + size_t batch_id, size_t group_id, size_t channel_id, \ + size_t group_pack_size, size_t channel_pack_size) const; \ + template const T* ConvBiasImpl::NCBKernParam::bias( \ + size_t batch_id, size_t group_id, size_t channel_id, \ + size_t group_pack_size, size_t channel_pack_size) const; \ + template const T* ConvBiasImpl::NCBKernParam::filter( \ + size_t group_id, size_t group_pack_size) const; \ + template T* ConvBiasImpl::NCBKernParam::dst( \ + size_t batch_id, size_t group_id, size_t channel_id, \ + size_t group_pack_size, size_t channel_pack_size) const; #define INST_DT(d) INST(DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) +INST(void) #undef INST #undef INST_DT } // namespace fallback diff --git 
a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 1b121382..c4d081bc 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -103,10 +103,32 @@ public: src_type.assert_is_compatible_ctype(); return static_cast(src_ptr); } + //! when format is nchwxx, multi channel will pack into one + //! chnannel_pack_id. pack_channel_size is the number of packed channel + //! when format is nchwxx and channel wise, multi group will pack into + //! one group_pack_id. group_pack_size is the number of packed group + //! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} + template + const T* src(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id = 0, size_t group_pack_size = 1, + size_t channel_pack_size = 1) const; + + template + const T* bias(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id = 0, size_t group_pack_size = 1, + size_t channel_pack_size = 1) const; template - const T* src(size_t batch_id, size_t group_id, - size_t group_pack_size = 1_z) const; + T* dst(size_t batch_id, size_t group_pack_id, + size_t channel_pack_id = 0, size_t group_pack_size = 1, + size_t channel_pack_size = 1) const; + + //! when format is nchwxx and channel wise, multi group will pack into + //! one group_pack_id. group_pack_size is the number of packed group + //! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} + template + const T* filter(size_t group_pack_id, + size_t pack_group_size = 1_z) const; template const T* filter() const { @@ -114,29 +136,18 @@ public: return static_cast(filter_ptr); } - template - const T* filter(size_t group_id, size_t pack_group_size = 1_z) const; - template const T* bias() const { bias_type.assert_is_compatible_ctype(); return static_cast(bias_ptr); } - template - const T* bias(size_t batch_id, size_t group_id, - size_t group_pack_size = 1_z) const; - template T* dst() const { dst_type.assert_is_compatible_ctype(); return static_cast(dst_ptr); } - template - T* dst(size_t batch_id, size_t group_id, - size_t group_pack_size = 1_z) const; - template T* workspace() const { return static_cast(workspace_ptr); diff --git a/dnn/src/fallback/convolution/algos.cpp b/dnn/src/fallback/convolution/algos.cpp index 9ceb9489..340270ce 100644 --- a/dnn/src/fallback/convolution/algos.cpp +++ b/dnn/src/fallback/convolution/algos.cpp @@ -197,9 +197,12 @@ ConvolutionImpl::AlgoFallback::dispatch_kern( auto kern_fallback = [workspace_per_thread](const NCBKernParam& p, const NCBKernIndex& ncb_index) { UNPACK_CONV_F32_NCB_KERN_SIZES(p); + size_t batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0]; MEGDNN_MARK_USED_VAR(N); - auto src = p.src(), filter = p.filter(); - auto dst = p.dst(); + auto src = p.src(batch_id, group_id), + filter = p.filter(group_id); + auto dst = p.dst(batch_id, group_id); size_t thread_id = ncb_index.thread_id; void* workspace_ptr = reinterpret_cast( reinterpret_cast(p.workspace_ptr) + diff --git a/dnn/src/fallback/convolution/algos.h b/dnn/src/fallback/convolution/algos.h index b24c7fbe..091be295 100644 --- a/dnn/src/fallback/convolution/algos.h +++ b/dnn/src/fallback/convolution/algos.h @@ -20,18 +20,25 @@ namespace fallback { template void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p, - const ConvolutionImpl::NCBKernIndex& /*index*/) { + const ConvolutionImpl::NCBKernIndex& ncb_index) { + size_t batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0]; auto IC = p.filter_meta.icpg, IH = p.isz[0], IW = p.isz[1], 
OC = p.filter_meta.ocpg, OH = p.osz[0], OW = p.osz[1]; + ptrdiff_t fstrd = p.filter_meta.icpg * p.filter_meta.ocpg * + p.filter_meta.spatial[0] * p.filter_meta.spatial[1] * + p.filter_type.size(); + ptrdiff_t istrd = p.filter_meta.icpg * p.src_type.size(); + ptrdiff_t ostrd = p.filter_meta.ocpg * p.dst_type.size(); TensorND src, dst; - src.raw_ptr = const_cast(p.src_ptr); - dst.raw_ptr = p.dst_ptr; src.layout.dtype = p.src_type; dst.layout.dtype = p.dst_type; if (p.filter_meta.format == param::Convolution::Format::NCHW) { - src.layout.init_contiguous_stride({1, IC, IH, IW}); - dst.layout.init_contiguous_stride({1, OC, OH, OW}); + istrd *= p.isz[0] * p.isz[1]; + ostrd *= p.osz[0] * p.osz[1]; + src.layout.init_contiguous_stride({1, IC, IH, IW}); + dst.layout.init_contiguous_stride({1, OC, OH, OW}); } else { // Must be NHWC megdnn_assert( @@ -41,9 +48,17 @@ void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p, src.layout.init_contiguous_stride({1, IH, IW, IC}); dst.layout.init_contiguous_stride({1, OH, OW, OC}); } + src.raw_ptr = reinterpret_cast( + reinterpret_cast(p.src_ptr) + + batch_id * p.inp_bs * p.src_type.size() + group_id * istrd); + dst.raw_ptr = reinterpret_cast( + reinterpret_cast(p.dst_ptr) + + batch_id * p.out_bs * p.dst_type.size() + group_id * ostrd); + ST* filter = reinterpret_cast( + reinterpret_cast(p.filter_ptr) + group_id * fstrd); std::copy(p.inp_s, p.inp_s + 4, src.layout.stride); std::copy(p.out_s, p.out_s + 4, dst.layout.stride); - naive::convolution::forward(src, p.filter(), dst, + naive::convolution::forward(src, filter, dst, p.filter_meta); } diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 88a02010..50e809da 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -189,41 +189,15 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param( void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo) { auto kerns = ncb_algo_dispatch_kern(algo, param); - size_t src_batch_stride = param.inp_bs * param.src_type.size(); - size_t dst_batch_stride = param.out_bs * param.dst_type.size(); - auto group = param.filter_meta.group; auto fallback_handle = handle(); for (auto kernel : kerns) { megdnn_assert(param.filter_meta.format == Param::Format::NCHW || - param.filter_meta.format == Param::Format::NHWC || + param.filter_meta.format == Param::Format::NHWC || + param.filter_meta.format == Param::Format::NCHW88, "invalid conv format"); - - ptrdiff_t istrd = 0, fstrd = 0, ostrd = 0; - fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * - param.filter_meta.spatial[0] * param.filter_meta.spatial[1] * - param.filter_type.size(); - istrd = param.filter_meta.icpg * param.src_type.size(); - ostrd = param.filter_meta.ocpg * param.dst_type.size(); - if (param.filter_meta.format == Param::Format::NCHW) { - istrd *= param.isz[0] * param.isz[1]; - ostrd *= param.osz[0] * param.osz[1]; - } else { - // must be NHWC. No action performed. - } - auto run = [=](size_t index, size_t thread_id) { - auto copy_param = param; + auto run = [param, kernel](size_t index, size_t thread_id) { CpuNDRange ndrange_id(kernel.global_size, index); - size_t group_id = ndrange_id[0]; - size_t batch_id = ndrange_id[1]; - megdnn_assert(group_id < group, - "The group id should smaller than gruop"); - //! 
The kernel ptr point to batch index - incr_ptr(copy_param.src_ptr, - group_id * istrd + batch_id * src_batch_stride); - incr_ptr(copy_param.filter_ptr, group_id * fstrd); - incr_ptr(copy_param.dst_ptr, - group_id * ostrd + batch_id * dst_batch_stride); - kernel.kern(copy_param, {thread_id, ndrange_id}); + kernel.kern(param, {thread_id, ndrange_id}); }; static_cast(fallback_handle) ->dispatch_kern(run, kernel.global_size.total_size()); diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h index ebf2254b..467e2879 100644 --- a/dnn/src/fallback/convolution/opr_impl.h +++ b/dnn/src/fallback/convolution/opr_impl.h @@ -100,6 +100,43 @@ public: T* workspace() const { return static_cast(workspace_ptr); } + + //! when format is nchwxx and channel wise, multi group will pack into + //! one group_pack_id. group_pack_size is the number of packed group + //! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} + template + T* dst(size_t batch_id, size_t group_pack_id, + size_t group_pack_size = 1_z) const{ + size_t batch_offset = batch_id * out_bs * dst_type.size(); + size_t group_offset = group_pack_size * group_pack_id * + filter_meta.ocpg * osz[0] * osz[1] * + dst_type.size(); + return reinterpret_cast(reinterpret_cast(dst_ptr) + + batch_offset + group_offset); + } + + template + const T* src(size_t batch_id, size_t group_pack_id, + size_t group_pack_size = 1_z) const { + size_t batch_offset = batch_id * inp_bs * src_type.size(); + size_t group_offset = group_pack_size * group_pack_id * + filter_meta.icpg * isz[0] * isz[1] * + src_type.size(); + return reinterpret_cast(reinterpret_cast(src_ptr) + + batch_offset + group_offset); + + } + + template + const T* filter(size_t group_pack_id, + size_t pack_group_size = 1_z) const { + size_t group_offset = pack_group_size * group_pack_id * + filter_meta.icpg * filter_meta.ocpg * + filter_meta.spatial[0] * + filter_meta.spatial[1] * filter_type.size(); + return reinterpret_cast( + reinterpret_cast(filter_ptr) + group_offset); + } }; static void* const sm_fallback_conv_algo_type; diff --git a/dnn/src/x86/conv_bias/f32/algos.cpp b/dnn/src/x86/conv_bias/f32/algos.cpp index 41a4eee5..01a629bf 100644 --- a/dnn/src/x86/conv_bias/f32/algos.cpp +++ b/dnn/src/x86/conv_bias/f32/algos.cpp @@ -58,43 +58,45 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, } } // namespace -#define GET_KERN \ - auto fm = param.filter_meta; \ - size_t N = param.n; \ - size_t IC = param.filter_meta.icpg; \ - size_t OC = param.filter_meta.ocpg; \ - size_t group = fm.group; \ - WorkspaceBundle wbundle = get_bundle(param); \ - SmallVector ret_kerns; \ - if (m_large_group) { \ - auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \ - const NCBKernIndex& ncb_index) { \ - auto fm = kern_param.filter_meta; \ - size_t IC = fm.icpg; \ - size_t OC = fm.ocpg; \ - WorkspaceBundle bundle = wbundle; \ - for (size_t ic = 0; ic < IC; ic++) { \ - copy_padding_kern( \ - bundle, kern_param, \ - {ncb_index.thread_id, {ncb_index.thread_id, 0, ic}}); \ - } \ - for (size_t oc = 0; oc < OC; oc++) { \ - do_conv_kern( \ - bundle, kern_param, \ - {ncb_index.thread_id, {ncb_index.thread_id, 0, oc}}); \ - } \ - }; \ - ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \ - } else { \ - WorkspaceBundle bundle = wbundle; \ - auto copy_padding = \ - std::bind(copy_padding_kern, bundle, std::placeholders::_1, \ - std::placeholders::_2); \ - ret_kerns.push_back({copy_padding, {group, N, IC}}); \ - auto do_conv = 
std::bind(do_conv_kern, bundle, std::placeholders::_1, \ - std::placeholders::_2); \ - ret_kerns.push_back({do_conv, {group, N, OC}}); \ - } \ +#define GET_KERN \ + auto fm = param.filter_meta; \ + size_t N = param.n; \ + size_t IC = param.filter_meta.icpg; \ + size_t OC = param.filter_meta.ocpg; \ + size_t group = fm.group; \ + WorkspaceBundle wbundle = get_bundle(param); \ + SmallVector ret_kerns; \ + if (m_large_group) { \ + auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \ + const NCBKernIndex& ncb_index) { \ + auto fm = kern_param.filter_meta; \ + size_t IC = fm.icpg; \ + size_t OC = fm.ocpg; \ + WorkspaceBundle bundle = wbundle; \ + for (size_t ic = 0; ic < IC; ic++) { \ + copy_padding_kern(bundle, kern_param, ncb_index, \ + {ncb_index.thread_id, 0, ic}); \ + } \ + for (size_t oc = 0; oc < OC; oc++) { \ + do_conv_kern(bundle, kern_param, ncb_index, \ + {ncb_index.thread_id, 0, oc}); \ + } \ + }; \ + ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \ + } else { \ + auto copy_padding = [wbundle](const NCBKernParam& kern_param, \ + const NCBKernIndex& ncb_index) { \ + copy_padding_kern(wbundle, kern_param, ncb_index, \ + ncb_index.ndrange_id); \ + }; \ + ret_kerns.push_back({copy_padding, {group, N, IC}}); \ + auto do_conv = [wbundle](const NCBKernParam& kern_param, \ + const NCBKernIndex& ncb_index) { \ + do_conv_kern(wbundle, kern_param, ncb_index, \ + ncb_index.ndrange_id); \ + }; \ + ret_kerns.push_back({do_conv, {group, N, OC}}); \ + } \ return ret_kerns; /* ===================== direct algo ===================== */ @@ -145,7 +147,8 @@ size_t ConvBiasImpl::AlgoDirect::get_workspace( //! Process one input channel copy padding void ConvBiasImpl::AlgoDirect::copy_padding_kern( WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -160,14 +163,18 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern( get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); bool rectify_src = (IH != IH2 || IW != IW2); size_t padding_group_size = IH2 * IW2 * IC; - const float* sptr = static_cast(kern_param.src_ptr) + - ncb_index.ndrange_id[2] * IH * IW; + size_t batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0]; + size_t channel_id = workspace_ids[2]; + const float* sptr = static_cast( + kern_param.src(batch_id, group_id)) + + channel_id * IH * IW; bundle.set(kern_param.workspace_ptr); //! Used for get the workspace offset - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - workspace_channel_id = ncb_index.ndrange_id[2]; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + workspace_channel_id = workspace_ids[2]; //! If large group, each thread has its own worspace, set group_id with //! thread_id if (rectify_src) { @@ -234,7 +241,8 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern( //! 
compute one output channel void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t IH = kern_param.isz[0]; @@ -265,14 +273,16 @@ void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle, megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { bias_offset = 1_z; } + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; //! Used for get the workspace offset - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - oc = ncb_index.ndrange_id[2]; - const float* sptr = kern_param.src(); - const float* filter = kern_param.filter() + oc * FH * FW * IC; - const float* bias_ptr = kern_param.bias() + oc * bias_offset; - float* dst = kern_param.dst() + oc * OH * OW; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + const float* sptr = kern_param.src(batch_id, group_id); + const float* filter = kern_param.filter(group_id) + oc * FH * FW * IC; + const float* bias_ptr = + kern_param.bias(batch_id, group_id) + oc * bias_offset; + float* dst = kern_param.dst(batch_id, group_id) + oc * OH * OW; if (rectify_src) { sptr = static_cast(bundle.get(0)) + workspace_group_id * padding_group_size + @@ -358,7 +368,8 @@ size_t ConvBiasImpl::AlgoDirectStride2::get_workspace( //! Process one input channel copy padding void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -373,13 +384,17 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); bool rectify_src = need_src_copy(kern_param); size_t padding_group_size = IH2 * IW2 * IC; - const float* sptr = static_cast(kern_param.src_ptr) + - ncb_index.ndrange_id[2] * IH * IW; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; + size_t channel_id = workspace_ids[2]; + const float* sptr = static_cast( + kern_param.src(batch_id, group_id)) + + channel_id * IH * IW; bundle.set(kern_param.workspace_ptr); //! Used for get the workspace offset - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - workspace_channel_id = ncb_index.ndrange_id[2]; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], + workspace_channel_id = workspace_ids[2]; if (rectify_src) { //! copy to sptr_base to eliminate padding effect float* sptr_base = static_cast(bundle.get(0)) + @@ -397,7 +412,7 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( //! 
compute one output channel void ConvBiasImpl::AlgoDirectStride2::do_conv_kern( WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t IH = kern_param.isz[0]; @@ -439,14 +454,17 @@ void ConvBiasImpl::AlgoDirectStride2::do_conv_kern( megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { bias_offset = 1_z; } + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; //! Used for get the workspace offset - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - oc = ncb_index.ndrange_id[2]; - const float* sptr = kern_param.src(); - const float* filter = kern_param.filter() + oc * FH * FW * IC; - const float* bias_ptr = kern_param.bias() + oc * bias_offset; - float* dst = kern_param.dst() + oc * OH * OW; + size_t workspace_group_id = workspace_ids[0], + workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + const float* sptr = kern_param.src(batch_id, group_id); + const float* filter = + kern_param.filter(group_id) + oc * FH * FW * IC; + const float* bias_ptr = + kern_param.bias(batch_id, group_id) + oc * bias_offset; + float* dst = kern_param.dst(batch_id, group_id) + oc * OH * OW; if (rectify_src) { sptr = static_cast(bundle.get(0)) + workspace_group_id * padding_group_size + @@ -547,23 +565,22 @@ MatrixMul* ConvBiasImpl::AlgoMatrixMul::get_matmul_opr() { } void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param, - const NCBKernIndex&) { + const NCBKernIndex& ncb_index) { UNPACK_CONV_F32_NCB_KERN_SIZES(param); auto IH2 = IH + 2 * PH; auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; bool is_xcorr = !param.filter_meta.should_flip; auto bundle = get_bundle(param); bundle.set(param.workspace_ptr); // workspace = tmp..src2 for (size_t n = 0; n < N; ++n) { - float* src = const_cast(param.src()) + n * param.inp_bs; - float* dst = param.dst() + n * param.out_bs; - float* bias_ptr = - static_cast(const_cast(param.bias_ptr)); - if (param.bias_mode == megdnn::BiasMode::BIAS) { - bias_ptr += n * param.out_bs; - } + float* src = const_cast(param.src(n, group_id)); + float* dst = param.dst(n, group_id); + float* bias_ptr = static_cast( + const_cast(param.bias(n, group_id))); + float *B, *src2; if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { // special case: 1x1 @@ -613,7 +630,7 @@ void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param, { TensorND A_, B_, C_; A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32()); - A_.raw_ptr = const_cast(param.filter()); + A_.raw_ptr = const_cast(param.filter(group_id)); B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32()); B_.raw_ptr = B; C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32()); diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h index 05a8b57a..e5a63351 100644 --- a/dnn/src/x86/conv_bias/f32/algos.h +++ b/dnn/src/x86/conv_bias/f32/algos.h @@ -22,10 +22,12 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { static void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); static void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); bool 
m_large_group; public: @@ -57,10 +59,12 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { static void copy_padding_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); static void do_conv_kern(WorkspaceBundle bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); + const NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids); bool m_large_group; public: diff --git a/dnn/src/x86/conv_bias/int8/algos.cpp b/dnn/src/x86/conv_bias/int8/algos.cpp index 0a8c3854..3777d887 100644 --- a/dnn/src/x86/conv_bias/int8/algos.cpp +++ b/dnn/src/x86/conv_bias/int8/algos.cpp @@ -146,9 +146,11 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle( } while (0) void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( - const NCBKernParam& param, const NCBKernIndex&) { + const NCBKernParam& param, const NCBKernIndex& ncb_index) { UNPACK_CONV_F32_NCB_KERN_SIZES(param); MEGDNN_MARK_USED_VAR(N); + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; auto x86_handle = static_cast(inplace_cpu_handle().get()); megdnn_assert(x86_handle != nullptr, "x86 handle can not be null"); auto eng_mkldnn = x86_handle->mkldnn_engine(); @@ -167,10 +169,11 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( auto megdnn_dst_md = memory::desc({dst_shape}, memory::data_type::s32, memory::format_tag::nchw); - auto megdnn_weight_memory = memory(megdnn_weight_md, eng_mkldnn, - const_cast(param.filter_ptr)); - int8_t* src = const_cast(param.src()); - int32_t* dst = param.dst(); + auto megdnn_weight_memory = + memory(megdnn_weight_md, eng_mkldnn, + const_cast(param.filter(group_id))); + int8_t* src = const_cast(param.src(batch_id, group_id)); + int32_t* dst = param.dst(batch_id, group_id); auto megdnn_src_memory = memory(megdnn_src_md, eng_mkldnn, static_cast(src)); @@ -353,18 +356,18 @@ MatrixMul* ConvBiasImpl::AlgoMkldnnMatmulQint8::get_matmul_opr() { } void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32( - const NCBKernParam& param, const NCBKernIndex&) { + const NCBKernParam& param, const NCBKernIndex& ncb_index) { UNPACK_CONV_F32_NCB_KERN_SIZES(param); auto IH2 = IH + 2 * PH; auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; bool is_xcorr = !param.filter_meta.should_flip; auto bundle = get_bundle(param); bundle.set(param.workspace_ptr); for (size_t n = 0; n < N; ++n) { - int8_t* src = - const_cast(param.src()) + n * param.inp_bs; - int32_t* dst = param.dst() + n * param.out_bs; + int8_t* src = const_cast(param.src(n, group_id)); + int32_t* dst = param.dst(n, group_id); int8_t *B, *src2; if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { // special case: 1x1 @@ -414,7 +417,7 @@ void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32( { TensorND A_, B_, C_; A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Int8()); - A_.raw_ptr = const_cast(param.filter()); + A_.raw_ptr = const_cast(param.filter(group_id)); B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Int8()); B_.raw_ptr = B; C_.layout = TensorLayout({OC, OH * OW}, dtype::Int32()); diff --git a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp index 574a5481..865f9399 100644 --- a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp +++ b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp @@ -47,8 +47,8 @@ void 
pack_src_conv_avx2_stride1(WorkspaceBundle bundle, batch_id = ncb_index.ndrange_id[1], channel_id = ncb_index.ndrange_id[2]; - const int8_t* src_ptr = - kern_param.src() + ic_step * channel_id * c_stride; + const int8_t* src_ptr = kern_param.src(batch_id, group_id) + + ic_step * channel_id * c_stride; bundle.set(kern_param.workspace_ptr); int8_t* packed_src = static_cast(bundle.get(0)) + batch_id * group * packed_group_size + @@ -129,7 +129,7 @@ static inline void pack_filter_conv_avx2_stride1( size_t group_id = ncb_index.ndrange_id[0], oc_index_id = ncb_index.ndrange_id[1]; - const int8_t* pack_filter_ptr = kern_param.filter(); + const int8_t* pack_filter_ptr = kern_param.filter(group_id); bundle.set(kern_param.workspace_ptr); int16_t* out_ptr = static_cast(bundle.get(1)) + group_id * round_up(oc, oc_step) * oc_out_stride; @@ -632,19 +632,18 @@ void do_conv_kern(WorkspaceBundle bundle, const uint32_t packed_group_size = div_ceil(ic, ic_step) * pack_ih * pack_iw; - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - workspace_channel_id = ncb_index.ndrange_id[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1], + channel_id = ncb_index.ndrange_id[2]; bundle.set(kern_param.workspace_ptr); int8_t* src_ptr = static_cast(bundle.get(0)) + - workspace_group_id * packed_group_size + - workspace_batch_id * group * packed_group_size; - int16_t* filter_ptr = - static_cast(bundle.get(1)) + - workspace_group_id * round_up(oc, oc_step) * filter_round_size + - oc_step * workspace_channel_id * filter_round_size; + group_id * packed_group_size + + batch_id * group * packed_group_size; + int16_t* filter_ptr = static_cast(bundle.get(1)) + + group_id * round_up(oc, oc_step) * filter_round_size + + oc_step * channel_id * filter_round_size; bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; @@ -652,12 +651,11 @@ void do_conv_kern(WorkspaceBundle bundle, int32_t* dst_tptr = nullptr; if (need_post_process) { dst_tptr = static_cast(bundle.get(2)) + - workspace_batch_id * group * oc * oc_stride + - workspace_group_id * oc * oc_stride + - oc_step * workspace_channel_id * oh * ow; + batch_id * group * oc * oc_stride + + group_id * oc * oc_stride + oc_step * channel_id * oh * ow; } else { - dst_tptr = kern_param.dst() + - oc_step * workspace_channel_id * oh * ow; + dst_tptr = kern_param.dst(batch_id, group_id) + + oc_step * channel_id * oh * ow; } const uint32_t oc_end = oc / oc_step * oc_step; @@ -666,7 +664,7 @@ void do_conv_kern(WorkspaceBundle bundle, const uint32_t oh_remain = oh - oh_end; const uint32_t ow_end = ow / ow_step * ow_step; const uint32_t ow_remain = ow - ow_end; - const uint32_t oc_index = oc_step * workspace_channel_id; + const uint32_t oc_index = oc_step * channel_id; AlgoAVX2DirectConvStride1S8S8S32_forward( @@ -684,29 +682,29 @@ void do_post_process(WorkspaceBundle bundle, const uint32_t oh = kern_param.osz[0]; const uint32_t ow = kern_param.osz[1]; - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; bundle.set(kern_param.workspace_ptr); bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; void* dst_tptr = nullptr; if (need_post_process) { dst_tptr = static_cast(bundle.get(2)) + - workspace_batch_id * group * oc * oh * ow + - workspace_group_id * oc * oh * ow; + batch_id * group * oc * oh * ow + group_id * oc * oh * ow; } 
else { - dst_tptr = kern_param.dst(); + dst_tptr = kern_param.dst(batch_id, group_id); } + void* dst_ptr = kern_param.dst(batch_id, group_id); #define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ { \ - const dt_int32* bias_ptr = kern_param.bias(); \ + const dt_int32* bias_ptr = \ + kern_param.bias(batch_id, group_id); \ PostProcess::ctype, \ DTypeTrait<_dst_ctype>::ctype, \ _postprocess_mode>::run(dst_tptr, \ const_cast(bias_ptr), \ - kern_param.dst_ptr, \ - kern_param.bias_mode, \ + dst_ptr, kern_param.bias_mode, \ kern_param.nonlineMode, \ kern_param.bias_type, \ kern_param.dst_type, 1, oc, oh, \ diff --git a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp index a480bbef..28f3a4b7 100644 --- a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp +++ b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp @@ -45,8 +45,8 @@ void pack_src_conv_avx2_stride2(WorkspaceBundle bundle, batch_id = ncb_index.ndrange_id[1], channel_id = ncb_index.ndrange_id[2]; - const int8_t* src_ptr = - kern_param.src() + ic_step * channel_id * c_stride; + const int8_t* src_ptr = kern_param.src(batch_id, group_id) + + ic_step * channel_id * c_stride; bundle.set(kern_param.workspace_ptr); int8_t* packed_src = static_cast(bundle.get(0)) + batch_id * group * packed_group_size + @@ -187,7 +187,7 @@ static inline void pack_filter_conv_avx2_stride2( size_t group_id = ncb_index.ndrange_id[0], oc_index_id = ncb_index.ndrange_id[1]; - const int8_t* pack_filter_ptr = kern_param.filter(); + const int8_t* pack_filter_ptr = kern_param.filter(group_id); bundle.set(kern_param.workspace_ptr); int16_t* out_ptr = static_cast(bundle.get(1)) + group_id * round_up(oc, oc_step) * oc_out_stride; @@ -705,18 +705,17 @@ void kernel_imp(WorkspaceBundle bundle, const uint32_t packed_group_size = div_ceil(ic, ic_step) * pack_ih * pack_iw; - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1], - workspace_channel_id = ncb_index.ndrange_id[2]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1], + channel_id = ncb_index.ndrange_id[2]; bundle.set(kern_param.workspace_ptr); int8_t* src_ptr = static_cast(bundle.get(0)) + - workspace_group_id * packed_group_size + - workspace_batch_id * group * packed_group_size; - int16_t* filter_ptr = - static_cast(bundle.get(1)) + - workspace_group_id * round_up(oc, oc_step) * filter_round_size + - oc_step * workspace_channel_id * filter_round_size; + group_id * packed_group_size + + batch_id * group * packed_group_size; + int16_t* filter_ptr = static_cast(bundle.get(1)) + + group_id * round_up(oc, oc_step) * filter_round_size + + oc_step * channel_id * filter_round_size; bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; @@ -724,12 +723,11 @@ void kernel_imp(WorkspaceBundle bundle, int32_t* dst_tptr = nullptr; if (need_post_process) { dst_tptr = static_cast(bundle.get(2)) + - workspace_batch_id * group * oc * oc_stride + - workspace_group_id * oc * oc_stride + - oc_step * workspace_channel_id * oh * ow; + batch_id * group * oc * oc_stride + + group_id * oc * oc_stride + oc_step * channel_id * oh * ow; } else { - dst_tptr = kern_param.dst() + - oc_step * workspace_channel_id * oh * ow; + dst_tptr = kern_param.dst(batch_id, group_id) + + oc_step * channel_id * oh * ow; } const uint32_t oc_end = oc / oc_step * oc_step; const uint32_t oc_remain = oc - oc_end; @@ -737,7 +735,7 @@ void kernel_imp(WorkspaceBundle bundle, const uint32_t 
oh_remain = oh - oh_end; const uint32_t ow_end = ow / ow_step * ow_step; const uint32_t ow_remain = ow - ow_end; - const uint32_t oc_index = oc_step * workspace_channel_id; + const uint32_t oc_index = oc_step * channel_id; kernel_handle_oh_remain( oh_remain, oc_remain, ow_remain, filter_ptr, src_ptr, dst_tptr, @@ -754,8 +752,8 @@ void do_post_process(WorkspaceBundle bundle, const uint32_t oh = kern_param.osz[0]; const uint32_t ow = kern_param.osz[1]; - size_t workspace_group_id = ncb_index.ndrange_id[0], - workspace_batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; bundle.set(kern_param.workspace_ptr); bool need_post_process = @@ -763,21 +761,22 @@ void do_post_process(WorkspaceBundle bundle, void* dst_tptr = nullptr; if (need_post_process) { dst_tptr = static_cast(bundle.get(2)) + - workspace_batch_id * group * oc * oh * ow + - workspace_group_id * oc * oh * ow; + batch_id * group * oc * oh * ow + + group_id * oc * oh * ow; } else { - dst_tptr = kern_param.dst(); + dst_tptr = kern_param.dst(batch_id, group_id); } + void* dst_ptr = kern_param.dst(batch_id, group_id); #define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ { \ - const dt_int32* bias_ptr = kern_param.bias(); \ + const dt_int32* bias_ptr = \ + kern_param.bias(batch_id, group_id); \ PostProcess::ctype, \ DTypeTrait<_dst_ctype>::ctype, \ _postprocess_mode>::run(dst_tptr, \ const_cast(bias_ptr), \ - kern_param.dst_ptr, \ - kern_param.bias_mode, \ + dst_ptr, kern_param.bias_mode, \ kern_param.nonlineMode, \ kern_param.bias_type, \ kern_param.dst_type, 1, oc, oh, \ -- GitLab
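
Editor's notes (not part of the patch; three hedged C++ sketches follow, illustrating the indexing scheme this commit introduces).

The core of the change is that NCBKernParam::src/bias/dst gain channel_pack_id and channel_pack_size arguments, and all offsets are computed in bytes from the packed group/channel indices. Below is a minimal standalone model of that arithmetic, using the same formula as the patched src(); FakeParam and its field names are simplified stand-ins, not the MegDNN types.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for ConvBiasImpl::NCBKernParam, keeping only the
// fields the offset formula needs.
struct FakeParam {
    void* src_ptr;
    size_t inp_bs;      // input batch stride, in elements
    size_t icpg;        // input channels per group
    size_t isz[2];      // IH, IW
    size_t dtype_size;  // size of one element in bytes

    // Mirrors the patched src<T>(): batch offset + packed-group offset +
    // packed-channel offset, all computed in bytes.
    const void* src(size_t batch_id, size_t group_pack_id,
                    size_t channel_pack_id = 0, size_t group_pack_size = 1,
                    size_t channel_pack_size = 1) const {
        size_t batch_offset = batch_id * inp_bs * dtype_size;
        size_t group_offset = group_pack_size * group_pack_id * icpg *
                              isz[0] * isz[1] * dtype_size;
        size_t channel_offset = channel_pack_size * channel_pack_id *
                                isz[0] * isz[1] * dtype_size;
        return reinterpret_cast<const void*>(
                reinterpret_cast<uintptr_t>(src_ptr) + batch_offset +
                group_offset + channel_offset);
    }
};

int main() {
    float buf[2 * 4 * 3 * 3] = {};  // N=2, C=4 (2 groups, icpg=2), 3x3 images
    FakeParam p{buf, /*inp_bs=*/4 * 3 * 3, /*icpg=*/2, {3, 3}, sizeof(float)};
    // Element offset of (batch 1, group 1, channel 1) should be
    // 1*36 + 1*2*9 + 1*9 = 63.
    const float* q = static_cast<const float*>(p.src(1, 1, 1));
    std::printf("offset = %td\n", q - buf);  // prints 63
    return 0;
}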
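
The second behavioral change is in exec_with_ncb_kern: the dispatcher no longer pre-offsets src/filter/dst per (group, batch) before invoking the kernel; it passes the untouched param together with a CpuNDRange index, and each kernel derives its own pointers from ncb_index.ndrange_id. Below is a serial sketch of that control flow; KernIndex and Kern are simplified stand-ins for NCBKernIndex/NCBKern, and the linear-to-3D unflattening imitates what CpuNDRange does.

#include <cstddef>
#include <cstdio>
#include <functional>

struct KernIndex {
    size_t thread_id;
    size_t ndrange_id[3];  // {group, batch, channel-or-oc}
};

struct Kern {
    std::function<void(const KernIndex&)> kern;
    size_t global_size[3];  // iteration space: {group, batch, channel}
    size_t total() const {
        return global_size[0] * global_size[1] * global_size[2];
    }
};

int main() {
    size_t group = 2, N = 2, OC = 3;
    Kern kernel{[](const KernIndex& idx) {
                    // Each work item resolves its own offsets from the
                    // ndrange id; nothing was pre-applied by the dispatcher.
                    std::printf("g=%zu n=%zu oc=%zu on thread %zu\n",
                                idx.ndrange_id[0], idx.ndrange_id[1],
                                idx.ndrange_id[2], idx.thread_id);
                },
                {group, N, OC}};
    // Serial stand-in for dispatch_kern(run, kernel.global_size.total_size()):
    // unflatten the linear index into the 3-D ndrange.
    for (size_t index = 0; index < kernel.total(); ++index) {
        KernIndex idx{/*thread_id=*/0,
                      {index / (N * OC), index / OC % N, index % OC}};
        kernel.kern(idx);
    }
    return 0;
}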
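
Finally, the x86 f32 GET_KERN macro now splits scheduling on m_large_group: with many groups, one kernel per (group, batch) loops over all input/output channels itself and indexes its workspace by thread_id; with few groups, copy-padding and convolution are exposed as separate {group, N, IC} and {group, N, OC} grids whose workspace slot comes from the ndrange id. The sketch below shows only the two task shapes; the group >= 8 threshold and the printf "work" are illustrative assumptions, not the algorithm's actual heuristic.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

struct Task {
    std::function<void(size_t g, size_t n, size_t c, size_t thread_id)> kern;
    size_t grid[3];  // {group, batch, channel}
};

int main() {
    size_t group = 8, N = 1, IC = 2, OC = 2;
    bool large_group = group >= 8;  // illustrative stand-in for m_large_group
    std::vector<Task> kerns;
    if (large_group) {
        // One task per (group, batch): pad all IC then conv all OC inside,
        // indexing the workspace by thread_id so concurrent groups never
        // clash (this is the {thread_id, 0, ic} workspace_ids in the patch).
        kerns.push_back({[&](size_t g, size_t n, size_t, size_t tid) {
                             for (size_t ic = 0; ic < IC; ++ic)
                                 std::printf("pad  g=%zu n=%zu ic=%zu ws=%zu\n",
                                             g, n, ic, tid);
                             for (size_t oc = 0; oc < OC; ++oc)
                                 std::printf("conv g=%zu n=%zu oc=%zu ws=%zu\n",
                                             g, n, oc, tid);
                         },
                         {group, N, 1}});
    } else {
        // Two finer-grained grids; the workspace slot is the ndrange id.
        kerns.push_back({[](size_t g, size_t n, size_t ic, size_t) {
                             std::printf("pad  g=%zu n=%zu ic=%zu\n", g, n, ic);
                         },
                         {group, N, IC}});
        kerns.push_back({[](size_t g, size_t n, size_t oc, size_t) {
                             std::printf("conv g=%zu n=%zu oc=%zu\n", g, n, oc);
                         },
                         {group, N, OC}});
    }
    for (auto& t : kerns)
        for (size_t g = 0; g < t.grid[0]; ++g)
            for (size_t n = 0; n < t.grid[1]; ++n)
                for (size_t c = 0; c < t.grid[2]; ++c)
                    t.kern(g, n, c, /*thread_id=*/0);
    return 0;
}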