Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c6eb2e8d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c6eb2e8d
编写于
3月 23, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn): optimize winograd input transpose
GitOrigin-RevId: a43077550c0e729063b0214be7f71b39ec89d710
上级
f077a529
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
360 addition
and
230 deletion
+360
-230
dnn/src/common/winograd/winograd_helper.cpp
dnn/src/common/winograd/winograd_helper.cpp
+66
-68
dnn/src/common/winograd/winograd_helper.h
dnn/src/common/winograd/winograd_helper.h
+2
-2
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+1
-1
dnn/src/fallback/conv_bias/winograd/strategy.cpp
dnn/src/fallback/conv_bias/winograd/strategy.cpp
+211
-88
dnn/src/fallback/conv_bias/winograd/winograd.h
dnn/src/fallback/conv_bias/winograd/winograd.h
+19
-33
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
+31
-19
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
+30
-19
未找到文件。
dnn/src/common/winograd/winograd_helper.cpp
浏览文件 @
c6eb2e8d
...
...
@@ -247,33 +247,31 @@ void StrategyHelper<
Getter
<
ctype
,
input_filter_compute_type
>
getter
(
dtype
);
InputVisitor
<
layout
,
format
>
intput_visitor
(
IC
);
rep
(
ic
,
IC
)
{
memset
(
mid_buf1
,
0
,
alpha
*
alpha
*
sizeof
(
input_filter_compute_type
));
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
int
ih
=
ih_start
+
i
;
int
iw
=
iw_start
+
j
;
if
(
ih
>=
0
&&
ih
<
(
int
)
IH
&&
iw
>=
0
&&
iw
<
(
int
)
IW
)
{
mid_buf1
[
i
*
alpha
+
j
]
=
getter
(
input
[
intput_visitor
.
get
(
alpha
,
ic
,
IH
,
IW
,
ih
,
iw
)]);
}
memset
(
mid_buf1
,
0
,
alpha
*
alpha
*
sizeof
(
input_filter_compute_type
));
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
int
ih
=
ih_start
+
i
;
int
iw
=
iw_start
+
j
;
if
(
ih
>=
0
&&
ih
<
(
int
)
IH
&&
iw
>=
0
&&
iw
<
(
int
)
IW
)
{
mid_buf1
[
i
*
alpha
+
j
]
=
getter
(
input
[
intput_visitor
.
get
(
alpha
,
ic
,
IH
,
IW
,
ih
,
iw
)]);
}
}
megdnn
::
naive
::
run_matrix_mul_tpl
<
input_filter_compute_type
,
input_filter_compute_type
,
true
,
false
>
(
winograd_coeff
.
B
(
rescale
).
data
(),
mid_buf1
,
mid_buf2
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
dtype
,
dtype
);
megdnn
::
naive
::
run_matrix_mul_tpl
<
input_filter_compute_type
,
input_filter_compute_type
,
false
,
false
>
(
mid_buf2
,
winograd_coeff
.
B
(
rescale
).
data
(),
mid_buf1
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
dtype
,
dtype
);
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
input_transform_buf
[
intput_visitor
.
put
(
alpha
,
ic
,
nr_units_in_tile
,
unit_idx
,
i
,
j
)]
=
mid_buf1
[
i
*
alpha
+
j
];
}
megdnn
::
naive
::
run_matrix_mul_tpl
<
input_filter_compute_type
,
input_filter_compute_type
,
true
,
false
>
(
winograd_coeff
.
B
(
rescale
).
data
(),
mid_buf1
,
mid_buf2
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
dtype
,
dtype
);
megdnn
::
naive
::
run_matrix_mul_tpl
<
input_filter_compute_type
,
input_filter_compute_type
,
false
,
false
>
(
mid_buf2
,
winograd_coeff
.
B
(
rescale
).
data
(),
mid_buf1
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
alpha
,
dtype
,
dtype
);
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
input_transform_buf
[
intput_visitor
.
put
(
alpha
,
ic
,
nr_units_in_tile
,
unit_idx
,
i
,
j
)]
=
mid_buf1
[
i
*
alpha
+
j
];
}
}
...
...
@@ -287,7 +285,7 @@ void StrategyHelper<
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_
end
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
oc_
index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
input_filter_scale
,
float
input_filter_rescale
,
...
...
@@ -300,49 +298,49 @@ void StrategyHelper<
OutputGetter
<
output_compute_type
,
dst_type
>
getter
(
dtype
);
OutputVisitor
<
layout
,
format
>
output_visitor
(
oc_end
-
oc_start
);
for
(
size_t
oc
=
oc_start
;
oc
<
oc_end
;
oc
++
)
{
/* gather */
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
mid_buf1
[
i
*
alpha
+
j
]
=
output_transform_buf
[
output_visitor
.
get
(
alpha
,
oc
-
oc_start
,
oc
,
nr_units_in_tile
,
unit_idx
,
i
,
j
)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn
::
naive
::
run_matrix_mul_tpl
<
output_compute_type
,
output_compute_type
,
true
,
false
>
(
winograd_coeff
.
A
(
rescale
).
data
(),
mid_buf1
,
mid_buf2
,
m
,
alpha
,
alpha
,
m
,
alpha
,
alpha
,
dtype
,
dtype
);
megdnn
::
naive
::
run_matrix_mul_tpl
<
output_compute_type
,
output_compute_type
,
false
,
false
>
(
mid_buf2
,
winograd_coeff
.
A
(
rescale
).
data
(),
mid_buf1
,
m
,
m
,
alpha
,
alpha
,
m
,
m
,
dtype
,
dtype
);
rep
(
i
,
m
)
rep
(
j
,
m
)
{
auto
oh
=
oh_start
+
i
;
auto
ow
=
ow_start
+
j
;
if
(
oh
<
OH
&&
ow
<
OW
)
{
float
val
=
mid_buf1
[
i
*
m
+
j
];
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
val
+=
bias
[
oc
]
*
input_filter_rescale
*
input_filter_rescale
;
}
else
if
(
bmode
==
BiasMode
::
BIAS
)
{
val
+=
bias
[
output_visitor
.
put
(
oc
,
OH
,
OW
,
oh
,
ow
)]
*
input_filter_rescale
*
input_filter_rescale
;
}
val
=
val
*
input_filter_scale
/
(
input_filter_rescale
*
input_filter_rescale
*
rescale
*
rescale
);
if
(
nonline_mode
==
NonlineMode
::
RELU
)
{
val
=
val
>
0
?
val
:
0
;
}
else
if
(
nonline_mode
==
NonlineMode
::
SIGMOID
)
{
val
=
1.
f
/
(
expf
(
-
val
)
+
1.
f
);
}
else
if
(
nonline_mode
==
NonlineMode
::
H_SWISH
)
{
val
=
val
*
std
::
min
(
std
::
max
(
val
+
3
,
0.
f
),
6.
f
)
/
6.
f
;
}
else
{
megdnn_assert
(
nonline_mode
==
NonlineMode
::
IDENTITY
);
}
output
[
output_visitor
.
put
(
oc
,
OH
,
OW
,
oh
,
ow
)]
=
getter
(
val
);
size_t
oc
=
oc_start
+
oc_index
;
/* gather */
rep
(
i
,
alpha
)
rep
(
j
,
alpha
)
{
mid_buf1
[
i
*
alpha
+
j
]
=
output_transform_buf
[
output_visitor
.
get
(
alpha
,
oc_index
,
oc
,
nr_units_in_tile
,
unit_idx
,
i
,
j
)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn
::
naive
::
run_matrix_mul_tpl
<
output_compute_type
,
output_compute_type
,
true
,
false
>
(
winograd_coeff
.
A
(
rescale
).
data
(),
mid_buf1
,
mid_buf2
,
m
,
alpha
,
alpha
,
m
,
alpha
,
alpha
,
dtype
,
dtype
);
megdnn
::
naive
::
run_matrix_mul_tpl
<
output_compute_type
,
output_compute_type
,
false
,
false
>
(
mid_buf2
,
winograd_coeff
.
A
(
rescale
).
data
(),
mid_buf1
,
m
,
m
,
alpha
,
alpha
,
m
,
m
,
dtype
,
dtype
);
rep
(
i
,
m
)
rep
(
j
,
m
)
{
auto
oh
=
oh_start
+
i
;
auto
ow
=
ow_start
+
j
;
if
(
oh
<
OH
&&
ow
<
OW
)
{
float
val
=
mid_buf1
[
i
*
m
+
j
];
if
(
bmode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
val
+=
bias
[
oc
]
*
input_filter_rescale
*
input_filter_rescale
;
}
else
if
(
bmode
==
BiasMode
::
BIAS
)
{
val
+=
bias
[
output_visitor
.
put
(
oc
,
OH
,
OW
,
oh
,
ow
)]
*
input_filter_rescale
*
input_filter_rescale
;
}
val
=
val
*
input_filter_scale
/
(
input_filter_rescale
*
input_filter_rescale
*
rescale
*
rescale
);
if
(
nonline_mode
==
NonlineMode
::
RELU
)
{
val
=
val
>
0
?
val
:
0
;
}
else
if
(
nonline_mode
==
NonlineMode
::
SIGMOID
)
{
val
=
1.
f
/
(
expf
(
-
val
)
+
1.
f
);
}
else
if
(
nonline_mode
==
NonlineMode
::
H_SWISH
)
{
val
=
val
*
std
::
min
(
std
::
max
(
val
+
3
,
0.
f
),
6.
f
)
/
6.
f
;
}
else
{
megdnn_assert
(
nonline_mode
==
NonlineMode
::
IDENTITY
);
}
output
[
output_visitor
.
put
(
oc
,
OH
,
OW
,
oh
,
ow
)]
=
getter
(
val
);
}
}
};
...
...
dnn/src/common/winograd/winograd_helper.h
浏览文件 @
c6eb2e8d
...
...
@@ -44,7 +44,7 @@ public:
input_filter_compute_type
*
input_transform_buf
,
input_filter_compute_type
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
IC
,
size_t
ic
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
rescale
=
1.0
f
);
...
...
@@ -54,7 +54,7 @@ public:
const
output_compute_type
*
bias
,
dst_type
*
output
,
output_compute_type
*
transform_mid_buf
,
BiasMode
bmode
,
NonlineMode
nonline_mode
,
size_t
oh_start
,
size_t
ow_start
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_
end
,
size_t
OH
,
size_t
OW
,
size_t
oc_start
,
size_t
oc_
index
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
m
,
size_t
r
,
const
std
::
vector
<
float
>&
interp_points
,
DType
dtype
,
float
input_filter_scale
=
1.0
f
,
// input_scale * filter_scale
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
c6eb2e8d
...
...
@@ -55,7 +55,7 @@ public:
ohw_tile_size
));
all_algos
.
emplace_back
(
refhold
.
back
().
get
());
}
#if
0
#if
1
//! As these algos maybe very slow, it will make fastrun search slow, so
//! we disable it, but for the test of strategyhelper, we just keep it.
//! FIXME: I do not know a better way to do it.
...
...
dnn/src/fallback/conv_bias/winograd/strategy.cpp
浏览文件 @
c6eb2e8d
此差异已折叠。
点击以展开。
dnn/src/fallback/conv_bias/winograd/winograd.h
浏览文件 @
c6eb2e8d
...
...
@@ -321,17 +321,10 @@ public:
"nr_tiles_in_unit: %zu TILE_SIZE:%zu"
,
nr_tiles_in_unit
,
unit_tile_size
);
}
rep
(
unit_idx
,
nr_tiles_in_unit
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
size_t
nh
=
index
/
units_w
;
size_t
nw
=
index
%
units_w
;
int
ih_start
=
nh
*
Strategy
::
OUTPUT_BLOCK_SIZE
-
PH
;
int
iw_start
=
nw
*
Strategy
::
OUTPUT_BLOCK_SIZE
-
PW
;
strategy
.
input
(
src_ptr
,
input_transform_buf
,
transform_mid_buf
,
ih_start
,
iw_start
,
IH
,
IW
,
IC
,
unit_idx
,
nr_tiles_in_unit
);
}
//! BTdB
strategy
.
input
(
src_ptr
,
input_transform_buf
,
transform_mid_buf
,
IH
,
IW
,
IC
,
PH
,
PW
,
unit_start_idx
,
nr_tiles_in_unit
);
rep
(
i
,
Strategy
::
ALPHA
)
rep
(
j
,
Strategy
::
ALPHA
)
{
if
(
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
)
{
matmul_param
.
A_ptr
=
...
...
@@ -368,22 +361,14 @@ public:
}
matmul_kern
(
matmul_param
);
}
/* Y = ATmA */
rep
(
unit_idx
,
nr_tiles_in_unit
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
Strategy
::
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
Strategy
::
OUTPUT_BLOCK_SIZE
;
size_t
oc_end_idx
=
oc_start_idx
+
nr_oc_in_unit
;
strategy
.
output
(
output_transform_buf
,
bias_ptr
,
dst_ptr
,
reinterpret_cast
<
output_compute_type
*>
(
transform_mid_buf
),
ncb_param
.
bias_mode
,
ncb_param
.
nonlineMode
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start_idx
,
oc_end_idx
,
unit_idx
,
nr_tiles_in_unit
);
}
//! Y = ATmA
size_t
oc_end_idx
=
oc_start_idx
+
nr_oc_in_unit
;
strategy
.
output
(
output_transform_buf
,
bias_ptr
,
dst_ptr
,
reinterpret_cast
<
output_compute_type
*>
(
transform_mid_buf
),
ncb_param
.
bias_mode
,
ncb_param
.
nonlineMode
,
OH
,
OW
,
oc_start_idx
,
oc_end_idx
,
unit_start_idx
,
nr_tiles_in_unit
);
};
SmallVector
<
NCBKern
>
get_kerns
(
...
...
@@ -542,15 +527,16 @@ public:
size_t IC, size_t oc_start, size_t oc_end); \
void input(const stype* input, \
input_filter_compute_type* input_transform_buf, \
input_filter_compute_type* transform_mid_buf,
int ih_start,
\
int iw_start, size_t IH, size_t IW, size_t IC,
\
size_t unit_
idx, size_t nr_tiles_in_unit);
\
input_filter_compute_type* transform_mid_buf,
\
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW,
\
size_t unit_
start_idx, size_t nr_tiles_in_unit);
\
void output(const output_compute_type* output_transform_buf, \
const output_compute_type* bias, dst_type* output, \
output_compute_type* transform_mid_buf, BiasMode bmode, \
NonlineMode nonline_mode, size_t oh_start, \
size_t ow_start, size_t OH, size_t OW, size_t oc_start, \
size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
}
;
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
...
...
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
浏览文件 @
c6eb2e8d
...
...
@@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter,
transform_mid_buf
,
OC
,
IC
,
oc_start
,
oc_end
);
}
void
winograd_nchw88_2x3_8x8_f
::
input
(
const
float
*
input
,
float
*
input_transform_buf
,
float
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
uni
t_idx
,
float
*
transform_mid_buf
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
PH
,
size_t
PW
,
size_t
unit_star
t_idx
,
size_t
nr_units_in_tile
)
{
megdnn_assert
(
IC
%
8
==
0
);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
float
*
patch
=
transform_mid_buf
;
float
*
patchT
=
transform_mid_buf
+
8
*
alpha
*
alpha
;
if
(
ih_start
>=
0
&&
ih_start
+
alpha
<=
static_cast
<
size_t
>
(
IH
)
&&
iw_start
>=
0
&&
iw_start
+
alpha
<=
static_cast
<
size_t
>
(
IW
))
{
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
InputTransform2X3_NCHW88
::
prepare
<
true
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform2X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
}
else
{
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
InputTransform2X3_NCHW88
::
prepare
<
false
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform2X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
size_t
nh
=
index
/
units_w
;
size_t
nw
=
index
%
units_w
;
int
ih_start
=
nh
*
OUTPUT_BLOCK_SIZE
-
PH
;
int
iw_start
=
nw
*
OUTPUT_BLOCK_SIZE
-
PW
;
if
(
ih_start
>=
0
&&
ih_start
+
alpha
<=
static_cast
<
size_t
>
(
IH
)
&&
iw_start
>=
0
&&
iw_start
+
alpha
<=
static_cast
<
size_t
>
(
IW
))
{
InputTransform2X3_NCHW88
::
prepare
<
true
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform2X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
else
{
InputTransform2X3_NCHW88
::
prepare
<
false
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform2X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
}
}
}
...
...
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp
浏览文件 @
c6eb2e8d
...
...
@@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter,
transform_mid_buf
,
OC
,
IC
,
oc_start
,
oc_end
);
}
void
winograd_nchw88_6x3_8x8_f
::
input
(
const
float
*
input
,
float
*
input_transform_buf
,
float
*
transform_mid_buf
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
uni
t_idx
,
float
*
transform_mid_buf
,
size_t
IH
,
size_t
IW
,
size_t
IC
,
size_t
PH
,
size_t
PW
,
size_t
unit_star
t_idx
,
size_t
nr_units_in_tile
)
{
megdnn_assert
(
IC
%
8
==
0
);
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
float
*
patch
=
transform_mid_buf
;
float
*
patchT
=
transform_mid_buf
+
8
*
alpha
*
alpha
;
if
(
ih_start
>=
0
&&
ih_start
+
alpha
<=
static_cast
<
size_t
>
(
IH
)
&&
iw_start
>=
0
&&
iw_start
+
alpha
<=
static_cast
<
size_t
>
(
IW
))
{
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
InputTransform6X3_NCHW88
::
prepare
<
true
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform6X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
}
else
{
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
InputTransform6X3_NCHW88
::
prepare
<
false
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform6X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
8
)
{
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
size_t
nh
=
index
/
units_w
;
size_t
nw
=
index
%
units_w
;
int
ih_start
=
nh
*
OUTPUT_BLOCK_SIZE
-
PH
;
int
iw_start
=
nw
*
OUTPUT_BLOCK_SIZE
-
PW
;
if
(
ih_start
>=
0
&&
ih_start
+
alpha
<=
static_cast
<
size_t
>
(
IH
)
&&
iw_start
>=
0
&&
iw_start
+
alpha
<=
static_cast
<
size_t
>
(
IW
))
{
InputTransform6X3_NCHW88
::
prepare
<
true
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform6X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
else
{
InputTransform6X3_NCHW88
::
prepare
<
false
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransform6X3_NCHW88
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录