Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6c29548d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6c29548d
编写于
6月 02, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/arm): fix nchw_nchw44 dot stride1 support
GitOrigin-RevId: c8d3d55b258e2a43c27b903808566f2ea1857842
上级
02cbb13b
变更
12
展开全部
显示空白变更内容
内联
并排
Showing
12 changed file
with
1432 addition
and
200 deletion
+1432
-200
dnn/src/arm_common/conv_bias/block_helper.h
dnn/src/arm_common/conv_bias/block_helper.h
+36
-0
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
+7
-21
dnn/src/arm_common/conv_bias/int8/algos.h
dnn/src/arm_common/conv_bias/int8/algos.h
+15
-0
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
...arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
+321
-0
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
...c/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
+779
-0
dnn/src/arm_common/conv_bias/intrinsic_helper.h
dnn/src/arm_common/conv_bias/intrinsic_helper.h
+200
-169
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+2
-0
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+1
-0
dnn/src/arm_common/neon_struct.h
dnn/src/arm_common/neon_struct.h
+8
-0
dnn/src/arm_common/simd_macro/marm_neon.h
dnn/src/arm_common/simd_macro/marm_neon.h
+40
-6
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+13
-4
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+10
-0
未找到文件。
dnn/src/arm_common/conv_bias/block_helper.h
0 → 100644
浏览文件 @
6c29548d
/**
* \file dnn/src/arm_common/conv_bias/block_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
namespace
megdnn
{
namespace
{
// block_helper is used to calculate oh block size
static
inline
int
l2_block_helper
(
const
int
nthread
,
const
int
amount
,
const
int
size_per_unit
)
{
constexpr
int
l2_cache_size
=
256
*
1024
;
const
int
block_per_thread
=
div_ceil
(
amount
,
nthread
);
const
int
best_block
=
std
::
min
(
amount
,
(
l2_cache_size
+
size_per_unit
/
2
)
/
size_per_unit
);
const
int
max_block_num
=
div_ceil
(
block_per_thread
,
best_block
);
const
int
min_block_num
=
std
::
max
(
max_block_num
-
1
,
1
);
const
int
max_block
=
div_ceil
(
block_per_thread
,
max_block_num
);
const
int
min_block
=
div_ceil
(
block_per_thread
,
min_block_num
);
const
int
max_loss
=
std
::
abs
(
max_block_num
*
max_block
-
block_per_thread
);
const
int
min_loss
=
std
::
abs
(
min_block_num
*
min_block
-
block_per_thread
);
int
block
=
max_loss
>
min_loss
?
min_block
:
max_block
;
return
block
;
}
}
// namespace
}
// namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
浏览文件 @
6c29548d
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
*/
*/
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
...
@@ -26,22 +27,7 @@ using conv_fun = std::function<void(
...
@@ -26,22 +27,7 @@ using conv_fun = std::function<void(
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
>
;
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
>
;
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_fp32_nchw44_stride1
)
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_fp32_nchw44_stride1
)
namespace
{
namespace
{
// block_helper is used to calculate oh block size
static
inline
int
block_helper
(
const
int
nthread
,
const
int
amount
,
const
int
size_per_unit
)
{
constexpr
int
l2_cache_size
=
256
*
1024
;
const
int
block_per_thread
=
div_ceil
(
amount
,
nthread
);
const
int
best_block
=
std
::
min
(
amount
,
(
l2_cache_size
+
size_per_unit
/
2
)
/
size_per_unit
);
const
int
max_block_num
=
div_ceil
(
block_per_thread
,
best_block
);
const
int
min_block_num
=
std
::
max
(
max_block_num
-
1
,
1
);
const
int
max_block
=
div_ceil
(
block_per_thread
,
max_block_num
);
const
int
min_block
=
div_ceil
(
block_per_thread
,
min_block_num
);
const
int
max_loss
=
std
::
abs
(
max_block_num
*
max_block
-
block_per_thread
);
const
int
min_loss
=
std
::
abs
(
min_block_num
*
min_block
-
block_per_thread
);
int
block
=
max_loss
>
min_loss
?
min_block
:
max_block
;
return
block
;
}
static
inline
size_t
get_perthread_cache_bytes
(
const
int
ic
,
const
int
ih2
,
static
inline
size_t
get_perthread_cache_bytes
(
const
int
ic
,
const
int
ih2
,
const
int
iw2
)
{
const
int
iw2
)
{
// border_size is used to avoid read illegal memory
// border_size is used to avoid read illegal memory
...
@@ -60,7 +46,7 @@ static void get_rectified_size(
...
@@ -60,7 +46,7 @@ static void get_rectified_size(
ow2
=
ow
;
ow2
=
ow
;
constexpr
int
cacheline
=
64
/
sizeof
(
float
);
constexpr
int
cacheline
=
64
/
sizeof
(
float
);
int
block_oh
=
int
block_oh
=
block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
float
)
*
2
);
l2_
block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
float
)
*
2
);
auto
&&
fm
=
param
.
filter_meta
;
auto
&&
fm
=
param
.
filter_meta
;
const
int
stride_h
=
static_cast
<
int
>
(
fm
.
stride
[
0
]);
const
int
stride_h
=
static_cast
<
int
>
(
fm
.
stride
[
0
]);
const
int
filter_h
=
static_cast
<
int
>
(
fm
.
spatial
[
0
]);
const
int
filter_h
=
static_cast
<
int
>
(
fm
.
spatial
[
0
]);
...
@@ -106,7 +92,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
...
@@ -106,7 +92,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
const
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
constexpr
int
oc_idx
=
0
;
constexpr
int
oc_idx
=
0
;
int
oc_block
=
oc
;
int
oc_block
=
oc
;
int
oh_block
=
block_helper
(
kern_param
.
nr_threads
,
oh2
,
int
oh_block
=
l2_
block_helper
(
kern_param
.
nr_threads
,
oh2
,
ic
*
iw
*
sizeof
(
float
)
*
stride_h
);
ic
*
iw
*
sizeof
(
float
)
*
stride_h
);
const
int
oh_idx
=
ncb_index
.
ndrange_id
[
2
];
const
int
oh_idx
=
ncb_index
.
ndrange_id
[
2
];
const
int
oh_block_real
=
std
::
min
(
oh
-
oh_idx
*
oh_block
,
oh_block
);
const
int
oh_block_real
=
std
::
min
(
oh
-
oh_idx
*
oh_block
,
oh_block
);
...
@@ -298,7 +284,7 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
...
@@ -298,7 +284,7 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
int
ic
=
param
.
filter_meta
.
icpg
;
int
ic
=
param
.
filter_meta
.
icpg
;
int
iw
=
param
.
isz
[
1
];
int
iw
=
param
.
isz
[
1
];
int
stride_h
=
param
.
filter_meta
.
stride
[
0
];
int
stride_h
=
param
.
filter_meta
.
stride
[
0
];
int
oh_block
=
block_helper
(
param
.
nr_threads
,
oh
,
int
oh_block
=
l2_
block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
float
)
*
stride_h
);
ic
*
iw
*
sizeof
(
float
)
*
stride_h
);
CpuNDRange
ncb_range
=
{
static_cast
<
size_t
>
(
batch
),
CpuNDRange
ncb_range
=
{
static_cast
<
size_t
>
(
batch
),
static_cast
<
size_t
>
(
group
),
static_cast
<
size_t
>
(
group
),
...
...
dnn/src/arm_common/conv_bias/int8/algos.h
浏览文件 @
6c29548d
...
@@ -133,6 +133,21 @@ public:
...
@@ -133,6 +133,21 @@ public:
};
};
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
class
ConvBiasImpl
::
AlgoDotS8DirectNCHWNCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARMDOTS8_NCHW_NCHW44"
;
}
bool
usable
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
,
AlgoSelectionStrategy
algo_selection_strategy
)
const
override
;
size_t
get_workspace
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
};
class
ConvBiasImpl
::
AlgoDotS8DirectStride1
final
:
public
AlgoBase
{
class
ConvBiasImpl
::
AlgoDotS8DirectStride1
final
:
public
AlgoBase
{
bool
m_large_group
;
bool
m_large_group
;
...
...
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
0 → 100644
浏览文件 @
6c29548d
/**
* \file
* dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
using
conv_fun
=
std
::
function
<
void
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
>
;
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_int8_nchw44_dot
)
namespace
{
static
inline
size_t
get_perthread_cache_bytes
(
const
int
ic
,
const
int
ih2
,
const
int
iw2
,
const
int
stride
)
{
//! border_size is used to avoid read illegal memory
constexpr
int
cacheline_size
=
64
;
constexpr
int
border_size
=
2
*
cacheline_size
;
const
int
pack_iw_len
=
stride
==
1
?
4
:
1
;
return
round_up
(
ic
*
ih2
*
iw2
*
pack_iw_len
*
(
int
)
sizeof
(
int8_t
)
+
border_size
,
cacheline_size
);
}
static
inline
size_t
get_temp_bytes
(
const
int
iw
,
const
int
pw
)
{
//! border_size is used to avoid read illegal memory
constexpr
int
cacheline_size
=
64
;
constexpr
int
border_size
=
1
*
cacheline_size
;
return
round_up
(
iw
+
pw
*
2
,
cacheline_size
)
+
border_size
;
}
static
void
get_rectified_size
(
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
int
&
ih2
,
int
&
iw2
)
{
auto
&&
fm
=
param
.
filter_meta
;
const
int
stride_h
=
static_cast
<
int
>
(
fm
.
stride
[
0
]);
const
int
filter_h
=
static_cast
<
int
>
(
fm
.
spatial
[
0
]);
int
ic
=
param
.
filter_meta
.
icpg
;
int
iw
=
param
.
isz
[
1
];
int
oh
=
param
.
osz
[
0
];
int
block_oh
=
l2_block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
int8_t
)
*
stride_h
);
ih2
=
block_oh
*
stride_h
+
filter_h
-
stride_h
;
iw2
=
iw
+
2
*
static_cast
<
int
>
(
fm
.
padding
[
1
]);
}
static
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
auto
&&
fm
=
param
.
filter_meta
;
int
ic
=
fm
.
icpg
;
int
fh
=
fm
.
spatial
[
0
];
int
fw
=
fm
.
spatial
[
1
];
int
iw
=
param
.
isz
[
1
];
int
pw
=
param
.
filter_meta
.
padding
[
1
];
int
stride_w
=
param
.
filter_meta
.
stride
[
1
];
int
ih2
,
iw2
;
get_rectified_size
(
param
,
ih2
,
iw2
);
size_t
src_size
=
get_perthread_cache_bytes
(
ic
,
ih2
,
iw2
,
stride_w
);
size_t
weight_size
=
fm
.
group
*
fm
.
icpg
*
fm
.
ocpg
*
fh
*
round_up
(
fw
,
4
);
size_t
temp_size
=
0
;
if
(
fm
.
stride
[
0
]
==
1
)
{
temp_size
=
get_temp_bytes
(
iw
,
pw
);
}
return
{
nullptr
,
{
src_size
*
param
.
nr_threads
,
weight_size
,
temp_size
*
param
.
nr_threads
}};
};
void
do_weight_trans
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
,
const
CpuNDRange
&
)
{
const
int
ic
=
kern_param
.
filter_meta
.
icpg
;
const
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
const
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
const
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
const
int
fw2
=
round_up
(
fw
,
4
);
bundle
.
set
(
kern_param
.
workspace_ptr
);
auto
packed_weight
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
1
));
auto
origin_weight
=
kern_param
.
filter
<
dt_int8
>
();
pack_weight_int8_nchw_nchw44_dot
(
packed_weight
,
origin_weight
,
oc
,
ic
,
fh
,
fw
,
fw2
);
}
template
<
size_t
filter
,
BiasMode
bias_mode
,
typename
Op
,
int
stride
>
static
void
do_conv_kern
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
,
const
CpuNDRange
&
)
{
const
int
oh
=
kern_param
.
osz
[
0
];
const
int
ow
=
kern_param
.
osz
[
1
];
const
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
const
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
const
int
ic
=
kern_param
.
filter_meta
.
icpg
;
const
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
const
int
ih
=
kern_param
.
isz
[
0
];
const
int
iw
=
kern_param
.
isz
[
1
];
const
int
stride_h
=
kern_param
.
filter_meta
.
stride
[
0
];
const
int
stride_w
=
kern_param
.
filter_meta
.
stride
[
1
];
const
int
ph
=
kern_param
.
filter_meta
.
padding
[
0
];
const
int
pw
=
kern_param
.
filter_meta
.
padding
[
1
];
int
ih2
=
0
;
int
iw2
=
0
;
get_rectified_size
(
kern_param
,
ih2
,
iw2
);
bundle
.
set
(
kern_param
.
workspace_ptr
);
constexpr
int
pack_c
=
4
;
const
int
batch_id
=
ncb_index
.
ndrange_id
[
0
];
const
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
constexpr
int
oc_idx
=
0
;
int
oc_block
=
oc
;
int
oh_block
=
l2_block_helper
(
kern_param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
int8_t
)
*
stride_h
);
const
int
oh_idx
=
ncb_index
.
ndrange_id
[
2
];
const
int
oh_block_real
=
std
::
min
(
oh
-
oh_idx
*
oh_block
,
oh_block
);
const
int
ih_real
=
oh_block_real
*
stride_h
+
fh
-
stride_h
;
const
int
src_top_pad
=
std
::
max
(
ph
-
oh_idx
*
oh_block
*
stride_h
,
0
);
const
int
src_bottom_pad
=
std
::
max
(
(
oh_idx
*
oh_block
+
oh_block_real
-
1
)
*
stride_h
+
fh
-
ih
-
ph
,
0
);
const
int
remain_right_pad
=
std
::
max
(
iw2
-
iw
-
pw
,
0
);
const
int
src_offset
=
std
::
max
(
oh_idx
*
oh_block
*
stride_h
-
ph
,
0
)
*
iw
;
const
int8_t
*
origin_sptr
=
static_cast
<
const
int8_t
*>
(
kern_param
.
src
<
int8_t
>
(
batch_id
,
group_id
,
0
,
1
,
1
))
+
src_offset
;
const
size_t
src_size
=
get_perthread_cache_bytes
(
ic
,
ih2
,
iw2
,
stride_w
);
int8_t
*
sptr
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
0
))
+
ncb_index
.
thread_id
*
src_size
;
int8_t
*
tmp_ptr
=
nullptr
;
if
(
stride
==
1
)
{
const
size_t
tmp_size
=
get_temp_bytes
(
iw
,
pw
);
tmp_ptr
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
2
))
+
ncb_index
.
thread_id
*
tmp_size
;
}
pack_src_int8_nchw_nchw44_dot
<
stride
>
(
sptr
,
origin_sptr
,
ph
,
pw
,
remain_right_pad
,
ih_real
-
src_top_pad
-
src_bottom_pad
,
iw
,
iw2
,
src_top_pad
,
src_bottom_pad
,
ic
,
ih
*
iw
,
tmp_ptr
);
const
int8_t
*
fptr
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
1
))
+
oc_idx
*
fh
*
fw
*
ic
;
int8_t
*
dst
=
kern_param
.
dst
<
int8_t
>
(
batch_id
,
group_id
)
+
oh_idx
*
oh_block
*
ow
*
pack_c
;
const
int
bias_offset
=
oc_idx
;
const
int32_t
*
bptr
=
kern_param
.
bias
<
dt_int32
>
(
batch_id
,
group_id
)
+
bias_offset
;
float
scale_bias
=
kern_param
.
bias_type
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
float
scale_dst
=
kern_param
.
dst_type
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
Op
op
(
scale_bias
,
scale_dst
);
conv_direct_int8_nchw_nchw44_dot
<
bias_mode
,
Op
,
filter
,
stride
>
(
sptr
,
fptr
,
bptr
,
nullptr
,
dst
,
oc_block
,
ic
,
ih_real
,
iw2
,
oh
,
oh_block_real
,
ow
,
op
);
}
}
// namespace
bool
ConvBiasImpl
::
AlgoDotS8DirectNCHWNCHW44
::
usable
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
)
const
{
auto
&&
fm
=
param
.
filter_meta
;
auto
fh
=
fm
.
spatial
[
0
];
int
oc
=
fm
.
ocpg
;
int
ic
=
fm
.
icpg
;
bool
ok_type
=
((
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)))
&&
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
);
bool
ok_src_dst
=
(
oc
%
4
==
0
&&
oc
>=
4
&&
ic
<
4
);
bool
ok_filter
=
fm
.
spatial_ndim
==
2
&&
fh
==
fm
.
spatial
[
1
]
&&
(
fh
==
2
||
fh
==
3
||
fh
==
5
||
fh
==
7
);
bool
ok_slide
=
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
fm
.
stride
[
0
]
==
fm
.
stride
[
1
]
&&
(
fm
.
stride
[
0
]
==
1
||
fm
.
stride
[
0
]
==
2
);
bool
ok_conv
=
!
fm
.
should_flip
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
bool
avaible
=
ok_type
&&
ok_src_dst
&&
ok_filter
&&
ok_slide
&&
ok_conv
;
return
avaible
;
}
size_t
ConvBiasImpl
::
AlgoDotS8DirectNCHWNCHW44
::
get_workspace
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
get_bundle
(
param
).
total_size_in_bytes
();
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ConvBiasImpl
::
AlgoDotS8DirectNCHWNCHW44
::
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
auto
fm
=
param
.
filter_meta
;
const
int
batch
=
param
.
n
;
const
int
group
=
fm
.
group
;
WorkspaceBundle
wbundle
=
get_bundle
(
param
);
conv_fun
do_conv_fun
=
nullptr
;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
switch
(
param
.
filter_meta
.
stride
[
0
])
{
case
1
:
DISPATCH_CONV_KERN
(
1
);
break
;
case
2
:
DISPATCH_CONV_KERN
(
2
);
break
;
default:
megdnn_assert
(
0
);
break
;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert
(
do_conv_fun
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kerns
;
WorkspaceBundle
bundle
=
wbundle
;
int
oh
=
param
.
osz
[
0
];
int
ic
=
param
.
filter_meta
.
icpg
;
int
iw
=
param
.
isz
[
1
];
int
stride_h
=
param
.
filter_meta
.
stride
[
0
];
int
oh_block
=
l2_block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
int8_t
)
*
stride_h
);
CpuNDRange
ncb_range
=
{
static_cast
<
size_t
>
(
batch
),
static_cast
<
size_t
>
(
group
),
static_cast
<
size_t
>
(
div_ceil
(
oh
,
oh_block
))};
auto
do_trans_weight
=
[
bundle
](
const
NCBKernParam
&
kern_param
,
const
NCBKernIndex
&
ncb_index
)
{
do_weight_trans
(
bundle
,
kern_param
,
ncb_index
,
ncb_index
.
ndrange_id
);
};
ret_kerns
.
push_back
({
do_trans_weight
,
{
1
}});
auto
do_conv
=
[
bundle
,
do_conv_fun
,
ncb_range
](
const
NCBKernParam
&
kern_param
,
const
NCBKernIndex
&
ncb_index
)
{
do_conv_fun
(
bundle
,
kern_param
,
ncb_index
,
ncb_index
.
ndrange_id
,
ncb_range
);
};
ret_kerns
.
push_back
({
do_conv
,
ncb_range
});
return
ret_kerns
;
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
0 → 100644
浏览文件 @
6c29548d
此差异已折叠。
点击以展开。
dnn/src/arm_common/conv_bias/intrinsic_helper.h
浏览文件 @
6c29548d
此差异已折叠。
点击以展开。
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
6c29548d
...
@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
...
@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44
s8_channel_wise_stride2_nchw44
;
AlgoS8ChanWiseStride2NCHW44
s8_channel_wise_stride2_nchw44
;
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectNCHWNCHW44
ds8_direct_stride2_nchw_nchw44
;
AlgoDotS8DirectStride1
ds8_direct_stride1_large_group
{
true
};
AlgoDotS8DirectStride1
ds8_direct_stride1_large_group
{
true
};
AlgoDotS8DirectStride1
ds8_direct_stride1_small_group
{
false
};
AlgoDotS8DirectStride1
ds8_direct_stride1_small_group
{
false
};
AlgoDotS8DirectStride2
ds8_direct_stride2_large_group
{
true
};
AlgoDotS8DirectStride2
ds8_direct_stride2_large_group
{
true
};
...
@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
...
@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
public:
AlgoPack
()
{
AlgoPack
()
{
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
direct_algos
.
emplace_back
(
&
ds8_direct_stride2_nchw_nchw44
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride2_large_group
);
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
6c29548d
...
@@ -62,6 +62,7 @@ private:
...
@@ -62,6 +62,7 @@ private:
class
AlgoFP16WinogradF23_8x8
;
class
AlgoFP16WinogradF23_8x8
;
#endif
#endif
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
class
AlgoDotS8DirectNCHWNCHW44
;
class
AlgoDotS8DirectStride1
;
class
AlgoDotS8DirectStride1
;
class
AlgoDotS8DirectStride2
;
class
AlgoDotS8DirectStride2
;
class
AlgoDotU8DirectStride1
;
class
AlgoDotU8DirectStride1
;
...
...
dnn/src/arm_common/neon_struct.h
浏览文件 @
6c29548d
...
@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 {
...
@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 {
return
vfmaq_laneq_f32
(
a
,
b
,
v
,
lane
);
return
vfmaq_laneq_f32
(
a
,
b
,
v
,
lane
);
}
}
};
};
#if __ARM_FEATURE_DOTPROD
struct
Vdotq_laneq_s32
{
template
<
const
int
lane
>
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
)
{
return
vdotq_laneq_s32
(
a
,
b
,
v
,
lane
);
}
};
#endif
}
// namespace
}
// namespace
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/simd_macro/marm_neon.h
浏览文件 @
6c29548d
...
@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb);
...
@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb);
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
namespace
{
namespace
{
template
<
int
lane
>
template
<
int
lane
>
struct
Vfma
p
_laneq_f32_armv7
{
struct
Vfma
q
_laneq_f32_armv7
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
);
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
);
};
};
template
<
>
template
<
>
struct
Vfma
p
_laneq_f32_armv7
<
0
>
{
struct
Vfma
q
_laneq_f32_armv7
<
0
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
0
);
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
0
);
}
}
};
};
template
<
>
template
<
>
struct
Vfma
p
_laneq_f32_armv7
<
1
>
{
struct
Vfma
q
_laneq_f32_armv7
<
1
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
1
);
return
vmlaq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
1
);
}
}
};
};
template
<
>
template
<
>
struct
Vfma
p
_laneq_f32_armv7
<
2
>
{
struct
Vfma
q
_laneq_f32_armv7
<
2
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
0
);
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
0
);
}
}
};
};
template
<
>
template
<
>
struct
Vfma
p
_laneq_f32_armv7
<
3
>
{
struct
Vfma
q
_laneq_f32_armv7
<
3
>
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
static
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
}
}
};
};
}
// namespace
}
// namespace
#define vfmaq_laneq_f32(a, b, v, lane) \
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmap_laneq_f32_armv7<lane>::impl(a, b, v)
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
#if __ARM_FEATURE_DOTPROD
template
<
int
lane
>
struct
Vdotq_laneq_s32_armv7
{
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
);
};
template
<
>
struct
Vdotq_laneq_s32_armv7
<
0
>
{
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
)
{
return
vdotq_lane_s32
(
a
,
b
,
vget_low_s32
(
v
),
0
);
}
};
template
<
>
struct
Vdotq_laneq_s32_armv7
<
1
>
{
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
)
{
return
vdotq_lane_s32
(
a
,
b
,
vget_low_s32
(
v
),
1
);
}
};
template
<
>
struct
Vdotq_laneq_s32_armv7
<
2
>
{
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
)
{
return
vdotq_lane_s32
(
a
,
b
,
vget_high_s32
(
v
),
0
);
}
};
template
<
>
struct
Vdotq_laneq_s32_armv7
<
3
>
{
static
int32x4_t
impl
(
int32x4_t
a
,
int8x16_t
b
,
int8x16_t
v
)
{
return
vdotq_lane_s32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
}
};
#define vdotq_laneq_s32(a, b, v, lane) \
Vdotq_laneq_s32_armv7<lane>::impl(a, b, v)
#endif
#endif
#endif
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
6c29548d
...
@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
...
@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
))
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
))
.
set_display
(
false
);
.
set_display
(
false
);
benchmarker_int
.
set_before_exec_callback
(
benchmarker_int
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
"IM2COLMATMUL:.+"
));
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
));
Benchmarker
<
ConvBias
>
benchmarker_float
(
handle
);
Benchmarker
<
ConvBias
>
benchmarker_float
(
handle
);
benchmarker_float
.
set_display
(
false
).
set_times
(
RUNS
);
benchmarker_float
.
set_display
(
false
).
set_times
(
RUNS
);
benchmarker_float
.
set_before_exec_callback
(
benchmarker_float
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBias
>
(
"IM2COLMATMUL:.+"
));
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
));
Benchmarker
<
ConvBias
>
benchmarker_nchw44
(
handle
);
Benchmarker
<
ConvBias
>
benchmarker_nchw44
(
handle
);
if
(
is_fp32
)
{
if
(
is_fp32
)
{
...
@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
...
@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
run
(
1
,
256
,
256
,
14
,
14
,
3
,
1
,
false
);
run
(
1
,
256
,
256
,
14
,
14
,
3
,
1
,
false
);
run
(
1
,
512
,
512
,
7
,
7
,
3
,
1
,
false
);
run
(
1
,
512
,
512
,
7
,
7
,
3
,
1
,
false
);
}
else
{
}
else
{
run
(
1
,
1
,
4
,
112
,
112
,
2
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
3
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
5
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
2
,
true
);
run
(
1
,
1
,
4
,
112
,
112
,
2
,
1
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
3
,
1
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
5
,
1
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
1
,
true
);
for
(
size_t
stride
:
{
1
,
2
})
{
for
(
size_t
stride
:
{
1
,
2
})
{
printf
(
"stride %zu
\n
"
,
stride
);
printf
(
"stride %zu
\n
"
,
stride
);
for
(
size_t
filter_size
:
{
2
,
3
,
5
,
7
})
{
for
(
size_t
filter_size
:
{
2
,
3
,
5
,
7
})
{
...
@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
...
@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
}
}
TEST_F
(
ARM_COMMON
,
BENCHMARK_CONVBIAS_NCHW44
)
{
TEST_F
(
ARM_COMMON
,
BENCHMARK_CONVBIAS_NCHW44
)
{
benchmark_convbias
(
handle
(),
true
);
benchmark_convbias
(
handle
(),
true
);
benchmark_convbias
(
handle
(),
false
);
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
BENCHMARK_CONVBIAS_NCHW44
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
BENCHMARK_CONVBIAS_NCHW44
)
{
benchmark_convbias
(
handle
(),
true
);
benchmark_convbias
(
handle
(),
true
);
benchmark_convbias
(
handle
(),
false
);
}
}
#endif
#endif
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
6c29548d
...
@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
...
@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_DOT_NCHW_NCHW44
)
{
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
false
,
false
,
true
),
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP
)
{
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP
)
{
checker_conv_bias_qint8x8x8
(
get_int8_quint8_conv_bias_args
(
checker_conv_bias_qint8x8x8
(
get_int8_quint8_conv_bias_args
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录