提交 481a6cbb 编写于 作者: M Megvii Engine Team

feat(x86): make nchw44 apply on x86

GitOrigin-RevId: f10f51d3a2ddab296ea42a08d8f3799f1a6b748f
上级 5873d5f5
......@@ -32,7 +32,7 @@ namespace x86 {
thin_function<void(const ctype*, ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<_simd_type, ctype, ctype>, _simd_type>::run; \
run(static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW);
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
#define CALL_BINARY_BROADCAST(_op, _simd_type) \
thin_function<void( \
......@@ -45,6 +45,17 @@ namespace x86 {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);
#define CALL_BINARY_BROADCAST_NCHWXX(_op, _simd_type) \
thin_function<void( \
const ctype*, const ctype*, ctype*, DType, DType, DType, size_t, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_simd_type, ctype, ctype>, _simd_type, \
megdnn::x86::BcastType::VEC_BCAST101xX>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define CALL_BINARY(_op, _simd_type) \
thin_function<void( \
const ctype*, const ctype*, ctype*, DType, DType, DType, size_t)> \
......@@ -53,7 +64,7 @@ namespace x86 {
megdnn::x86::BcastType::VEC_VEC>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW);
N* OC* OH* OW* pack_oc_size);
#define cb_unary(_simd_type) \
if (elem_mode == megdnn::param::Elemwise::Mode::RELU) { \
......@@ -93,19 +104,24 @@ namespace x86 {
cb_binary(CALLER, SIMDType::NONE) \
}
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
break; \
case BiasMode::BIAS: \
FOR_NONLINEAR(CALL_BINARY); \
break; \
default: \
break; \
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \
FOR_NONLINEAR(CALL_BINARY_BROADCAST_NCHWXX); \
} \
break; \
case BiasMode::BIAS: \
FOR_NONLINEAR(CALL_BINARY); \
break; \
default: \
break; \
}
template <
......@@ -119,7 +135,9 @@ struct PostProcess {
DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86");
megdnn_assert(
pack_oc_size == 1 || pack_oc_size == 4,
"PostProcess only support nchw/44 in x86");
megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
switch (nonlineMode) {
......@@ -320,16 +338,21 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
CALLER(AddOp, SIMDType::NONE) \
}
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::BIAS: \
FOR_SIMD(CALL_BINARY); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_SIMD(CALL_BINARY_BROADCAST); \
break; \
default: \
break; \
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::BIAS: \
FOR_SIMD(CALL_BINARY); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_SIMD(CALL_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \
FOR_SIMD(CALL_BINARY_BROADCAST_NCHWXX); \
} \
break; \
default: \
break; \
}
template <typename ctype, typename dtype>
......
......@@ -53,6 +53,33 @@ cb(dt_int8, __m256i, "avx2", int8_t, __m256i, mm256, si256, epi8, SIMDType::AVX2
cb(dt_uint8, __m256i, "avx2", uint8_t, __m256i, mm256, si256, epi8, SIMDType::AVX2);
cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2);
#undef cb
//! Visitor handling BCAST101xX(4) on AVX2: load one 128-bit lane (a 4-element
//! channel block) and duplicate it into both halves of a 256-bit register, so
//! an 8-lane AVX2 vector of 4-byte elements holds the channel block twice.
//! NOTE(review): "BoardCast" is a historical misspelling of "Broadcast"; the
//! name is kept because other code already refers to it.
template <typename ctype, SIMDType simd_type = SIMDType::AVX2>
struct ParamElemVisitorHalfBoardCast;
#define cb(                                                                       \
        _ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type,             \
        board_cast_func)                                                          \
    template <>                                                                   \
    struct ParamElemVisitorHalfBoardCast<_ctype, SIMDType::AVX2> {                \
        MEGDNN_ATTRIBUTE_TARGET("avx2")                                           \
        _simd_type operator()(const _ctype* src) const {                          \
            half_type tmp =                                                       \
                    load_half_fuc(reinterpret_cast<_simd_ptr_type const*>(src));  \
            return board_cast_func(tmp, tmp);                                     \
        }                                                                         \
    }
cb(dt_qint32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_qint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_quint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
//! Use the unaligned load here: the bias block pointer handed in by
//! OpCallerBinary (src1 + c * channel_block_dim) carries no 16-byte alignment
//! guarantee, and every integer variant above already uses loadu. The aligned
//! _mm_load_ps would fault on a misaligned bias tensor.
cb(dt_float32, float, _mm_loadu_ps, __m128, __m256, _mm256_set_m128);
#undef cb
/*!
* \brief broadcast type
......@@ -71,7 +98,8 @@ enum BcastType {
BCAST101_VEC_BCAST101,
VEC_BCAST101_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR
VEC_SCALAR_SCALAR,
VEC_BCAST101xX
};
///////////////////////////////// OpCaller /////////////////////////////
......@@ -227,6 +255,106 @@ struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101> {
};
#undef OP_CALLER
//! SSE4.2 path for a binary elementwise op where src0 is a full NCHWxx tensor
//! and src1 is a per-channel bias broadcast over the spatial extent of each
//! channel block (BCAST101xX). Only channel_block_dim == 4 (nchw44) is
//! supported, as asserted below.
//!
//! Parameter meanings (inferred from the indexing below — confirm against
//! callers): `channel` counts channel blocks of src1 per batch,
//! `channel_stride` is the spatial size covered by one block (e.g. OH*OW).
template <typename Op>
struct OpCallerBinary<Op, SIMDType::SSE4_2, VEC_BCAST101xX> {
    MEGDNN_ATTRIBUTE_TARGET("sse4.2")
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride,
            size_t channel_block_dim) {
        megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitor<typename Op::src_ctype, SIMDType::SSE4_2> vis0;
        ParamElemVisitor<typename Op::src_ctype, SIMDType::SSE4_2> vis1;
        for (size_t b = 0; b < batch; b++) {
            // bias (src1) is re-read from the start for every batch.
            const typename Op::src_ctype* src1_ptr = src1;
            for (size_t c = 0; c < channel; c++) {
                auto src1_block_ptr = src1_ptr + c * channel_block_dim;
                // NOTE(review): vis1 loads a full Op::SIMD_WIDTH-element vector
                // starting at the 4-element bias block. That broadcasts
                // correctly only when Op::SIMD_WIDTH == channel_block_dim
                // (4-byte element types on SSE); for narrower element types the
                // vector would also pick up neighboring blocks' bias values —
                // confirm which Ops instantiate this specialization.
                auto channel_block_vec = vis1(src1_block_ptr);
                size_t img_index = 0;
                // one SIMD vector covers SIMD_WIDTH/4 spatial positions of the
                // [.., H, W, 4] layout; the main loop consumes two vectors per
                // iteration.
                auto src0_offset = Op::SIMD_WIDTH / channel_block_dim;
                for (; img_index + 2 * src0_offset <= channel_stride;
                     img_index += 2 * src0_offset) {
                    op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}},
                       {{channel_block_vec, channel_block_vec}}, dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // scalar tail: remaining spatial positions, one channel lane at
                // a time.
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*src0, *(src1_block_ptr + c_iter), dst);
                        src0++;
                        dst++;
                    }
                }
            }
        }
    }
};
//! AVX2 path for a binary elementwise op with per-channel bias broadcast over
//! each NCHWxx channel block (BCAST101xX). Only channel_block_dim == 4
//! (nchw44) is supported, as asserted below.
//!
//! The bias vector is produced by ParamElemVisitorHalfBoardCast, which loads
//! the 4-element block as 128 bits and duplicates it into both halves of a
//! 256-bit register — so for 4-byte element types (SIMD_WIDTH == 8) each
//! vector lane lines up with the [.., H, W, 4] layout of src0.
template <typename Op>
struct OpCallerBinary<Op, SIMDType::AVX2, VEC_BCAST101xX> {
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride,
            size_t channel_block_dim) {
        megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis0;
        ParamElemVisitorHalfBoardCast<typename Op::src_ctype, SIMDType::AVX2> vis1;
        for (size_t b = 0; b < batch; b++) {
            // bias (src1) is re-read from the start for every batch.
            const typename Op::src_ctype* src1_ptr = src1;
            for (size_t c = 0; c < channel; c++) {
                auto src1_block_ptr = src1_ptr + c * channel_block_dim;
                // 4-element block duplicated across both 128-bit halves.
                auto channel_block_vec = vis1(src1_block_ptr);
                size_t img_index = 0;
                // one SIMD vector covers SIMD_WIDTH/4 spatial positions; the
                // main loop consumes two vectors per iteration.
                auto src0_offset = Op::SIMD_WIDTH / channel_block_dim;
                for (; img_index + 2 * src0_offset <= channel_stride;
                     img_index += 2 * src0_offset) {
                    op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}},
                       {{channel_block_vec, channel_block_vec}}, dst);
                    src0 += Op::SIMD_WIDTH * 2;
                    dst += Op::SIMD_WIDTH * 2;
                }
                // scalar tail: remaining spatial positions, one channel lane at
                // a time.
                for (; img_index < channel_stride; img_index++) {
                    for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
                        op(*src0, *(src1_block_ptr + c_iter), dst);
                        src0++;
                        dst++;
                    }
                }
            }
        }
    }
};
template <typename Op>
struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101xX> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride,
size_t channel_block_dim) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
auto src1_ptr = src1;
for (size_t cb = 0; cb < channel; cb++) {
auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
for (size_t img_index = 0; img_index < channel_stride; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) {
op(*src0, *(src1_block_ptr + c_iter), dst);
src0++;
dst++;
}
}
}
}
}
};
#define OP_CALLER(simd_type, target_simd) \
template <typename Op> \
struct OpCallerBinary<Op, simd_type, VEC_SCALAR> { \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册