提交 d346c878 编写于 作者: M Megvii Engine Team

fix(dnn/fallbackls): delete the conv_bias fallback offset

GitOrigin-RevId: c91aee2c7cfc95d1f31cc7f7eb7a05ece40ba002
上级 a7e28712
...@@ -57,11 +57,12 @@ public: ...@@ -57,11 +57,12 @@ public:
const ConvBiasImpl::NCBKernParam& param, const ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, size_t bundle_id, const WorkspaceBundle& bundle_thread, size_t bundle_id,
size_t oc_cur_index, size_t OHW, bool is_dst_8bit, size_t oc_cur_index, size_t OHW, bool is_dst_8bit,
bool ohw_bigger_ohwblock) { bool ohw_bigger_ohwblock, size_t batch_id, size_t group_id) {
if (is_dst_8bit || !ohw_bigger_ohwblock) { if (is_dst_8bit || !ohw_bigger_ohwblock) {
return static_cast<dtype*>(bundle_thread.get(bundle_id)); return static_cast<dtype*>(bundle_thread.get(bundle_id));
} else { } else {
dtype* dst = param.dst<dtype>() + oc_cur_index * OHW; dtype* dst =
param.dst<dtype>(batch_id, group_id) + oc_cur_index * OHW;
return static_cast<dtype*>(dst); return static_cast<dtype*>(dst);
} }
} }
...@@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle,
size_t IW2 = IW + 2 * PW; size_t IW2 = IW + 2 * PW;
size_t IH2 = IH + 2 * PH; size_t IH2 = IH + 2 * PH;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t padding_group_size = IH2 * IW2 * IC; size_t padding_group_size = IH2 * IW2 * IC;
size_t input_channel_offset = IH * IW * ncb_index.ndrange_id[2]; size_t input_channel_offset = IH * IW * channel_id;
size_t workspace_channel_offset = IH2 * IW2 * ncb_index.ndrange_id[2]; size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_group_offset = size_t workspace_group_offset = group_id * padding_group_size;
ncb_index.ndrange_id[0] * padding_group_size; size_t workspace_batch_offset =
size_t workspace_batch_offset = param.filter_meta.group * param.filter_meta.group * batch_id * padding_group_size;
ncb_index.ndrange_id[1] *
padding_group_size;
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
src_ctype src_zp = static_cast<src_ctype>(0); src_ctype src_zp = static_cast<src_ctype>(0);
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
} }
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>() + src_ctype* src = const_cast<src_ctype*>(
input_channel_offset); param.src<src_ctype>(batch_id, group_id) + input_channel_offset);
src_ctype* src2; src_ctype* src2;
src2 = static_cast<src_ctype*>( src2 = static_cast<src_ctype*>(
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) +
...@@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
*/ */
#define COPY_BIAS() \ #define COPY_BIAS() \
const bias_ctype* bias_ptr = \ const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( \
static_cast<const bias_ctype*>(param.bias_ptr); \ param.bias<bias_ctype>(batch_id, group_id)); \
bias_ctype* bias_temp_ptr = \ bias_ctype* bias_temp_ptr = \
PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \ PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \
if (param.bias_mode == megdnn::BiasMode::BIAS) { \ if (param.bias_mode == megdnn::BiasMode::BIAS) { \
...@@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
#define IM2COL() \ #define IM2COL() \
src_ctype* im2col_dst = nullptr; \ src_ctype* im2col_dst = nullptr; \
src_ctype* no_padding_src = \ src_ctype* no_padding_src = \
const_cast<src_ctype*>(param.src<src_ctype>()) + ohw_cur_index; \ const_cast<src_ctype*>(param.src<src_ctype>(batch_id, group_id)) + \
ohw_cur_index; \
if (!special_1x1) { \ if (!special_1x1) { \
size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \ size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \
src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
...@@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
param.filter_meta.group * ncb_index.ndrange_id[1]) * \ param.filter_meta.group * ncb_index.ndrange_id[1]) * \
padding_group_size); \ padding_group_size); \
if (PH == 0 && PW == 0) { \ if (PH == 0 && PW == 0) { \
src2 = const_cast<src_ctype*>(param.src<src_ctype>()); \ src2 = const_cast<src_ctype*>( \
param.src<src_ctype>(batch_id, group_id)); \
} \ } \
im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \ im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \ Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \
...@@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
output_block_size); \ output_block_size); \
if (!skip_copy_dst) { \ if (!skip_copy_dst) { \
dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \ dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \
dst_ctype* dst = \ dst_ctype* dst = param.dst<dst_ctype>(batch_id, group_id) + \
param.dst<dst_ctype>() + oc_cur_index * OHW + ohw_cur_index; \ oc_cur_index * OHW + ohw_cur_index; \
for (size_t oc = 0; oc < output_block_oc_size; oc++) { \ for (size_t oc = 0; oc < output_block_oc_size; oc++) { \
std::memcpy(dst, dst_tmp_ptr, \ std::memcpy(dst, dst_tmp_ptr, \
sizeof(dst_ctype) * output_block_size); \ sizeof(dst_ctype) * output_block_size); \
...@@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \ param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \ Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \
is_dst_8bit, is_ohw_size_bigger); is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
#define MATMUL_COMPUTE() \ #define MATMUL_COMPUTE() \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
...@@ -272,6 +276,7 @@ public: ...@@ -272,6 +276,7 @@ public:
ConvBiasImpl::NCBKernIndex ncb_index) { ConvBiasImpl::NCBKernIndex ncb_index) {
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param; fallback::MatrixMulImpl::KernParam matmul_param;
size_t group_id = ncb_index.ndrange_id[0];
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
matmulparam; matmulparam;
size_t packA_group_size = size_t packA_group_size =
...@@ -283,11 +288,11 @@ public: ...@@ -283,11 +288,11 @@ public:
matmul_algo->get_packA_type_size(); matmul_algo->get_packA_type_size();
size_t a_panel_offset = size_t a_panel_offset =
ncb_index.ndrange_id[2] * packed_per_oc_block_size; ncb_index.ndrange_id[2] * packed_per_oc_block_size;
int8_t* a_panel = int8_t* a_panel = static_cast<int8_t*>(bundle.get(
static_cast<int8_t*>( Im2colBundelIndex::BUNDLE_PACKA_INDEX)) +
bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + group_id * packA_group_size + a_panel_offset;
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; matmul_param.A_ptr =
matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()); const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2], matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2],
matmul_algo->get_inner_block_size().m); matmul_algo->get_inner_block_size().m);
}; };
...@@ -309,6 +314,8 @@ public: ...@@ -309,6 +314,8 @@ public:
auto IH2 = IH + 2 * PH; auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW; auto IW2 = IW + 2 * PW;
size_t OHW = OH * OW; size_t OHW = OH * OW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t output_block_size = std::min( size_t output_block_size = std::min(
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
size_t output_block_oc_size = std::min( size_t output_block_oc_size = std::min(
...@@ -369,11 +376,11 @@ public: ...@@ -369,11 +376,11 @@ public:
\ \
src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \ bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset); \ group_id * packA_group_size + a_panel_offset); \
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \ param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
OHW, is_dst_8bit, is_ohw_size_bigger); OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
#define MATMUL_COMPUTE() \ #define MATMUL_COMPUTE() \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
...@@ -402,6 +409,7 @@ public: ...@@ -402,6 +409,7 @@ public:
matmulparam; matmulparam;
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
size_t oc_tile_size = matmul_param.M; size_t oc_tile_size = matmul_param.M;
size_t group_id = ncb_index.ndrange_id[0];
size_t output_block_oc_size = std::min( size_t output_block_oc_size = std::min(
oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size); oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size);
size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size; size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size;
...@@ -411,11 +419,11 @@ public: ...@@ -411,11 +419,11 @@ public:
size_t a_panel_offset = size_t a_panel_offset =
ncb_index.ndrange_id[2] * ncb_index.ndrange_id[2] *
matmul_algo->get_bundle(matmul_param).get_size(0); matmul_algo->get_bundle(matmul_param).get_size(0);
int8_t* a_panel = int8_t* a_panel = static_cast<int8_t*>(bundle.get(
static_cast<int8_t*>( Im2colBundelIndex::BUNDLE_PACKA_INDEX)) +
bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + group_id * packA_group_size + a_panel_offset;
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; matmul_param.A_ptr =
matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()) + const_cast<src_ctype*>(param.filter<src_ctype>(group_id)) +
oc_cur_index * matmul_param.K; oc_cur_index * matmul_param.K;
matmul_param.M = output_block_oc_size; matmul_param.M = output_block_oc_size;
matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z);
...@@ -437,6 +445,8 @@ public: ...@@ -437,6 +445,8 @@ public:
MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(N);
auto IH2 = IH + 2 * PH; auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW; auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t OHW = OH * OW; size_t OHW = OH * OW;
size_t output_block_size = std::min( size_t output_block_size = std::min(
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
...@@ -490,11 +500,11 @@ public: ...@@ -490,11 +500,11 @@ public:
#define PREPAR_MATMUL_DATA() \ #define PREPAR_MATMUL_DATA() \
bias_ctype* matmul_dst = nullptr; \ bias_ctype* matmul_dst = nullptr; \
const src_ctype* filter = \ const src_ctype* filter = \
param.filter<src_ctype>() + oc_cur_index * IC * FH * FW; \ param.filter<src_ctype>(group_id) + oc_cur_index * IC * FH * FW; \
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \ param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
OHW, is_dst_8bit, is_ohw_size_bigger); OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
#define MATMUL_COMPUTE() \ #define MATMUL_COMPUTE() \
matmul_param.M = output_block_oc_size; \ matmul_param.M = output_block_oc_size; \
...@@ -526,6 +536,8 @@ public: ...@@ -526,6 +536,8 @@ public:
MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(N);
auto IH2 = IH + 2 * PH; auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW; auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t OHW = OH * OW; size_t OHW = OH * OW;
size_t output_block_size = std::min( size_t output_block_size = std::min(
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
......
...@@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param( ...@@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvBiasImpl::Algorithm* algo) { ConvBiasImpl::Algorithm* algo) {
auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param);
size_t src_batch_stride = param.inp_bs * param.src_type.size();
size_t dst_batch_stride = param.out_bs * param.dst_type.size();
size_t bias_batch_stride = 0;
if (param.bias_mode == BiasMode::BIAS) {
bias_batch_stride = param.bias_bs * param.bias_type.size();
}
for (auto&& kernel : ncb_kerns) { for (auto&& kernel : ncb_kerns) {
megdnn_assert(
param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format ==
Param::Format::NCHW_WINOGRAD ||
param.filter_meta.format == Param::Format::NCHW88 ||
param.filter_meta.format ==
Param::Format::NCHW88_WINOGRAD,
"invalid conv format");
ptrdiff_t istrd = 0, fstrd = 0, bstrd = 0, ostrd = 0;
if (param.filter_meta.format == Param::Format::NCHW_WINOGRAD ||
param.filter_meta.format == Param::Format::NCHW88_WINOGRAD) {
fstrd = param.filter_meta.icpg * param.filter_meta.ocpg *
(param.filter_meta.spatial[0] + param.output_block_size -
1) *
(param.filter_meta.spatial[1] + param.output_block_size -
1) *
param.filter_type.size();
} else {
fstrd = param.filter_meta.icpg * param.filter_meta.ocpg *
param.filter_meta.spatial[0] *
param.filter_meta.spatial[1] * param.filter_type.size();
}
istrd = param.filter_meta.icpg * param.src_type.size();
ostrd = param.filter_meta.ocpg * param.dst_type.size();
if (param.bias_mode != BiasMode::NO_BIAS) {
bstrd = param.filter_meta.ocpg * param.bias_type.size();
}
if (param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NCHW_WINOGRAD ||
param.filter_meta.format == Param::Format::NCHW88_WINOGRAD) {
istrd *= param.isz[0] * param.isz[1];
ostrd *= param.osz[0] * param.osz[1];
if (param.bias_mode == BiasMode::BIAS) {
bstrd *= param.osz[0] * param.osz[1];
}
} else {
// must be NHWC. No action performed.
}
auto run = [=](size_t index, size_t thread_id) { auto run = [=](size_t index, size_t thread_id) {
auto copy_param = param; auto copy_param = param;
CpuNDRange ndrange_id(kernel.global_size, index); CpuNDRange ndrange_id(kernel.global_size, index);
size_t group_id = ndrange_id[0];
size_t batch_id = ndrange_id[1];
//! The kernel ptr point to batch index
incr_ptr(copy_param.src_ptr,
group_id * istrd + batch_id * src_batch_stride);
incr_ptr(copy_param.filter_ptr, group_id * fstrd);
incr_ptr(copy_param.bias_ptr,
group_id * bstrd + batch_id * bias_batch_stride);
incr_ptr(copy_param.dst_ptr,
group_id * ostrd + batch_id * dst_batch_stride);
kernel.kern(copy_param, {thread_id, ndrange_id}); kernel.kern(copy_param, {thread_id, ndrange_id});
}; };
static_cast<naive::HandleImpl*>(handle())->dispatch_kern( static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
...@@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { ...@@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
return "F0"; return "F0";
} }
namespace megdnn{
namespace fallback {
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
src_type.assert_is_compatible_ctype<T>();
size_t batch_offset = batch_id * inp_bs * src_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.icpg *
isz[0] * isz[1] * src_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
batch_offset + group_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
size_t pack_group_size) const {
size_t group_offset = 0_z;
switch (filter_meta.format) {
case Param::Format::NCHW: {
group_offset = pack_group_size * group_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
break;
}
case Param::Format::NCHW88: {
size_t group = filter_meta.group;
size_t icpg = filter_meta.icpg;
size_t ocpg = filter_meta.ocpg;
//! four format of weight layout
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh,
//! fw, 8, 8}
//! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8}
megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) ||
(group % 8 == 0 && icpg == 1 && ocpg == 1 &&
pack_group_size > 1) ||
(group == 1 && ocpg % 8 == 0),
"The filter shepe is not right of nchw88");
group_offset = pack_group_size * group_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
break;
}
case ConvBiasImpl::Param::Format::NCHW_WINOGRAD:
case ConvBiasImpl::Param::Format::NCHW88_WINOGRAD: {
//! four format of weight layout
//! 1. {g, alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 3. {g, alpha, alpha, oc, ic, 8, 8}
//! 4. {alpha, alpha, oc, ic}
group_offset = pack_group_size * group_id * filter_meta.icpg *
filter_meta.ocpg *
(filter_meta.spatial[0] + output_block_size - 1) *
(filter_meta.spatial[1] + output_block_size - 1) *
filter_type.size();
break;
}
default:
megdnn_assert("other filter format is not support yet");
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) +
group_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
bias_type.assert_is_compatible_ctype<T>();
size_t batch_offset = 0_z;
size_t group_offset = 0_z;
if (bias_mode == BiasMode::BIAS) {
batch_offset = batch_id * bias_bs * bias_type.size();
group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] *
osz[1] * bias_type.size();
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
group_offset = group_pack_size * group_id * filter_meta.ocpg *
bias_type.size();
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) +
batch_offset + group_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
dst_type.assert_is_compatible_ctype<T>();
size_t batch_offset = batch_id * out_bs * dst_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.ocpg *
osz[0] * osz[1] * dst_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) +
batch_offset + group_offset);
}
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const;
#define INST_DT(d) INST(DTypeTrait<d>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
#undef INST
#undef INST_DT
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -104,24 +104,39 @@ public: ...@@ -104,24 +104,39 @@ public:
return static_cast<const T*>(src_ptr); return static_cast<const T*>(src_ptr);
} }
template <typename T>
const T* src(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T> template <typename T>
const T* filter() const { const T* filter() const {
filter_type.assert_is_compatible_ctype<T>(); filter_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(filter_ptr); return static_cast<const T*>(filter_ptr);
} }
template <typename T>
const T* filter(size_t group_id, size_t pack_group_size = 1_z) const;
template <typename T> template <typename T>
const T* bias() const { const T* bias() const {
bias_type.assert_is_compatible_ctype<T>(); bias_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(bias_ptr); return static_cast<const T*>(bias_ptr);
} }
template <typename T>
const T* bias(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T> template <typename T>
T* dst() const { T* dst() const {
dst_type.assert_is_compatible_ctype<T>(); dst_type.assert_is_compatible_ctype<T>();
return static_cast<T*>(dst_ptr); return static_cast<T*>(dst_ptr);
} }
template <typename T>
T* dst(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T> template <typename T>
T* workspace() const { T* workspace() const {
return static_cast<T*>(workspace_ptr); return static_cast<T*>(workspace_ptr);
......
...@@ -210,7 +210,7 @@ public: ...@@ -210,7 +210,7 @@ public:
reinterpret_cast<input_filter_compute_type*>( reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + reinterpret_cast<uintptr_t>(bundle_compute.get(2)) +
compute_workspace_size_per_thread * thread_id); compute_workspace_size_per_thread * thread_id);
const stype* filter_ptr = kern_param.filter<stype>(); const stype* filter_ptr = kern_param.filter<stype>(group_id);
size_t oc_start = oc_id, oc_end = oc_id+1; size_t oc_start = oc_id, oc_end = oc_id+1;
if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) {
oc_start = 8 * oc_id; oc_start = 8 * oc_id;
...@@ -246,16 +246,19 @@ public: ...@@ -246,16 +246,19 @@ public:
size_t oc_block_id = ncb_index.ndrange_id[3]; size_t oc_block_id = ncb_index.ndrange_id[3];
size_t tile_id = ncb_index.ndrange_id[2]; size_t tile_id = ncb_index.ndrange_id[2];
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0]; size_t group_id = ncb_index.ndrange_id[0];
size_t thread_id = ncb_index.thread_id; size_t thread_id = ncb_index.thread_id;
bundle_top.set(ncb_param.workspace_ptr); bundle_top.set(ncb_param.workspace_ptr);
bundle_compute.set(bundle_top.get(0)); bundle_compute.set(bundle_top.get(0));
const stype* src_ptr = ncb_param.src<stype>(); const stype* src_ptr = ncb_param.src<stype>(batch_id, group_id);
dst_type* dst_ptr = ncb_param.dst<dst_type>(); dst_type* dst_ptr = ncb_param.dst<dst_type>(batch_id, group_id);
const output_compute_type* bias_ptr = const output_compute_type* bias_ptr =
static_cast<const output_compute_type*>(ncb_param.bias_ptr); static_cast<const output_compute_type*>(
ncb_param.bias<output_compute_type>(batch_id,
group_id));
input_filter_compute_type* input_transform_buf = input_filter_compute_type* input_transform_buf =
reinterpret_cast<input_filter_compute_type*>( reinterpret_cast<input_filter_compute_type*>(
...@@ -271,9 +274,10 @@ public: ...@@ -271,9 +274,10 @@ public:
reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + reinterpret_cast<uintptr_t>(bundle_compute.get(2)) +
compute_workspace_size_per_thread * thread_id); compute_workspace_size_per_thread * thread_id);
//! NCHW88_WINOGRAD and NCHW_WINOGRAD is the same offset
const input_filter_compute_type* filter_transform_buf = const input_filter_compute_type* filter_transform_buf =
static_cast<const input_filter_compute_type*>( static_cast<const input_filter_compute_type*>(
ncb_param.filter_ptr); ncb_param.filter<input_filter_compute_type>(group_id));
if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW ||
ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) { ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) {
filter_transform_buf = reinterpret_cast<input_filter_compute_type*>( filter_transform_buf = reinterpret_cast<input_filter_compute_type*>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册