提交 ea93927d 编写于 作者: M Megvii Engine Team

fix(dnn/fallbackls): fix fallback convolution and all conv_bias algos

GitOrigin-RevId: e37fbe0ffe9f8ab6325f41d9ee6857115630b17a
上级 d346c878
......@@ -213,11 +213,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoNaive::dispatch_kerns(
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
MIDOUT_BEGIN(megdnn_fallback_naive, 2) {
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t thread_id = ncb_index.thread_id;
auto thread_param = param;
thread_param.workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(param.workspace_ptr) +
thread_id * workspace_per_thread);
thread_param.filter_ptr = param.filter<void>(group_id);
thread_param.dst_ptr = param.dst<void>(batch_id, group_id);
thread_param.src_ptr = param.src<void>(batch_id, group_id);
thread_param.bias_ptr = param.bias<void>(batch_id, group_id);
kern_default(opr_param, thread_param);
}
MIDOUT_END();
......
......@@ -111,7 +111,6 @@ static void copy_padding_kern(WorkspaceBundle bundle,
size_t channel_id = ncb_index.ndrange_id[2];
size_t padding_group_size = IH2 * IW2 * IC;
size_t input_channel_offset = IH * IW * channel_id;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
......@@ -123,7 +122,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id) + input_channel_offset);
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src2;
src2 = static_cast<src_ctype*>(
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) +
......
......@@ -246,10 +246,9 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvBiasImpl::Algorithm* algo) {
auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param);
for (auto&& kernel : ncb_kerns) {
auto run = [=](size_t index, size_t thread_id) {
auto copy_param = param;
auto run = [kernel, param](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
kernel.kern(copy_param, {thread_id, ndrange_id});
kernel.kern(param, {thread_id, ndrange_id});
};
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
run, kernel.global_size.total_size());
......@@ -328,28 +327,29 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
namespace megdnn{
namespace fallback {
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
src_type.assert_is_compatible_ctype<T>();
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = batch_id * inp_bs * src_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.icpg *
size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg *
isz[0] * isz[1] * src_type.size();
size_t channel_offset = channel_pack_size * channel_pack_id * isz[0] *
isz[1] * src_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id,
size_t pack_group_size) const {
size_t group_offset = 0_z;
switch (filter_meta.format) {
case Param::Format::NCHW: {
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
break;
......@@ -359,15 +359,15 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
size_t icpg = filter_meta.icpg;
size_t ocpg = filter_meta.ocpg;
//! four format of weight layout
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh,
//! fw, 8, 8}
//! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8}
//! 1. {oc/8, ic/8, fh, fw, 8, 8},
//! 2. {g, oc/8, ic/8, fh, fw, 8, 8},
//! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8}
megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) ||
(group % 8 == 0 && icpg == 1 && ocpg == 1 &&
pack_group_size > 1) ||
(group == 1 && ocpg % 8 == 0),
"The filter shepe is not right of nchw88");
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
......@@ -380,7 +380,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
//! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 3. {g, alpha, alpha, oc, ic, 8, 8}
//! 4. {alpha, alpha, oc, ic}
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg *
(filter_meta.spatial[0] + output_block_size - 1) *
(filter_meta.spatial[1] + output_block_size - 1) *
......@@ -388,58 +388,66 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
break;
}
default:
megdnn_assert("other filter format is not support yet");
megdnn_assert(0, "other filter format is not support yet");
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) +
group_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
bias_type.assert_is_compatible_ctype<T>();
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = 0_z;
size_t group_offset = 0_z;
size_t channel_offset = 0_z;
if (bias_mode == BiasMode::BIAS) {
batch_offset = batch_id * bias_bs * bias_type.size();
group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] *
osz[1] * bias_type.size();
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
osz[0] * osz[1] * bias_type.size();
channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] *
bias_type.size();
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
group_offset = group_pack_size * group_id * filter_meta.ocpg *
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
bias_type.size();
channel_offset = channel_pack_size * channel_pack_id * bias_type.size();
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
dst_type.assert_is_compatible_ctype<T>();
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = batch_id * out_bs * dst_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.ocpg *
size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
osz[0] * osz[1] * dst_type.size();
size_t channel_offset = channel_pack_size * channel_pack_id * osz[0] *
osz[1] * dst_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const;
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const;
#define INST_DT(d) INST(DTypeTrait<d>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
INST(void)
#undef INST
#undef INST_DT
} // namespace fallback
......
......@@ -103,10 +103,32 @@ public:
src_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(src_ptr);
}
//! when format is nchwxx, multi channel will pack into one
//! chnannel_pack_id. pack_channel_size is the number of packed channel
//! when format is nchwxx and channel wise, multi group will pack into
//! one group_pack_id. group_pack_size is the number of packed group
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
const T* src(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
template <typename T>
const T* bias(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
template <typename T>
const T* src(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
T* dst(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
//! when format is nchwxx and channel wise, multi group will pack into
//! one group_pack_id. group_pack_size is the number of packed group
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
const T* filter(size_t group_pack_id,
size_t pack_group_size = 1_z) const;
template <typename T>
const T* filter() const {
......@@ -114,29 +136,18 @@ public:
return static_cast<const T*>(filter_ptr);
}
template <typename T>
const T* filter(size_t group_id, size_t pack_group_size = 1_z) const;
template <typename T>
const T* bias() const {
bias_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(bias_ptr);
}
template <typename T>
const T* bias(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T>
T* dst() const {
dst_type.assert_is_compatible_ctype<T>();
return static_cast<T*>(dst_ptr);
}
template <typename T>
T* dst(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T>
T* workspace() const {
return static_cast<T*>(workspace_ptr);
......
......@@ -197,9 +197,12 @@ ConvolutionImpl::AlgoFallback::dispatch_kern(
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(), filter = p.filter<float>();
auto dst = p.dst<float>();
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
......
......@@ -20,18 +20,25 @@ namespace fallback {
template <typename ST, typename DT, typename CT>
void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p,
const ConvolutionImpl::NCBKernIndex& /*index*/) {
const ConvolutionImpl::NCBKernIndex& ncb_index) {
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
auto IC = p.filter_meta.icpg, IH = p.isz[0], IW = p.isz[1],
OC = p.filter_meta.ocpg, OH = p.osz[0], OW = p.osz[1];
ptrdiff_t fstrd = p.filter_meta.icpg * p.filter_meta.ocpg *
p.filter_meta.spatial[0] * p.filter_meta.spatial[1] *
p.filter_type.size();
ptrdiff_t istrd = p.filter_meta.icpg * p.src_type.size();
ptrdiff_t ostrd = p.filter_meta.ocpg * p.dst_type.size();
TensorND src, dst;
src.raw_ptr = const_cast<void*>(p.src_ptr);
dst.raw_ptr = p.dst_ptr;
src.layout.dtype = p.src_type;
dst.layout.dtype = p.dst_type;
if (p.filter_meta.format == param::Convolution::Format::NCHW) {
src.layout.init_contiguous_stride({1, IC, IH, IW});
dst.layout.init_contiguous_stride({1, OC, OH, OW});
istrd *= p.isz[0] * p.isz[1];
ostrd *= p.osz[0] * p.osz[1];
src.layout.init_contiguous_stride({1, IC, IH, IW});
dst.layout.init_contiguous_stride({1, OC, OH, OW});
} else {
// Must be NHWC
megdnn_assert(
......@@ -41,9 +48,17 @@ void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p,
src.layout.init_contiguous_stride({1, IH, IW, IC});
dst.layout.init_contiguous_stride({1, OH, OW, OC});
}
src.raw_ptr = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(p.src_ptr) +
batch_id * p.inp_bs * p.src_type.size() + group_id * istrd);
dst.raw_ptr = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(p.dst_ptr) +
batch_id * p.out_bs * p.dst_type.size() + group_id * ostrd);
ST* filter = reinterpret_cast<ST*>(
reinterpret_cast<uintptr_t>(p.filter_ptr) + group_id * fstrd);
std::copy(p.inp_s, p.inp_s + 4, src.layout.stride);
std::copy(p.out_s, p.out_s + 4, dst.layout.stride);
naive::convolution::forward<ST, ST, DT, CT>(src, p.filter<ST>(), dst,
naive::convolution::forward<ST, ST, DT, CT>(src, filter, dst,
p.filter_meta);
}
......
......@@ -189,41 +189,15 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
Algorithm* algo) {
auto kerns = ncb_algo_dispatch_kern(algo, param);
size_t src_batch_stride = param.inp_bs * param.src_type.size();
size_t dst_batch_stride = param.out_bs * param.dst_type.size();
auto group = param.filter_meta.group;
auto fallback_handle = handle();
for (auto kernel : kerns) {
megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NCHW88,
"invalid conv format");
ptrdiff_t istrd = 0, fstrd = 0, ostrd = 0;
fstrd = param.filter_meta.icpg * param.filter_meta.ocpg *
param.filter_meta.spatial[0] * param.filter_meta.spatial[1] *
param.filter_type.size();
istrd = param.filter_meta.icpg * param.src_type.size();
ostrd = param.filter_meta.ocpg * param.dst_type.size();
if (param.filter_meta.format == Param::Format::NCHW) {
istrd *= param.isz[0] * param.isz[1];
ostrd *= param.osz[0] * param.osz[1];
} else {
// must be NHWC. No action performed.
}
auto run = [=](size_t index, size_t thread_id) {
auto copy_param = param;
auto run = [param, kernel](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
size_t group_id = ndrange_id[0];
size_t batch_id = ndrange_id[1];
megdnn_assert(group_id < group,
"The group id should smaller than gruop");
//! The kernel ptr point to batch index
incr_ptr(copy_param.src_ptr,
group_id * istrd + batch_id * src_batch_stride);
incr_ptr(copy_param.filter_ptr, group_id * fstrd);
incr_ptr(copy_param.dst_ptr,
group_id * ostrd + batch_id * dst_batch_stride);
kernel.kern(copy_param, {thread_id, ndrange_id});
kernel.kern(param, {thread_id, ndrange_id});
};
static_cast<naive::HandleImpl*>(fallback_handle)
->dispatch_kern(run, kernel.global_size.total_size());
......
......@@ -100,6 +100,43 @@ public:
T* workspace() const {
return static_cast<T*>(workspace_ptr);
}
//! when format is nchwxx and channel wise, multi group will pack into
//! one group_pack_id. group_pack_size is the number of packed group
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
T* dst(size_t batch_id, size_t group_pack_id,
size_t group_pack_size = 1_z) const{
size_t batch_offset = batch_id * out_bs * dst_type.size();
size_t group_offset = group_pack_size * group_pack_id *
filter_meta.ocpg * osz[0] * osz[1] *
dst_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) +
batch_offset + group_offset);
}
template <typename T>
const T* src(size_t batch_id, size_t group_pack_id,
size_t group_pack_size = 1_z) const {
size_t batch_offset = batch_id * inp_bs * src_type.size();
size_t group_offset = group_pack_size * group_pack_id *
filter_meta.icpg * isz[0] * isz[1] *
src_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
batch_offset + group_offset);
}
template <typename T>
const T* filter(size_t group_pack_id,
size_t pack_group_size = 1_z) const {
size_t group_offset = pack_group_size * group_pack_id *
filter_meta.icpg * filter_meta.ocpg *
filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
return reinterpret_cast<T*>(
reinterpret_cast<ptrdiff_t>(filter_ptr) + group_offset);
}
};
static void* const sm_fallback_conv_algo_type;
......
......@@ -58,43 +58,45 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH,
}
} // namespace
#define GET_KERN \
auto fm = param.filter_meta; \
size_t N = param.n; \
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
WorkspaceBundle wbundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (m_large_group) { \
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
auto fm = kern_param.filter_meta; \
size_t IC = fm.icpg; \
size_t OC = fm.ocpg; \
WorkspaceBundle bundle = wbundle; \
for (size_t ic = 0; ic < IC; ic++) { \
copy_padding_kern( \
bundle, kern_param, \
{ncb_index.thread_id, {ncb_index.thread_id, 0, ic}}); \
} \
for (size_t oc = 0; oc < OC; oc++) { \
do_conv_kern( \
bundle, kern_param, \
{ncb_index.thread_id, {ncb_index.thread_id, 0, oc}}); \
} \
}; \
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \
} else { \
WorkspaceBundle bundle = wbundle; \
auto copy_padding = \
std::bind(copy_padding_kern, bundle, std::placeholders::_1, \
std::placeholders::_2); \
ret_kerns.push_back({copy_padding, {group, N, IC}}); \
auto do_conv = std::bind(do_conv_kern, bundle, std::placeholders::_1, \
std::placeholders::_2); \
ret_kerns.push_back({do_conv, {group, N, OC}}); \
} \
#define GET_KERN \
auto fm = param.filter_meta; \
size_t N = param.n; \
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
WorkspaceBundle wbundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (m_large_group) { \
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
auto fm = kern_param.filter_meta; \
size_t IC = fm.icpg; \
size_t OC = fm.ocpg; \
WorkspaceBundle bundle = wbundle; \
for (size_t ic = 0; ic < IC; ic++) { \
copy_padding_kern(bundle, kern_param, ncb_index, \
{ncb_index.thread_id, 0, ic}); \
} \
for (size_t oc = 0; oc < OC; oc++) { \
do_conv_kern(bundle, kern_param, ncb_index, \
{ncb_index.thread_id, 0, oc}); \
} \
}; \
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \
} else { \
auto copy_padding = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
copy_padding_kern(wbundle, kern_param, ncb_index, \
ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({copy_padding, {group, N, IC}}); \
auto do_conv = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
do_conv_kern(wbundle, kern_param, ncb_index, \
ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({do_conv, {group, N, OC}}); \
} \
return ret_kerns;
/* ===================== direct algo ===================== */
......@@ -145,7 +147,8 @@ size_t ConvBiasImpl::AlgoDirect::get_workspace(
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirect::copy_padding_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
......@@ -160,14 +163,18 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern(
get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = (IH != IH2 || IW != IW2);
size_t padding_group_size = IH2 * IW2 * IC;
const float* sptr = static_cast<const float*>(kern_param.src_ptr) +
ncb_index.ndrange_id[2] * IH * IW;
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
size_t channel_id = workspace_ids[2];
const float* sptr = static_cast<const float*>(
kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
bundle.set(kern_param.workspace_ptr);
//! Used for get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
//! If large group, each thread has its own worspace, set group_id with
//! thread_id
if (rectify_src) {
......@@ -234,7 +241,8 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern(
//! compute one output channel
void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
......@@ -265,14 +273,16 @@ void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle,
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
//! Used for get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
oc = ncb_index.ndrange_id[2];
const float* sptr = kern_param.src<float>();
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC;
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset;
float* dst = kern_param.dst<float>() + oc * OH * OW;
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter = kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
......@@ -358,7 +368,8 @@ size_t ConvBiasImpl::AlgoDirectStride2::get_workspace(
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
......@@ -373,13 +384,17 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = need_src_copy(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
const float* sptr = static_cast<const float*>(kern_param.src_ptr) +
ncb_index.ndrange_id[2] * IH * IW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];
const float* sptr = static_cast<const float*>(
kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
bundle.set(kern_param.workspace_ptr);
//! Used for get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
if (rectify_src) {
//! copy to sptr_base to eliminate padding effect
float* sptr_base = static_cast<float*>(bundle.get(0)) +
......@@ -397,7 +412,7 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
//! compute one output channel
void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
WorkspaceBundle bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
......@@ -439,14 +454,17 @@ void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
//! Used for get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
oc = ncb_index.ndrange_id[2];
const float* sptr = kern_param.src<float>();
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC;
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset;
float* dst = kern_param.dst<float>() + oc * OH * OW;
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter =
kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
......@@ -547,23 +565,22 @@ MatrixMul* ConvBiasImpl::AlgoMatrixMul::get_matmul_opr() {
}
void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex&) {
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
bool is_xcorr = !param.filter_meta.should_flip;
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
float* src = const_cast<float*>(param.src<float>()) + n * param.inp_bs;
float* dst = param.dst<float>() + n * param.out_bs;
float* bias_ptr =
static_cast<float*>(const_cast<void*>(param.bias_ptr));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ptr += n * param.out_bs;
}
float* src = const_cast<float*>(param.src<float>(n, group_id));
float* dst = param.dst<float>(n, group_id);
float* bias_ptr = static_cast<float*>(
const_cast<void*>(param.bias<void>(n, group_id)));
float *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
......@@ -613,7 +630,7 @@ void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param,
{
TensorND A_, B_, C_;
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32());
A_.raw_ptr = const_cast<float*>(param.filter<float>());
A_.raw_ptr = const_cast<float*>(param.filter<float>(group_id));
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32());
B_.raw_ptr = B;
C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32());
......
......@@ -22,10 +22,12 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
static void copy_padding_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
......@@ -57,10 +59,12 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
static void copy_padding_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
......
......@@ -146,9 +146,11 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle(
} while (0)
void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
const NCBKernParam& param, const NCBKernIndex&) {
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get());
megdnn_assert(x86_handle != nullptr, "x86 handle can not be null");
auto eng_mkldnn = x86_handle->mkldnn_engine();
......@@ -167,10 +169,11 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
auto megdnn_dst_md = memory::desc({dst_shape}, memory::data_type::s32,
memory::format_tag::nchw);
auto megdnn_weight_memory = memory(megdnn_weight_md, eng_mkldnn,
const_cast<void*>(param.filter_ptr));
int8_t* src = const_cast<int8_t*>(param.src<int8_t>());
int32_t* dst = param.dst<int32_t>();
auto megdnn_weight_memory =
memory(megdnn_weight_md, eng_mkldnn,
const_cast<void*>(param.filter<void>(group_id)));
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(batch_id, group_id));
int32_t* dst = param.dst<int32_t>(batch_id, group_id);
auto megdnn_src_memory =
memory(megdnn_src_md, eng_mkldnn, static_cast<void*>(src));
......@@ -353,18 +356,18 @@ MatrixMul* ConvBiasImpl::AlgoMkldnnMatmulQint8::get_matmul_opr() {
}
void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32(
const NCBKernParam& param, const NCBKernIndex&) {
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
bool is_xcorr = !param.filter_meta.should_flip;
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
for (size_t n = 0; n < N; ++n) {
int8_t* src =
const_cast<int8_t*>(param.src<int8_t>()) + n * param.inp_bs;
int32_t* dst = param.dst<int32_t>() + n * param.out_bs;
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(n, group_id));
int32_t* dst = param.dst<int32_t>(n, group_id);
int8_t *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
......@@ -414,7 +417,7 @@ void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32(
{
TensorND A_, B_, C_;
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Int8());
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>());
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>(group_id));
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Int8());
B_.raw_ptr = B;
C_.layout = TensorLayout({OC, OH * OW}, dtype::Int32());
......
......@@ -47,8 +47,8 @@ void pack_src_conv_avx2_stride1(WorkspaceBundle bundle,
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
const int8_t* src_ptr =
kern_param.src<int8_t>() + ic_step * channel_id * c_stride;
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) +
ic_step * channel_id * c_stride;
bundle.set(kern_param.workspace_ptr);
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) +
batch_id * group * packed_group_size +
......@@ -129,7 +129,7 @@ static inline void pack_filter_conv_avx2_stride1(
size_t group_id = ncb_index.ndrange_id[0],
oc_index_id = ncb_index.ndrange_id[1];
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>();
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id);
bundle.set(kern_param.workspace_ptr);
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * oc_out_stride;
......@@ -632,19 +632,18 @@ void do_conv_kern(WorkspaceBundle bundle,
const uint32_t packed_group_size =
div_ceil(ic, ic_step) * pack_ih * pack_iw;
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
bundle.set(kern_param.workspace_ptr);
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * packed_group_size +
workspace_batch_id * group * packed_group_size;
int16_t* filter_ptr =
static_cast<int16_t*>(bundle.get(1)) +
workspace_group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * workspace_channel_id * filter_round_size;
group_id * packed_group_size +
batch_id * group * packed_group_size;
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * channel_id * filter_round_size;
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
......@@ -652,12 +651,11 @@ void do_conv_kern(WorkspaceBundle bundle,
int32_t* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oc_stride +
workspace_group_id * oc * oc_stride +
oc_step * workspace_channel_id * oh * ow;
batch_id * group * oc * oc_stride +
group_id * oc * oc_stride + oc_step * channel_id * oh * ow;
} else {
dst_tptr = kern_param.dst<int32_t>() +
oc_step * workspace_channel_id * oh * ow;
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) +
oc_step * channel_id * oh * ow;
}
const uint32_t oc_end = oc / oc_step * oc_step;
......@@ -666,7 +664,7 @@ void do_conv_kern(WorkspaceBundle bundle,
const uint32_t oh_remain = oh - oh_end;
const uint32_t ow_end = ow / ow_step * ow_step;
const uint32_t ow_remain = ow - ow_end;
const uint32_t oc_index = oc_step * workspace_channel_id;
const uint32_t oc_index = oc_step * channel_id;
AlgoAVX2DirectConvStride1S8S8S32_forward<oc_step, ic_step, oh_step,
ow_step>(
......@@ -684,29 +682,29 @@ void do_post_process(WorkspaceBundle bundle,
const uint32_t oh = kern_param.osz[0];
const uint32_t ow = kern_param.osz[1];
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
bundle.set(kern_param.workspace_ptr);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
void* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oh * ow +
workspace_group_id * oc * oh * ow;
batch_id * group * oc * oh * ow + group_id * oc * oh * ow;
} else {
dst_tptr = kern_param.dst<dt_int32>();
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id);
}
void* dst_ptr = kern_param.dst<void>(batch_id, group_id);
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \
{ \
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \
const dt_int32* bias_ptr = \
kern_param.bias<dt_int32>(batch_id, group_id); \
PostProcess<DTypeTrait<_bias_ctype>::ctype, \
DTypeTrait<_dst_ctype>::ctype, \
_postprocess_mode>::run(dst_tptr, \
const_cast<dt_int32*>(bias_ptr), \
kern_param.dst_ptr, \
kern_param.bias_mode, \
dst_ptr, kern_param.bias_mode, \
kern_param.nonlineMode, \
kern_param.bias_type, \
kern_param.dst_type, 1, oc, oh, \
......
......@@ -45,8 +45,8 @@ void pack_src_conv_avx2_stride2(WorkspaceBundle bundle,
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
const int8_t* src_ptr =
kern_param.src<int8_t>() + ic_step * channel_id * c_stride;
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) +
ic_step * channel_id * c_stride;
bundle.set(kern_param.workspace_ptr);
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) +
batch_id * group * packed_group_size +
......@@ -187,7 +187,7 @@ static inline void pack_filter_conv_avx2_stride2(
size_t group_id = ncb_index.ndrange_id[0],
oc_index_id = ncb_index.ndrange_id[1];
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>();
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id);
bundle.set(kern_param.workspace_ptr);
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * oc_out_stride;
......@@ -705,18 +705,17 @@ void kernel_imp(WorkspaceBundle bundle,
const uint32_t packed_group_size =
div_ceil(ic, ic_step) * pack_ih * pack_iw;
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
bundle.set(kern_param.workspace_ptr);
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * packed_group_size +
workspace_batch_id * group * packed_group_size;
int16_t* filter_ptr =
static_cast<int16_t*>(bundle.get(1)) +
workspace_group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * workspace_channel_id * filter_round_size;
group_id * packed_group_size +
batch_id * group * packed_group_size;
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * channel_id * filter_round_size;
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
......@@ -724,12 +723,11 @@ void kernel_imp(WorkspaceBundle bundle,
int32_t* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oc_stride +
workspace_group_id * oc * oc_stride +
oc_step * workspace_channel_id * oh * ow;
batch_id * group * oc * oc_stride +
group_id * oc * oc_stride + oc_step * channel_id * oh * ow;
} else {
dst_tptr = kern_param.dst<int32_t>() +
oc_step * workspace_channel_id * oh * ow;
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) +
oc_step * channel_id * oh * ow;
}
const uint32_t oc_end = oc / oc_step * oc_step;
const uint32_t oc_remain = oc - oc_end;
......@@ -737,7 +735,7 @@ void kernel_imp(WorkspaceBundle bundle,
const uint32_t oh_remain = oh - oh_end;
const uint32_t ow_end = ow / ow_step * ow_step;
const uint32_t ow_remain = ow - ow_end;
const uint32_t oc_index = oc_step * workspace_channel_id;
const uint32_t oc_index = oc_step * channel_id;
kernel_handle_oh_remain<oc_step, ic_step, oh_step, ow_step>(
oh_remain, oc_remain, ow_remain, filter_ptr, src_ptr, dst_tptr,
......@@ -754,8 +752,8 @@ void do_post_process(WorkspaceBundle bundle,
const uint32_t oh = kern_param.osz[0];
const uint32_t ow = kern_param.osz[1];
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
bundle.set(kern_param.workspace_ptr);
bool need_post_process =
......@@ -763,21 +761,22 @@ void do_post_process(WorkspaceBundle bundle,
void* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oh * ow +
workspace_group_id * oc * oh * ow;
batch_id * group * oc * oh * ow +
group_id * oc * oh * ow;
} else {
dst_tptr = kern_param.dst<dt_int32>();
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id);
}
void* dst_ptr = kern_param.dst<void>(batch_id, group_id);
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \
{ \
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \
const dt_int32* bias_ptr = \
kern_param.bias<dt_int32>(batch_id, group_id); \
PostProcess<DTypeTrait<_bias_ctype>::ctype, \
DTypeTrait<_dst_ctype>::ctype, \
_postprocess_mode>::run(dst_tptr, \
const_cast<dt_int32*>(bias_ptr), \
kern_param.dst_ptr, \
kern_param.bias_mode, \
dst_ptr, kern_param.bias_mode, \
kern_param.nonlineMode, \
kern_param.bias_type, \
kern_param.dst_type, 1, oc, oh, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册