Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5e8aa333
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5e8aa333
编写于
3月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn): refactor winograd output transpose
GitOrigin-RevId: 6d4b225ea54a14c6c5479788b1d2b42a5b9d3cf5
上级
c6eb2e8d
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
131 addition
and
98 deletion
+131
-98
dnn/src/common/winograd/winograd_helper.cpp
dnn/src/common/winograd/winograd_helper.cpp
+3
-3
dnn/src/common/winograd/winograd_helper.h
dnn/src/common/winograd/winograd_helper.h
+4
-5
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+0
-1
dnn/src/fallback/conv_bias/winograd/winograd.h
dnn/src/fallback/conv_bias/winograd/winograd.h
+0
-1
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
+52
-35
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
+72
-53
未找到文件。
dnn/src/common/winograd/winograd_helper.cpp
浏览文件 @
5e8aa333
...
@@ -235,7 +235,7 @@ void StrategyHelper<
...
@@ -235,7 +235,7 @@ void StrategyHelper<
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
transform_mid_buf
,
input_filter_compute_type
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
rescale
)
{
float
rescale
)
{
...
@@ -284,7 +284,7 @@ void StrategyHelper<
...
@@ -284,7 +284,7 @@ void StrategyHelper<
const
output_compute_type
*
bias
,
dst_type
*
output
,
const
output_compute_type
*
bias
,
dst_type
*
output
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
OC
,
size_t
oc_start
,
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
...
@@ -296,7 +296,7 @@ void StrategyHelper<
...
@@ -296,7 +296,7 @@ void StrategyHelper<
output_compute_type
*
mid_buf1
=
transform_mid_buf
;
output_compute_type
*
mid_buf1
=
transform_mid_buf
;
output_compute_type
*
mid_buf2
=
transform_mid_buf
+
alpha
*
alpha
;
output_compute_type
*
mid_buf2
=
transform_mid_buf
+
alpha
*
alpha
;
OutputGetter
<
output_compute_type
,
dst_type
>
getter
(
dtype
);
OutputGetter
<
output_compute_type
,
dst_type
>
getter
(
dtype
);
OutputVisitor
<
layout
,
format
>
output_visitor
(
oc_end
-
oc_start
);
OutputVisitor
<
layout
,
format
>
output_visitor
(
OC
);
size_t
oc
=
oc_start
+
oc_index
;
size_t
oc
=
oc_start
+
oc_index
;
...
...
dnn/src/common/winograd/winograd_helper.h
浏览文件 @
5e8aa333
...
@@ -6,8 +6,7 @@
...
@@ -6,8 +6,7 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -44,8 +43,8 @@ public:
...
@@ -44,8 +43,8 @@ public:
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
transform_mid_buf
,
input_filter_compute_type
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
m
,
size_t
r
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
rescale
=
1.0
f
);
float
rescale
=
1.0
f
);
...
@@ -54,7 +53,7 @@ public:
...
@@ -54,7 +53,7 @@ public:
const
output_compute_type
*
bias
,
dst_type
*
output
,
const
output_compute_type
*
bias
,
dst_type
*
output
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_index
,
size_t
OH
,
size_t
OW
,
size_t
OC
,
size_t
oc_start
,
size_t
oc_index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
input_filter_scale
=
1.0
f
,
// input_scale * filter_scale
float
input_filter_scale
=
1.0
f
,
// input_scale * filter_scale
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
5e8aa333
...
@@ -45,7 +45,6 @@ public:
...
@@ -45,7 +45,6 @@ public:
static_cast
<
fallback
::
MatrixMulImpl
*>
(
matmul_opr
)
->
algo_pack
();
static_cast
<
fallback
::
MatrixMulImpl
*>
(
matmul_opr
)
->
algo_pack
();
for
(
auto
&&
algo
:
matmul_algos
)
{
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
algoset
()
==
if
(
algo
->
algoset
()
==
//! TODO: threre should filter MK matmul
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
)
{
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
)
{
continue
;
continue
;
}
}
...
...
dnn/src/fallback/conv_bias/winograd/winograd.h
浏览文件 @
5e8aa333
...
@@ -536,7 +536,6 @@ public:
...
@@ -536,7 +536,6 @@ public:
NonlineMode nonline_mode, size_t OH, size_t OW, \
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
size_t nr_tiles_in_unit); \
};
};
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
...
...
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
浏览文件 @
5e8aa333
...
@@ -186,18 +186,16 @@ struct OutputTransform2X3_NCHW88 {
...
@@ -186,18 +186,16 @@ struct OutputTransform2X3_NCHW88 {
float
*
output
,
float
*
transform_mid_buf
,
float
*
output
,
float
*
transform_mid_buf
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
oc_index
,
size_t
unit_idx
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
megdnn_assert
(
(
oc_end
-
oc_start
)
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
Op
op
(
src_dtype
,
dst_dtype
);
Op
op
(
src_dtype
,
dst_dtype
);
//! AT * m * A
//! AT * m * A
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc
=
oc_start
+
oc_index
;
size_t
ocb
=
(
oc
-
oc_start
)
/
8
;
size_t
ocb
=
oc_index
/
8
;
#define cb(m, n) \
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
output_transform_buf + \
...
@@ -254,7 +252,6 @@ struct OutputTransform2X3_NCHW88 {
...
@@ -254,7 +252,6 @@ struct OutputTransform2X3_NCHW88 {
} while (0);
} while (0);
UNROLL_CALL_RAW_D2
(
2
,
2
,
out_save
);
UNROLL_CALL_RAW_D2
(
2
,
2
,
out_save
);
}
}
}
};
};
#undef CONCAT
#undef CONCAT
}
// namespace
}
// namespace
...
@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
...
@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input,
}
}
}
}
void
winograd_nchw88_2x3_8x8_f
::
output
(
void
winograd_nchw88_2x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
NonlineMode
nonline_mode
,
size_t
OH
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
)
{
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
#define cb(_bmode, _nonline_op, ...) \
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
__VA_ARGS__);
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
size_t
OC
=
oc_end
-
oc_start
;
megdnn_assert
(
OC
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F23_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
megdnn_x86_winograd_nchw88_fp32_F23_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
#undef cb
#undef cb
}
}
...
...
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
浏览文件 @
5e8aa333
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/common/unroll_macro.h"
#include "src/common/unroll_macro.h"
...
@@ -19,10 +20,10 @@
...
@@ -19,10 +20,10 @@
#include <x86intrin.h>
#include <x86intrin.h>
#ifdef WIN32CMAKE
#ifdef WIN32CMAKE
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h>
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#endif
#include "midout.h"
#include "midout.h"
...
@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
...
@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 {
for
(
size_t
ocb
=
oc_start
/
8
;
ocb
<
oc_end
/
8
;
ocb
++
)
{
for
(
size_t
ocb
=
oc_start
/
8
;
ocb
<
oc_end
/
8
;
ocb
++
)
{
for
(
size_t
icb
=
0
;
icb
<
ICB
;
icb
++
)
{
for
(
size_t
icb
=
0
;
icb
<
ICB
;
icb
++
)
{
for
(
size_t
ic_inner
=
0
;
ic_inner
<
8
;
ic_inner
++
){
for
(
size_t
ic_inner
=
0
;
ic_inner
<
8
;
ic_inner
++
)
{
const
float
*
fptr
=
filter
+
const
float
*
fptr
=
filter
+
(
ocb
*
ICB
+
icb
)
*
3
*
3
*
8
*
8
+
(
ocb
*
ICB
+
icb
)
*
3
*
3
*
8
*
8
+
ic_inner
*
8
;
ic_inner
*
8
;
...
@@ -220,18 +221,17 @@ struct OutputTransform6X3_NCHW88 {
...
@@ -220,18 +221,17 @@ struct OutputTransform6X3_NCHW88 {
float
*
output
,
float
*
transform_mid_buf
,
float
*
output
,
float
*
transform_mid_buf
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
oc_index
,
size_t
unit_idx
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
size_t
nr_units_in_tile
,
const
DType
&
src_dtype
,
const
DType
&
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
MEGDNN_MARK_USED_VAR
(
transform_mid_buf
);
megdnn_assert
(
(
oc_end
-
oc_start
)
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
Op
op
(
src_dtype
,
dst_dtype
);
Op
op
(
src_dtype
,
dst_dtype
);
//! AT * m * A
//! AT * m * A
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
size_t
OCB
=
(
oc_end
-
oc_start
)
/
8
;
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc
=
oc_start
+
oc_index
;
size_t
ocb
=
(
oc
-
oc_start
)
/
8
;
size_t
ocb
=
oc_index
/
8
;
#define cb(m, n) \
#define cb(m, n) \
auto v##m##n = Vector<float, 8>::load( \
auto v##m##n = Vector<float, 8>::load( \
output_transform_buf + \
output_transform_buf + \
...
@@ -253,8 +253,7 @@ struct OutputTransform6X3_NCHW88 {
...
@@ -253,8 +253,7 @@ struct OutputTransform6X3_NCHW88 {
* 0 0.0 0 0 0 1
* 0 0.0 0 0 0 1
*/
*/
Vector
<
float
,
8
>
v1addv2
,
v1subv2
,
v3addv4
,
v3subv4
,
v5addv6
,
Vector
<
float
,
8
>
v1addv2
,
v1subv2
,
v3addv4
,
v3subv4
,
v5addv6
,
v5subv6
;
v5subv6
;
#define cb(m) \
#define cb(m) \
v1addv2 = v1##m + v2##m; \
v1addv2 = v1##m + v2##m; \
v1subv2 = v1##m - v2##m; \
v1subv2 = v1##m - v2##m; \
...
@@ -318,7 +317,6 @@ struct OutputTransform6X3_NCHW88 {
...
@@ -318,7 +317,6 @@ struct OutputTransform6X3_NCHW88 {
} while (0);
} while (0);
UNROLL_CALL_RAW_D2
(
6
,
6
,
out_save
);
UNROLL_CALL_RAW_D2
(
6
,
6
,
out_save
);
}
}
}
};
};
#undef CONCAT
#undef CONCAT
}
// namespace
}
// namespace
...
@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
...
@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
megdnn_assert
(
IC
%
8
==
0
);
megdnn_assert
(
IC
%
8
==
0
);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
float
*
patch
=
transform_mid_buf
;
float
*
patch
=
transform_mid_buf
;
float
*
patchT
=
transform_mid_buf
+
8
*
alpha
*
alpha
;
float
*
patchT
=
transform_mid_buf
+
8
*
alpha
*
alpha
;
...
@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
...
@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input,
}
}
}
}
void
winograd_nchw88_6x3_8x8_f
::
output
(
void
winograd_nchw88_6x3_8x8_f
::
output
(
const
float
*
output_transform_buf
,
const
float
*
output_transform_buf
,
const
float
*
bias
,
float
*
output
,
const
float
*
bias
,
float
*
output
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
float
*
transform_mid_buf
,
BiasMode
bmode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
NonlineMode
nonline_mode
,
size_t
OH
,
size_t
oc_end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
)
{
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
#define cb(_bmode, _nonline_op, ...) \
#define cb(_bmode, _nonline_op, ...) \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
__VA_ARGS__);
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
size_t
OC
=
oc_end
-
oc_start
;
megdnn_assert
(
OC
%
8
==
0
&&
oc_start
%
8
==
0
&&
oc_end
%
8
==
0
,
"Winograd output transform input param is not times of 8!"
);
for
(
size_t
oc
=
oc_start
;
oc
+
8
<=
oc_end
;
oc
+=
8
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_x86_winograd_nchw88_fp32_F63_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
megdnn_x86_winograd_nchw88_fp32_F63_8x8
,
cb
,
SIMDType
::
AVX2
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
#undef cb
#undef cb
}
}
}
// namespace winograd
}
// namespace winograd
}
// namespace
arm_common
}
// namespace
x86
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录