Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
481a6cbb
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
481a6cbb
编写于
6月 21, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(x86): make nchw44 happly on x86
GitOrigin-RevId: f10f51d3a2ddab296ea42a08d8f3799f1a6b748f
上级
5873d5f5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
178 addition
and
27 deletion
+178
-27
dnn/src/x86/conv_bias/postprocess_helper.h
dnn/src/x86/conv_bias/postprocess_helper.h
+49
-26
dnn/src/x86/elemwise_op.h
dnn/src/x86/elemwise_op.h
+129
-1
未找到文件。
dnn/src/x86/conv_bias/postprocess_helper.h
浏览文件 @
481a6cbb
...
...
@@ -32,7 +32,7 @@ namespace x86 {
thin_function<void(const ctype*, ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<_simd_type, ctype, ctype>, _simd_type>::run; \
run(static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW);
bias_type, dst_type, N* OC* OH* OW
* pack_oc_size
);
#define CALL_BINARY_BROADCAST(_op, _simd_type) \
thin_function<void( \
...
...
@@ -45,6 +45,17 @@ namespace x86 {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);
//! Runs the binary op `_op` on the convolution result and the bias through
//! OpCallerBinary<..., BcastType::VEC_BCAST101xX>::run.
//! The expansion reads ctype, conv_dst_ptr, bias_ptr, dst_ptr, bias_type,
//! dst_type, N, OC, OH, OW and pack_oc_size from the enclosing scope.
//! bias_type is passed as the dtype of both inputs, OH * OW as the third size
//! argument and pack_oc_size as the last one.
//! NOTE(review): the callee indexes the bias as `c * channel_block_dim` for
//! c in [0, OC), so OC is presumably the number of packed channel blocks
//! rather than the raw channel count -- confirm against the caller.
#define CALL_BINARY_BROADCAST_NCHWXX(_op, _simd_type) \
thin_function<void( \
const ctype*, const ctype*, ctype*, DType, DType, DType, size_t, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_simd_type, ctype, ctype>, _simd_type, \
megdnn::x86::BcastType::VEC_BCAST101xX>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define CALL_BINARY(_op, _simd_type) \
thin_function<void( \
const ctype*, const ctype*, ctype*, DType, DType, DType, size_t)> \
...
...
@@ -53,7 +64,7 @@ namespace x86 {
megdnn::x86::BcastType::VEC_VEC>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW);
N* OC* OH* OW
* pack_oc_size
);
#define cb_unary(_simd_type) \
if (elem_mode == megdnn::param::Elemwise::Mode::RELU) { \
...
...
@@ -93,19 +104,24 @@ namespace x86 {
cb_binary(CALLER, SIMDType::NONE) \
}
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
break; \
case BiasMode::BIAS: \
FOR_NONLINEAR(CALL_BINARY); \
break; \
default: \
break; \
//! Selects the post-process path from bias_mode:
//!  - NO_BIAS: FOR_NONLINEAR_NOBIAS()
//!  - BROADCAST_CHANNEL_BIAS: CALL_BINARY_BROADCAST when pack_oc_size == 1;
//!    otherwise asserts pack_oc_size == 4 and uses CALL_BINARY_BROADCAST_NCHWXX
//!  - BIAS: CALL_BINARY
//!  - any other mode: nothing is executed
//! pack_oc_size is read from the scope in which the macro is expanded.
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(CALL_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \
FOR_NONLINEAR(CALL_BINARY_BROADCAST_NCHWXX); \
} \
break; \
case BiasMode::BIAS: \
FOR_NONLINEAR(CALL_BINARY); \
break; \
default: \
break; \
}
template
<
...
...
@@ -119,7 +135,9 @@ struct PostProcess {
DType
dst_type
,
size_t
N
,
size_t
OC
,
size_t
OH
,
size_t
OW
,
size_t
pack_oc_size
=
1
)
{
MEGDNN_MARK_USED_VAR
(
pack_oc_size
);
megdnn_assert
(
pack_oc_size
==
1
,
"PostProcess only support nchw in x86"
);
megdnn_assert
(
pack_oc_size
==
1
||
pack_oc_size
==
4
,
"PostProcess only support nchw/44 in x86"
);
megdnn
::
param
::
Elemwise
::
Mode
elem_mode
=
megdnn
::
param
::
Elemwise
::
Mode
::
ADD
;
if
(
bias_mode
!=
megdnn
::
ConvBiasForward
::
BiasMode
::
NO_BIAS
)
{
switch
(
nonlineMode
)
{
...
...
@@ -320,16 +338,21 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
CALLER(AddOp, SIMDType::NONE) \
}
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::BIAS: \
FOR_SIMD(CALL_BINARY); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_SIMD(CALL_BINARY_BROADCAST); \
break; \
default: \
break; \
//! Selects the bias-add path from bias_mode:
//!  - BIAS: CALL_BINARY
//!  - BROADCAST_CHANNEL_BIAS: CALL_BINARY_BROADCAST when pack_oc_size == 1;
//!    otherwise asserts pack_oc_size == 4 and uses CALL_BINARY_BROADCAST_NCHWXX
//!  - any other mode (no NO_BIAS case is listed here): nothing is executed
//! pack_oc_size is read from the scope in which the macro is expanded.
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::BIAS: \
FOR_SIMD(CALL_BINARY); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_SIMD(CALL_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \
FOR_SIMD(CALL_BINARY_BROADCAST_NCHWXX); \
} \
break; \
default: \
break; \
}
template
<
typename
ctype
,
typename
dtype
>
...
...
dnn/src/x86/elemwise_op.h
浏览文件 @
481a6cbb
...
...
@@ -53,6 +53,33 @@ cb(dt_int8, __m256i, "avx2", int8_t, __m256i, mm256, si256, epi8, SIMDType::AVX2
cb
(
dt_uint8
,
__m256i
,
"avx2"
,
uint8_t
,
__m256i
,
mm256
,
si256
,
epi8
,
SIMDType
::
AVX2
);
cb
(
dt_float32
,
float
,
"avx2"
,
float
,
__m256
,
mm256
,
ps
,
ps
,
SIMDType
::
AVX2
);
#undef cb
//! Visitor for BCAST101xX(4) on AVX2: loads 128 bits from src and duplicates
//! them into both halves of a 256-bit register.
//! NOTE(review): 128 bits hold exactly four elements only for 32-bit element
//! types; confirm how the 8/16-bit specializations below are meant to be used.
template <typename ctype, SIMDType simd_type = SIMDType::AVX2>
struct ParamElemVisitorHalfBoardCast;

//! cb(_ctype, _simd_ptr_type, load_half_func, half_type, _simd_type, broadcast_func)
//!   _ctype          element type the specialization is written for
//!   _simd_ptr_type  pointee type that load_half_func expects
//!   load_half_func  intrinsic performing the 128-bit load
//!   half_type       type of the loaded 128-bit value
//!   _simd_type      256-bit return type
//!   broadcast_func  intrinsic combining two 128-bit halves into 256 bits
#define cb(_ctype, _simd_ptr_type, load_half_func, half_type, _simd_type, \
           broadcast_func) \
    template <> \
    struct ParamElemVisitorHalfBoardCast<_ctype, SIMDType::AVX2> { \
        MEGDNN_ATTRIBUTE_TARGET("avx2") \
        _simd_type operator()(const _ctype* src) const { \
            half_type tmp = \
                    load_half_func(reinterpret_cast<_simd_ptr_type const*>(src)); \
            return broadcast_func(tmp, tmp); \
        } \
    }
cb(dt_qint32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_qint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_quint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
//! Unaligned load, matching the integer specializations: the AVX2 caller hands
//! in `src1 + c * channel_block_dim`, and nothing here establishes that this
//! address is 16-byte aligned, which the aligned _mm_load_ps would require.
cb(dt_float32, float, _mm_loadu_ps, __m128, __m256, _mm256_set_m128);
#undef cb
/*!
* \brief broadcast type
...
...
@@ -71,7 +98,8 @@ enum BcastType {
BCAST101_VEC_BCAST101
,
VEC_BCAST101_VEC
,
VEC_SCALAR_VEC
,
VEC_SCALAR_SCALAR
VEC_SCALAR_SCALAR
,
VEC_BCAST101xX
};
///////////////////////////////// OpCaller /////////////////////////////
...
...
@@ -227,6 +255,106 @@ struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101> {
};
#undef OP_CALLER
template
<
typename
Op
>
struct
OpCallerBinary
<
Op
,
SIMDType
::
SSE4_2
,
VEC_BCAST101xX
>
{
MEGDNN_ATTRIBUTE_TARGET
(
"sse4.2"
)
static
void
run
(
const
typename
Op
::
src_ctype
*
src0
,
const
typename
Op
::
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
DType
src0_dtype
,
DType
src1_dtype
,
DType
dst_dtype
,
size_t
batch
,
size_t
channel
,
size_t
channel_stride
,
size_t
channel_block_dim
)
{
megdnn_assert
(
channel_block_dim
==
4
,
"only imp for nchw44"
);
Op
op
(
src0_dtype
,
src1_dtype
,
dst_dtype
);
ParamElemVisitor
<
typename
Op
::
src_ctype
,
SIMDType
::
SSE4_2
>
vis0
;
ParamElemVisitor
<
typename
Op
::
src_ctype
,
SIMDType
::
SSE4_2
>
vis1
;
for
(
size_t
b
=
0
;
b
<
batch
;
b
++
)
{
const
typename
Op
::
src_ctype
*
src1_ptr
=
src1
;
for
(
size_t
c
=
0
;
c
<
channel
;
c
++
)
{
auto
src1_block_ptr
=
src1_ptr
+
c
*
channel_block_dim
;
auto
channel_block_vec
=
vis1
(
src1_block_ptr
);
size_t
img_index
=
0
;
auto
src0_offset
=
Op
::
SIMD_WIDTH
/
channel_block_dim
;
for
(;
img_index
+
2
*
src0_offset
<=
channel_stride
;
img_index
+=
2
*
src0_offset
)
{
op
({{
vis0
(
src0
),
vis0
(
src0
+
Op
::
SIMD_WIDTH
)}},
{{
channel_block_vec
,
channel_block_vec
}},
dst
);
src0
+=
Op
::
SIMD_WIDTH
*
2
;
dst
+=
Op
::
SIMD_WIDTH
*
2
;
}
for
(;
img_index
<
channel_stride
;
img_index
++
)
{
for
(
size_t
c_iter
=
0
;
c_iter
<
channel_block_dim
;
c_iter
++
)
{
op
(
*
src0
,
*
(
src1_block_ptr
+
c_iter
),
dst
);
src0
++
;
dst
++
;
}
}
}
}
}
};
template
<
typename
Op
>
struct
OpCallerBinary
<
Op
,
SIMDType
::
AVX2
,
VEC_BCAST101xX
>
{
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
void
run
(
const
typename
Op
::
src_ctype
*
src0
,
const
typename
Op
::
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
DType
src0_dtype
,
DType
src1_dtype
,
DType
dst_dtype
,
size_t
batch
,
size_t
channel
,
size_t
channel_stride
,
size_t
channel_block_dim
)
{
megdnn_assert
(
channel_block_dim
==
4
,
"only imp for nchw44"
);
Op
op
(
src0_dtype
,
src1_dtype
,
dst_dtype
);
ParamElemVisitor
<
typename
Op
::
src_ctype
,
SIMDType
::
AVX2
>
vis0
;
ParamElemVisitorHalfBoardCast
<
typename
Op
::
src_ctype
,
SIMDType
::
AVX2
>
vis1
;
for
(
size_t
b
=
0
;
b
<
batch
;
b
++
)
{
const
typename
Op
::
src_ctype
*
src1_ptr
=
src1
;
for
(
size_t
c
=
0
;
c
<
channel
;
c
++
)
{
auto
src1_block_ptr
=
src1_ptr
+
c
*
channel_block_dim
;
auto
channel_block_vec
=
vis1
(
src1_block_ptr
);
size_t
img_index
=
0
;
auto
src0_offset
=
Op
::
SIMD_WIDTH
/
channel_block_dim
;
for
(;
img_index
+
2
*
src0_offset
<=
channel_stride
;
img_index
+=
2
*
src0_offset
)
{
op
({{
vis0
(
src0
),
vis0
(
src0
+
Op
::
SIMD_WIDTH
)}},
{{
channel_block_vec
,
channel_block_vec
}},
dst
);
src0
+=
Op
::
SIMD_WIDTH
*
2
;
dst
+=
Op
::
SIMD_WIDTH
*
2
;
}
for
(;
img_index
<
channel_stride
;
img_index
++
)
{
for
(
size_t
c_iter
=
0
;
c_iter
<
channel_block_dim
;
c_iter
++
)
{
op
(
*
src0
,
*
(
src1_block_ptr
+
c_iter
),
dst
);
src0
++
;
dst
++
;
}
}
}
}
}
};
/*!
 * \brief scalar (SIMDType::NONE) binary caller for BcastType VEC_BCAST101xX
 *
 * Walks src0 and dst contiguously in the order
 * batch -> channel -> channel_stride -> channel_block_dim and applies op to
 * one element at a time. The second operand is
 * src1[c * channel_block_dim + lane]; it does not depend on the batch index
 * or on the position inside channel_stride.
 */
template <typename Op>
struct OpCallerBinary<Op, SIMDType::NONE, VEC_BCAST101xX> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride,
            size_t channel_block_dim) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        for (size_t n = 0; n < batch; ++n) {
            for (size_t blk = 0; blk < channel; ++blk) {
                const typename Op::src_ctype* bias_block =
                        src1 + blk * channel_block_dim;
                for (size_t pixel = 0; pixel < channel_stride; ++pixel) {
                    for (size_t lane = 0; lane < channel_block_dim; ++lane) {
                        op(*src0, bias_block[lane], dst);
                        ++src0;
                        ++dst;
                    }
                }
            }
        }
    }
};
#define OP_CALLER(simd_type, target_simd) \
template <typename Op> \
struct OpCallerBinary<Op, simd_type, VEC_SCALAR> { \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录