Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5e8aa333
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5e8aa333
编写于
3月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn): refactor winograd output transpose
GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级
c6eb2e8d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
131 addition
and
98 deletion
+131
-98
dnn/src/common/winograd/winograd_helper.cpp
dnn/src/common/winograd/winograd_helper.cpp
+3
-3
dnn/src/common/winograd/winograd_helper.h
dnn/src/common/winograd/winograd_helper.h
+4
-5
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+0
-1
dnn/src/fallback/conv_bias/winograd/winograd.h
dnn/src/fallback/conv_bias/winograd/winograd.h
+0
-1
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
+52
-35
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
+72
-53
未找到文件。
dnn/src/common/winograd/winograd_helper.cpp
浏览文件 @
5e8aa333
...
...
@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
rescale
)
{
...
...
@@ -284,7 +284,7 @@ void StrategyHelper<
const
output_compute_type
*
bias
,
dst_type
*
output
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
OC
,
size_t
oc_start
,
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
...
...
@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type
*
mid_buf1
=
transform_mid_buf
;
output_compute_type
*
mid_buf2
=
transform_mid_buf
+
alpha
*
alpha
;
OutputGetter
<
output_compute_type
,
dst_type
>
getter
(
dtype
);
OutputVisitor
<
layout
,
format
>
output_visitor
(
oc_end
-
oc_start
);
OutputVisitor
<
layout
,
format
>
output_visitor
(
OC
);
size_t
oc
=
oc_start
+
oc_index
;
...
...
dnn/src/common/winograd/winograd_helper.h
浏览文件 @
5e8aa333
...
...
@@ -6,8 +6,7 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
...
...
@@ -44,8 +43,8 @@ public:
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
rescale
=
1.0
f
);
...
...
@@ -54,7 +53,7 @@ public:
const
output_compute_type
*
bias
,
dst_type
*
output
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_index
,
size_t
OH
,
size_t
OW
,
size_t
OC
,
size_t
oc_start
,
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
input_filter_scale
=
1.0
f
,
// input_scale * filter_scale
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
5e8aa333
...
...
@@ -45,7 +45,6 @@ public:
static_cast
<
fallback
::
MatrixMulImpl
*>
(
matmul_opr
)
->
algo_pack
();
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
algoset
()
==
//! TODO: MK matmul should be filtered here
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
)
{
continue
;
}
...
...
dnn/src/fallback/conv_bias/winograd/winograd.h
浏览文件 @
5e8aa333
...
...
@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
};
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
...
...
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
浏览文件 @
5e8aa333
...
...
@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 {
float
*
output
,
float
*
transform_mid_buf
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
megdnn_assert
(
(
oc_end
-
oc_start
)
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
Op
op
(
src_dtype
,
dst_dtype
);
//! AT * m * A
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
ocb
=
(
oc
-
oc_start
)
/
8
;
size_t
oc
=
oc_start
+
oc_index
;
size_t
ocb
=
oc_index
/
8
;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2
(
4
,
4
,
cb
);
UNROLL_CALL_NOWRAPPER_D2
(
4
,
4
,
cb
);
#undef cb
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
//! 1 1 1 0 v00 v01 v02 v03 1 0
//! 0 1 -1 1 v10 v11 v12 v13 1 1
//! v20 v21 v22 v23 1 -1
//! v30 v31 v32 v33 0 1
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;
UNROLL_CALL_NOWRAPPER
(
4
,
cb
);
UNROLL_CALL_NOWRAPPER
(
4
,
cb
);
#undef cb
#define cb(m) \
v##m##0 = t##m##0 + t##m##1 + t##m##2; \
v##m##1 = t##m##1 - t##m##2 + t##m##3;
UNROLL_CALL_NOWRAPPER
(
2
,
cb
);
UNROLL_CALL_NOWRAPPER
(
2
,
cb
);
#undef cb
Vector
<
float
,
8
>
vbias
;
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
vbias
=
Vector
<
float
,
8
>::
load
(
bias
+
oc
);
Vector
<
float
,
8
>
vbias
;
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
vbias
=
Vector
<
float
,
8
>::
load
(
bias
+
oc
);
#define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2
(
2
,
2
,
cb
);
UNROLL_CALL_RAW_D2
(
2
,
2
,
cb
);
#undef cb
}
if
(
bmode
!=
BiasMode
::
BIAS
)
{
}
if
(
bmode
!=
BiasMode
::
BIAS
)
{
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2
(
2
,
2
,
cb
);
UNROLL_CALL_RAW_D2
(
2
,
2
,
cb
);
#undef cb
}
}
#define out_save(oho, owo) \
do { \
size_t oh = oh_start + oho; \
...
...
@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 {
ow * 8); \
} \
} while (0);
UNROLL_CALL_RAW_D2
(
2
,
2
,
out_save
);
}
UNROLL_CALL_RAW_D2
(
2
,
2
,
out_save
);
}
};
#undef CONCAT
...
...
@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
}
}
void
winograd_nchw88_2x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
)
{
void
winograd_nchw88_2x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F23_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
size_t
OC
=
oc_end
-
oc_start
;
megdnn_assert
(
OC
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F23_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
#undef cb
}
...
...
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
浏览文件 @
5e8aa333
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/unroll_macro.h"
...
...
@@ -19,10 +20,10 @@
#include <x86intrin.h>
#ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include "midout.h"
...
...
@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 {
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
ic
,
size_t
IC
)
{
MEGDNN_MARK_USED_VAR
(
patch
);
size_t
IW8
=
IW
*
8
;
//! For nchw88 mode
size_t
IW8
=
IW
*
8
;
//! For nchw88 mode
size_t
iw8_start
=
iw_start
*
8
;
//! For nchw88 mode
size_t
icb
=
ic
/
8
;
if
(
!
(
inner
&&
ic
+
8
<
IC
))
{
...
...
@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for
(
size_t
ocb
=
oc_start
/
8
;
ocb
<
oc_end
/
8
;
ocb
++
)
{
for
(
size_t
icb
=
0
;
icb
<
ICB
;
icb
++
)
{
for
(
size_t
ic_inner
=
0
;
ic_inner
<
8
;
ic_inner
++
){
for
(
size_t
ic_inner
=
0
;
ic_inner
<
8
;
ic_inner
++
)
{
const
float
*
fptr
=
filter
+
(
ocb
*
ICB
+
icb
)
*
3
*
3
*
8
*
8
+
ic_inner
*
8
;
...
...
@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 {
float
*
output
,
float
*
transform_mid_buf
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
megdnn_assert
(
(
oc_end
-
oc_start
)
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
Op
op
(
src_dtype
,
dst_dtype
);
//! AT * m * A
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
ocb
=
(
oc
-
oc_start
)
/
8
;
size_t
oc
=
oc_start
+
oc_index
;
size_t
ocb
=
oc_index
/
8
;
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \
ocb * nr_units_in_tile * 8 + unit_idx * 8);
UNROLL_CALL_NOWRAPPER_D2
(
8
,
8
,
cb
);
UNROLL_CALL_NOWRAPPER_D2
(
8
,
8
,
cb
);
#undef cb
/**
* A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/
Vector
<
float
,
8
>
v1addv2
,
v1subv2
,
v3addv4
,
v3subv4
,
v5addv6
,
v5subv6
;
/**
* A
*
* 1 0 0 0 0 0
* 1 1 1 1 1 1
* 1 -1 1 -1 1 -1
* 1 2 4 8 16 32
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
*/
Vector
<
float
,
8
>
v1addv2
,
v1subv2
,
v3addv4
,
v3subv4
,
v5addv6
,
v5subv6
;
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
...
...
@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 {
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
UNROLL_CALL_NOWRAPPER
(
8
,
cb
);
UNROLL_CALL_NOWRAPPER
(
8
,
cb
);
#undef cb
#define cb(m) \
...
...
@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 {
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
UNROLL_CALL_NOWRAPPER
(
6
,
cb
);
UNROLL_CALL_NOWRAPPER
(
6
,
cb
);
#undef cb
Vector
<
float
,
8
>
vbias
;
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
vbias
=
Vector
<
float
,
8
>::
load
(
bias
+
oc
);
Vector
<
float
,
8
>
vbias
;
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
vbias
=
Vector
<
float
,
8
>::
load
(
bias
+
oc
);
#define cb(m, n) v##m##n += vbias;
UNROLL_CALL_RAW_D2
(
6
,
6
,
cb
);
UNROLL_CALL_RAW_D2
(
6
,
6
,
cb
);
#undef cb
}
if
(
bmode
!=
BiasMode
::
BIAS
)
{
}
if
(
bmode
!=
BiasMode
::
BIAS
)
{
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
UNROLL_CALL_RAW_D2
(
6
,
6
,
cb
);
UNROLL_CALL_RAW_D2
(
6
,
6
,
cb
);
#undef cb
}
}
#define out_save(oho, owo) \
do { \
size_t oh = oh_start + oho; \
...
...
@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 {
ow * 8); \
} \
} while (0);
UNROLL_CALL_RAW_D2
(
6
,
6
,
out_save
);
}
UNROLL_CALL_RAW_D2
(
6
,
6
,
out_save
);
}
};
#undef CONCAT
...
...
@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert
(
IC
%
8
==
0
);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
float
*
patch
=
transform_mid_buf
;
float
*
patchT
=
transform_mid_buf
+
8
*
alpha
*
alpha
;
...
...
@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
}
}
void
winograd_nchw88_6x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
)
{
void
winograd_nchw88_6x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F63_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
size_t
OC
=
oc_end
-
oc_start
;
megdnn_assert
(
OC
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F63_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
#undef cb
}
}
// namespace winograd
}
// namespace
arm_common
}
// namespace
x86
}
// namespace megdnn
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录