Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d346c878
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d346c878
编写于
3月 20, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/fallbackls): delete the conv_bias fallback offset
GitOrigin-RevId: c91aee2c7cfc95d1f31cc7f7eb7a05ece40ba002
上级
a7e28712
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
189 addition
and
93 deletion
+189
-93
dnn/src/fallback/conv_bias/im2col/algos.cpp
dnn/src/fallback/conv_bias/im2col/algos.cpp
+45
-33
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+120
-55
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+15
-0
dnn/src/fallback/conv_bias/winograd/winograd.h
dnn/src/fallback/conv_bias/winograd/winograd.h
+9
-5
未找到文件。
dnn/src/fallback/conv_bias/im2col/algos.cpp
浏览文件 @
d346c878
...
@@ -57,11 +57,12 @@ public:
...
@@ -57,11 +57,12 @@ public:
const
ConvBiasImpl
::
NCBKernParam
&
param
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
const
WorkspaceBundle
&
bundle_thread
,
size_t
bundle_id
,
const
WorkspaceBundle
&
bundle_thread
,
size_t
bundle_id
,
size_t
oc_cur_index
,
size_t
OHW
,
bool
is_dst_8bit
,
size_t
oc_cur_index
,
size_t
OHW
,
bool
is_dst_8bit
,
bool
ohw_bigger_ohwblock
)
{
bool
ohw_bigger_ohwblock
,
size_t
batch_id
,
size_t
group_id
)
{
if
(
is_dst_8bit
||
!
ohw_bigger_ohwblock
)
{
if
(
is_dst_8bit
||
!
ohw_bigger_ohwblock
)
{
return
static_cast
<
dtype
*>
(
bundle_thread
.
get
(
bundle_id
));
return
static_cast
<
dtype
*>
(
bundle_thread
.
get
(
bundle_id
));
}
else
{
}
else
{
dtype
*
dst
=
param
.
dst
<
dtype
>
()
+
oc_cur_index
*
OHW
;
dtype
*
dst
=
param
.
dst
<
dtype
>
(
batch_id
,
group_id
)
+
oc_cur_index
*
OHW
;
return
static_cast
<
dtype
*>
(
dst
);
return
static_cast
<
dtype
*>
(
dst
);
}
}
}
}
...
@@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle,
size_t
IW2
=
IW
+
2
*
PW
;
size_t
IW2
=
IW
+
2
*
PW
;
size_t
IH2
=
IH
+
2
*
PH
;
size_t
IH2
=
IH
+
2
*
PH
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
batch_id
=
ncb_index
.
ndrange_id
[
1
];
size_t
channel_id
=
ncb_index
.
ndrange_id
[
2
];
size_t
padding_group_size
=
IH2
*
IW2
*
IC
;
size_t
padding_group_size
=
IH2
*
IW2
*
IC
;
size_t
input_channel_offset
=
IH
*
IW
*
ncb_index
.
ndrange_id
[
2
];
size_t
input_channel_offset
=
IH
*
IW
*
channel_id
;
size_t
workspace_channel_offset
=
IH2
*
IW2
*
ncb_index
.
ndrange_id
[
2
];
size_t
workspace_channel_offset
=
IH2
*
IW2
*
channel_id
;
size_t
workspace_group_offset
=
size_t
workspace_group_offset
=
group_id
*
padding_group_size
;
ncb_index
.
ndrange_id
[
0
]
*
padding_group_size
;
size_t
workspace_batch_offset
=
size_t
workspace_batch_offset
=
param
.
filter_meta
.
group
*
param
.
filter_meta
.
group
*
batch_id
*
padding_group_size
;
ncb_index
.
ndrange_id
[
1
]
*
padding_group_size
;
bundle
.
set
(
param
.
workspace_ptr
);
bundle
.
set
(
param
.
workspace_ptr
);
src_ctype
src_zp
=
static_cast
<
src_ctype
>
(
0
);
src_ctype
src_zp
=
static_cast
<
src_ctype
>
(
0
);
if
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
{
if
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
{
src_zp
=
param
.
src_type
.
param
<
dtype
::
Quantized8Asymm
>
().
zero_point
;
src_zp
=
param
.
src_type
.
param
<
dtype
::
Quantized8Asymm
>
().
zero_point
;
}
}
src_ctype
*
src
=
const_cast
<
src_ctype
*>
(
param
.
src
<
src_ctype
>
()
+
src_ctype
*
src
=
const_cast
<
src_ctype
*>
(
input_channel_offset
);
param
.
src
<
src_ctype
>
(
batch_id
,
group_id
)
+
input_channel_offset
);
src_ctype
*
src2
;
src_ctype
*
src2
;
src2
=
static_cast
<
src_ctype
*>
(
src2
=
static_cast
<
src_ctype
*>
(
bundle
.
get
(
Im2colBundelIndex
::
BUNDLE_PADDING_INDEX
))
+
bundle
.
get
(
Im2colBundelIndex
::
BUNDLE_PADDING_INDEX
))
+
...
@@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
*/
*/
#define COPY_BIAS() \
#define COPY_BIAS() \
const bias_ctype* bias_ptr =
\
const bias_ctype* bias_ptr =
static_cast<const bias_ctype*>(
\
static_cast<const bias_ctype*>(param.bias_ptr);
\
param.bias<bias_ctype>(batch_id, group_id));
\
bias_ctype* bias_temp_ptr = \
bias_ctype* bias_temp_ptr = \
PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \
PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \
if (param.bias_mode == megdnn::BiasMode::BIAS) { \
if (param.bias_mode == megdnn::BiasMode::BIAS) { \
...
@@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
#define IM2COL() \
#define IM2COL() \
src_ctype* im2col_dst = nullptr; \
src_ctype* im2col_dst = nullptr; \
src_ctype* no_padding_src = \
src_ctype* no_padding_src = \
const_cast<src_ctype*>(param.src<src_ctype>()) + ohw_cur_index; \
const_cast<src_ctype*>(param.src<src_ctype>(batch_id, group_id)) + \
ohw_cur_index; \
if (!special_1x1) { \
if (!special_1x1) { \
size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \
size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \
src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
...
@@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
param.filter_meta.group * ncb_index.ndrange_id[1]) * \
param.filter_meta.group * ncb_index.ndrange_id[1]) * \
padding_group_size); \
padding_group_size); \
if (PH == 0 && PW == 0) { \
if (PH == 0 && PW == 0) { \
src2 = const_cast<src_ctype*>(param.src<src_ctype>()); \
src2 = const_cast<src_ctype*>( \
param.src<src_ctype>(batch_id, group_id)); \
} \
} \
im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \
im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \
...
@@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
output_block_size); \
output_block_size); \
if (!skip_copy_dst) { \
if (!skip_copy_dst) { \
dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \
dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \
dst_ctype* dst =
\
dst_ctype* dst =
param.dst<dst_ctype>(batch_id, group_id) +
\
param.dst<dst_ctype>() + oc_cur_index * OHW + ohw_cur_index;
\
oc_cur_index * OHW + ohw_cur_index;
\
for (size_t oc = 0; oc < output_block_oc_size; oc++) { \
for (size_t oc = 0; oc < output_block_oc_size; oc++) { \
std::memcpy(dst, dst_tmp_ptr, \
std::memcpy(dst, dst_tmp_ptr, \
sizeof(dst_ctype) * output_block_size); \
sizeof(dst_ctype) * output_block_size); \
...
@@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \
param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \
is_dst_8bit, is_ohw_size_bigger);
is_dst_8bit, is_ohw_size_bigger
, batch_id, group_id
);
#define MATMUL_COMPUTE() \
#define MATMUL_COMPUTE() \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
...
@@ -272,6 +276,7 @@ public:
...
@@ -272,6 +276,7 @@ public:
ConvBiasImpl
::
NCBKernIndex
ncb_index
)
{
ConvBiasImpl
::
NCBKernIndex
ncb_index
)
{
bundle
.
set
(
param
.
workspace_ptr
);
bundle
.
set
(
param
.
workspace_ptr
);
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
static_cast
<
fallback
::
MatrixMulImpl
::
KernSizeParam
&>
(
matmul_param
)
=
static_cast
<
fallback
::
MatrixMulImpl
::
KernSizeParam
&>
(
matmul_param
)
=
matmulparam
;
matmulparam
;
size_t
packA_group_size
=
size_t
packA_group_size
=
...
@@ -283,11 +288,11 @@ public:
...
@@ -283,11 +288,11 @@ public:
matmul_algo
->
get_packA_type_size
();
matmul_algo
->
get_packA_type_size
();
size_t
a_panel_offset
=
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
2
]
*
packed_per_oc_block_size
;
ncb_index
.
ndrange_id
[
2
]
*
packed_per_oc_block_size
;
int8_t
*
a_panel
=
int8_t
*
a_panel
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
static_cast
<
int8_t
*>
(
Im2colBundelIndex
::
BUNDLE_PACKA_INDEX
))
+
bundle
.
get
(
Im2colBundelIndex
::
BUNDLE_PACKA_INDEX
))
+
group_id
*
packA_group_size
+
a_panel_offset
;
ncb_index
.
ndrange_id
[
0
]
*
packA_group_size
+
a_panel_offset
;
matmul_param
.
A_ptr
=
matmul_param
.
A_ptr
=
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
));
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
group_id
));
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
ncb_index
.
ndrange_id
[
2
],
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
ncb_index
.
ndrange_id
[
2
],
matmul_algo
->
get_inner_block_size
().
m
);
matmul_algo
->
get_inner_block_size
().
m
);
};
};
...
@@ -309,6 +314,8 @@ public:
...
@@ -309,6 +314,8 @@ public:
auto
IH2
=
IH
+
2
*
PH
;
auto
IH2
=
IH
+
2
*
PH
;
auto
IW2
=
IW
+
2
*
PW
;
auto
IW2
=
IW
+
2
*
PW
;
size_t
OHW
=
OH
*
OW
;
size_t
OHW
=
OH
*
OW
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
batch_id
=
ncb_index
.
ndrange_id
[
1
];
size_t
output_block_size
=
std
::
min
(
size_t
output_block_size
=
std
::
min
(
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
size_t
output_block_oc_size
=
std
::
min
(
size_t
output_block_oc_size
=
std
::
min
(
...
@@ -369,11 +376,11 @@ public:
...
@@ -369,11 +376,11 @@ public:
\
\
src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \
bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \
bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset);
\
group_id * packA_group_size + a_panel_offset);
\
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \
param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
OHW, is_dst_8bit, is_ohw_size_bigger);
OHW, is_dst_8bit, is_ohw_size_bigger
, batch_id, group_id
);
#define MATMUL_COMPUTE() \
#define MATMUL_COMPUTE() \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \
...
@@ -402,6 +409,7 @@ public:
...
@@ -402,6 +409,7 @@ public:
matmulparam
;
matmulparam
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_tile_size
=
matmul_param
.
M
;
size_t
oc_tile_size
=
matmul_param
.
M
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
output_block_oc_size
=
std
::
min
(
size_t
output_block_oc_size
=
std
::
min
(
oc_tile_size
,
OC
-
ncb_index
.
ndrange_id
[
2
]
*
oc_tile_size
);
oc_tile_size
,
OC
-
ncb_index
.
ndrange_id
[
2
]
*
oc_tile_size
);
size_t
oc_cur_index
=
ncb_index
.
ndrange_id
[
2
]
*
oc_tile_size
;
size_t
oc_cur_index
=
ncb_index
.
ndrange_id
[
2
]
*
oc_tile_size
;
...
@@ -411,11 +419,11 @@ public:
...
@@ -411,11 +419,11 @@ public:
size_t
a_panel_offset
=
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
2
]
*
ncb_index
.
ndrange_id
[
2
]
*
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
int8_t
*
a_panel
=
int8_t
*
a_panel
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
static_cast
<
int8_t
*>
(
Im2colBundelIndex
::
BUNDLE_PACKA_INDEX
))
+
bundle
.
get
(
Im2colBundelIndex
::
BUNDLE_PACKA_INDEX
))
+
group_id
*
packA_group_size
+
a_panel_offset
;
ncb_index
.
ndrange_id
[
0
]
*
packA_group_size
+
a_panel_offset
;
matmul_param
.
A_ptr
=
matmul_param
.
A_ptr
=
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
))
+
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
group_id
))
+
oc_cur_index
*
matmul_param
.
K
;
oc_cur_index
*
matmul_param
.
K
;
matmul_param
.
M
=
output_block_oc_size
;
matmul_param
.
M
=
output_block_oc_size
;
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
0
_z
,
0
_z
);
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
0
_z
,
0
_z
);
...
@@ -437,6 +445,8 @@ public:
...
@@ -437,6 +445,8 @@ public:
MEGDNN_MARK_USED_VAR
(
N
);
MEGDNN_MARK_USED_VAR
(
N
);
auto
IH2
=
IH
+
2
*
PH
;
auto
IH2
=
IH
+
2
*
PH
;
auto
IW2
=
IW
+
2
*
PW
;
auto
IW2
=
IW
+
2
*
PW
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
batch_id
=
ncb_index
.
ndrange_id
[
1
];
size_t
OHW
=
OH
*
OW
;
size_t
OHW
=
OH
*
OW
;
size_t
output_block_size
=
std
::
min
(
size_t
output_block_size
=
std
::
min
(
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
...
@@ -490,11 +500,11 @@ public:
...
@@ -490,11 +500,11 @@ public:
#define PREPAR_MATMUL_DATA() \
#define PREPAR_MATMUL_DATA() \
bias_ctype* matmul_dst = nullptr; \
bias_ctype* matmul_dst = nullptr; \
const src_ctype* filter = \
const src_ctype* filter = \
param.filter<src_ctype>(
) + oc_cur_index * IC * FH * FW;
\
param.filter<src_ctype>(
group_id) + oc_cur_index * IC * FH * FW;
\
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \
param, bundle_thread, \
param, bundle_thread, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \
OHW, is_dst_8bit, is_ohw_size_bigger);
OHW, is_dst_8bit, is_ohw_size_bigger
, batch_id, group_id
);
#define MATMUL_COMPUTE() \
#define MATMUL_COMPUTE() \
matmul_param.M = output_block_oc_size; \
matmul_param.M = output_block_oc_size; \
...
@@ -526,6 +536,8 @@ public:
...
@@ -526,6 +536,8 @@ public:
MEGDNN_MARK_USED_VAR
(
N
);
MEGDNN_MARK_USED_VAR
(
N
);
auto
IH2
=
IH
+
2
*
PH
;
auto
IH2
=
IH
+
2
*
PH
;
auto
IW2
=
IW
+
2
*
PW
;
auto
IW2
=
IW
+
2
*
PW
;
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
batch_id
=
ncb_index
.
ndrange_id
[
1
];
size_t
OHW
=
OH
*
OW
;
size_t
OHW
=
OH
*
OW
;
size_t
output_block_size
=
std
::
min
(
size_t
output_block_size
=
std
::
min
(
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
ohw_tile_size
,
OHW
-
ncb_index
.
ndrange_id
[
2
]
*
ohw_tile_size
);
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
d346c878
...
@@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
...
@@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
void
ConvBiasImpl
::
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
void
ConvBiasImpl
::
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
ConvBiasImpl
::
Algorithm
*
algo
)
{
ConvBiasImpl
::
Algorithm
*
algo
)
{
auto
ncb_kerns
=
ncb_algo_dispatch_kerns
(
algo
,
param
);
auto
ncb_kerns
=
ncb_algo_dispatch_kerns
(
algo
,
param
);
size_t
src_batch_stride
=
param
.
inp_bs
*
param
.
src_type
.
size
();
size_t
dst_batch_stride
=
param
.
out_bs
*
param
.
dst_type
.
size
();
size_t
bias_batch_stride
=
0
;
if
(
param
.
bias_mode
==
BiasMode
::
BIAS
)
{
bias_batch_stride
=
param
.
bias_bs
*
param
.
bias_type
.
size
();
}
for
(
auto
&&
kernel
:
ncb_kerns
)
{
for
(
auto
&&
kernel
:
ncb_kerns
)
{
megdnn_assert
(
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NHWC
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW88
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
,
"invalid conv format"
);
ptrdiff_t
istrd
=
0
,
fstrd
=
0
,
bstrd
=
0
,
ostrd
=
0
;
if
(
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
fstrd
=
param
.
filter_meta
.
icpg
*
param
.
filter_meta
.
ocpg
*
(
param
.
filter_meta
.
spatial
[
0
]
+
param
.
output_block_size
-
1
)
*
(
param
.
filter_meta
.
spatial
[
1
]
+
param
.
output_block_size
-
1
)
*
param
.
filter_type
.
size
();
}
else
{
fstrd
=
param
.
filter_meta
.
icpg
*
param
.
filter_meta
.
ocpg
*
param
.
filter_meta
.
spatial
[
0
]
*
param
.
filter_meta
.
spatial
[
1
]
*
param
.
filter_type
.
size
();
}
istrd
=
param
.
filter_meta
.
icpg
*
param
.
src_type
.
size
();
ostrd
=
param
.
filter_meta
.
ocpg
*
param
.
dst_type
.
size
();
if
(
param
.
bias_mode
!=
BiasMode
::
NO_BIAS
)
{
bstrd
=
param
.
filter_meta
.
ocpg
*
param
.
bias_type
.
size
();
}
if
(
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
istrd
*=
param
.
isz
[
0
]
*
param
.
isz
[
1
];
ostrd
*=
param
.
osz
[
0
]
*
param
.
osz
[
1
];
if
(
param
.
bias_mode
==
BiasMode
::
BIAS
)
{
bstrd
*=
param
.
osz
[
0
]
*
param
.
osz
[
1
];
}
}
else
{
// must be NHWC. No action performed.
}
auto
run
=
[
=
](
size_t
index
,
size_t
thread_id
)
{
auto
run
=
[
=
](
size_t
index
,
size_t
thread_id
)
{
auto
copy_param
=
param
;
auto
copy_param
=
param
;
CpuNDRange
ndrange_id
(
kernel
.
global_size
,
index
);
CpuNDRange
ndrange_id
(
kernel
.
global_size
,
index
);
size_t
group_id
=
ndrange_id
[
0
];
size_t
batch_id
=
ndrange_id
[
1
];
//! The kernel ptr point to batch index
incr_ptr
(
copy_param
.
src_ptr
,
group_id
*
istrd
+
batch_id
*
src_batch_stride
);
incr_ptr
(
copy_param
.
filter_ptr
,
group_id
*
fstrd
);
incr_ptr
(
copy_param
.
bias_ptr
,
group_id
*
bstrd
+
batch_id
*
bias_batch_stride
);
incr_ptr
(
copy_param
.
dst_ptr
,
group_id
*
ostrd
+
batch_id
*
dst_batch_stride
);
kernel
.
kern
(
copy_param
,
{
thread_id
,
ndrange_id
});
kernel
.
kern
(
copy_param
,
{
thread_id
,
ndrange_id
});
};
};
static_cast
<
naive
::
HandleImpl
*>
(
handle
())
->
dispatch_kern
(
static_cast
<
naive
::
HandleImpl
*>
(
handle
())
->
dispatch_kern
(
...
@@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
...
@@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
return
"F0"
;
return
"F0"
;
}
}
namespace
megdnn
{
namespace
fallback
{
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template
<
typename
T
>
const
T
*
ConvBiasImpl
::
NCBKernParam
::
src
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
)
const
{
src_type
.
assert_is_compatible_ctype
<
T
>
();
size_t
batch_offset
=
batch_id
*
inp_bs
*
src_type
.
size
();
size_t
group_offset
=
group_pack_size
*
group_id
*
filter_meta
.
icpg
*
isz
[
0
]
*
isz
[
1
]
*
src_type
.
size
();
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
ptrdiff_t
>
(
src_ptr
)
+
batch_offset
+
group_offset
);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template
<
typename
T
>
const
T
*
ConvBiasImpl
::
NCBKernParam
::
filter
(
size_t
group_id
,
size_t
pack_group_size
)
const
{
size_t
group_offset
=
0
_z
;
switch
(
filter_meta
.
format
)
{
case
Param
::
Format
::
NCHW
:
{
group_offset
=
pack_group_size
*
group_id
*
filter_meta
.
icpg
*
filter_meta
.
ocpg
*
filter_meta
.
spatial
[
0
]
*
filter_meta
.
spatial
[
1
]
*
filter_type
.
size
();
break
;
}
case
Param
::
Format
::
NCHW88
:
{
size_t
group
=
filter_meta
.
group
;
size_t
icpg
=
filter_meta
.
icpg
;
size_t
ocpg
=
filter_meta
.
ocpg
;
//! four format of weight layout
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh,
//! fw, 8, 8}
//! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8}
megdnn_assert
((
icpg
%
8
==
0
&&
ocpg
%
8
==
0
)
||
(
group
%
8
==
0
&&
icpg
==
1
&&
ocpg
==
1
&&
pack_group_size
>
1
)
||
(
group
==
1
&&
ocpg
%
8
==
0
),
"The filter shepe is not right of nchw88"
);
group_offset
=
pack_group_size
*
group_id
*
filter_meta
.
icpg
*
filter_meta
.
ocpg
*
filter_meta
.
spatial
[
0
]
*
filter_meta
.
spatial
[
1
]
*
filter_type
.
size
();
break
;
}
case
ConvBiasImpl
::
Param
::
Format
::
NCHW_WINOGRAD
:
case
ConvBiasImpl
::
Param
::
Format
::
NCHW88_WINOGRAD
:
{
//! four format of weight layout
//! 1. {g, alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 3. {g, alpha, alpha, oc, ic, 8, 8}
//! 4. {alpha, alpha, oc, ic}
group_offset
=
pack_group_size
*
group_id
*
filter_meta
.
icpg
*
filter_meta
.
ocpg
*
(
filter_meta
.
spatial
[
0
]
+
output_block_size
-
1
)
*
(
filter_meta
.
spatial
[
1
]
+
output_block_size
-
1
)
*
filter_type
.
size
();
break
;
}
default:
megdnn_assert
(
"other filter format is not support yet"
);
}
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
ptrdiff_t
>
(
filter_ptr
)
+
group_offset
);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template
<
typename
T
>
const
T
*
ConvBiasImpl
::
NCBKernParam
::
bias
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
)
const
{
bias_type
.
assert_is_compatible_ctype
<
T
>
();
size_t
batch_offset
=
0
_z
;
size_t
group_offset
=
0
_z
;
if
(
bias_mode
==
BiasMode
::
BIAS
)
{
batch_offset
=
batch_id
*
bias_bs
*
bias_type
.
size
();
group_offset
=
group_pack_size
*
group_id
*
filter_meta
.
ocpg
*
osz
[
0
]
*
osz
[
1
]
*
bias_type
.
size
();
}
else
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
group_offset
=
group_pack_size
*
group_id
*
filter_meta
.
ocpg
*
bias_type
.
size
();
}
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
ptrdiff_t
>
(
bias_ptr
)
+
batch_offset
+
group_offset
);
}
//! when format is nchwxx and channel wise mode, multi group will pack
//! together, so pack_group_size is the number of packed group
template
<
typename
T
>
T
*
ConvBiasImpl
::
NCBKernParam
::
dst
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
)
const
{
dst_type
.
assert_is_compatible_ctype
<
T
>
();
size_t
batch_offset
=
batch_id
*
out_bs
*
dst_type
.
size
();
size_t
group_offset
=
group_pack_size
*
group_id
*
filter_meta
.
ocpg
*
osz
[
0
]
*
osz
[
1
]
*
dst_type
.
size
();
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
ptrdiff_t
>
(
dst_ptr
)
+
batch_offset
+
group_offset
);
}
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const;
#define INST_DT(d) INST(DTypeTrait<d>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE
(
INST_DT
)
#undef INST
#undef INST_DT
}
// namespace fallback
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
d346c878
...
@@ -104,24 +104,39 @@ public:
...
@@ -104,24 +104,39 @@ public:
return
static_cast
<
const
T
*>
(
src_ptr
);
return
static_cast
<
const
T
*>
(
src_ptr
);
}
}
template
<
typename
T
>
const
T
*
src
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
=
1
_z
)
const
;
template
<
typename
T
>
template
<
typename
T
>
const
T
*
filter
()
const
{
const
T
*
filter
()
const
{
filter_type
.
assert_is_compatible_ctype
<
T
>
();
filter_type
.
assert_is_compatible_ctype
<
T
>
();
return
static_cast
<
const
T
*>
(
filter_ptr
);
return
static_cast
<
const
T
*>
(
filter_ptr
);
}
}
template
<
typename
T
>
const
T
*
filter
(
size_t
group_id
,
size_t
pack_group_size
=
1
_z
)
const
;
template
<
typename
T
>
template
<
typename
T
>
const
T
*
bias
()
const
{
const
T
*
bias
()
const
{
bias_type
.
assert_is_compatible_ctype
<
T
>
();
bias_type
.
assert_is_compatible_ctype
<
T
>
();
return
static_cast
<
const
T
*>
(
bias_ptr
);
return
static_cast
<
const
T
*>
(
bias_ptr
);
}
}
template
<
typename
T
>
const
T
*
bias
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
=
1
_z
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
*
dst
()
const
{
T
*
dst
()
const
{
dst_type
.
assert_is_compatible_ctype
<
T
>
();
dst_type
.
assert_is_compatible_ctype
<
T
>
();
return
static_cast
<
T
*>
(
dst_ptr
);
return
static_cast
<
T
*>
(
dst_ptr
);
}
}
template
<
typename
T
>
T
*
dst
(
size_t
batch_id
,
size_t
group_id
,
size_t
group_pack_size
=
1
_z
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
*
workspace
()
const
{
T
*
workspace
()
const
{
return
static_cast
<
T
*>
(
workspace_ptr
);
return
static_cast
<
T
*>
(
workspace_ptr
);
...
...
dnn/src/fallback/conv_bias/winograd/winograd.h
浏览文件 @
d346c878
...
@@ -210,7 +210,7 @@ public:
...
@@ -210,7 +210,7 @@ public:
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
compute_workspace_size_per_thread
*
thread_id
);
compute_workspace_size_per_thread
*
thread_id
);
const
stype
*
filter_ptr
=
kern_param
.
filter
<
stype
>
();
const
stype
*
filter_ptr
=
kern_param
.
filter
<
stype
>
(
group_id
);
size_t
oc_start
=
oc_id
,
oc_end
=
oc_id
+
1
;
size_t
oc_start
=
oc_id
,
oc_end
=
oc_id
+
1
;
if
(
kern_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
if
(
kern_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
oc_start
=
8
*
oc_id
;
oc_start
=
8
*
oc_id
;
...
@@ -246,16 +246,19 @@ public:
...
@@ -246,16 +246,19 @@ public:
size_t
oc_block_id
=
ncb_index
.
ndrange_id
[
3
];
size_t
oc_block_id
=
ncb_index
.
ndrange_id
[
3
];
size_t
tile_id
=
ncb_index
.
ndrange_id
[
2
];
size_t
tile_id
=
ncb_index
.
ndrange_id
[
2
];
size_t
batch_id
=
ncb_index
.
ndrange_id
[
1
];
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
group_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
thread_id
=
ncb_index
.
thread_id
;
size_t
thread_id
=
ncb_index
.
thread_id
;
bundle_top
.
set
(
ncb_param
.
workspace_ptr
);
bundle_top
.
set
(
ncb_param
.
workspace_ptr
);
bundle_compute
.
set
(
bundle_top
.
get
(
0
));
bundle_compute
.
set
(
bundle_top
.
get
(
0
));
const
stype
*
src_ptr
=
ncb_param
.
src
<
stype
>
();
const
stype
*
src_ptr
=
ncb_param
.
src
<
stype
>
(
batch_id
,
group_id
);
dst_type
*
dst_ptr
=
ncb_param
.
dst
<
dst_type
>
();
dst_type
*
dst_ptr
=
ncb_param
.
dst
<
dst_type
>
(
batch_id
,
group_id
);
const
output_compute_type
*
bias_ptr
=
const
output_compute_type
*
bias_ptr
=
static_cast
<
const
output_compute_type
*>
(
ncb_param
.
bias_ptr
);
static_cast
<
const
output_compute_type
*>
(
ncb_param
.
bias
<
output_compute_type
>
(
batch_id
,
group_id
));
input_filter_compute_type
*
input_transform_buf
=
input_filter_compute_type
*
input_transform_buf
=
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
input_filter_compute_type
*>
(
...
@@ -271,9 +274,10 @@ public:
...
@@ -271,9 +274,10 @@ public:
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
compute_workspace_size_per_thread
*
thread_id
);
compute_workspace_size_per_thread
*
thread_id
);
//! NCHW88_WINOGRAD and NCHW_WINOGRAD is the same offset
const
input_filter_compute_type
*
filter_transform_buf
=
const
input_filter_compute_type
*
filter_transform_buf
=
static_cast
<
const
input_filter_compute_type
*>
(
static_cast
<
const
input_filter_compute_type
*>
(
ncb_param
.
filter
_ptr
);
ncb_param
.
filter
<
input_filter_compute_type
>
(
group_id
)
);
if
(
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
if
(
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
filter_transform_buf
=
reinterpret_cast
<
input_filter_compute_type
*>
(
filter_transform_buf
=
reinterpret_cast
<
input_filter_compute_type
*>
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录