Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c9986df5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c9986df5
编写于
5月 11, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add fp32 nchw_nchw44 conv
GitOrigin-RevId: f19fe892d9f3e4c166d4835804bf5fc0ad31ccbc
上级
ca855d8d
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
1231 addition
and
100 deletion
+1231
-100
dnn/src/arm_common/conv_bias/fp32/algos.h
dnn/src/arm_common/conv_bias/fp32/algos.h
+20
-2
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
...on/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
+317
-0
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
...on/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
+430
-0
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
...mmon/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
+38
-0
dnn/src/arm_common/conv_bias/intrinsic_helper.h
dnn/src/arm_common/conv_bias/intrinsic_helper.h
+295
-62
dnn/src/arm_common/conv_bias/neon_struct.h
dnn/src/arm_common/conv_bias/neon_struct.h
+11
-0
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+2
-0
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+1
-0
dnn/src/arm_common/elemwise_helper/kimpl/hswish.h
dnn/src/arm_common/elemwise_helper/kimpl/hswish.h
+6
-1
dnn/src/arm_common/elemwise_helper/kimpl/none.h
dnn/src/arm_common/elemwise_helper/kimpl/none.h
+9
-1
dnn/src/arm_common/elemwise_helper/kimpl/relu.h
dnn/src/arm_common/elemwise_helper/kimpl/relu.h
+5
-0
dnn/src/arm_common/simd_macro/marm_neon.h
dnn/src/arm_common/simd_macro/marm_neon.h
+33
-0
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+53
-31
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+11
-3
未找到文件。
dnn/src/arm_common/conv_bias/fp32/algos.h
浏览文件 @
c9986df5
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -156,7 +157,6 @@ private:
...
@@ -156,7 +157,6 @@ private:
uint32_t
m_tile_size
;
uint32_t
m_tile_size
;
};
};
class
ConvBiasImpl
::
AlgoF32Direct
final
:
public
AlgoBase
{
class
ConvBiasImpl
::
AlgoF32Direct
final
:
public
AlgoBase
{
SmallVector
<
NCBKern
>
get_kimpls
(
const
NCBKernSizeParam
&
param
)
const
;
SmallVector
<
NCBKern
>
get_kimpls
(
const
NCBKernSizeParam
&
param
)
const
;
bool
m_large_group
;
bool
m_large_group
;
...
@@ -217,6 +217,24 @@ public:
...
@@ -217,6 +217,24 @@ public:
fallback
::
ConvBiasImpl
*
opr
,
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
const
NCBKernSizeParam
&
param
)
const
override
;
};
};
class
ConvBiasImpl
::
AlgoF32DirectStride2NCHWNCHW44
final
:
public
AlgoBase
{
SmallVector
<
NCBKern
>
get_kimpls
(
const
NCBKernSizeParam
&
param
)
const
;
public:
AlgoF32DirectStride2NCHWNCHW44
()
{}
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"F32_CONV_NCHW_NCHW44"
;
}
bool
usable
(
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
algo_selection_strategy
)
const
override
;
size_t
get_workspace
(
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
};
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
0 → 100644
浏览文件 @
c9986df5
/**
* \file
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
using
conv_fun
=
std
::
function
<
void
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
>
;
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2
)
namespace
{
static
inline
int
block_helper
(
const
int
nthread
,
const
int
amount
,
const
int
per_unit_bytes
)
{
MEGDNN_MARK_USED_VAR
(
per_unit_bytes
);
const
int
block_per_thread
=
div_ceil
(
amount
,
nthread
);
const
int
best_block
=
16
;
const
int
max_block_num
=
div_ceil
(
block_per_thread
,
best_block
);
const
int
min_block_num
=
std
::
max
(
max_block_num
-
1
,
1
);
const
int
max_block
=
div_ceil
(
block_per_thread
,
max_block_num
);
const
int
min_block
=
div_ceil
(
block_per_thread
,
min_block_num
);
const
int
max_loss
=
std
::
abs
(
max_block_num
*
max_block
-
block_per_thread
);
const
int
min_loss
=
std
::
abs
(
min_block_num
*
min_block
-
block_per_thread
);
int
block
=
max_loss
>
min_loss
?
min_block
:
max_block
;
return
block
;
}
static
inline
size_t
get_perthread_cache_bytes
(
const
int
ic
,
const
int
ih2
,
const
int
iw2
)
{
// border_size is used to avoid read illegal memory
int
border_size
=
64
*
2
;
return
ic
*
ih2
*
iw2
*
sizeof
(
float
)
+
border_size
;
}
static
void
get_rectified_size
(
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
int
&
ih2
,
int
&
iw2
,
int
&
oh2
,
int
&
ow2
)
{
int
iw
=
param
.
isz
[
1
];
int
oh
=
param
.
osz
[
0
];
int
ow
=
param
.
osz
[
1
];
oh2
=
oh
;
ow2
=
ow
;
constexpr
int
cacheline
=
64
/
sizeof
(
float
);
int
block_oh
=
block_helper
(
param
.
nr_threads
,
oh
,
0
);
auto
&&
fm
=
param
.
filter_meta
;
const
int
stride_h
=
static_cast
<
int
>
(
fm
.
stride
[
0
]);
const
int
filter_h
=
static_cast
<
int
>
(
fm
.
spatial
[
0
]);
ih2
=
block_oh
*
stride_h
+
filter_h
-
stride_h
;
iw2
=
round_up
(
iw
+
2
*
static_cast
<
int
>
(
fm
.
padding
[
1
]),
cacheline
);
}
static
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
auto
&&
fm
=
param
.
filter_meta
;
int
group
=
fm
.
group
;
int
ic
=
fm
.
icpg
;
int
oc
=
fm
.
ocpg
;
int
fh
=
fm
.
spatial
[
0
];
int
fw
=
fm
.
spatial
[
1
];
int
ih2
,
iw2
,
oh2
,
ow2
;
get_rectified_size
(
param
,
ih2
,
iw2
,
oh2
,
ow2
);
int
oh_block
=
block_helper
(
param
.
nr_threads
,
oh2
,
0
);
megdnn_assert
(
oh_block
!=
0
,
"oh_block!=0"
);
size_t
src_size
=
get_perthread_cache_bytes
(
ic
,
ih2
,
iw2
);
size_t
weight_size
=
group
*
oc
*
ic
*
fh
*
fw
*
sizeof
(
float
);
return
{
nullptr
,
{
src_size
*
param
.
nr_threads
,
weight_size
}};
};
static
inline
void
copy_pad_src
(
float
*
sptr_base
,
const
float
*
sptr_origin
,
int
ph
,
int
pw
,
int
pad_right
,
int
ih
,
int
iw
,
int
iw2
,
int
pad_top
,
int
pad_bottom
,
int
ic
,
int
ic_stride
)
{
MEGDNN_MARK_USED_VAR
(
ph
);
rep
(
ic_idx
,
ic
)
{
const
float
*
sptr
=
sptr_origin
+
ic_idx
*
ic_stride
;
memset
(
sptr_base
,
0
,
sizeof
(
float
)
*
iw2
*
pad_top
);
sptr_base
+=
iw2
*
pad_top
;
rep
(
ih_idx
,
ih
)
{
memset
(
sptr_base
,
0
,
sizeof
(
float
)
*
pw
);
sptr_base
+=
pw
;
memcpy
(
sptr_base
,
sptr
,
sizeof
(
float
)
*
iw
);
sptr_base
+=
iw
;
sptr
+=
iw
;
memset
(
sptr_base
,
0
,
sizeof
(
float
)
*
pad_right
);
sptr_base
+=
pad_right
;
}
memset
(
sptr_base
,
0
,
sizeof
(
float
)
*
iw2
*
pad_bottom
);
sptr_base
+=
iw2
*
pad_bottom
;
}
}
static
void
pack_weight
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
bundle
.
set
(
kern_param
.
workspace_ptr
);
const
int
group_id
=
ncb_index
.
ndrange_id
[
0
];
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
int
ic
=
kern_param
.
filter_meta
.
icpg
;
int
oc_block
=
oc
;
int
oc_idx
=
0
;
const
float
*
fptr
=
kern_param
.
filter
<
dt_float32
>
(
group_id
)
+
oc_idx
*
fh
*
fw
*
ic
;
auto
packed_weight
=
reinterpret_cast
<
float
*>
(
bundle
.
get
(
1
))
+
group_id
*
oc
*
ic
*
fh
*
fw
+
oc_idx
*
ic
*
fh
*
fw
;
conv_bias
::
pack_weight_fp32_nchw_nchw44
(
fptr
,
packed_weight
,
oc_block
,
fh
,
fw
,
ic
);
}
template
<
size_t
filter
,
BiasMode
bias_mode
,
typename
Op
>
static
void
do_conv_kern
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
,
const
CpuNDRange
&
)
{
const
int
oh
=
kern_param
.
osz
[
0
];
const
int
ow
=
kern_param
.
osz
[
1
];
const
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
const
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
const
int
ic
=
kern_param
.
filter_meta
.
icpg
;
const
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
const
int
ih
=
kern_param
.
isz
[
0
];
const
int
iw
=
kern_param
.
isz
[
1
];
const
int
stride_h
=
kern_param
.
filter_meta
.
stride
[
0
];
const
int
ph
=
kern_param
.
filter_meta
.
padding
[
0
];
const
int
pw
=
kern_param
.
filter_meta
.
padding
[
1
];
int
ih2
=
0
;
int
iw2
=
0
;
int
oh2
=
0
;
int
ow2
=
0
;
get_rectified_size
(
kern_param
,
ih2
,
iw2
,
oh2
,
ow2
);
bundle
.
set
(
kern_param
.
workspace_ptr
);
constexpr
int
pack_c
=
4
;
const
int
batch_id
=
ncb_index
.
ndrange_id
[
0
];
const
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
int
oc_idx
=
0
;
int
oc_block
=
oc
;
int
oh_block
=
block_helper
(
kern_param
.
nr_threads
,
oh2
,
0
);
const
int
oh_idx
=
ncb_index
.
ndrange_id
[
2
];
const
int
oh_block_real
=
std
::
min
(
oh
-
oh_idx
*
oh_block
,
oh_block
);
const
int
ih_real
=
oh_block_real
*
stride_h
+
fh
-
stride_h
;
const
int
src_top_pad
=
std
::
max
(
ph
-
oh_idx
*
oh_block
*
stride_h
,
0
);
const
int
src_bottom_pad
=
std
::
max
(
(
oh_idx
*
oh_block
+
oh_block_real
-
1
)
*
stride_h
+
fh
-
ih
-
ph
,
0
);
const
int
remain_right_pad
=
std
::
max
(
iw2
-
iw
-
pw
,
0
);
const
int
src_offset
=
std
::
max
(
oh_idx
*
oh_block
*
stride_h
-
ph
,
0
)
*
iw
;
const
float
*
origin_sptr
=
static_cast
<
const
float
*>
(
kern_param
.
src
<
float
>
(
batch_id
,
group_id
,
0
,
1
,
1
))
+
src_offset
;
const
size_t
src_size
=
get_perthread_cache_bytes
(
ic
,
ih2
,
iw2
);
float
*
sptr
=
reinterpret_cast
<
float
*>
((
int8_t
*
)
bundle
.
get
(
0
)
+
ncb_index
.
thread_id
*
src_size
);
copy_pad_src
(
sptr
,
origin_sptr
,
ph
,
pw
,
remain_right_pad
,
ih_real
-
src_top_pad
-
src_bottom_pad
,
iw
,
iw2
,
src_top_pad
,
src_bottom_pad
,
ic
,
ih
*
iw
);
// pack weight
auto
packed_weight
=
reinterpret_cast
<
float
*>
(
bundle
.
get
(
1
))
+
group_id
*
oc
*
ic
*
fh
*
fw
+
oc_idx
*
ic
*
fh
*
fw
;
// get param
float_t
*
dst
=
kern_param
.
dst
<
float_t
>
(
batch_id
,
group_id
)
+
oh_idx
*
oh_block
*
ow
*
pack_c
;
const
float
*
bptr
=
kern_param
.
bias
<
dt_float32
>
(
batch_id
,
group_id
)
+
oc_idx
;
Op
op
;
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \
\
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \
ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \
pw)
DISPATCH_FILTER
(
filter
,
KERN1_NCHW44_CONV
);
#undef KERN1_NCHW44_CONV
}
}
// namespace
/* ===================== stride2 algo ===================== */
bool
ConvBiasImpl
::
AlgoF32DirectStride2NCHWNCHW44
::
usable
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
)
const
{
auto
&&
fm
=
param
.
filter_meta
;
auto
fh
=
fm
.
spatial
[
0
];
int
oc
=
fm
.
ocpg
;
bool
ok_type
=
((
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Float32
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Float32
&&
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Float32
)))
&&
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
);
bool
ok_src_dst
=
fm
.
icpg
<
4
&&
(
oc
%
4
==
0
&&
oc
>=
4
)
&&
fm
.
group
==
1
;
bool
ok_filter
=
fm
.
spatial_ndim
==
2
&&
fh
==
fm
.
spatial
[
1
]
&&
(
fh
==
3
||
fh
==
5
||
fh
==
7
);
bool
ok_slide
=
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
fm
.
stride
[
0
]
==
2
&&
fm
.
stride
[
1
]
==
2
;
bool
ok_conv
=
!
fm
.
should_flip
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
bool
avaible
=
ok_type
&&
ok_src_dst
&&
ok_filter
&&
ok_slide
&&
ok_conv
;
return
avaible
;
}
size_t
ConvBiasImpl
::
AlgoF32DirectStride2NCHWNCHW44
::
get_workspace
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
get_bundle
(
param
).
total_size_in_bytes
();
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ConvBiasImpl
::
AlgoF32DirectStride2NCHWNCHW44
::
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
auto
fm
=
param
.
filter_meta
;
const
int
batch
=
param
.
n
;
const
int
group
=
fm
.
group
;
WorkspaceBundle
wbundle
=
get_bundle
(
param
);
conv_fun
do_conv_fun
=
nullptr
;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN
();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert
(
do_conv_fun
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kerns
;
WorkspaceBundle
bundle
=
wbundle
;
int
oh
=
param
.
osz
[
0
];
int
oh_block
=
block_helper
(
param
.
nr_threads
,
oh
,
0
);
auto
do_pack_weight
=
[
bundle
](
const
NCBKernParam
&
kern_param
,
const
NCBKernIndex
&
ncb_index
)
{
pack_weight
(
bundle
,
kern_param
,
ncb_index
);
};
ret_kerns
.
push_back
({
do_pack_weight
,
{
static_cast
<
size_t
>
(
group
)}});
CpuNDRange
ncb_range
=
{
static_cast
<
size_t
>
(
batch
),
static_cast
<
size_t
>
(
group
),
static_cast
<
size_t
>
(
div_ceil
(
oh
,
oh_block
))};
auto
do_conv
=
[
bundle
,
do_conv_fun
,
ncb_range
](
const
NCBKernParam
&
kern_param
,
const
NCBKernIndex
&
ncb_index
)
{
do_conv_fun
(
bundle
,
kern_param
,
ncb_index
,
ncb_index
.
ndrange_id
,
ncb_range
);
};
ret_kerns
.
push_back
({
do_conv
,
ncb_range
});
return
ret_kerns
;
}
// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
0 → 100644
浏览文件 @
c9986df5
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
namespace
{
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
);
};
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
2
,
Func
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
)
{
constexpr
int
stride
=
2
;
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]); \
c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4]);
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
}
};
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
1
,
Func
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
)
{
constexpr
int
stride
=
2
;
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]);
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
}
};
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
FUNC
,
typename
T
,
typename
T2
,
typename
T3
>
inline
void
cal_helper
(
T
&
c
,
T2
&
src
,
T3
&
weight
)
{
ShiftCalHelper
<
src_idx
,
weight_idx
,
c_dim
,
FUNC
,
T
,
T2
,
T3
,
int
>::
impl
(
c
,
src
,
weight
);
};
template
<
int
oc
>
struct
OCHelper
{
public:
static
const
int
val
=
-
1
;
};
template
<
>
struct
OCHelper
<
4
>
{
public:
static
const
int
val
=
1
;
};
template
<
>
struct
OCHelper
<
8
>
{
public:
static
const
int
val
=
2
;
};
/**
* oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
* */
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
filter_size
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44FP32
{
static
void
impl
(
const
float32_t
*
src_ptr
,
const
float32_t
*
weight_ptr
,
const
float32_t
*
bias_ptr
,
float32_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
);
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44FP32
<
bias_mode
,
Op
,
remain_w
,
7
,
oc_block
>
{
static
void
impl
(
const
float32_t
*
src_ptr
,
const
float32_t
*
weight_ptr
,
const
float32_t
*
bias_ptr
,
float32_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
loop_ic_step
=
1
;
constexpr
int
filter_size
=
7
;
constexpr
int
oc_step
=
4
;
constexpr
int
simd_len
=
4
;
constexpr
int
src_reg_size
=
6
;
constexpr
int
ld_weight_fw
=
oc_step
*
filter_size
;
const
int
ld_weight_oc
=
oc_step
*
filter_size
*
filter_size
*
ic
;
const
int
ld_weight_ic
=
oc_step
*
filter_size
*
filter_size
;
const
int
ld_src_ic
=
ih
*
iw
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
float32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
float32x4_t
src
[
src_reg_size
];
float32x4_t
weight
[
c_dim
][
filter_size
];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<5, 5, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<6, 6, c_dim, Vfmaq_laneq_f32>(c, src, weight);
UNROLL_CALL_RAW
(
7
,
KERNEL_CB
)
#undef KERNEL_CB
src_ptr
+=
ld_src_ic
;
weight_ptr
+=
ld_weight_ic
;
}
store_ocx_ow8_remain_static
<
c_dim
,
remain_w
,
Op
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44FP32
<
bias_mode
,
Op
,
remain_w
,
5
,
oc_block
>
{
static
void
impl
(
const
float32_t
*
src_ptr
,
const
float32_t
*
weight_ptr
,
const
float32_t
*
bias_ptr
,
float32_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
loop_ic_step
=
1
;
constexpr
int
filter_size
=
5
;
constexpr
int
oc_step
=
4
;
constexpr
int
simd_len
=
4
;
constexpr
int
src_reg_size
=
5
;
constexpr
int
ld_weight_fw
=
oc_step
*
filter_size
;
const
int
ld_weight_oc
=
oc_step
*
filter_size
*
filter_size
*
ic
;
const
int
ld_weight_ic
=
oc_step
*
filter_size
*
filter_size
;
const
int
ld_src_ic
=
ih
*
iw
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
float32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
float32x4_t
src
[
src_reg_size
];
float32x4_t
weight
[
c_dim
][
filter_size
];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight);
UNROLL_CALL_RAW
(
5
,
KERNEL_CB
)
#undef KERNEL_CB
src_ptr
+=
ld_src_ic
;
weight_ptr
+=
ld_weight_ic
;
}
store_ocx_ow8_remain_static
<
c_dim
,
remain_w
,
Op
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44FP32
<
bias_mode
,
Op
,
remain_w
,
3
,
oc_block
>
{
static
void
impl
(
const
float32_t
*
src_ptr
,
const
float32_t
*
weight_ptr
,
const
float32_t
*
bias_ptr
,
float32_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
loop_ic_step
=
1
;
constexpr
int
filter_size
=
3
;
constexpr
int
oc_step
=
4
;
constexpr
int
simd_len
=
4
;
constexpr
int
src_reg_size
=
5
;
constexpr
int
ld_weight_fw
=
oc_step
*
filter_size
;
const
int
ld_weight_oc
=
oc_step
*
filter_size
*
filter_size
*
ic
;
const
int
ld_weight_ic
=
oc_step
*
filter_size
*
filter_size
;
const
int
ld_src_ic
=
ih
*
iw
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
float32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
float32x4_t
src
[
src_reg_size
];
float32x4_t
weight
[
c_dim
][
filter_size
];
// row 0
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
,
0
);
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
// row 1
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
+
iw
,
0
);
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
+
1
*
ld_weight_fw
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
// row 2
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
+
2
*
iw
,
0
);
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
+
2
*
ld_weight_fw
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
src_ptr
+=
ld_src_ic
;
weight_ptr
+=
ld_weight_ic
;
}
store_ocx_ow8_remain_static
<
c_dim
,
remain_w
,
Op
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
}
// namespace
void
conv_bias
::
pack_weight_fp32_nchw_nchw44
(
const
float32_t
*
in_ptr
,
float32_t
*
dst_ptr
,
const
int
oc
,
const
int
kh
,
const
int
kw
,
const
int
ic
)
{
constexpr
int
oc_step
=
4
;
const
int
filter_oc_stride
=
kh
*
kw
*
ic
;
const
int
filter_ic_stride
=
kh
*
kw
*
oc_step
;
for
(
int
oc_idx
=
0
;
oc_idx
<
oc
;
oc_idx
+=
oc_step
)
{
const
float32_t
*
in_ptr_oc
=
in_ptr
+
oc_idx
*
filter_oc_stride
;
float32_t
*
dst_ptr_oc
=
dst_ptr
+
oc_idx
*
filter_oc_stride
;
for
(
int
kh_idx
=
0
;
kh_idx
<
kh
;
++
kh_idx
)
{
for
(
int
kw_idx
=
0
;
kw_idx
<
kw
;
++
kw_idx
)
{
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
++
ic_idx
)
{
float32x4_t
vsrc
=
vld1q_f32
(
in_ptr_oc
);
vst1q_f32
(
dst_ptr_oc
+
ic_idx
*
filter_ic_stride
,
vsrc
);
in_ptr_oc
+=
oc_step
;
}
dst_ptr_oc
+=
oc_step
;
}
}
}
}
template
<
BiasMode
bias_mode
,
typename
Op
,
int
filter_size
>
static
void
conv_direct_stride2_fp32_nchw_nchw44
(
const
float32_t
*
src
,
const
float32_t
*
filter
,
const
float32_t
*
bias
,
float32_t
*
,
float32_t
*
dst
,
const
int
oc
,
const
int
ic
,
const
int
ih
,
const
int
iw
,
const
int
oh
,
const
int
oh_block
,
const
int
ow
,
const
Op
&
op
,
const
int
,
const
int
)
{
constexpr
int
fh
=
filter_size
;
constexpr
int
fw
=
filter_size
;
constexpr
int
ic_step
=
1
;
constexpr
int
big_oc_step
=
8
;
constexpr
int
oc_step
=
4
;
constexpr
int
ih_step
=
1
;
constexpr
int
oh_step
=
1
;
constexpr
int
ow_step
=
8
;
constexpr
int
stride_h
=
2
;
constexpr
int
stride_w
=
2
;
constexpr
int
pack_iw_len
=
1
;
const
int
img_stride
=
oh
*
ow
;
const
int
ow_end
=
ow
/
ow_step
*
ow_step
;
const
int
ow_remain
=
ow
-
ow_end
;
const
int
oc_end
=
oc
/
big_oc_step
*
big_oc_step
;
const
int
oc_remain
=
oc
-
oc_end
;
const
int
ld_dst_oc
=
oc_step
*
img_stride
;
using
remain_fun
=
std
::
function
<
void
(
const
float32_t
*
src_ptr
,
const
float32_t
*
weight_ptr
,
const
float32_t
*
bias_ptr
,
float32_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
>
;
remain_fun
kern_big_oc_remain
=
nullptr
;
remain_fun
kern_small_oc_remain
=
nullptr
;
switch
(
ow_remain
)
{
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step>::impl; \
break;
UNROLL_CALL_RAW
(
8
,
cb
);
default:
megdnn_assert
(
0
,
"no remain %d for kern"
,
ow_remain
);
}
for
(
int
oc_idx
=
0
;
oc_idx
<
oc_end
;
oc_idx
+=
big_oc_step
)
{
const
int
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
int
oh_idx
=
0
;
oh_idx
<
oh_block
;
oh_idx
+=
oh_step
)
{
for
(
int
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
int
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
int
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44FP32
<
bias_mode
,
Op
,
0
,
filter_size
,
big_oc_step
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
int
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
int
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_big_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
if
(
oc_remain
>
0
)
{
int
oc_idx
=
oc_end
;
const
int
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
int
oh_idx
=
0
;
oh_idx
<
oh_block
;
oh_idx
+=
oh_step
)
{
for
(
int
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
int
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
int
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44FP32
<
bias_mode
,
Op
,
0
,
filter_size
,
oc_step
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
int
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
int
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_small_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride2_fp32_nchw_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC
(
3
);
CONSTRUCT_FUNC
(
5
);
CONSTRUCT_FUNC
(
7
);
#undef CONSTRUCT_FUNC
template
<
BiasMode
bias_mode
,
typename
Op
>
void
conv_bias
::
conv_direct_stride2_2x2_fp32_nchw_nchw44
(
const
float32_t
*
,
const
float32_t
*
,
const
float32_t
*
,
float32_t
*
,
float32_t
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
Op
&
,
const
int
,
const
int
)
{
megdnn_assert
(
0
,
"not imple nchw_nchw44 2x2s2 conv"
);
}
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \
const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER
(
stride2
)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
0 → 100644
浏览文件 @
c9986df5
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
conv_bias
{
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);
KERN
(
stride2
,
2
,
nchw44
)
KERN
(
stride2
,
3
,
nchw44
)
KERN
(
stride2
,
5
,
nchw44
)
KERN
(
stride2
,
7
,
nchw44
)
#undef KERN
void
pack_weight_fp32_nchw_nchw44
(
const
float_t
*
in_ptr
,
float_t
*
dst_ptr
,
const
int
oc
,
const
int
kh
,
const
int
kw
,
const
int
ic
);
}
// namespace conv_bias
}
// namespace arm_common
}
// namespace megdnn
\ No newline at end of file
dnn/src/arm_common/conv_bias/intrinsic_helper.h
浏览文件 @
c9986df5
...
@@ -174,7 +174,167 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
...
@@ -174,7 +174,167 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
int
ld_dst_oc
)
{
int
ld_dst_oc
)
{
StoreOcxOw4Remain
<
c_dim
,
ow_remain
,
Op
,
T
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
StoreOcxOw4Remain
<
c_dim
,
ow_remain
,
Op
,
T
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
}
////////////////////Store_OCX_OW8_Remain/////////////////////////
template
<
int
c_dim
,
int
ow_remain
,
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
);
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
0
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
op
({{
c
[
0
][
6
],
c
[
0
][
7
]}},
dst_ptr
+
24
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
({{
c
[
1
][
2
],
c
[
1
][
3
]}},
dst_ptr
+
ld_dst_oc
+
8
);
op
({{
c
[
1
][
4
],
c
[
1
][
5
]}},
dst_ptr
+
ld_dst_oc
+
16
);
op
({{
c
[
1
][
6
],
c
[
1
][
7
]}},
dst_ptr
+
ld_dst_oc
+
24
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
7
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
op
(
c
[
0
][
6
],
dst_ptr
+
24
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
({{
c
[
1
][
2
],
c
[
1
][
3
]}},
dst_ptr
+
ld_dst_oc
+
8
);
op
({{
c
[
1
][
4
],
c
[
1
][
5
]}},
dst_ptr
+
ld_dst_oc
+
16
);
op
(
c
[
1
][
6
],
dst_ptr
+
ld_dst_oc
+
24
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
6
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
({{
c
[
1
][
2
],
c
[
1
][
3
]}},
dst_ptr
+
ld_dst_oc
+
8
);
op
({{
c
[
1
][
4
],
c
[
1
][
5
]}},
dst_ptr
+
ld_dst_oc
+
16
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
5
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
(
c
[
0
][
4
],
dst_ptr
+
16
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
({{
c
[
1
][
2
],
c
[
1
][
3
]}},
dst_ptr
+
ld_dst_oc
+
8
);
op
(
c
[
1
][
4
],
dst_ptr
+
ld_dst_oc
+
16
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
4
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
({{
c
[
1
][
2
],
c
[
1
][
3
]}},
dst_ptr
+
ld_dst_oc
+
8
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
3
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
(
c
[
0
][
2
],
dst_ptr
+
8
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
op
(
c
[
1
][
2
],
dst_ptr
+
ld_dst_oc
+
8
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
2
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
1
][
0
],
c
[
1
][
1
]}},
dst_ptr
+
ld_dst_oc
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
2
,
1
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
op
(
c
[
0
][
0
],
dst_ptr
);
op
(
c
[
1
][
0
],
dst_ptr
+
ld_dst_oc
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
0
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
op
({{
c
[
0
][
6
],
c
[
0
][
7
]}},
dst_ptr
+
24
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
7
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
op
(
c
[
0
][
6
],
dst_ptr
+
24
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
6
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
({{
c
[
0
][
4
],
c
[
0
][
5
]}},
dst_ptr
+
16
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
5
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
op
(
c
[
0
][
4
],
dst_ptr
+
16
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
4
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
3
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
(
c
[
0
][
2
],
dst_ptr
+
8
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
2
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
1
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
(
c
[
0
][
0
],
dst_ptr
);
}
};
template
<
int
c_dim
,
int
ow_remain
,
typename
Op
,
typename
T
>
inline
void
store_ocx_ow8_remain_static
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
ld_dst_oc
)
{
StoreOcxOw8Remain
<
c_dim
,
ow_remain
,
Op
,
T
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
////////////////////Store_OC8_OW8_Remain/////////////////////////
////////////////////Store_OC8_OW8_Remain/////////////////////////
template
<
int
ow_remain
,
typename
Op
>
template
<
int
ow_remain
,
typename
Op
>
...
@@ -299,14 +459,15 @@ struct Store_OC8_OW8_Remain<1, Op> {
...
@@ -299,14 +459,15 @@ struct Store_OC8_OW8_Remain<1, Op> {
}
}
};
};
template
<
int
ow_remain
,
typename
Op
>
///////////
inline
void
store_oc8_ow8_remain_static
(
int32x4_t
c
[
2
][
8
],
const
Op
&
op
,
int8_t
*
dst_ptr
,
int
ld_dst_oc
)
{
template
<
int
ow_remain
,
typename
Op
,
typename
T
,
typename
T2
>
inline
void
store_oc8_ow8_remain_static
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
ld_dst_oc
)
{
Store_OC8_OW8_Remain
<
ow_remain
,
Op
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
Store_OC8_OW8_Remain
<
ow_remain
,
Op
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
}
///////////////////////////////////////////////////////
//////////////////////////////////////
template
<
BiasMode
bias_mode
>
template
<
BiasMode
bias_mode
>
inline
void
init_oc4_ow8
(
int32x4_t
c
[
8
],
const
int32_t
*
bias_ptr
)
{
inline
void
init_oc4_ow8
(
int32x4_t
c
[
8
],
const
int32_t
*
bias_ptr
)
{
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
...
@@ -337,6 +498,49 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
...
@@ -337,6 +498,49 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
#undef BAIS_INIT
#undef BAIS_INIT
}
}
}
}
/////////////////////////init_ocx_ow8////////////////////
template
<
int
c_dim
,
BiasMode
bias_mode
,
typename
T
,
typename
T2
>
struct
InitOcxOw8
{
static
void
impl
(
T
&
c
,
T2
bias_ptr
,
int
oc_step
);
};
template
<
BiasMode
bias_mode
,
typename
T
,
typename
T2
>
struct
InitOcxOw8
<
2
,
bias_mode
,
T
,
T2
>
{
static
void
impl
(
T
&
c
,
const
float32_t
*
bias_ptr
,
int
oc_step
)
{
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr); \
c[1][step] = vld1q_f32(bias_ptr + oc_step);
UNROLL_CALL_RAW
(
8
,
BAIS_INIT
);
#undef BAIS_INIT
}
else
{
#define BAIS_INIT(step) \
c[0][step] = vdupq_n_f32(0); \
c[1][step] = vdupq_n_f32(0);
UNROLL_CALL_RAW
(
8
,
BAIS_INIT
);
#undef BAIS_INIT
}
}
};
template
<
BiasMode
bias_mode
,
typename
T
,
typename
T2
>
struct
InitOcxOw8
<
1
,
bias_mode
,
T
,
T2
>
{
static
void
impl
(
T
&
c
,
const
float32_t
*
bias_ptr
,
int
)
{
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
UNROLL_CALL_RAW
(
8
,
BAIS_INIT
);
#undef BAIS_INIT
}
else
{
#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
UNROLL_CALL_RAW
(
8
,
BAIS_INIT
);
#undef BAIS_INIT
}
}
};
template
<
int
c_dim
,
BiasMode
bias_mode
,
typename
T
,
typename
T2
>
inline
void
init_ocx_ow8
(
T
&
c
,
T2
bias_ptr
,
int
oc_step
)
{
InitOcxOw8
<
c_dim
,
bias_mode
,
T
,
T2
>::
impl
(
c
,
bias_ptr
,
oc_step
);
}
/////////////////////init_ocx_ow4/////////////////////
template
<
int
c_dim
,
BiasMode
bias_mode
,
typename
T
>
template
<
int
c_dim
,
BiasMode
bias_mode
,
typename
T
>
struct
InitOcxOw4
{
struct
InitOcxOw4
{
static
void
impl
(
T
&
c
,
const
int32_t
*
bias_ptr
,
int
oc_step
);
static
void
impl
(
T
&
c
,
const
int32_t
*
bias_ptr
,
int
oc_step
);
...
@@ -383,57 +587,54 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
...
@@ -383,57 +587,54 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
}
}
///////////////////////////////////////
///////////////////////////////////////
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
oc_block
,
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
oc_block
,
typename
Func
,
typename
T
,
typename
...
XT
>
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
struct
LoadHelper
{
struct
LoadHelper
{
static
void
impl
(
T
&
weight
,
const
int8_t
*
ptr
,
int
oc_offset
,
XT
...
args
);
static
void
impl
(
T
&
weight
,
T2
ptr
,
int
oc_offset
,
XT
...
args
);
};
};
#define WEIGHT_CB(step) \
#define WEIGHT_CB(step) \
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...);
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...);
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offse
t
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
in
t
,
XT
...
args
)
{
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offse
t
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
in
t
,
XT
...
args
)
{
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offse
t
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
in
t
,
XT
...
args
)
{
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
4
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
4
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
,
XT
...
args
)
{
MEGDNN_MARK_USED_VAR
(
oc_offset
);
UNROLL_CALL_RAW
(
4
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
4
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
5
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
5
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
,
XT
...
args
)
{
MEGDNN_MARK_USED_VAR
(
oc_offset
);
UNROLL_CALL_RAW
(
5
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
5
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
typename
...
XT
>
struct
LoadHelper
<
6
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
XT
...
>
{
struct
LoadHelper
<
6
,
base_offset
,
ptr_step
,
0
,
Func
,
T
,
T2
,
XT
...
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
,
XT
...
args
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
,
XT
...
args
)
{
MEGDNN_MARK_USED_VAR
(
oc_offset
);
UNROLL_CALL_RAW
(
6
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
6
,
WEIGHT_CB
);
}
}
};
};
...
@@ -441,27 +642,36 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> {
...
@@ -441,27 +642,36 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> {
#define WEIGHT_CB(step) \
#define WEIGHT_CB(step) \
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step);
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step);
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
1
,
Func
,
T
>
{
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
}
MEGDNN_MARK_USED_VAR
(
oc_offset
);
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
1
,
Func
,
T
>
{
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
}
MEGDNN_MARK_USED_VAR
(
oc_offset
);
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
1
,
Func
,
T
>
{
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
}
MEGDNN_MARK_USED_VAR
(
oc_offset
);
};
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
}
struct
LoadHelper
<
4
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
4
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
5
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
5
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
6
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
6
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
7
,
base_offset
,
ptr_step
,
1
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
)
{
UNROLL_CALL_RAW
(
7
,
WEIGHT_CB
);
}
};
};
#undef WEIGHT_CB
#undef WEIGHT_CB
...
@@ -470,40 +680,63 @@ struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> {
...
@@ -470,40 +680,63 @@ struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> {
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset);
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset);
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
2
,
Func
,
T
>
{
struct
LoadHelper
<
1
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
1
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
2
,
Func
,
T
>
{
struct
LoadHelper
<
2
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
2
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
>
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
2
,
Func
,
T
>
{
struct
LoadHelper
<
3
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
UNROLL_CALL_RAW
(
3
,
WEIGHT_CB
);
}
}
};
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
4
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
4
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
5
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
5
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
6
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
6
,
WEIGHT_CB
);
}
};
template
<
int
base_offset
,
int
ptr_step
,
typename
Func
,
typename
T
,
typename
T2
>
struct
LoadHelper
<
7
,
base_offset
,
ptr_step
,
2
,
Func
,
T
,
T2
>
{
static
void
impl
(
T
&
src
,
T2
ptr
,
int
oc_offset
)
{
UNROLL_CALL_RAW
(
7
,
WEIGHT_CB
);
}
};
#undef WEIGHT_CB
#undef WEIGHT_CB
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
c_dim
,
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
c_dim
,
typename
Func
,
typename
T
>
typename
Func
,
typename
T
,
typename
T2
>
inline
void
load_helper
(
T
&
weight
,
const
int8_t
*
ptr
,
int
oc_offset
)
{
inline
void
load_helper
(
T
&
weight
,
T2
ptr
,
int
oc_offset
)
{
LoadHelper
<
weight_number
,
base_offset
,
ptr_step
,
c_dim
,
Func
,
T
>::
impl
(
LoadHelper
<
weight_number
,
base_offset
,
ptr_step
,
c_dim
,
Func
,
T
,
T2
>::
impl
(
weight
,
ptr
,
oc_offset
);
weight
,
ptr
,
oc_offset
);
}
}
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
c_dim
,
template
<
int
weight_number
,
int
base_offset
,
int
ptr_step
,
int
c_dim
,
typename
Func
,
typename
T
,
typename
...
XT
>
typename
Func
,
typename
T
,
typename
T2
,
typename
...
XT
>
inline
void
load_helper_x
(
T
&
weight
,
const
int8_t
*
ptr
,
int
oc_offset
,
inline
void
load_helper_x
(
T
&
weight
,
T2
ptr
,
int
oc_offset
,
XT
...
args
)
{
XT
...
args
)
{
LoadHelper
<
weight_number
,
base_offset
,
ptr_step
,
c_dim
,
Func
,
T
,
T2
,
LoadHelper
<
weight_number
,
base_offset
,
ptr_step
,
c_dim
,
Func
,
T
,
XT
...
>::
impl
(
weight
,
ptr
,
oc_offset
,
args
...);
XT
...
>::
impl
(
weight
,
ptr
,
oc_offset
,
args
...);
}
}
...
...
dnn/src/arm_common/conv_bias/neon_struct.h
浏览文件 @
c9986df5
...
@@ -34,6 +34,9 @@ struct Vmlal_s16 {
...
@@ -34,6 +34,9 @@ struct Vmlal_s16 {
struct
Vld1q_s8
{
struct
Vld1q_s8
{
static
int8x16_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1q_s8
(
ptr
);
}
static
int8x16_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1q_s8
(
ptr
);
}
};
};
struct
Vld1q_f32
{
static
float32x4_t
impl
(
const
float32_t
*
ptr
)
{
return
vld1q_f32
(
ptr
);
}
};
struct
Vld1_s8
{
struct
Vld1_s8
{
static
int8x8_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1_s8
(
ptr
);
}
static
int8x8_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1_s8
(
ptr
);
}
};
};
...
@@ -50,5 +53,13 @@ struct Vldq_tbl_low_s8 {
...
@@ -50,5 +53,13 @@ struct Vldq_tbl_low_s8 {
struct
Vld1_dup_s8_s16
{
struct
Vld1_dup_s8_s16
{
static
int16x8_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1_dup_s8_s16
(
ptr
);
}
static
int16x8_t
impl
(
const
int8_t
*
ptr
)
{
return
vld1_dup_s8_s16
(
ptr
);
}
};
};
struct
Vfmaq_laneq_f32
{
template
<
const
int
lane
>
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vfmaq_laneq_f32
(
a
,
b
,
v
,
lane
);
}
};
}
// namespace
}
// namespace
}
// namespace megdnn
}
// namespace megdnn
\ No newline at end of file
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
c9986df5
...
@@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
...
@@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32DirectStride2
f32_direct_stride2_small_group
{
false
};
AlgoF32DirectStride2
f32_direct_stride2_small_group
{
false
};
AlgoF32DirectStride1
f32_direct_stride1_large_group
{
true
};
AlgoF32DirectStride1
f32_direct_stride1_large_group
{
true
};
AlgoF32DirectStride1
f32_direct_stride1_small_group
{
false
};
AlgoF32DirectStride1
f32_direct_stride1_small_group
{
false
};
AlgoF32DirectStride2NCHWNCHW44
f32_direct_stride2_nchw_nchw44
;
AlgoI8x8x16Direct
i8x8x16_direct_large_group
{
true
};
AlgoI8x8x16Direct
i8x8x16_direct_large_group
{
true
};
AlgoI8x8x16Direct
i8x8x16_direct_small_group
{
false
};
AlgoI8x8x16Direct
i8x8x16_direct_small_group
{
false
};
AlgoI8x8x16Stride2
i8x8x16_stride2_large_group
{
true
};
AlgoI8x8x16Stride2
i8x8x16_stride2_large_group
{
true
};
...
@@ -123,6 +124,7 @@ public:
...
@@ -123,6 +124,7 @@ public:
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_filter2
);
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_filter2
);
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_large_group
);
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_large_group
);
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_small_group
);
direct_algos
.
emplace_back
(
&
i8x8x16_stride2_small_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride2_nchw_nchw44
);
direct_algos
.
emplace_back
(
&
f32_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
f32_direct_stride2_large_group
);
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
c9986df5
...
@@ -67,6 +67,7 @@ private:
...
@@ -67,6 +67,7 @@ private:
class
AlgoF32Direct
;
class
AlgoF32Direct
;
class
AlgoF32DirectStride1
;
class
AlgoF32DirectStride1
;
class
AlgoF32DirectStride2
;
class
AlgoF32DirectStride2
;
class
AlgoF32DirectStride2NCHWNCHW44
;
class
AlgoI8x8x16Direct
;
class
AlgoI8x8x16Direct
;
class
AlgoI8x8x16Stride2
;
class
AlgoI8x8x16Stride2
;
class
AlgoI8x8x16Stride2Filter2
;
class
AlgoI8x8x16Stride2Filter2
;
...
...
dnn/src/arm_common/elemwise_helper/kimpl/hswish.h
浏览文件 @
c9986df5
...
@@ -45,13 +45,17 @@ struct HSwishOp;
...
@@ -45,13 +45,17 @@ struct HSwishOp;
vst1q_##_func_suffix(dst, vitem.val[0]); \
vst1q_##_func_suffix(dst, vitem.val[0]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
vst1q_##_func_suffix(dst, vitem); \
} \
_neon_type2 operator()(const _neon_type2& src) const { \
_neon_type2 operator()(const _neon_type2& src) const { \
auto val1 = src.val[0]; \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
auto val2 = src.val[1]; \
H_SWISH_KERN(_func_suffix, val1, val2); \
H_SWISH_KERN(_func_suffix, val1, val2); \
return {{val1, val2}}; \
return {{val1, val2}}; \
} \
} \
_neon_type operator()(const _neon_type& src)
{
\
_neon_type operator()(const _neon_type& src)
const {
\
auto val_zero = vdupq_n_##_func_suffix(0.f); \
auto val_zero = vdupq_n_##_func_suffix(0.f); \
auto val_six = vdupq_n_##_func_suffix(6.f); \
auto val_six = vdupq_n_##_func_suffix(6.f); \
auto val_three = vdupq_n_##_func_suffix(3.f); \
auto val_three = vdupq_n_##_func_suffix(3.f); \
...
@@ -64,6 +68,7 @@ struct HSwishOp;
...
@@ -64,6 +68,7 @@ struct HSwishOp;
val_rec_six); \
val_rec_six); \
} \
} \
};
};
OP
(
dt_float32
,
float32x4_t
,
float32x4x2_t
,
f32
,
4
)
OP
(
dt_float32
,
float32x4_t
,
float32x4x2_t
,
f32
,
4
)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
OP
(
__fp16
,
float16x8_t
,
float16x8x2_t
,
f16
,
8
)
OP
(
__fp16
,
float16x8_t
,
float16x8x2_t
,
f16
,
8
)
...
...
dnn/src/arm_common/elemwise_helper/kimpl/none.h
浏览文件 @
c9986df5
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -30,6 +31,13 @@ struct NoneOp;
...
@@ -30,6 +31,13 @@ struct NoneOp;
using NoneOpBase::operator(); \
using NoneOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
constexpr static size_t SIMD_WIDTH = _simd_width; \
_neon_type2 operator()(const _neon_type2& src) const { return src; } \
_neon_type2 operator()(const _neon_type2& src) const { return src; } \
void operator()(const _neon_type2& src, _ctype* dst) const { \
vst1q_##_func_suffix(dst, src.val[0]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
vst1q_##_func_suffix(dst, src); \
} \
_neon_type operator()(const _neon_type& src) const { return src; } \
_neon_type operator()(const _neon_type& src) const { return src; } \
};
};
...
...
dnn/src/arm_common/elemwise_helper/kimpl/relu.h
浏览文件 @
c9986df5
...
@@ -47,11 +47,16 @@ struct ReluOp;
...
@@ -47,11 +47,16 @@ struct ReluOp;
auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \
auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \
return {{vitem0, vitem1}}; \
} \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
vst1q_##_func_suffix(dst, vitem); \
} \
_neon_type operator()(const _neon_type& src) const { \
_neon_type operator()(const _neon_type& src) const { \
auto vzero = vdupq_n_##_func_suffix(0); \
auto vzero = vdupq_n_##_func_suffix(0); \
return vmaxq_##_func_suffix(src, vzero); \
return vmaxq_##_func_suffix(src, vzero); \
} \
} \
};
};
OP
(
dt_float32
,
float32x4_t
,
float32x4x2_t
,
f32
,
4
)
OP
(
dt_float32
,
float32x4_t
,
float32x4x2_t
,
f32
,
4
)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
OP
(
__fp16
,
float16x8_t
,
float16x8x2_t
,
f16
,
8
)
OP
(
__fp16
,
float16x8_t
,
float16x8x2_t
,
f16
,
8
)
...
...
dnn/src/arm_common/simd_macro/marm_neon.h
浏览文件 @
c9986df5
...
@@ -479,6 +479,39 @@ UNROLL_CALL_RAW(4, cb);
...
@@ -479,6 +479,39 @@ UNROLL_CALL_RAW(4, cb);
#undef cb
#undef cb
}
// namespace
}
// namespace
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
namespace
{
template
<
int
lane
>
struct
Vfmap_laneq_f32_armv7
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
);
};
template
<
>
struct
Vfmap_laneq_f32_armv7
<
0
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
0
);
}
};
template
<
>
struct
Vfmap_laneq_f32_armv7
<
1
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
1
);
}
};
template
<
>
struct
Vfmap_laneq_f32_armv7
<
2
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
0
);
}
};
template
<
>
struct
Vfmap_laneq_f32_armv7
<
3
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
}
};
}
// namespace
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmap_laneq_f32_armv7<lane>::impl(a, b, v)
#endif
#endif
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
c9986df5
...
@@ -85,7 +85,7 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) {
...
@@ -85,7 +85,7 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) {
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
static
void
benchmark_convbias
(
Handle
*
handle
)
{
static
void
benchmark_convbias
(
Handle
*
handle
,
bool
is_fp32
=
false
)
{
constexpr
size_t
RUNS
=
30
;
constexpr
size_t
RUNS
=
30
;
Benchmarker
<
ConvBias
>
benchmarker_int
(
handle
);
Benchmarker
<
ConvBias
>
benchmarker_int
(
handle
);
...
@@ -102,15 +102,25 @@ static void benchmark_convbias(Handle* handle) {
...
@@ -102,15 +102,25 @@ static void benchmark_convbias(Handle* handle) {
Benchmarker
<
ConvBias
>
benchmarker_float
(
handle
);
Benchmarker
<
ConvBias
>
benchmarker_float
(
handle
);
benchmarker_float
.
set_display
(
false
).
set_times
(
RUNS
);
benchmarker_float
.
set_display
(
false
).
set_times
(
RUNS
);
benchmarker_float
.
set_before_exec_callback
(
benchmarker_float
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
".+"
));
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
));
Benchmarker
<
ConvBias
>
benchmarker_int_nchw44
(
handle
);
Benchmarker
<
ConvBias
>
benchmarker_int_nchw44
(
handle
);
if
(
is_fp32
)
{
benchmarker_int_nchw44
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
())
.
set_display
(
false
);
}
else
{
benchmarker_int_nchw44
.
set_times
(
RUNS
)
benchmarker_int_nchw44
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
))
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
))
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
))
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
))
.
set_display
(
false
);
.
set_display
(
false
);
}
benchmarker_int_nchw44
.
set_before_exec_callback
(
benchmarker_int_nchw44
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
".+"
));
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
".+"
));
...
@@ -151,7 +161,6 @@ static void benchmark_convbias(Handle* handle) {
...
@@ -151,7 +161,6 @@ static void benchmark_convbias(Handle* handle) {
auto
int_nchw44_used
=
benchmarker_int_nchw44
.
set_param
(
param
).
exec
(
auto
int_nchw44_used
=
benchmarker_int_nchw44
.
set_param
(
param
).
exec
(
{
src
,
filter
,
bias
,
{},
dst
})
/
{
src
,
filter
,
bias
,
{},
dst
})
/
RUNS
;
RUNS
;
float
computations
=
IC
*
(
FS
*
FS
)
*
dst
.
total_nr_elems
()
*
2
*
1e-6
;
float
computations
=
IC
*
(
FS
*
FS
)
*
dst
.
total_nr_elems
()
*
2
*
1e-6
;
printf
(
"run: %s %s %s->%s
\n
"
,
src
.
to_string
().
c_str
(),
printf
(
"run: %s %s %s->%s
\n
"
,
src
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
(),
bias
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
(),
bias
.
to_string
().
c_str
(),
...
@@ -160,32 +169,42 @@ static void benchmark_convbias(Handle* handle) {
...
@@ -160,32 +169,42 @@ static void benchmark_convbias(Handle* handle) {
computations
/
float_used
);
computations
/
float_used
);
printf
(
"int_nchw: %f ms %f Gflops, "
,
int_used
,
printf
(
"int_nchw: %f ms %f Gflops, "
,
int_used
,
computations
/
int_used
);
computations
/
int_used
);
auto
speed_up
=
int_used
/
int_nchw44_used
;
if
(
is_fp32
)
{
speed_up
=
float_used
/
int_nchw44_used
;
printf
(
"fp32_nchw44: %f ms %f Gflops %f speedup, "
,
int_nchw44_used
,
computations
/
int_nchw44_used
,
speed_up
);
}
else
{
printf
(
"int_nchw44: %f ms %f Gflops %f speedup, "
,
int_nchw44_used
,
printf
(
"int_nchw44: %f ms %f Gflops %f speedup, "
,
int_nchw44_used
,
computations
/
int_nchw44_used
,
int_used
/
int_nchw44_used
);
computations
/
int_nchw44_used
,
speed_up
);
}
printf
(
"
\n
"
);
printf
(
"
\n
"
);
};
};
if
(
is_fp32
)
{
run
(
1
,
3
,
32
,
224
,
224
,
3
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
3
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
5
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
7
,
2
,
true
);
}
else
{
for
(
size_t
stride
:
{
1
,
2
})
{
for
(
size_t
stride
:
{
1
,
2
})
{
printf
(
"stride %zu
\n
"
,
stride
);
printf
(
"stride %zu
\n
"
,
stride
);
for
(
size_t
filter_size
:
{
2
,
3
,
5
,
7
})
{
for
(
size_t
filter_size
:
{
2
,
3
,
5
,
7
})
{
for
(
size_t
img_size
:
{
32
})
{
for
(
size_t
img_size
:
{
32
})
{
for
(
size_t
channel
:
{
8
,
16
,
32
,
64
,
128
,
256
})
{
for
(
size_t
channel
:
{
8
,
16
,
32
,
64
,
128
,
256
})
{
run
(
1
,
channel
,
channel
,
img_size
,
img_size
,
filter_size
,
run
(
1
,
channel
,
channel
,
img_size
,
img_size
,
stride
,
false
);
filter_size
,
stride
,
false
);
}
}
}
}
}
}
}
}
}
}
}
TEST_F
(
ARM_COMMON
,
BENCHMARK_CONVBIAS_NCHW44
)
{
TEST_F
(
ARM_COMMON
,
BENCHMARK_CONVBIAS_NCHW44
)
{
benchmark_convbias
(
handle
());
benchmark_convbias
(
handle
()
,
true
);
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
BENCHMARK_CONVBIAS_NCHW44
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
BENCHMARK_CONVBIAS_NCHW44
)
{
benchmark_convbias
(
handle
());
benchmark_convbias
(
handle
()
,
true
);
}
}
#endif
#endif
TEST_F
(
ARM_COMMON
,
CONV_BIAS_MATMUL_QS8
)
{
TEST_F
(
ARM_COMMON
,
CONV_BIAS_MATMUL_QS8
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
...
@@ -1464,7 +1483,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
...
@@ -1464,7 +1483,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
namespace
{
namespace
{
std
::
vector
<
conv_bias
::
TestArg
>
get_conv_bias_1x1_benchmark_args
(
size_t
pack_size
=
1
)
{
std
::
vector
<
conv_bias
::
TestArg
>
get_conv_bias_1x1_benchmark_args
(
size_t
pack_size
=
1
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
;
std
::
vector
<
TestArg
>
args
;
param
::
ConvBias
param
;
param
::
ConvBias
param
;
...
@@ -1474,14 +1494,16 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_siz
...
@@ -1474,14 +1494,16 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_siz
param
.
pad_w
=
0
;
param
.
pad_w
=
0
;
param
.
nonlineMode
=
param
::
ConvBias
::
NonlineMode
::
IDENTITY
;
param
.
nonlineMode
=
param
::
ConvBias
::
NonlineMode
::
IDENTITY
;
auto
bench_case
=
[
&
](
size_t
OC
,
size_t
IC
,
size_t
H
,
size_t
W
)
{
auto
bench_case
=
[
&
](
size_t
OC
,
size_t
IC
,
size_t
H
,
size_t
W
)
{
if
(
pack_size
==
1
)
if
(
pack_size
==
1
)
args
.
emplace_back
(
param
,
TensorShape
{
1
,
IC
,
H
,
W
},
args
.
emplace_back
(
param
,
TensorShape
{
1
,
IC
,
H
,
W
},
TensorShape
{
OC
,
IC
,
1
,
1
},
TensorShape
{});
TensorShape
{
OC
,
IC
,
1
,
1
},
TensorShape
{});
else
{
else
{
if
(
pack_size
==
4
)
if
(
pack_size
==
4
)
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44
;
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44
;
args
.
emplace_back
(
param
,
TensorShape
{
1
,
IC
/
pack_size
,
H
,
W
,
pack_size
},
args
.
emplace_back
(
param
,
TensorShape
{
OC
/
pack_size
,
IC
/
pack_size
,
1
,
1
,
pack_size
,
pack_size
},
TensorShape
{
1
,
IC
/
pack_size
,
H
,
W
,
pack_size
},
TensorShape
{
OC
/
pack_size
,
IC
/
pack_size
,
1
,
1
,
pack_size
,
pack_size
},
TensorShape
{});
TensorShape
{});
}
}
};
};
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
c9986df5
...
@@ -78,9 +78,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
...
@@ -78,9 +78,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
std
::
vector
<
TestArg
>
args
;
std
::
vector
<
TestArg
>
args
;
auto
pack
=
[
&
](
size_t
n
,
size_t
oc
,
size_t
ic
,
size_t
h
,
size_t
w
,
auto
pack
=
[
&
](
size_t
n
,
size_t
oc
,
size_t
ic
,
size_t
h
,
size_t
w
,
size_t
kernel
,
size_t
stride
,
size_t
group
,
NLMode
nlmode
)
{
size_t
kernel
,
size_t
stride
,
size_t
group
,
NLMode
nlmode
,
int
any_pad
=
-
1
)
{
constexpr
int
pack_c
=
4
;
constexpr
int
pack_c
=
4
;
const
size_t
pad
=
no_pad
?
0
:
kernel
/
2
;
const
size_t
pad
=
any_pad
>=
0
?
any_pad
:
kernel
/
2
;
auto
bias_mode
=
no_bias
?
megdnn
::
BiasMode
::
NO_BIAS
auto
bias_mode
=
no_bias
?
megdnn
::
BiasMode
::
NO_BIAS
:
megdnn
::
BiasMode
::
BROADCAST_CHANNEL_BIAS
;
:
megdnn
::
BiasMode
::
BROADCAST_CHANNEL_BIAS
;
auto
oc_per_group
=
oc
/
group
;
auto
oc_per_group
=
oc
/
group
;
...
@@ -90,7 +91,8 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
...
@@ -90,7 +91,8 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
ic_per_group
>
0
;
ic_per_group
>
0
;
bool
nchw_disable
=
group
>
1
||
ic_per_group
>=
4
;
bool
nchw_disable
=
group
>
1
||
ic_per_group
>=
4
;
bool
nchw44_disable
=
ic_per_group
%
pack_c
!=
0
;
bool
nchw44_disable
=
ic_per_group
%
pack_c
!=
0
;
if
(
!
(
ok_group
))
{
bool
invalid_pad
=
(
w
+
2
*
pad
<
kernel
)
||
(
h
+
2
*
pad
<
kernel
);
if
(
!
(
ok_group
)
||
invalid_pad
)
{
return
;
return
;
}
}
if
((
is_input_nchw
&&
nchw_disable
)
||
if
((
is_input_nchw
&&
nchw_disable
)
||
...
@@ -107,6 +109,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
...
@@ -107,6 +109,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
param
.
pad_h
=
pad
;
param
.
pad_h
=
pad
;
param
.
pad_w
=
pad
;
param
.
pad_w
=
pad
;
param
.
nonlineMode
=
nlmode
;
param
.
nonlineMode
=
nlmode
;
auto
src_tensor_shape
=
TensorShape
{
n
,
ic
/
pack_c
,
h
,
w
,
pack_c
};
auto
src_tensor_shape
=
TensorShape
{
n
,
ic
/
pack_c
,
h
,
w
,
pack_c
};
auto
weight_tensor_shape
=
TensorShape
{
auto
weight_tensor_shape
=
TensorShape
{
oc
/
pack_c
,
ic
/
pack_c
,
kernel_h
,
kernel_w
,
pack_c
,
pack_c
};
oc
/
pack_c
,
ic
/
pack_c
,
kernel_h
,
kernel_w
,
pack_c
,
pack_c
};
...
@@ -338,6 +341,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
...
@@ -338,6 +341,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
check_conv_bias
(
get_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
),
check_conv_bias
(
get_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
),
handle
(),
"F32STRD2_SMALL_GROUP"
);
handle
(),
"F32STRD2_SMALL_GROUP"
);
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_NCHW_NCHW44_F32
)
{
check_conv_bias
(
get_nchw44_conv_bias_args
({
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
handle
(),
"F32_CONV_NCHW_NCHW44"
);
}
/**********************************F16 direct************************/
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_DIRECT_FP16_LARGE_GROUP
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_DIRECT_FP16_LARGE_GROUP
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录