PaddlePaddle / Paddle-Lite

Commit 532cff71
Authored Mar 07, 2019 by hjchen2
Refator depthwise conv3x3 and fix it's bugs for armv8
Parent: cb5e15b9

Showing 17 changed files with 1110 additions and 2532 deletions (+1110, -2532)
src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp      +6    -16
src/operators/kernel/arm/convolution/conv_add_kernel.cpp              +34   -1
src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp         +7    -14
src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp      +36   -2
src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp          +6    -16
src/operators/kernel/arm/convolution/conv_common.cpp                  +12   -16
src/operators/kernel/arm/convolution/conv_kernel.cpp                  +8    -14
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp        +6    -16
src/operators/kernel/central-arm-func/conv_add_arm_func.h             +0    -29
src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h     +0    -25
src/operators/kernel/central-arm-func/conv_transpose_arm_func.h       +0    -1
src/operators/math/depthwise_conv3x3.cpp                              +972  -1996
src/operators/math/depthwise_conv3x3.h                                +0    -42
src/operators/math/gemm/cblas.cc                                      +2    -4
src/operators/math/gemm/executor.h                                    +7    -16
src/operators/math/im2col.cpp                                         +12   -320
src/operators/op_param.h                                              +2    -4
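The recurring change across the kernel files below is that the old fused helpers (DepthwiseConvAddBNRelu3x3s1p1, DepthwiseConv3x3s2p1v2, DepthwiseConv3x3s2p0) are replaced by a stride-specific depthwise convolution (DepthwiseConv3x3S1 / DepthwiseConv3x3S2, which take the paddings explicitly) followed by a separate channel-wise pass (AddChannelWise / ScaleAddChannelWise) for bias, batch-norm scale and activation. Below is a minimal standalone C++ sketch of that two-stage composition on plain CHW float arrays; the function names and layout here are illustrative, not the framework's API.

#include <algorithm>

// Naive 3x3 depthwise convolution for one image in CHW layout (illustrative only).
void DepthwiseConv3x3Ref(const float *in, const float *filter, float *out,
                         int channels, int h, int w, int stride, int pad,
                         int out_h, int out_w) {
  for (int c = 0; c < channels; ++c) {
    const float *in_c = in + c * h * w;
    const float *k = filter + c * 9;  // one 3x3 kernel per channel
    float *out_c = out + c * out_h * out_w;
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        float acc = 0.f;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh * stride - pad + kh;
            int iw = ow * stride - pad + kw;
            if (ih >= 0 && ih < h && iw >= 0 && iw < w) {
              acc += k[kh * 3 + kw] * in_c[ih * w + iw];
            }
          }
        }
        out_c[oh * out_w + ow] = acc;
      }
    }
  }
}

// Separate per-channel scale + bias + ReLU pass, the role played by
// ScaleAddChannelWise<RELU> after the refactor.
void ScaleAddChannelWiseRelu(float *data, const float *scale,
                             const float *bias, int channels, int spatial) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < spatial; ++i) {
      float v = data[c * spatial + i] * scale[c] + bias[c];
      data[c * spatial + i] = std::max(v, 0.f);
    }
  }
}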
src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp

@@ -61,25 +61,15 @@ template <>
 void ConvAddBNReluKernel<CPU, float>::Compute(
     const FusionConvAddBNReluParam<CPU> &param) {
   switch (param.ExecMode()) {
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                            param.Output(), param.NewScale(),
-                                            param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
...
src/operators/kernel/arm/convolution/conv_add_kernel.cpp

@@ -14,6 +14,7 @@ limitations under the License. */
 #ifdef FUSION_CONVADD_OP
 
 #include "operators/kernel/conv_add_kernel.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
 #include "operators/kernel/central-arm-func/conv_add_arm_func.h"
 
 namespace paddle_mobile {

@@ -21,12 +22,44 @@ namespace operators {
 template <>
 bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
+  InitBaseConvKernel(param);
   return true;
 }
 
 template <>
 void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
-  ConvAddCompute<float>(param);
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
+                                     param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
+                                     param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
+                                     param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
+                                     param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      ConvAddBasic(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
 }
 
 template class ConvAddKernel<CPU, float>;
...
src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp

@@ -31,21 +31,14 @@ template <>
 void ConvAddReluKernel<CPU, float>::Compute(
     const FusionConvAddReluParam<CPU> &param) {
   switch (param.ExecMode()) {
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                 param.Bias(), true, true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), param.Bias(), true, true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 param.Bias(), true, true);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
       break;
 #ifndef __aarch64__
...
src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp

@@ -16,7 +16,8 @@ limitations under the License. */
 
 #include "operators/kernel/conv_bn_add_relu_kernel.h"
 #include <cmath>
-#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
 
 namespace paddle_mobile {
 namespace operators {

@@ -51,13 +52,46 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
   }
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
+  InitBaseConvKernel(param);
   return true;
 }
 
 template <>
 void ConvBNAddReluKernel<CPU, float>::Compute(
     const FusionConvBNAddReluParam<CPU> &param) {
-  ConvBNAddReluCompute<float>(param);
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#ifndef __aarch64__
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
+                                      param.NewBias(), param.Output());
+      break;
+#endif  // __aarch64__
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      ConvBNReluBasic<FusionConvBNAddReluParam<CPU>>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
 }
 
 template class ConvBNAddReluKernel<CPU, float>;
...
src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp

@@ -60,25 +60,15 @@ template <>
 void ConvBNReluKernel<CPU, float>::Compute(
     const FusionConvBNReluParam<CPU> &param) {
   switch (param.ExecMode()) {
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                            param.Output(), param.NewScale(),
-                                            param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
...
src/operators/kernel/arm/convolution/conv_common.cpp

@@ -44,29 +44,25 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
 #endif  // __aarch64__
   } else {
-    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-        param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
-        param->Paddings()[0] == param->Paddings()[1]) {
-      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
-    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-               param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
-               param->Paddings()[0] == param->Paddings()[1]) {
-      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
-    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
-               param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
-               param->Paddings()[0] == param->Paddings()[1]) {
-      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
-    } else if (depth3x3) {
-      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
+    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+        param->Strides()[0] == 1) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
+    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+               param->Strides()[0] == 2) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
 #ifndef __aarch64__
     } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
                param->Strides()[0] == 1) {
       param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
-    } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
+    } else if (conv3x3 && !depth3x3 &&
+               param->Strides()[0] == param->Strides()[1] &&
                param->Dilations()[0] == param->Dilations()[1] &&
-               param->Strides()[0] == 1 && param->Dilations()[0] == 1 /* &&
-               param->Output()->dims()[1] >= 16 &&
+               param->Strides()[0] == 1 && param->Dilations()[0] == 1
+#if 0
+               && param->Output()->dims()[1] >= 16 &&
                param->Input()->dims()[1] >= 16 &&
-               param->Input()->dims()[2] <= 140 */
-               /* refered from ncnn */
-    ) {
+               param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
+#endif
+    ) {
       param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
       // transform weight
       param->transformed_filter_ = new framework::LoDTensor;
...
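The simplified selection above keys only on the stride: padding no longer picks a dedicated kernel variant, because the new DepthwiseConv3x3S1/S2 implementations take the paddings as an argument. A standalone sketch of that decision logic with plain ints in place of ConvParam; the enum and function names here are illustrative, not the framework's:

enum class ConvExec { kDepthwise3x3S1, kDepthwise3x3S2, kWinograd3x3, kGemm };

// depthwise3x3: groups == in_channels == out_channels with a 3x3 filter.
ConvExec ChooseExecMode(bool depthwise3x3, bool conv3x3, int stride_h,
                        int stride_w, int dilation) {
  if (depthwise3x3 && stride_h == stride_w && stride_h == 1) {
    return ConvExec::kDepthwise3x3S1;  // any symmetric padding is handled
  } else if (depthwise3x3 && stride_h == stride_w && stride_h == 2) {
    return ConvExec::kDepthwise3x3S2;
  } else if (conv3x3 && !depthwise3x3 && stride_h == stride_w &&
             stride_h == 1 && dilation == 1) {
    return ConvExec::kWinograd3x3;  // non-depthwise 3x3 with unit stride
  }
  return ConvExec::kGemm;  // everything else falls back to im2col + GEMM
}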
src/operators/kernel/arm/convolution/conv_kernel.cpp

@@ -18,6 +18,8 @@ limitations under the License. */
 #include "operators/kernel/arm/convolution/conv_common.h"
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
+#include <iostream>
 
 namespace paddle_mobile {
 namespace operators {

@@ -41,21 +43,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
       DepthwiseConv5x5<int8_t, int32_t>(param);
       break;
 #endif  // __aarch64__
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), nullptr, false, false);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       break;
 #ifndef __aarch64__
     case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
...
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp

@@ -60,25 +60,15 @@ template <>
 void DWConvBNReluKernel<CPU, float>::Compute(
     const FusionDWConvBNReluParam<CPU> &param) {
   switch (param.ExecMode()) {
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
-      math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                            param.Output(), param.NewScale(),
-                                            param.NewBias(), true);
-      break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false, false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
-    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
       math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                       param.NewBias(), param.Output());
       break;
...
src/operators/kernel/central-arm-func/conv_add_arm_func.h

@@ -115,35 +115,6 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
   }
 }
 
-template <typename P>
-void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
-  param.Output()->mutable_data<float>();
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Input()->dims()[1] == param.Output()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                               param.Bias(), true, false);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    // math::DepthwiseConv3x3(param.Input(), param.Strides(),
-    // param.Paddings(),
-    //                        param.Filter(), param.Bias(),
-    //                        param.Output(), false);
-    if (param.Paddings()[0] == 0) {
-      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
-                                 param.Bias(), true, false);
-    } else {
-      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
-                                   param.Output(), param.Bias(), true, false);
-    }
-  } else {
-    ConvAddBasic(param);
-  }
-}
-
 }  // namespace operators
 }  // namespace paddle_mobile
...
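The deleted ConvAddCompute re-derived the depthwise fast-path decision from tensor shapes on every Compute call; after this commit that decision is made once in InitBaseConvKernel. The shape test it used amounts to the following predicate, sketched here with plain ints standing in for the tensor dims:

// True when a convolution is depthwise 3x3: one group per channel, the same
// number of input and output channels, and a square 3x3 filter.
bool IsDepthwise3x3(int groups, int in_channels, int out_channels,
                    int filter_h, int filter_w) {
  return groups == in_channels && in_channels == out_channels &&
         filter_h == filter_w && filter_h == 3;
}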
src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h

@@ -115,31 +115,6 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
     }
   }
 }
 
-template <typename P>
-void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
-  Tensor Bias;
-  Bias.mutable_data<float>({param.Groups()});
-  if (param.Groups() == param.Input()->dims()[1] &&
-      param.Input()->dims()[1] == param.Output()->dims()[1] &&
-      param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-    math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
-                                        param.Output(), param.NewScale(),
-                                        param.NewBias(), true);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
-             param.Input()->dims()[1] == param.Output()->dims()[1] &&
-             param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-             param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-    // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
-    //                                     param.Output(), param.NewScale(),
-    //                                     param.NewBias(), 1);
-    math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
-                                          param.Output(), param.NewScale(),
-                                          param.NewBias(), true);
-  } else {
-    ConvBNAddReluBasic(param);
-  }
-}
-
 }  // namespace operators
 }  // namespace paddle_mobile
...
src/operators/kernel/central-arm-func/conv_transpose_arm_func.h

@@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &out_slice);
       } else if (data_dim == 3U) {
         col2vol(col, dilations, strides, paddings, &out_slice);
       }
...
src/operators/math/depthwise_conv3x3.cpp
浏览文件 @
532cff71
...
@@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,2066 +12,1042 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON
#include <arm_neon.h>
#include <arm_neon.h>
#endif
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
void
DepthwiseConv3x3
(
const
framework
::
Tensor
*
input
,
#ifndef __aarch64__
const
std
::
vector
<
int
>
&
strides
,
inline
float32x4_t
vpaddq_f32
(
float32x4_t
r0
,
float32x4_t
r1
)
{
const
std
::
vector
<
int
>
&
paddings
,
float32x2_t
sum0
=
vpadd_f32
(
vget_low_f32
(
r0
),
vget_high_f32
(
r0
));
const
framework
::
Tensor
*
filter
,
framework
::
Tensor
*
bias
,
float32x2_t
sum1
=
vpadd_f32
(
vget_low_f32
(
r1
),
vget_high_f32
(
r1
));
framework
::
Tensor
*
output
,
bool
if_bias
)
{
return
vcombine_f32
(
sum0
,
sum1
);
const
int
batch_size
=
input
->
dims
()[
0
];
const
int
input_height
=
input
->
dims
()[
2
];
const
int
input_width
=
input
->
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
_kernel_size
=
3
;
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
float
zero
=
0
;
const
int
input_channel_stride
=
input_height
*
input_width
;
const
int
output_channel_stride
=
output_height
*
output_width
;
const
int
filter_channel_stride
=
9
;
const
float
*
input_ptr
=
input
->
data
<
float
>
();
const
float
*
filter_ptr
=
filter
->
data
<
float
>
();
if
(
if_bias
)
{
math
::
expand_bias
(
*
bias
,
1
,
output
->
dims
());
output
->
ShareDataWith
(
*
bias
);
}
float
*
output_ptr
=
output
->
mutable_data
<
float
>
();
const
float
*
pos1
,
*
pos2
,
*
pos3
,
*
filter1
,
*
filter2
,
*
filter3
,
*
output_ptr2
;
int
hstart
,
wstart
,
hend
,
wend
;
float
result
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
const
float
*
input_data
=
input_ptr
+
(
i
*
output_channels
+
c
)
*
input_channel_stride
;
float
*
output_data
=
output_ptr
+
(
i
*
output_channels
+
c
)
*
output_channel_stride
;
filter1
=
filter_ptr
+
c
*
filter_channel_stride
;
filter2
=
filter1
+
3
;
filter3
=
filter2
+
3
;
for
(
int
ph
=
0
;
ph
<
output_height
;
ph
++
)
{
for
(
int
pw
=
0
;
pw
<
output_width
;
pw
++
)
{
hstart
=
ph
*
stride_height
-
padding_height
;
wstart
=
pw
*
stride_width
-
padding_width
;
hend
=
std
::
min
(
hstart
+
_kernel_size
,
input_height
+
padding_height
);
wend
=
std
::
min
(
wstart
+
_kernel_size
,
input_width
+
padding_width
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
hend
=
std
::
min
(
hend
,
input_height
);
wend
=
std
::
min
(
wend
,
input_width
);
pos1
=
input_data
+
hstart
*
input_width
+
wstart
;
pos2
=
input_data
+
(
hstart
+
1
)
*
input_width
+
wstart
;
pos3
=
input_data
+
(
hstart
+
2
)
*
input_width
+
wstart
;
output_ptr2
=
output_data
+
ph
*
output_width
+
pw
;
if
(
hend
-
hstart
!=
3
||
wend
-
wstart
!=
3
)
{
result
=
0
;
float
fake_input
[
9
]
=
{
0
};
if
(
hstart
==
0
&&
wstart
==
0
)
{
// 左上角
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
>=
3
-
hend
&&
k
>=
3
-
wend
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
-
(
3
-
hend
))
*
input_width
+
k
-
(
3
-
wend
)];
}
}
}
}
else
if
(
hstart
==
0
&&
wend
==
input_width
)
{
// 右上角
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
>=
3
-
hend
&&
k
<=
input_width
-
wstart
-
1
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
-
(
3
-
hend
))
*
input_width
+
k
+
wstart
];
}
}
}
}
else
if
(
hend
==
input_height
&&
wstart
==
0
)
{
// 左下角
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
<=
input_height
-
1
-
hstart
&&
k
>=
3
-
wend
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
+
hstart
)
*
input_width
+
k
-
(
3
-
wend
)];
}
}
}
}
else
if
(
hend
==
input_height
&&
wend
==
input_width
)
{
// 右下角
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
<=
input_height
-
hstart
-
1
&&
k
<=
input_width
-
wstart
-
1
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
+
hstart
)
*
input_width
+
k
+
wstart
];
}
}
}
}
else
if
(
hstart
==
0
)
{
// 顶部
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
>=
3
-
hend
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
-
(
3
-
hend
))
*
input_width
+
k
+
wstart
];
}
}
}
}
else
if
(
hend
==
input_height
)
{
// 底部
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
j
<=
input_height
-
hstart
-
1
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
+
hstart
)
*
input_width
+
k
+
wstart
];
}
}
}
}
else
if
(
wstart
==
0
)
{
// 左侧
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
k
>=
3
-
wend
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
+
hstart
)
*
input_width
+
(
k
-
(
3
-
wend
))];
}
}
}
}
else
if
(
wend
==
input_width
)
{
// 右侧
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
for
(
int
k
=
0
;
k
<
3
;
++
k
)
{
if
(
k
<=
input_width
-
wstart
-
1
)
{
fake_input
[
3
*
j
+
k
]
=
input_data
[(
j
+
hstart
)
*
input_width
+
k
+
wstart
];
}
}
}
}
for
(
int
l
=
0
;
l
<
9
;
++
l
)
{
result
+=
fake_input
[
l
]
*
filter1
[
l
];
}
if
(
if_bias
)
{
output_data
[
ph
*
output_width
+
pw
]
+=
result
;
}
else
{
output_data
[
ph
*
output_width
+
pw
]
=
result
;
}
}
else
{
#if __ARM_NEON
#if __aarch64__
const
float32x4_t
data1
=
vld1q_f32
(
pos1
);
const
float32x4_t
data2
=
vld1q_f32
(
pos2
);
const
float32x4_t
data3
=
vld1q_f32
(
pos3
);
const
float32x4_t
v_filter1
=
vld1q_f32
(
filter1
);
const
float32x4_t
v_filter2
=
vld1q_f32
(
filter2
);
const
float32x4_t
v_filter3
=
vld1q_f32
(
filter3
);
float32x4_t
mula
=
vmulq_f32
(
data1
,
v_filter1
);
mula
=
vmlaq_f32
(
mula
,
data2
,
v_filter2
);
mula
=
vmlaq_f32
(
mula
,
data3
,
v_filter3
);
float32x2_t
res
=
vpadd_f32
(
vget_high_f32
(
vsetq_lane_f32
(
0
,
mula
,
3
)),
vget_low_f32
(
mula
));
res
=
vpadd_f32
(
res
,
res
);
if
(
if_bias
)
{
output_data
[
ph
*
output_width
+
pw
]
+=
vget_lane_f32
(
res
,
0
);
}
else
{
output_data
[
ph
*
output_width
+
pw
]
=
vget_lane_f32
(
res
,
0
);
}
#else
asm
volatile
(
"vld1.32 {q1}, [%[pos1]]
\n\t
"
"vld1.32 {q4}, [%[filter1]]
\n\t
"
"vmov.f32 q0, #0.0
\n\t
"
"vld1.32 {q2}, [%[pos2]]
\n\t
"
"vld1.32 {q5}, [%[filter2]]
\n\t
"
"vmla.f32 q0, q1, q4
\n\t
"
"vld1.32 {q3}, [%[pos3]]
\n\t
"
"vld1.32 {q6}, [%[filter3]]
\n\t
"
"vmla.f32 q0, q2, q5
\n\t
"
"vmla.f32 q0, q3, q6
\n\t
"
"vmov.f32 d1[1], %[zero]
\n\t
"
"vadd.f32 d4, d0, d1
\n\t
"
"vadd.f32 s10, s8, s9
\n\t
"
"vst1.32 {d5[0]},[%[output_ptr]]
\n\t
"
:
:
[
input_data
]
"r"
(
input_data
),
[
pos1
]
"r"
(
pos1
),
[
pos2
]
"r"
(
pos2
),
[
pos3
]
"r"
(
pos3
),
[
filter1
]
"r"
(
filter1
),
[
filter2
]
"r"
(
filter2
),
[
filter3
]
"r"
(
filter3
),
[
output_ptr
]
"r"
(
output_ptr2
),
[
zero
]
"r"
(
zero
)
:
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
);
#endif // __aarch64__
#else
#endif // __ARM_NEON
}
}
}
}
}
}
}
#endif
void
DepthwiseConv3x3s1p1
(
const
framework
::
Tensor
*
input
,
template
<
int
Stride
=
1
>
const
framework
::
Tensor
*
filter
,
inline
void
Depth3x3NormalRowLoadInput
(
const
float
*
input
,
float32x4_t
*
y
)
{
framework
::
Tensor
*
output
,
framework
::
Tensor
*
bias
,
y
[
0
]
=
vld1q_f32
(
input
);
bool
if_bias
,
bool
if_relu
)
{
y
[
2
]
=
vld1q_f32
(
input
+
4
);
#if __ARM_NEON
y
[
1
]
=
vextq_f32
(
y
[
0
],
y
[
2
],
1
);
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
y
[
2
]
=
vextq_f32
(
y
[
0
],
y
[
2
],
2
);
const
int
c
=
static_cast
<
int
>
(
input
->
dims
()[
1
]);
}
const
int
h
=
static_cast
<
int
>
(
input
->
dims
()[
2
]);
const
int
w
=
static_cast
<
int
>
(
input
->
dims
()[
3
]);
const
int
hxw
=
h
*
w
;
// const int l = h;
// leftTop, rightTop, leftBottom, rightBottom
const
int
lt
=
0
;
const
int
rt
=
w
-
1
;
const
int
lb
=
(
h
-
1
)
*
w
;
const
int
rb
=
h
*
w
-
1
;
const
float
*
bias_data
;
if
(
if_bias
)
{
bias_data
=
bias
->
data
<
float
>
();
}
float32x4_t
zero
=
vdupq_n_f32
(
0.0
);
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
c
;
++
j
)
{
const
float
*
filter_data_tmp
=
filter
->
data
<
float
>
()
+
j
*
9
;
const
float
*
input_data
=
input
->
data
<
float
>
()
+
j
*
hxw
;
float
*
output_data
=
output
->
mutable_data
<
float
>
()
+
j
*
hxw
;
float32x4_t
vbias
;
if
(
if_bias
)
{
vbias
=
vdupq_n_f32
(
bias_data
[
j
]);
}
int
w_mid
=
w
-
2
;
// l=1->l_mid=-1,l=2->l_mid=0
float
w00
=
filter_data_tmp
[
0
];
float
w01
=
filter_data_tmp
[
1
];
float
w02
=
filter_data_tmp
[
2
];
float
w10
=
filter_data_tmp
[
3
];
float
w11
=
filter_data_tmp
[
4
];
float
w12
=
filter_data_tmp
[
5
];
float
w20
=
filter_data_tmp
[
6
];
float
w21
=
filter_data_tmp
[
7
];
float
w22
=
filter_data_tmp
[
8
];
output_data
[
lt
]
=
w11
*
input_data
[
0
]
+
w12
*
input_data
[
1
]
+
w21
*
input_data
[
w
]
+
w22
*
input_data
[
w
+
1
];
output_data
[
rt
]
=
w10
*
input_data
[
w
-
2
]
+
w11
*
input_data
[
w
-
1
]
+
w20
*
input_data
[
2
*
w
-
2
]
+
w21
*
input_data
[
2
*
w
-
1
];
output_data
[
lb
]
=
w01
*
input_data
[(
h
-
2
)
*
w
]
+
w02
*
input_data
[(
h
-
2
)
*
w
+
1
]
+
w11
*
input_data
[(
h
-
1
)
*
w
]
+
w12
*
input_data
[(
h
-
1
)
*
w
+
1
];
output_data
[
rb
]
=
w00
*
input_data
[
h
*
w
-
w
-
2
]
+
w01
*
input_data
[
h
*
w
-
w
-
1
]
+
w10
*
input_data
[
h
*
w
-
2
]
+
w11
*
input_data
[
h
*
w
-
1
];
if
(
if_bias
)
{
output_data
[
lt
]
+=
bias_data
[
j
];
output_data
[
rt
]
+=
bias_data
[
j
];
output_data
[
lb
]
+=
bias_data
[
j
];
output_data
[
rb
]
+=
bias_data
[
j
];
}
if
(
if_relu
)
{
output_data
[
lt
]
=
output_data
[
lt
]
<
0
?
0
:
output_data
[
lt
];
output_data
[
rt
]
=
output_data
[
rt
]
<
0
?
0
:
output_data
[
rt
];
output_data
[
lb
]
=
output_data
[
lb
]
<
0
?
0
:
output_data
[
lb
];
output_data
[
rb
]
=
output_data
[
rb
]
<
0
?
0
:
output_data
[
rb
];
}
for
(
int
i
=
1
;
i
<
h
-
1
;
++
i
)
{
int
left
=
i
*
w
;
int
right
=
i
*
w
+
w
-
1
;
output_data
[
left
]
=
w01
*
input_data
[
i
*
w
-
w
]
+
w02
*
input_data
[
i
*
w
-
w
+
1
]
+
w11
*
input_data
[
i
*
w
]
+
w12
*
input_data
[
i
*
w
+
1
]
+
w21
*
input_data
[
i
*
w
+
w
]
+
w22
*
input_data
[
i
*
w
+
w
+
1
];
output_data
[
right
]
=
w00
*
input_data
[
i
*
w
+
w
-
1
-
w
-
1
]
+
w01
*
input_data
[
i
*
w
+
w
-
1
-
w
]
+
w10
*
input_data
[
i
*
w
+
w
-
1
-
1
]
+
w11
*
input_data
[
i
*
w
+
w
-
1
]
+
w20
*
input_data
[
i
*
w
+
w
-
1
+
w
-
1
]
+
w21
*
input_data
[
i
*
w
+
w
-
1
+
w
];
if
(
if_bias
)
{
output_data
[
left
]
+=
bias_data
[
j
];
output_data
[
right
]
+=
bias_data
[
j
];
}
if
(
if_relu
)
{
output_data
[
left
]
=
output_data
[
left
]
<
0
?
0
:
output_data
[
left
];
output_data
[
right
]
=
output_data
[
right
]
<
0
?
0
:
output_data
[
right
];
}
}
// top 1 row and bottom 1 row
const
float
*
input_tmp
=
input_data
;
float32x4_t
in0
,
in1
,
in2
,
in3
,
in4
,
in5
,
in6
,
in7
,
tmp0
,
tmp1
,
tmp2
,
tmp3
,
tmp4
,
tmp5
,
out0
;
in0
=
vld1q_f32
(
input_tmp
);
in2
=
vld1q_f32
(
input_tmp
+
w
);
const
float
*
input_tmp_end
=
input_tmp
+
(
h
-
2
)
*
w
;
in4
=
vld1q_f32
(
input_tmp_end
);
in6
=
vld1q_f32
(
input_tmp_end
+
w
);
int
c_mid
=
w_mid
;
auto
output_ptr
=
output_data
+
1
;
for
(;
c_mid
>
3
;
c_mid
-=
4
)
{
in1
=
vld1q_f32
(
input_tmp
+
4
);
in3
=
vld1q_f32
(
input_tmp
+
w
+
4
);
tmp0
=
vextq_f32
(
in0
,
in1
,
1
);
tmp1
=
vextq_f32
(
in0
,
in1
,
2
);
tmp2
=
vextq_f32
(
in2
,
in3
,
1
);
tmp3
=
vextq_f32
(
in2
,
in3
,
2
);
out0
=
vmulq_n_f32
(
in0
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w12
);
out0
=
vmlaq_n_f32
(
out0
,
in2
,
w20
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w21
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w22
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
vst1q_f32
(
output_ptr
,
out0
);
in5
=
vld1q_f32
(
input_tmp_end
+
4
);
in7
=
vld1q_f32
(
input_tmp_end
+
w
+
4
);
tmp0
=
vextq_f32
(
in4
,
in5
,
1
);
tmp1
=
vextq_f32
(
in4
,
in5
,
2
);
tmp2
=
vextq_f32
(
in6
,
in7
,
1
);
tmp3
=
vextq_f32
(
in6
,
in7
,
2
);
out0
=
vmulq_n_f32
(
in4
,
w00
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w01
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w02
);
out0
=
vmlaq_n_f32
(
out0
,
in6
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w12
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
vst1q_f32
(
output_ptr
+
(
h
-
1
)
*
w
,
out0
);
// can optimize to each 8 stride.
input_tmp
+=
4
;
input_tmp_end
+=
4
;
output_ptr
+=
4
;
in0
=
in1
;
in2
=
in3
;
in4
=
in5
;
in6
=
in7
;
}
// top right pad
float32x4_t
pad0
=
vdupq_n_f32
(
input_data
[
w
-
1
]);
float32x4_t
pad1
=
vdupq_n_f32
(
input_data
[
2
*
w
-
1
]);
tmp0
=
vextq_f32
(
in0
,
pad0
,
1
);
tmp1
=
vextq_f32
(
in0
,
pad0
,
2
);
tmp2
=
vextq_f32
(
in2
,
pad1
,
1
);
tmp3
=
vextq_f32
(
in2
,
pad1
,
2
);
out0
=
vmulq_n_f32
(
in0
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w12
);
out0
=
vmlaq_n_f32
(
out0
,
in2
,
w20
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w21
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w22
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
for
(
int
i
=
0
;
i
<
c_mid
;
++
i
)
{
if
(
i
==
0
)
{
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
0
);
}
if
(
i
==
1
)
{
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
1
);
}
if
(
i
==
2
)
{
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
2
);
}
}
// bottom right pad
float32x4_t
pad2
=
vdupq_n_f32
(
input_data
[
h
*
w
-
1
-
w
]);
float32x4_t
pad3
=
vdupq_n_f32
(
input_data
[
h
*
w
-
1
]);
tmp0
=
vextq_f32
(
in4
,
pad2
,
1
);
tmp1
=
vextq_f32
(
in4
,
pad2
,
2
);
tmp2
=
vextq_f32
(
in6
,
pad3
,
1
);
tmp3
=
vextq_f32
(
in6
,
pad3
,
2
);
out0
=
vmulq_n_f32
(
in4
,
w00
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w01
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w02
);
out0
=
vmlaq_n_f32
(
out0
,
in6
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w12
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
for
(
int
i
=
0
;
i
<
c_mid
;
++
i
)
{
if
(
i
==
0
)
{
vst1q_lane_f32
(
output_ptr
+
(
h
-
1
)
*
w
+
i
,
out0
,
0
);
}
if
(
i
==
1
)
{
vst1q_lane_f32
(
output_ptr
+
(
h
-
1
)
*
w
+
i
,
out0
,
1
);
}
if
(
i
==
2
)
{
vst1q_lane_f32
(
output_ptr
+
(
h
-
1
)
*
w
+
i
,
out0
,
2
);
}
}
// mid
for
(
int
i
=
0
;
i
<
h
-
2
;
++
i
)
{
auto
output_ptr
=
output_data
+
(
i
+
1
)
*
w
+
1
;
input_tmp
=
input_data
+
i
*
w
;
auto
in0_tmp
=
vld1q_f32
(
input_tmp
);
auto
in2_tmp
=
vld1q_f32
(
input_tmp
+
w
);
auto
in4_tmp
=
vld1q_f32
(
input_tmp
+
w
+
w
);
c_mid
=
w_mid
;
for
(;
c_mid
>
3
;
c_mid
-=
4
)
{
auto
in1_tmp
=
vld1q_f32
(
input_tmp
+
4
);
auto
in3_tmp
=
vld1q_f32
(
input_tmp
+
w
+
4
);
auto
in5_tmp
=
vld1q_f32
(
input_tmp
+
w
+
w
+
4
);
tmp0
=
vextq_f32
(
in0_tmp
,
in1_tmp
,
1
);
tmp1
=
vextq_f32
(
in0_tmp
,
in1_tmp
,
2
);
tmp2
=
vextq_f32
(
in2_tmp
,
in3_tmp
,
1
);
tmp3
=
vextq_f32
(
in2_tmp
,
in3_tmp
,
2
);
tmp4
=
vextq_f32
(
in4_tmp
,
in5_tmp
,
1
);
tmp5
=
vextq_f32
(
in4_tmp
,
in5_tmp
,
2
);
out0
=
vmulq_n_f32
(
in0_tmp
,
w00
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w01
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w02
);
out0
=
vmlaq_n_f32
(
out0
,
in2_tmp
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w12
);
out0
=
vmlaq_n_f32
(
out0
,
in4_tmp
,
w20
);
out0
=
vmlaq_n_f32
(
out0
,
tmp4
,
w21
);
out0
=
vmlaq_n_f32
(
out0
,
tmp5
,
w22
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
vst1q_f32
(
output_ptr
,
out0
);
output_ptr
+=
4
;
template
<
>
input_tmp
+=
4
;
inline
void
Depth3x3NormalRowLoadInput
<
2
>
(
const
float
*
input
,
float32x4_t
*
y
)
{
in0_tmp
=
in1_tmp
;
float32x4x2_t
x
=
vld2q_f32
(
input
);
in2_tmp
=
in3_tmp
;
y
[
0
]
=
x
.
val
[
0
];
in4_tmp
=
in5_tmp
;
y
[
1
]
=
x
.
val
[
1
];
}
y
[
2
]
=
vextq_f32
(
y
[
0
],
y
[
0
],
1
);
y
[
2
]
=
vsetq_lane_f32
(
input
[
8
],
y
[
2
],
3
);
}
float32x4_t
pad0
=
vdupq_n_f32
(
input_data
[
i
*
w
+
w
-
1
]);
#define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \
float32x4_t
pad1
=
vdupq_n_f32
(
input_data
[
i
*
w
+
w
-
1
+
w
]);
for (int w = start; w < end; ++w) { \
float32x4_t
pad2
=
vdupq_n_f32
(
input_data
[
i
*
w
+
w
-
1
+
w
+
w
]);
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 3; \
tmp0
=
vextq_f32
(
in0_tmp
,
pad0
,
1
);
const int w_start = w_in_start > 0 ? w_in_start : 0; \
tmp1
=
vextq_f32
(
in0_tmp
,
pad0
,
2
);
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
tmp2
=
vextq_f32
(
in2_tmp
,
pad1
,
1
);
float value = 0; \
tmp3
=
vextq_f32
(
in2_tmp
,
pad1
,
2
);
for (int h_in = h_start; h_in < h_end; ++h_in) { \
tmp4
=
vextq_f32
(
in4_tmp
,
pad2
,
1
);
for (int w_in = w_start; w_in < w_end; ++w_in) { \
tmp5
=
vextq_f32
(
in4_tmp
,
pad2
,
2
);
value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
out0
=
vmulq_n_f32
(
in0_tmp
,
w00
);
} \
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w01
);
} \
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w02
);
output_ptr[w] = value; \
out0
=
vmlaq_n_f32
(
out0
,
in2_tmp
,
w10
);
}
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w12
);
out0
=
vmlaq_n_f32
(
out0
,
in4_tmp
,
w20
);
out0
=
vmlaq_n_f32
(
out0
,
tmp4
,
w21
);
out0
=
vmlaq_n_f32
(
out0
,
tmp5
,
w22
);
out0
=
vaddq_f32
(
out0
,
vbias
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
zero
);
}
for
(
int
i
=
0
;
i
<
c_mid
;
++
i
)
{
template
<
int
Stride_h
,
int
Stride_w
>
if
(
i
==
0
)
{
inline
void
DepthwiseConv3x3NormalRow
(
const
float
*
input
,
const
float
*
filter
,
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
0
);
const
int
h_output
,
const
int
input_h
,
}
const
int
input_w
,
const
int
padding_h
,
if
(
i
==
1
)
{
const
int
padding_w
,
const
int
output_w
,
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
1
);
float
*
output
,
float32x4_t
*
ker
)
{
}
const
int
h_in_start
=
-
padding_h
+
h_output
*
Stride_h
;
if
(
i
==
2
)
{
const
int
h_in_end
=
h_in_start
+
3
;
vst1q_lane_f32
(
output_ptr
+
i
,
out0
,
2
);
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
}
const
int
h_end
=
h_in_end
<
input_h
?
h_in_end
:
input_h
;
}
}
const
int
valid_w_start
=
(
padding_w
+
Stride_w
-
1
)
/
Stride_w
;
const
int
valid_w_end
=
(
input_w
+
padding_w
-
3
)
/
Stride_w
+
1
;
// const int valid_w_end = output_w - valid_w_start;
float
*
output_ptr
=
output
+
h_output
*
output_w
;
// border left
DEPTHWISE_CONV3X3_NORMAL_BORDER
(
0
,
valid_w_start
)
// middle
int
output_tiles
=
(
valid_w_end
-
valid_w_start
)
>>
2
;
float32x4_t
_sum
,
_x
[
3
];
// valid w
for
(
int
w
=
0
;
w
<
output_tiles
*
4
;
w
+=
4
)
{
_sum
=
vdupq_n_f32
(
0.
f
);
int
output_offset
=
valid_w_start
+
w
;
int
input_w_offset
=
output_offset
*
Stride_w
-
padding_w
;
for
(
int
h_in
=
h_start
;
h_in
<
h_end
;
++
h_in
)
{
int
index
=
h_in
-
h_in_start
;
Depth3x3NormalRowLoadInput
<
Stride_w
>
(
input
+
h_in
*
input_w
+
input_w_offset
,
_x
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
0
],
vget_low_f32
(
ker
[
index
]),
0
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
1
],
vget_low_f32
(
ker
[
index
]),
1
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
2
],
vget_high_f32
(
ker
[
index
]),
0
);
}
}
vst1q_f32
(
output_ptr
+
output_offset
,
_sum
);
}
}
#endif
// remain valid w
int
remain
=
(
valid_w_end
-
valid_w_start
)
&
0x3
;
if
(
remain
>
0
)
{
_sum
=
vdupq_n_f32
(
0.
f
);
int
remain_start
=
valid_w_start
+
(
output_tiles
<<
2
);
int
input_w_offset
=
remain_start
*
Stride_w
-
padding_w
;
float
*
output_ptr0
=
output_ptr
+
remain_start
;
for
(
int
h_in
=
h_start
;
h_in
<
h_end
;
++
h_in
)
{
int
index
=
h_in
-
h_in_start
;
Depth3x3NormalRowLoadInput
<
Stride_w
>
(
input
+
h_in
*
input_w
+
input_w_offset
,
_x
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
0
],
vget_low_f32
(
ker
[
index
]),
0
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
1
],
vget_low_f32
(
ker
[
index
]),
1
);
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
2
],
vget_high_f32
(
ker
[
index
]),
0
);
}
switch
(
remain
)
{
case
3
:
vst1q_lane_f32
(
output_ptr0
+
2
,
_sum
,
2
);
case
2
:
vst1_f32
(
output_ptr0
,
vget_low_f32
(
_sum
));
break
;
case
1
:
vst1_lane_f32
(
output_ptr0
,
vget_low_f32
(
_sum
),
0
);
break
;
}
}
// border right
DEPTHWISE_CONV3X3_NORMAL_BORDER
(
valid_w_end
,
output_w
)
}
}
void
DepthwiseConvAddBNRelu3x3s1p1
(
const
framework
::
Tensor
*
input
,
template
<
>
const
framework
::
Tensor
*
filter
,
void
DepthwiseConv3x3S1
<
float
,
float
>
(
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
output
,
const
framework
::
Tensor
&
filter
,
const
framework
::
Tensor
*
new_scale
,
const
std
::
vector
<
int
>
&
paddings
,
const
framework
::
Tensor
*
new_bias
,
framework
::
Tensor
*
output
)
{
bool
if_relu
)
{
const
float
*
input_data
=
input
.
data
<
float
>
();
#if __ARM_NEON
const
float
*
filter_data
=
filter
.
data
<
float
>
();
const
float
*
input_data
=
input
->
data
<
float
>
();
float
*
out_data
=
output
->
mutable_data
<
float
>
();
const
float
*
filter_data
=
filter
->
data
<
float
>
();
int
input_h
=
input
.
dims
()[
2
];
float
*
output_data
=
output
->
mutable_data
<
float
>
();
int
input_w
=
input
.
dims
()[
3
];
const
float
*
newscale_data
=
new_scale
->
data
<
float
>
();
int
output_h
=
output
->
dims
()[
2
];
const
float
*
newbias_data
=
new_bias
->
data
<
float
>
();
int
output_w
=
output
->
dims
()[
3
];
int
padding_h
=
paddings
[
0
];
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
int
padding_w
=
paddings
[
1
];
const
int
input_channel
=
static_cast
<
int
>
(
input
->
dims
()[
1
]);
int
image_size
=
input_h
*
input_w
;
int
out_image_size
=
output_h
*
output_w
;
const
int
input_height
=
static_cast
<
int
>
(
input
->
dims
()[
2
]);
int
valid_h_start
=
padding_h
;
const
int
input_width
=
static_cast
<
int
>
(
input
->
dims
()[
3
]);
int
valid_h_end
=
output_h
-
valid_h_start
;
const
int
output_height
=
static_cast
<
int
>
(
output
->
dims
()[
2
]);
int
valid_h
=
valid_h_end
-
valid_h_start
;
const
int
output_width
=
static_cast
<
int
>
(
output
->
dims
()[
3
]);
int
valid_w_start
=
padding_w
;
int
valid_w_end
=
output_w
-
valid_w_start
;
const
int
hxw
=
input_height
*
input_width
;
int
valid_w
=
valid_w_end
-
valid_w_start
;
// const int l = input_height;
#pragma omp parallel for
const
int
h
=
input_height
;
for
(
int
g
=
0
;
g
<
input
.
dims
()[
1
];
++
g
)
{
const
int
w
=
input_width
;
const
float
*
input_ptr
=
input_data
+
g
*
image_size
;
float32x4_t
vzero
=
vdupq_n_f32
(
0
);
const
float
*
filter_ptr
=
filter_data
+
g
*
9
;
float
*
output_ptr
=
out_data
+
g
*
out_image_size
;
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
#pragma omp parallel for
const
float
*
filter_ptr0
=
filter_ptr
;
for
(
int
c
=
0
;
c
<
input_channel
;
c
++
)
{
const
float
*
filter_ptr1
=
filter_ptr0
+
3
;
const
float
*
filter_data
=
filter
->
data
<
float
>
()
+
c
*
9
;
const
float
*
filter_ptr2
=
filter_ptr1
+
3
;
const
float
*
input_data
=
input
->
data
<
float
>
()
+
c
*
hxw
;
float32x4_t
_ker
[
3
];
float
*
output_data
=
output
->
data
<
float
>
()
+
c
*
hxw
;
_ker
[
0
]
=
vld1q_f32
(
filter_ptr0
);
float32x4_t
vnewbias
=
vdupq_n_f32
(
newbias_data
[
c
]);
_ker
[
1
]
=
vld1q_f32
(
filter_ptr1
);
float32x4_t
vnewscale
=
vdupq_n_f32
(
newscale_data
[
c
]);
_ker
[
2
]
=
vld1q_f32
(
filter_ptr2
);
float
w00
=
filter_data
[
0
];
// pad top
float
w01
=
filter_data
[
1
];
for
(
int
h
=
0
;
h
<
valid_h_start
;
++
h
)
{
float
w02
=
filter_data
[
2
];
DepthwiseConv3x3NormalRow
<
1
,
1
>
(
input_ptr
,
filter_ptr
,
h
,
input_h
,
float
w10
=
filter_data
[
3
];
input_w
,
padding_h
,
padding_w
,
output_w
,
float
w11
=
filter_data
[
4
];
output_ptr
,
_ker
);
float
w12
=
filter_data
[
5
];
}
float
w20
=
filter_data
[
6
];
float
w21
=
filter_data
[
7
];
float
w22
=
filter_data
[
8
];
for
(
int
i
=
1
;
i
<
output_height
-
1
;
i
++
)
{
float
*
output_ptr
;
float32x4_t
in0
,
in1
,
in2
,
in3
,
in4
,
in5
,
tmp0
,
tmp1
,
tmp2
,
tmp3
,
tmp4
,
tmp5
,
out0
;
for
(
int
m
=
1
;
m
<
output_width
-
4
;
m
+=
4
)
{
output_ptr
=
output_data
+
i
*
output_width
+
m
;
in0
=
vld1q_f32
(
input_data
+
(
i
-
1
)
*
input_width
+
m
-
1
);
in1
=
vld1q_f32
(
input_data
+
(
i
-
1
)
*
input_width
+
m
+
3
);
in2
=
vld1q_f32
(
input_data
+
i
*
input_width
+
m
-
1
);
in3
=
vld1q_f32
(
input_data
+
i
*
input_width
+
m
+
3
);
in4
=
vld1q_f32
(
input_data
+
(
i
+
1
)
*
input_width
+
m
-
1
);
in5
=
vld1q_f32
(
input_data
+
(
i
+
1
)
*
input_width
+
m
+
3
);
tmp0
=
vextq_f32
(
in0
,
in1
,
1
);
tmp1
=
vextq_f32
(
in0
,
in1
,
2
);
tmp2
=
vextq_f32
(
in2
,
in3
,
1
);
tmp3
=
vextq_f32
(
in2
,
in3
,
2
);
tmp4
=
vextq_f32
(
in4
,
in5
,
1
);
tmp5
=
vextq_f32
(
in4
,
in5
,
2
);
out0
=
vmulq_n_f32
(
in0
,
w00
);
out0
=
vmlaq_n_f32
(
out0
,
tmp0
,
w01
);
out0
=
vmlaq_n_f32
(
out0
,
tmp1
,
w02
);
out0
=
vmlaq_n_f32
(
out0
,
in2
,
w10
);
out0
=
vmlaq_n_f32
(
out0
,
tmp2
,
w11
);
out0
=
vmlaq_n_f32
(
out0
,
tmp3
,
w12
);
out0
=
vmlaq_n_f32
(
out0
,
in4
,
w20
);
out0
=
vmlaq_n_f32
(
out0
,
tmp4
,
w21
);
out0
=
vmlaq_n_f32
(
out0
,
tmp5
,
w22
);
out0
=
vmlaq_f32
(
vnewbias
,
vnewscale
,
out0
);
if
(
if_relu
)
{
out0
=
vmaxq_f32
(
out0
,
vzero
);
}
vst1q_f32
(
output_ptr
,
out0
);
}
int
m
;
for
(
m
=
1
;
(
m
+
3
)
<
output_width
-
1
;
m
=
m
+
4
)
{
}
for
(
int
j
=
m
;
j
<
output_width
-
1
;
j
++
)
{
// output 2x6
output_data
[
i
*
output_width
+
j
]
=
int
output_w_tiles
=
valid_w
/
6
;
input_data
[(
i
-
1
)
*
input_width
+
j
-
1
]
*
w00
+
int
output_w_remain
=
valid_w
-
output_w_tiles
*
6
;
input_data
[(
i
-
1
)
*
input_width
+
j
]
*
w01
+
for
(
int
h
=
valid_h_start
;
h
<
valid_h_end
-
1
;
h
+=
2
)
{
input_data
[(
i
-
1
)
*
input_width
+
j
+
1
]
*
w02
+
const
float
*
input_ptr0
=
input_ptr
+
(
h
-
padding_h
)
*
input_w
;
input_data
[(
i
)
*
input_width
+
j
-
1
]
*
w10
+
const
float
*
input_ptr1
=
input_ptr0
+
input_w
;
input_data
[(
i
)
*
input_width
+
j
]
*
w11
+
const
float
*
input_ptr2
=
input_ptr1
+
input_w
;
input_data
[(
i
)
*
input_width
+
j
+
1
]
*
w12
+
const
float
*
input_ptr3
=
input_ptr2
+
input_w
;
input_data
[(
i
+
1
)
*
input_width
+
j
-
1
]
*
w20
+
float
*
output_ptr0
=
output_ptr
+
h
*
output_w
;
input_data
[(
i
+
1
)
*
input_width
+
j
]
*
w21
+
float
*
output_ptr1
=
output_ptr0
+
output_w
;
input_data
[(
i
+
1
)
*
input_width
+
j
+
1
]
*
w22
;
// pad left
output_data
[
i
*
output_width
+
j
]
=
if
(
padding_w
)
{
newscale_data
[
c
]
*
output_data
[
i
*
output_width
+
j
]
+
float32x4_t
row0
=
vld1q_f32
(
input_ptr0
);
newbias_data
[
c
];
float32x4_t
row1
=
vld1q_f32
(
input_ptr1
);
if
(
if_relu
)
{
float32x4_t
row2
=
vld1q_f32
(
input_ptr2
);
output_data
[
i
*
output_width
+
j
]
=
float32x4_t
row3
=
vld1q_f32
(
input_ptr3
);
output_data
[
i
*
output_width
+
j
]
<
0
float32x4_t
zero
=
vdupq_n_f32
(
0.
f
);
?
0
row0
=
vextq_f32
(
zero
,
row0
,
3
);
:
output_data
[
i
*
output_width
+
j
];
row1
=
vextq_f32
(
zero
,
row1
,
3
);
row2
=
vextq_f32
(
zero
,
row2
,
3
);
row3
=
vextq_f32
(
zero
,
row3
,
3
);
float32x4_t
acc0
,
acc1
;
for
(
int
w
=
valid_w_start
-
1
;
w
>=
0
;
--
w
)
{
int
padding
=
padding_w
-
w
;
if
(
padding
>=
3
)
{
output_ptr0
[
w
]
=
0.
f
;
output_ptr1
[
w
]
=
0.
f
;
}
else
{
acc0
=
vmulq_f32
(
row0
,
_ker
[
0
]);
acc0
=
vmlaq_f32
(
acc0
,
row1
,
_ker
[
1
]);
acc0
=
vmlaq_f32
(
acc0
,
row2
,
_ker
[
2
]);
acc0
=
vextq_f32
(
acc0
,
acc0
,
1
);
acc1
=
vmulq_f32
(
row1
,
_ker
[
0
]);
acc1
=
vmlaq_f32
(
acc1
,
row2
,
_ker
[
1
]);
acc1
=
vmlaq_f32
(
acc1
,
row3
,
_ker
[
2
]);
acc1
=
vextq_f32
(
acc1
,
acc1
,
1
);
float32x2_t
sum
=
vpadd_f32
(
vget_low_f32
(
acc0
),
vget_low_f32
(
acc1
));
vst1_lane_f32
(
output_ptr0
+
w
,
sum
,
0
);
vst1_lane_f32
(
output_ptr1
+
w
,
sum
,
1
);
row0
=
vextq_f32
(
zero
,
row0
,
3
);
row1
=
vextq_f32
(
zero
,
row1
,
3
);
row2
=
vextq_f32
(
zero
,
row2
,
3
);
row3
=
vextq_f32
(
zero
,
row3
,
3
);
}
}
}
}
output_ptr0
+=
valid_w_start
;
output_ptr1
+=
valid_w_start
;
}
}
// valid
output_data
[
0
]
=
w11
*
input_data
[
0
]
+
w12
*
input_data
[
1
]
+
float32x4_t
_result0
,
_result1
,
_result2
,
_result3
;
w21
*
input_data
[
w
]
+
w22
*
input_data
[
w
+
1
];
for
(
int
loop
=
0
;
loop
<
output_w_tiles
;
++
        // ---- old AddBNRelu s1p1 path: remaining border pixels computed in
        // ---- scalar, then scale/bias and the optional relu applied.
        output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] +
                             w20 * input_data[2 * w - 2] +
                             w21 * input_data[2 * w - 1];
        output_data[(h - 1) * w] =
            w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] +
            w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
        output_data[h * w - 1] =
            w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] +
            w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1];
        output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
        output_data[w - 1] =
            output_data[w - 1] * newscale_data[c] + newbias_data[c];
        output_data[(h - 1) * w] =
            output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c];
        output_data[h * w - 1] =
            output_data[h * w - 1] * newscale_data[c] + newbias_data[c];
        if (if_relu) {
          output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
          output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1];
          output_data[(h - 1) * w] =
              output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w];
          output_data[h * w - 1] =
              output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1];
        }
        for (int i = 1; i < h - 1; ++i) {
          output_data[i * w] =
              w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] +
              w11 * input_data[i * w] + w12 * input_data[i * w + 1] +
              w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1];
          output_data[i * w + w - 1] =
              w00 * input_data[i * w + w - 1 - w - 1] +
              w01 * input_data[i * w + w - 1 - w] +
              w10 * input_data[i * w + w - 1 - 1] +
              w11 * input_data[i * w + w - 1] +
              w20 * input_data[i * w + w - 1 + w - 1] +
              w21 * input_data[i * w + w - 1 + w];
          output_data[i * w] =
              output_data[i * w] * newscale_data[c] + newbias_data[c];
          output_data[i * w + w - 1] =
              output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c];
          if (if_relu) {
            output_data[i * w] =
                output_data[i * w] < 0 ? 0 : output_data[i * w];
            output_data[i * w + w - 1] =
                output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1];
          }
        }

      // ---- refactored DepthwiseConv3x3S1 path: valid-region loop that
      // ---- produces a 2x6 output tile per iteration, with the three kernel
      // ---- rows kept in _ker[0..2]. _result0/_result1 accumulate output
      // ---- row 0 from input rows 0..2, _result2/_result3 accumulate output
      // ---- row 1 from input rows 1..3.
      for (int loop = 0; loop < output_w_tiles; ++loop) {
        float32x4_t _row00 = vld1q_f32(input_ptr0);
        float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
        float32x4_t _row10 = vld1q_f32(input_ptr1);
        float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
        float32x4_t _ext01 = vextq_f32(_row00, _row01, 1);
        float32x4_t _ext02 = vextq_f32(_row00, _row01, 2);
        float32x4_t _ext03 = vextq_f32(_row01, _row01, 1);
        float32x4_t _ext04 = vextq_f32(_row01, _row01, 2);
        _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0);
        _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0);
        _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0);
        _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1);
        _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0);
        _ext01 = vextq_f32(_row10, _row11, 1);
        _ext02 = vextq_f32(_row10, _row11, 2);
        _ext03 = vextq_f32(_row11, _row11, 1);
        _ext04 = vextq_f32(_row11, _row11, 2);
        _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0);
        _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0);
        _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0);
        _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1);
        _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0);
        _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0);
        _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1);
        _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0);
        _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0);
        _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1);
        _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0);
        _row00 = vld1q_f32(input_ptr2);
        _row01 = vld1q_f32(input_ptr2 + 4);
        _row10 = vld1q_f32(input_ptr3);
        _row11 = vld1q_f32(input_ptr3 + 4);
        // input rows 2 and 3 are folded in with the same ext/mla pattern:
        // _ker[2] into _result0/_result1, _ker[1] then _ker[2] into
        // _result2/_result3.
        vst1q_f32(output_ptr0, _result0);
        vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
        vst1q_f32(output_ptr1, _result2);
        vst1_f32(output_ptr1 + 4, vget_low_f32(_result3));
        input_ptr0 += 6;
        input_ptr1 += 6;
        input_ptr2 += 6;
        input_ptr3 += 6;
        output_ptr0 += 6;
        output_ptr1 += 6;
      }
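A reference-only sketch, not part of this diff: the scalar computation that one iteration of the tile loop above vectorizes. All names (input, kernel, out, input_w, valid_w) are illustrative assumptions; the kernel row k plays the same role as _ker[k] in the NEON code.

#include <cstddef>

void DepthwiseRow3x3S1Ref(const float *input, const float *kernel, float *out,
                          int input_w, int valid_w) {
  const float *r0 = input;                // input row y
  const float *r1 = input + input_w;      // input row y + 1
  const float *r2 = input + 2 * input_w;  // input row y + 2
  for (int x = 0; x < valid_w; ++x) {
    float acc = 0.f;
    for (int k = 0; k < 3; ++k) {
      acc += r0[x + k] * kernel[k];      // same role as _ker[0]
      acc += r1[x + k] * kernel[3 + k];  // same role as _ker[1]
      acc += r2[x + k] * kernel[6 + k];  // same role as _ker[2]
    }
    out[x] = acc;  // one of the 6 outputs per row produced by the tile loop
  }
}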
      // remain w
      if (output_w_remain > 0) {
        float32x4_t _row00 = vld1q_f32(input_ptr0);
        float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
        float32x4_t _row10 = vld1q_f32(input_ptr1);
        float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
        // same ext/mla accumulation over input_ptr0..input_ptr3 with
        // _ker[0..2] as the tile loop above; the leftover 1-5 outputs of the
        // two rows are then stored lane by lane.
        switch (output_w_remain) {
          case 5:
            vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
            vst1q_lane_f32(output_ptr1 + 4, _result3, 0);
          case 4:
            vst1q_f32(output_ptr0, _result0);
            vst1q_f32(output_ptr1, _result2);
            break;
          case 3:
            vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
            vst1q_lane_f32(output_ptr1 + 2, _result2, 2);
          case 2:
            vst1_f32(output_ptr0, vget_low_f32(_result0));
            vst1_f32(output_ptr1, vget_low_f32(_result2));
            break;
          case 1:
            vst1q_lane_f32(output_ptr0, _result0, 0);
            vst1q_lane_f32(output_ptr1, _result2, 0);
            break;
        }
        input_ptr0 += output_w_remain;
        input_ptr1 += output_w_remain;
        input_ptr2 += output_w_remain;
        input_ptr3 += output_w_remain;
        output_ptr0 += output_w_remain;
        output_ptr1 += output_w_remain;
      }
      // pad right
      if (padding_w) {
        float32x2_t row0 = vld1_f32(input_ptr0);
        float32x2_t row1 = vld1_f32(input_ptr1);
        float32x2_t row2 = vld1_f32(input_ptr2);
        float32x2_t row3 = vld1_f32(input_ptr3);
        float32x2_t zero = vdup_n_f32(0.f);
        float32x2_t acc0, acc1;
        for (int w = valid_w_end; w < output_w; ++w) {
          int padding = w + 3 - (padding_w + input_w);
          if (padding >= 3) {
            *output_ptr0 = 0.f;
            *output_ptr1 = 0.f;
          } else {
            acc0 = vmul_f32(row0, vget_low_f32(_ker[0]));
            acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1]));
            acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2]));
            acc1 = vmul_f32(row1, vget_low_f32(_ker[0]));
            acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1]));
            acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2]));
            float32x2_t sum = vpadd_f32(acc0, acc1);
            vst1_lane_f32(output_ptr0, sum, 0);
            vst1_lane_f32(output_ptr1, sum, 1);
            row0 = vext_f32(row0, zero, 1);
            row1 = vext_f32(row1, zero, 1);
            row2 = vext_f32(row2, zero, 1);
            row3 = vext_f32(row3, zero, 1);
          }
          output_ptr0++;
          output_ptr1++;
        }
      }
    }

        // ---- old AddBNRelu s1p1 path: vectorized top row (kernel rows 1-2
        // ---- only), followed by a scalar tail with scale/bias and relu.
        int m;
        for (m = 1; m < output_width - 4; m += 4) {
          float *output_ptr = output_data + m;
          float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
          in0 = vld1q_f32(input_data + m - 1);
          in1 = vld1q_f32(input_data + m + 3);
          in2 = vld1q_f32(input_data + input_width + m - 1);
          in3 = vld1q_f32(input_data + input_width + m + 3);
          tmp0 = vextq_f32(in0, in1, 1);
          tmp1 = vextq_f32(in0, in1, 2);
          tmp2 = vextq_f32(in2, in3, 1);
          tmp3 = vextq_f32(in2, in3, 2);
          out0 = vmulq_n_f32(in0, w10);
          out0 = vmlaq_n_f32(out0, tmp0, w11);
          out0 = vmlaq_n_f32(out0, tmp1, w12);
          out0 = vmlaq_n_f32(out0, in2, w20);
          out0 = vmlaq_n_f32(out0, tmp2, w21);
          out0 = vmlaq_n_f32(out0, tmp3, w22);
          out0 = vmlaq_f32(vnewbias, vnewscale, out0);
          if (if_relu) {
            out0 = vmaxq_f32(out0, vzero);
          }
          vst1q_f32(output_ptr, out0);
        }
        for (m = 1; (m + 3) < output_width - 1; m += 4) {
        }
        for (int j = m; j < output_width - 1; j++) {
          output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
                           input_data[j + 1] * w12 +
                           input_data[input_width + j - 1] * w20 +
                           input_data[input_width + j] * w21 +
                           input_data[input_width + j + 1] * w22;
          output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
          if (if_relu) {
            output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
          }
        }
        // ---- old AddBNRelu s1p1 path: vectorized bottom row, then scalar
        // ---- tail with scale/bias and relu.
        for (m = 1; m < output_width - 4; m += 4) {
          float *output_ptr =
              output_data + (output_height - 1) * output_width + m;
          float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
          in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
          in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
          in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
          in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
          tmp0 = vextq_f32(in0, in1, 1);
          tmp1 = vextq_f32(in0, in1, 2);
          tmp2 = vextq_f32(in2, in3, 1);
          tmp3 = vextq_f32(in2, in3, 2);
          out0 = vmulq_n_f32(in0, w00);
          out0 = vmlaq_n_f32(out0, tmp0, w01);
          out0 = vmlaq_n_f32(out0, tmp1, w02);
          out0 = vmlaq_n_f32(out0, in2, w10);
          out0 = vmlaq_n_f32(out0, tmp2, w11);
          out0 = vmlaq_n_f32(out0, tmp3, w12);
          out0 = vmlaq_f32(vnewbias, vnewscale, out0);
          if (if_relu) {
            out0 = vmaxq_f32(out0, vzero);
          }
          vst1q_f32(output_ptr, out0);
        }
        for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
        }
        for (int j = m; j < output_width - 1; j++) {
          output_data[(output_height - 1) * input_width + j] =
              input_data[(output_height - 2) * input_width + j - 1] * w00 +
              input_data[(output_height - 2) * input_width + j] * w01 +
              input_data[(output_height - 2) * input_width + j + 1] * w02 +
              input_data[(output_height - 1) * input_width + j - 1] * w10 +
              input_data[(output_height - 1) * input_width + j] * w11 +
              input_data[(output_height - 1) * input_width + j + 1] * w12;
          output_data[(output_height - 1) * output_width + j] =
              output_data[(output_height - 1) * output_width + j] *
                  newscale_data[c] +
              newbias_data[c];
          if (if_relu) {
            output_data[(output_height - 1) * output_width + j] =
                output_data[(output_height - 1) * output_width + j] < 0
                    ? 0
                    : output_data[(output_height - 1) * output_width + j];
          }
        }
      }
    }

    // ---- refactored DepthwiseConv3x3S1 path ----
    // remain height
    int start_h = valid_h_start + (valid_h & 0xfffe);
    if (start_h < valid_h_end) {
      const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
      const float *input_ptr1 = input_ptr0 + input_w;
      const float *input_ptr2 = input_ptr1 + input_w;
      float *output_ptr0 = output_ptr + start_h * output_w;
      // pad left
      if (padding_w) {
        float32x4_t row0 = vld1q_f32(input_ptr0);
        float32x4_t row1 = vld1q_f32(input_ptr1);
        float32x4_t row2 = vld1q_f32(input_ptr2);
        float32x4_t zero = vdupq_n_f32(0.f);
        row0 = vextq_f32(zero, row0, 3);
        row1 = vextq_f32(zero, row1, 3);
        row2 = vextq_f32(zero, row2, 3);
        float32x4_t acc;
        for (int w = valid_w_start - 1; w >= 0; --w) {
          int padding = padding_w - w;
          if (padding >= 3) {
            output_ptr0[w] = 0.f;
          } else {
            acc = vmulq_f32(row0, _ker[0]);
            acc = vmlaq_f32(acc, row1, _ker[1]);
            acc = vmlaq_f32(acc, row2, _ker[2]);
            acc = vextq_f32(acc, acc, 1);
            float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc));
            vst1_lane_f32(output_ptr0 + w, sum, 0);
            row0 = vextq_f32(zero, row0, 3);
            row1 = vextq_f32(zero, row1, 3);
            row2 = vextq_f32(zero, row2, 3);
          }
        }
        output_ptr0 += valid_w_start;
      }
      // valid
      float32x4_t _result0, _result1;
      for (int loop = 0; loop < output_w_tiles; ++loop) {
        float32x4_t _row00 = vld1q_f32(input_ptr0);
        float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
        float32x4_t _row10 = vld1q_f32(input_ptr1);
        float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
        // same ext/mla accumulation as the 2x6 tile loop, but over only the
        // three input rows input_ptr0..input_ptr2 with _ker[0..2], producing
        // a single output row in _result0/_result1.
        vst1q_f32(output_ptr0, _result0);
        vst1_f32(output_ptr0 + 4, vget_low_f32(_result1));
        input_ptr0 += 6;
        input_ptr1 += 6;
        input_ptr2 += 6;
        output_ptr0 += 6;
      }
/*
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
// const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[w] + w22 * input_data[w + 1];
output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w -
1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1];
output_data[(h - 1) * w] =
w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w +
1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1];
output_data[h * w - 1] = w00 * input_data[h*w-w-2] +
w01 * input_data[h*w-w-1] +
w10 * input_data[h * w - 2] +
w11 * input_data[h * w - 1];
output_data[0] = output_data[0] * newscale_data[j] +
newbias_data[j]; output_data[w - 1] = output_data[w - 1] *
newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] =
output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j];
output_data[h * w - 1] =
output_data[h * w - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w -
1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 :
output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1]
< 0 ? 0 : output_data[h * w - 1];
}
for (int i = 1; i < h - 1; ++i) {
output_data[i * w] =
w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1]
+ w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 *
input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i *
w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i
* w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 *
input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21
* input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w]
* newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] =
output_data[i * w + w - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i
* w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 0 :
output_data[i * w + w - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1,
tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 =
vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h -
2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end +
w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid >
3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 =
vld1q_f32(input_tmp + w + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (h - 1) * w, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
      float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]);
      float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]);
      tmp0 = vextq_f32(in0, pad0, 1);
      tmp1 = vextq_f32(in0, pad0, 2);
      tmp2 = vextq_f32(in2, pad1, 1);
      tmp3 = vextq_f32(in2, pad1, 2);
      out0 = vmulq_n_f32(in0, w10);
      out0 = vmlaq_n_f32(out0, tmp0, w11);
      out0 = vmlaq_n_f32(out0, tmp1, w12);
      out0 = vmlaq_n_f32(out0, in2, w20);
      out0 = vmlaq_n_f32(out0, tmp2, w21);
      out0 = vmlaq_n_f32(out0, tmp3, w22);
      out0 = vmlaq_f32(vnewbias, vnewscale, out0);
      if (if_relu) {
        out0 = vmaxq_f32(out0, vzero);
      }
      for (int i = 0; i < c_mid; ++i) {
        if (i == 0) {
          vst1q_lane_f32(output_ptr + i, out0, 0);
        }
        if (i == 1) {
          vst1q_lane_f32(output_ptr + i, out0, 1);
        }
        if (i == 2) {
          vst1q_lane_f32(output_ptr + i, out0, 2);
        }
      }
      // bottom right pad
      float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]);
      float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]);
      tmp0 = vextq_f32(in4, pad2, 1);
      tmp1 = vextq_f32(in4, pad2, 2);
      tmp2 = vextq_f32(in6, pad3, 1);
      tmp3 = vextq_f32(in6, pad3, 2);
      out0 = vmulq_n_f32(in4, w00);
      out0 = vmlaq_n_f32(out0, tmp0, w01);
      out0 = vmlaq_n_f32(out0, tmp1, w02);
      out0 = vmlaq_n_f32(out0, in6, w10);
      out0 = vmlaq_n_f32(out0, tmp2, w11);
      out0 = vmlaq_n_f32(out0, tmp3, w12);
      out0 = vmlaq_f32(vnewbias, vnewscale, out0);
      if (if_relu) {
        out0 = vmaxq_f32(out0, vzero);
      }
      for (int i = 0; i < c_mid; ++i) {
        if (i == 0) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0);
        }
        if (i == 1) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1);
        }
        if (i == 2) {
          vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2);
        }
      }

      // ---- refactored DepthwiseConv3x3S1 path (interleaved by the diff):
      // ---- remain-w handling for the single leftover valid row.
      if (output_w_remain > 0) {
        float32x4_t _row00 = vld1q_f32(input_ptr0);
        float32x4_t _row01 = vld1q_f32(input_ptr0 + 4);
        float32x4_t _row10 = vld1q_f32(input_ptr1);
        float32x4_t _row11 = vld1q_f32(input_ptr1 + 4);
        // same ext/mla accumulation with _ker[0..2] over
        // input_ptr0..input_ptr2 as the valid loop above
        switch (output_w_remain) {
          case 5:
            vst1q_lane_f32(output_ptr0 + 4, _result1, 0);
          case 4:
            vst1q_f32(output_ptr0, _result0);
            break;
          case 3:
            vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
          case 2:
            vst1_f32(output_ptr0, vget_low_f32(_result0));
            break;
          case 1:
            vst1q_lane_f32(output_ptr0, _result0, 0);
            break;
        }
// mid
for (int i = 0; i < h - 2; ++i) {
auto output_ptr = output_data + (i + 1) * w + 1;
input_tmp = input_data + i * w;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + w);
auto in4_tmp = vld1q_f32(input_tmp + w + w);
c_mid = w_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + w + 4);
auto in5_tmp = vld1q_f32(input_tmp + w + w + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
}
}
*/
#endif
        input_ptr0 += output_w_remain;
        input_ptr1 += output_w_remain;
        input_ptr2 += output_w_remain;
        output_ptr0 += output_w_remain;
      }
      // pad right
      if (padding_w) {
        float32x2_t row0 = vld1_f32(input_ptr0);
        float32x2_t row1 = vld1_f32(input_ptr1);
        float32x2_t row2 = vld1_f32(input_ptr2);
        float32x2_t zero = vdup_n_f32(0.f);
        float32x2_t acc;
        for (int w = valid_w_end; w < output_w; ++w) {
          int padding = w + 3 - (padding_w + input_w);
          if (padding >= 3) {
            *output_ptr0 = 0.f;
          } else {
            acc = vmul_f32(row0, vget_low_f32(_ker[0]));
            acc = vmla_f32(acc, row1, vget_low_f32(_ker[1]));
            acc = vmla_f32(acc, row2, vget_low_f32(_ker[2]));
            float32x2_t sum = vpadd_f32(acc, acc);
            vst1_lane_f32(output_ptr0, sum, 0);
            row0 = vext_f32(row0, zero, 1);
            row1 = vext_f32(row1, zero, 1);
            row2 = vext_f32(row2, zero, 1);
          }
          output_ptr0++;
        }
      }
    }
    // pad bottom
    for (int h = valid_h_end; h < output_h; ++h) {
      DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
                                      input_w, padding_h, padding_w, output_w,
                                      output_ptr, _ker);
    }
  }
}

// ---- old AddBNRelu s2p1 kernel (scale/bias variant, stride 2, pad 1) ----
/// w != h not fixed
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu) {
#if __ARM_NEON
  const int batch_size = input->dims()[0];
  const int input_height = input->dims()[2];
  const int input_width = input->dims()[3];
  const int output_channels = output->dims()[1];
  const int output_height = output->dims()[2];
  const int output_width = output->dims()[3];
  const int _kernel_size = 3;
  const int stride_height = 2;
  const int stride_width = 2;
  const int padding_height = 1;
  const int padding_width = 1;
  const float zero = 0;
  const int input_channel_stride = input_height * input_width;
  const int output_channel_stride = output_height * output_width;
  const int filter_channel_stride = 9;
  const float *newscale_data = new_scale->data<float>();
  const float *newbias_data = new_bias->data<float>();
  const float *input_data = input->data<float>();
  const float *filter_data = filter->data<float>();
  float *output_data = output->mutable_data<float>();
  const int input_batch_stride = output_channels * input_channel_stride;
  const int output_batch_stride = output_channels * output_channel_stride;
  const int filter_batch_stride = output_channels * output_channel_stride;
  const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
  int hstart, wstart, hend, wend;
  float result;
  for (int i = 0; i < batch_size; ++i) {
    for (int c = 0; c < output_channels; ++c) {
      filter1 = filter_data;
      filter2 = filter1 + 3;
      filter3 = filter2 + 3;
      for (int ph = 0; ph < output_height; ph++) {
        for (int pw = 0; pw < output_width; pw++) {
          hstart = ph * stride_height - padding_height;
          wstart = pw * stride_width - padding_width;
          hend = std::min(hstart + _kernel_size, input_height + padding_height);
          wend = std::min(wstart + _kernel_size, input_width + padding_width);
          hstart = std::max(hstart, 0);
          wstart = std::max(wstart, 0);
          hend = std::min(hend, input_height);
          wend = std::min(wend, input_width);
          pos1 = input_data + hstart * input_width + wstart;
          pos2 = input_data + (hstart + 1) * input_width + wstart;
          pos3 = input_data + (hstart + 2) * input_width + wstart;
          output_ptr = output_data + ph * output_width + pw;
          if (hend - hstart != 3 || wend - wstart != 3) {
            result = 0;
            float fake_input[9] = {0};
            if (hstart == 0 && wstart == 0) {
              // top-left corner
              for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                  if (j >= 3 - hend && k >= 3 - wend) {
                    fake_input[3 * j + k] =
                        input_data[(j - (3 - hend)) * input_width + k -
                                   (3 - wend)];
                  }
                }
              }
            } else if (hstart == 0 && wend == input_width) {
              // top-right corner
              for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                  if (j >= 3 - hend && k <= input_width - wstart - 1) {
                    fake_input[3 * j + k] =
                        input_data[(j - (3 - hend)) * input_width + k + wstart];
                  }
                }
              }
            } else if (hend == input_height && wstart == 0) {
              // bottom-left corner: analogous copy guarded by
              // j <= input_height - 1 - hstart && k >= 3 - wend
            } else if (hend == input_height && wend == input_width) {
              // bottom-right corner: analogous copy guarded by
              // j <= input_height - hstart - 1 && k <= input_width - wstart - 1
            } else if (hstart == 0) {
              // top edge: analogous copy guarded by j >= 3 - hend
            } else if (hend == input_height) {
              // bottom edge: analogous copy guarded by
              // j <= input_height - hstart - 1
            } else if (wstart == 0) {
              // left edge: analogous copy guarded by k >= 3 - wend
            } else if (wend == input_width) {
              // right edge: analogous copy guarded by
              // k <= input_width - wstart - 1
            }
            for (int l = 0; l < 9; ++l) {
              result += fake_input[l] * filter1[l];
            }
            output_data[ph * output_width + pw] =
                newscale_data[c] * result + newbias_data[c];
            if (if_relu) {
              output_data[ph * output_width + pw] =
                  output_data[ph * output_width + pw] < 0
                      ? 0
                      : output_data[ph * output_width + pw];
            }
          } else {
            const float32x4_t data1 = vld1q_f32(pos1);
            const float32x4_t data2 = vld1q_f32(pos2);
            const float32x4_t data3 = vld1q_f32(pos3);
            const float32x4_t v_filter1 = vld1q_f32(filter1);
            const float32x4_t v_filter2 = vld1q_f32(filter2);
            const float32x4_t v_filter3 = vld1q_f32(filter3);
            float32x4_t mula = vmulq_f32(data1, v_filter1);
            mula = vmlaq_f32(mula, data2, v_filter2);
            mula = vmlaq_f32(mula, data3, v_filter3);
            float32x2_t res = vpadd_f32(
                vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
            res = vpadd_f32(res, res);
            output_data[ph * output_width + pw] =
                vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c];
            if (if_relu) {
              output_data[ph * output_width + pw] =
                  output_data[ph * output_width + pw] < 0
                      ? 0
                      : output_data[ph * output_width + pw];
            }
          }
        }
      }
      input_data += input_channel_stride;
      output_data += output_channel_stride;
      filter_data += filter_channel_stride;
    }
    input_data += input_batch_stride;
    output_data += output_batch_stride;
  }
#endif
}
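An illustrative sketch only, not the DepthwiseConv3x3NormalRow implementation called above: the scalar computation a padded border row falls back to, clamping the 3x3 window to the input exactly as the hstart/hend/wstart/wend code in the removed kernel does. All names are assumptions, and the sketch is written for stride 1.

#include <algorithm>

void PaddedRow3x3Ref(const float *input, const float *kernel, float *out,
                     int input_h, int input_w, int padding, int oy) {
  // stride 1, 3x3 kernel: output width is input_w + 2 * padding - 2
  int out_w = input_w + 2 * padding - 2;
  for (int ox = 0; ox < out_w; ++ox) {
    int hstart = std::max(oy - padding, 0);
    int hend = std::min(oy - padding + 3, input_h);
    int wstart = std::max(ox - padding, 0);
    int wend = std::min(ox - padding + 3, input_w);
    float acc = 0.f;
    for (int ih = hstart; ih < hend; ++ih) {
      for (int iw = wstart; iw < wend; ++iw) {
        // kernel index is the tap position inside the unclamped 3x3 window
        acc += input[ih * input_w + iw] *
               kernel[(ih - (oy - padding)) * 3 + (iw - (ox - padding))];
      }
    }
    out[ox] = acc;  // taps that fall in the zero padding contribute nothing
  }
}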
// ---- refactored stride-2 kernel ----
template <>
void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
                                      const framework::Tensor &filter,
                                      const std::vector<int> &paddings,
                                      framework::Tensor *output) {
  const float *input_data = input.data<float>();
  const float *filter_data = filter.data<float>();
  float *out_data = output->mutable_data<float>();
  int input_h = input.dims()[2];
  int input_w = input.dims()[3];
  int output_h = output->dims()[2];
  int output_w = output->dims()[3];
  int padding_h = paddings[0];
  int padding_w = paddings[1];
  int image_size = input_h * input_w;
  int out_image_size = output_h * output_w;
  int valid_h_start = (padding_h + 1) / 2;
  int valid_h_end = (input_h + padding_h - 1) / 2;
  int valid_h = valid_h_end - valid_h_start;
  int valid_w_start = (padding_w + 1) / 2;
  int valid_w_end = (input_w + padding_w - 1) / 2;
  int valid_w = valid_w_end - valid_w_start;
  int input_w_start = 2 * valid_w_start - padding_w;
  #pragma omp parallel for
  for (int g = 0; g < input.dims()[1]; ++g) {
    const float *input_ptr = input_data + g * image_size;
    const float *filter_ptr = filter_data + g * 9;
    float *output_ptr = out_data + g * out_image_size;
    const float *filter_ptr0 = filter_ptr;
    const float *filter_ptr1 = filter_ptr0 + 3;
    const float *filter_ptr2 = filter_ptr1 + 3;
    float32x4_t _ker[3];
    _ker[0] = vld1q_f32(filter_ptr0);
    _ker[1] = vld1q_f32(filter_ptr1);
    _ker[2] = vld1q_f32(filter_ptr2);
    // pad top
    for (int h = 0; h < valid_h_start; ++h) {
      DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
                                      input_w, padding_h, padding_w, output_w,
                                      output_ptr, _ker);
    }
    // valid 2x4
    int output_w_tiles = valid_w / 4;
    int output_w_remain = valid_w - output_w_tiles * 4;
    for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
      const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
      const float *input_ptr1 = input_ptr0 + input_w;
      const float *input_ptr2 = input_ptr1 + input_w;
      const float *input_ptr3 = input_ptr2 + input_w;
      const float *input_ptr4 = input_ptr3 + input_w;
      float *output_ptr0 = output_ptr + h * output_w;
      float *output_ptr1 = output_ptr0 + output_w;
      // pad left
      if (padding_w) {
        for (int w = valid_w_start - 1; w >= 0; --w) {
          int padding = padding_w - (w << 1);
          if (padding >= 3) {
            output_ptr0[w] = 0;
            output_ptr1[w] = 0;
          } else {
            float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
            float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
            float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
            float32x4_t row3 = vld1q_f32(input_ptr3 - padding);
            float32x4_t row4 = vld1q_f32(input_ptr4 - padding);
            float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
            float32x4_t acc1 = vmulq_f32(row2, _ker[0]);
            acc0 = vmlaq_f32(acc0, row1, _ker[1]);
            acc1 = vmlaq_f32(acc1, row3, _ker[1]);
            acc0 = vmlaq_f32(acc0, row2, _ker[2]);
            acc1 = vmlaq_f32(acc1, row4, _ker[2]);
            float sum0 = vgetq_lane_f32(acc0, 2);
            float sum1 = vgetq_lane_f32(acc1, 2);
            if (padding == 1) {
              sum0 += vgetq_lane_f32(acc0, 1);
              sum1 += vgetq_lane_f32(acc1, 1);
            }
            output_ptr0[w] = sum0;
            output_ptr1[w] = sum1;
          }
        }
        input_ptr0 += input_w_start;
        input_ptr1 += input_w_start;
        input_ptr2 += input_w_start;
        input_ptr3 += input_w_start;
        input_ptr4 += input_w_start;
        output_ptr0 += valid_w_start;
        output_ptr1 += valid_w_start;
      }

// ---- old s2p1v2 kernel (bias variant, stride 2, pad 1) ----
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                            const framework::Tensor *filter,
                            framework::Tensor *output, framework::Tensor *bias,
                            bool if_bias, bool if_relu) {
#if __ARM_NEON
  const float *input_data = input->data<float>();
  const float *filter_data = filter->data<float>();
  float *output_data = output->mutable_data<float>();
  const float *bias_data;
  if (if_bias) {
    bias_data = bias->data<float>();
  }
  const int in_h = static_cast<int>(input->dims()[2]);
  const int in_w = static_cast<int>(input->dims()[3]);
  const int out_h = static_cast<int>(output->dims()[2]);
  const int out_w = static_cast<int>(output->dims()[3]);
  const int out_l = out_h;
  const int in_l = in_h;
  const int inhxw = in_h * in_w;
  const int outhxw = out_h * out_w;
  /// todo : fix if_pad when w != h
  const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
  const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int c = static_cast<int>(input->dims()[1]);
  const int w_times = (out_w - 2) / 3;
  float32x4_t vbias = vdupq_n_f32(0.0);
  const float *input_row_ptr;
  float *output_row_ptr;
  float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
  float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
  int out2in_mid;
  float32x4_t zero = vdupq_n_f32(0.0);
  for (int b = batch_size; b > 0; --b) {
    const float *filter_data_tmp = filter_data;
    for (int j = 0; j < c; ++j) {
      auto output_data_tmp = output_data + j * out_h * out_w;
      auto input_data_tmp = input_data + j * in_h * in_w;
      auto input_const = input_data_tmp;
      if (if_bias) {
        vbias = vdupq_n_f32(bias_data[j]);
      }
      float w00 = filter_data_tmp[0];
      float w01 = filter_data_tmp[1];
      float w02 = filter_data_tmp[2];
      float w10 = filter_data_tmp[3];
      float w11 = filter_data_tmp[4];
      float w12 = filter_data_tmp[5];
      float w20 = filter_data_tmp[6];
      float w21 = filter_data_tmp[7];
      float w22 = filter_data_tmp[8];
      int h_mid = 0;
      for (; h_mid < out_h - 1; h_mid++) {
        input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
        output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
        for (int w4 = 0; w4 < w_times + 1; w4++) {
          if (h_mid == 0) {
            elewise_res1 = zero;
            elewise_res0 = zero;
            elewise_res2 = zero;
          } else {
            elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
            elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
            elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
          }
          input_buff_mid = vld2q_f32(input_row_ptr);
          input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
          elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
          elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
          elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
          elewise_res1 =
              vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
          elewise_res0 =
              vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
          elewise_res2 =
              vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
          res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
                           vaddq_f32(elewise_res0, elewise_res1));
          res3 = vaddq_f32(res3, vbias);
          if (if_relu) {
            res3 = vmaxq_f32(res3, zero);
          }
          vst1q_f32(output_row_ptr, res3);
          input_row_ptr += 6;
          output_row_ptr += 3;
        }
      }
      clock();
      input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
      output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
      for (int w4 = 0; w4 < w_times + 1; w4++) {
        // last output row: same element-wise pattern as above, except the
        // bottom-row taps (w20/w21/w22) are skipped when if_pad_b is set,
        // and the final partial vector is stored lane by lane:
        if ((w4 != w_times)) {
          vst1q_f32(output_row_ptr, res3);
        } else {
          if (out_w - 2 - w_times * 3 == 1) {
            vst1q_lane_f32(output_row_ptr, res3, 0);
          } else if (out_w - 2 - w_times * 3 == 2) {
            vst1q_lane_f32(output_row_ptr, res3, 0);
            vst1q_lane_f32(output_row_ptr + 1, res3, 1);
          }
        }
        input_row_ptr += 6;
        output_row_ptr += 3;
      }
      // valid
      float32x4_t _result0, _result1, _ext;
      for (int loop = 0; loop < output_w_tiles; ++loop) {
        float32x4x2_t _row0 = vld2q_f32(input_ptr0);
        float32x4x2_t _row1 = vld2q_f32(input_ptr1);
        _ext = vextq_f32(_row0.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
        _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
        _result0 = vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
        _ext = vextq_f32(_row1.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
        _result0 = vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
        _result0 = vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
        _row0 = vld2q_f32(input_ptr2);
        _row1 = vld2q_f32(input_ptr3);
        _ext = vextq_f32(_row0.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
        _result0 = vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0);
        _result0 = vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
        _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
        _result1 = vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1);
        _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0);
        _ext = vextq_f32(_row1.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3);
        _result1 = vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0);
        _result1 = vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1);
        _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0);
        _row0 = vld2q_f32(input_ptr4);
        _ext = vextq_f32(_row0.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr4[8], _ext, 3);
        _result1 = vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0);
        _result1 = vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1);
        _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0);
        vst1q_f32(output_ptr0, _result0);
        vst1q_f32(output_ptr1, _result1);
        input_ptr0 += 8;
        input_ptr1 += 8;
        input_ptr2 += 8;
        input_ptr3 += 8;
        input_ptr4 += 8;
        output_ptr0 += 4;
        output_ptr1 += 4;
      }
      // remain w
      if (output_w_remain > 0) {
        float32x4x2_t _row0 = vld2q_f32(input_ptr0);
        float32x4x2_t _row1 = vld2q_f32(input_ptr1);
        // same de-interleaved accumulation over input_ptr0..input_ptr4 as the
        // loop above; the leftover 1-3 outputs per row are stored lane by lane
        switch (output_w_remain) {
          case 3:
            vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
            vst1q_lane_f32(output_ptr1 + 2, _result1, 2);
          case 2:
            vst1_f32(output_ptr0, vget_low_f32(_result0));
            vst1_f32(output_ptr1, vget_low_f32(_result1));
            break;
          case 1:
            vst1q_lane_f32(output_ptr0, _result0, 0);
            vst1q_lane_f32(output_ptr1, _result1, 0);
            break;
        }
        input_ptr0 += output_w_remain * 2;
        input_ptr1 += output_w_remain * 2;
        input_ptr2 += output_w_remain * 2;
        input_ptr3 += output_w_remain * 2;
        input_ptr4 += output_w_remain * 2;
        output_ptr0 += output_w_remain;
        output_ptr1 += output_w_remain;
      }
      // pad right

      // ---- old s2p1v2 kernel: corners and first/last columns in scalar
      // ---- (leftTop, rightTop, leftBottom, rightBottom) ----
      int lt = 0;
      int rt = out_w - 1;
      int lb = out_w * (out_h - 1);
      int rb = out_h * out_w - 1;
      output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 +
                            input_const[in_w] * w21 +
                            input_const[in_w + 1] * w22;
      out2in_mid = (out_w - 1) * 2;
      output_data_tmp[rt] =
          w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
          w20 * input_const[out2in_mid + in_w - 1] +
          w21 * input_const[out2in_mid + in_w] +
          (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] +
                            w22 * input_const[out2in_mid + in_w + 1]);
      out2in_mid = (out_h - 1) * 2 * in_w;
      output_data_tmp[lb] =
          w01 * input_const[out2in_mid - in_w] +
          w02 * input_const[out2in_mid - in_w + 1] +
          w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
          (1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] +
                            w22 * input_const[out2in_mid + in_w + 1]);
      out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2;
      output_data_tmp[rb] =
          w00 * input_const[out2in_mid - in_w - 1] +
          w01 * input_const[out2in_mid - in_w] +
          w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
          (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] +
                            w21 * input_const[out2in_mid + in_w]) +
          (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] +
                            w12 * input_const[out2in_mid + 1]) +
          (1 - if_pad_r) * (1 - if_pad_b) * w22 *
              input_const[out2in_mid + in_w + 1];
      if (if_bias) {
        output_data_tmp[lt] += bias_data[j];
        output_data_tmp[rt] += bias_data[j];
        output_data_tmp[lb] += bias_data[j];
        output_data_tmp[rb] += bias_data[j];
      }
      if (if_relu) {
        output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt];
        output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt];
        output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb];
        output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb];
      }
      for (int i = 1; i < out_h - 1; i++) {
        out2in_mid = i * 2 * in_w;
        int left = i * out_w;
        output_data_tmp[left] =
            w01 * input_const[out2in_mid - in_w] +
            w02 * input_const[out2in_mid - in_w + 1] +
            w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
            w21 * input_const[out2in_mid + in_w] +
            w22 * input_const[out2in_mid + in_w + 1];
        out2in_mid = i * 2 * in_w + (out_w - 1) * 2;
        int right = i * out_w + out_w - 1;
        output_data_tmp[right] =
            w00 * input_const[out2in_mid - in_w - 1] +
            w01 * input_const[out2in_mid - in_w] +
            w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
            w20 * input_const[out2in_mid + in_w - 1] +
            w21 * input_const[out2in_mid + in_w] +
            (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] +
                              w12 * input_const[out2in_mid + 1] +
                              w22 * input_const[out2in_mid + in_w + 1]);
        if (if_bias) {
          output_data_tmp[left] += bias_data[j];
          output_data_tmp[right] += bias_data[j];
        }
        if (if_relu) {
          output_data_tmp[left] =
              output_data_tmp[left] < 0 ? 0 : output_data_tmp[left];
          output_data_tmp[right] =
              output_data_tmp[right] < 0 ? 0 : output_data_tmp[right];
        }
      }
      filter_data_tmp += 9;
    }
    input_data += inhxw * c;
    output_data += outhxw * c;
  }
#endif
}
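A minimal standalone sketch, not library code, of why the stride-2 paths above use vld2q_f32: it splits eight consecutive floats into even and odd lanes, so val[0] and val[1] are exactly the first two taps of four stride-2 windows, and the third tap is recovered with vextq_f32 plus vsetq_lane_f32 as the kernels do. All names here are assumptions.

#include <arm_neon.h>
#include <cstdio>

int main() {
  float x[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8};
  float32x4x2_t v = vld2q_f32(x);              // val[0]: x0,x2,x4,x6  val[1]: x1,x3,x5,x7
  float32x4_t ext = vextq_f32(v.val[0], v.val[0], 1);
  ext = vsetq_lane_f32(x[8], ext, 3);          // third taps: x2,x4,x6,x8
  float even[4], odd[4], third[4];
  vst1q_f32(even, v.val[0]);
  vst1q_f32(odd, v.val[1]);
  vst1q_f32(third, ext);
  for (int i = 0; i < 4; ++i) {
    std::printf("window %d taps: %.0f %.0f %.0f\n", i, even[i], odd[i], third[i]);
  }
  return 0;
}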
      // ---- refactored DepthwiseConv3x3S2 path: pad right ----
      if (padding_w > 0) {
        float32x4_t row0 = vld1q_f32(input_ptr0);
        float32x4_t row1 = vld1q_f32(input_ptr1);
        float32x4_t row2 = vld1q_f32(input_ptr2);
        float32x4_t row3 = vld1q_f32(input_ptr3);
        float32x4_t row4 = vld1q_f32(input_ptr4);
        float32x4_t acc0, acc1;
        for (int w = valid_w_end; w < output_w; ++w) {
          int padding = 2 * w + 3 - (padding_w + input_w);
          if (padding >= 3) {
            *output_ptr0 = 0;
            *output_ptr1 = 0;
          }

// ---- old AddBNRelu s2p1v2 kernel (scale/bias variant) ----
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
                                     const framework::Tensor *filter,
                                     framework::Tensor *output,
                                     const framework::Tensor *new_scale,
                                     const framework::Tensor *new_bias,
                                     bool if_relu) {
#if __ARM_NEON
// #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>();
// const float *newbias_data = new_bias->data<float>();
//
// const int batch_size = static_cast<int>(input->dims()[0]);
// const int input_channel = static_cast<int>(input->dims()[1]);
//
// const int input_height = static_cast<int>(input->dims()[2]);
// const int input_width = static_cast<int>(input->dims()[3]);
// const int output_height = static_cast<int>(output->dims()[2]);
// const int output_width = static_cast<int>(output->dims()[3]);
// const int inhxw = input_height * input_width;
// const int outhxw = output_height * output_width;
//
// float32x4_t zero = vdupq_n_f32(0.0);
// for (int b = 0; b < batch_size; b++) {
// #pragma omp parallel for
// for (int c = 0; c < input_channel; c++) {
// const float *filter_data = filter->data<float>() + c * 9;
// const float *input_data = input->data<float>() + c * inhxw;
// float *output_data = output->data<float>() + c * outhxw;
// float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
// float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
//
// float w00 = filter_data[0];
// float w01 = filter_data[1];
// float w02 = filter_data[2];
// float w10 = filter_data[3];
// float w11 = filter_data[4];
// float w12 = filter_data[5];
// float w20 = filter_data[6];
// float w21 = filter_data[7];
// float w22 = filter_data[8];
//
// int m;
// for (m = 1; m < output_width - 2; m = m + 3) {
// float *output_ptr = output_data + m;
// float32x4x2_t input_buff_mid{}, input_buff_bottom{};
// float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
// input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
// input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m -
// 1));
//
// in0 = input_buff_mid.val[0];
// tmp0 = input_buff_mid.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_bottom.val[0];
// tmp2 = input_buff_bottom.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// out0 = vmulq_n_f32(in0, w10);
// out0 = vmlaq_n_f32(out0, tmp0, w11);
// out0 = vmlaq_n_f32(out0, tmp1, w12);
// out0 = vmlaq_n_f32(out0, in2, w20);
// out0 = vmlaq_n_f32(out0, tmp2, w21);
// out0 = vmlaq_n_f32(out0, tmp3, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] *
// w11 +
// input_data[2 * j + 1] * w12 +
// input_data[2 * j - 1 + input_width] * w20 +
// input_data[2 * j + input_width] * w21 +
// input_data[2 * j + 1 + input_width] * w22;
// output_data[j] = newscale_data[c] * output_data[j] +
// newbias_data[c]; if (if_relu) {
// output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// }
// }
//
// for (int i = 1; i < output_height; i += 1) {
// for (int m = 1; m < output_width - 2; m += 3) {
// float *output_ptr = output_data + i * output_width + m;
// float32x4x2_t input_buff_top{}, input_buff_mid{},
// input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5,
// tmp0, tmp1, tmp2, tmp3,
// tmp4, tmp5, out0;
// input_buff_top =
// vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m -
// 1));
// input_buff_mid =
// vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
// input_buff_bottom =
// vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m -
// 1));
//
// in0 = input_buff_top.val[0];
// tmp0 = input_buff_top.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
//
// in2 = input_buff_mid.val[0];
// tmp2 = input_buff_mid.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
//
// in4 = input_buff_bottom.val[0];
// tmp4 = input_buff_bottom.val[1];
// tmp5 = vextq_f32(in4, zero, 1);
//
// out0 = vmulq_n_f32(in0, w00);
// out0 = vmlaq_n_f32(out0, tmp0, w01);
// out0 = vmlaq_n_f32(out0, tmp1, w02);
// out0 = vmlaq_n_f32(out0, in2, w10);
// out0 = vmlaq_n_f32(out0, tmp2, w11);
// out0 = vmlaq_n_f32(out0, tmp3, w12);
// out0 = vmlaq_n_f32(out0, in4, w20);
// out0 = vmlaq_n_f32(out0, tmp4, w21);
// out0 = vmlaq_n_f32(out0, tmp5, w22);
// out0 = vmlaq_f32(vnewbias, vnewscale, out0);
// if (if_relu) {
// out0 = vmaxq_f32(out0, zero);
// }
// vst1q_lane_f32(output_ptr, out0, 0);
// vst1q_lane_f32(output_ptr + 1, out0, 1);
// vst1q_lane_f32(output_ptr + 2, out0, 2);
// }
// int m;
// for (m = 1; m < output_width - 2; m += 3) {
// }
// for (int j = m; j < output_width; j++) {
// output_data[i * output_width + j] =
// input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
// input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
// input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
// input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
// input_data[(2 * i) * input_width + 2 * j] * w11 +
// input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
// input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
// input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
// input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
// output_data[i * output_width + j] =
// newscale_data[c] * output_data[i * output_width + j] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width + j] =
// output_data[i * output_width + j] < 0
// ? 0
// : output_data[i * output_width + j];
// }
// }
// }
// output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
// input_data[input_height] * w21 +
// input_data[input_height + 1] * w22;
//
// output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
// if (if_relu) {
// output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
// }
// for (int i = 1; i < output_height; i++) {
// output_data[i * output_width] =
// input_data[(2 * i - 1) * input_width] * w01 +
// input_data[(2 * i - 1) * input_width + 1] * w02 +
// input_data[(2 * i) * input_width] * w11 +
// input_data[(2 * i) * input_width + 1] * w12 +
// input_data[(2 * i + 1) * input_width] * w21 +
// input_data[(2 * i + 1) * input_width + 1] * w22;
//
// output_data[i * output_width] =
// newscale_data[c] * output_data[i * output_width] +
// newbias_data[c];
// if (if_relu) {
// output_data[i * output_width] = output_data[i * output_width] < 0
// ? 0
// : output_data[i *
// output_width];
// }
// }
// }
// }
//
// #else
  const float *input_data = input->data<float>();
  const float *filter_data = filter->data<float>();
  float *output_data = output->mutable_data<float>();
  const float *newscale_data = new_scale->data<float>();
  const float *newbias_data = new_bias->data<float>();
  const int in_h = static_cast<int>(input->dims()[2]);
  const int in_w = static_cast<int>(input->dims()[3]);
  const int out_h = static_cast<int>(output->dims()[2]);
  const int out_w = static_cast<int>(output->dims()[3]);
  // const int out_l = out_h;
  // const int in_l = in_h;
  const int inhxw = in_h * in_w;
  const int outhxw = out_h * out_w;
  /// todo : fix if_pad when w != h
  const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0;
  const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0;
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int c = static_cast<int>(input->dims()[1]);
  const int w_times = (out_w - 2) / 3;
  float32x4_t zero = vdupq_n_f32(0.0);
  for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
    for (int j = 0; j < c; j++) {
      const float *input_row_ptr;
      float *output_row_ptr;
      float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
      float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
      int out2in_mid;
      float32x4_t vnewbias = vdupq_n_f32(0.0);
      float32x4_t vnewscale = vdupq_n_f32(1.0);
      auto output_data_tmp = output_data + j * out_h * out_w;
      auto input_data_tmp = input_data + j * in_h * in_w;
      auto input_const = input_data_tmp;
      const float *filter_data_tmp = filter_data + 9 * j;
      vnewbias = vdupq_n_f32(newbias_data[j]);
      vnewscale = vdupq_n_f32(newscale_data[j]);
      float w00 = filter_data_tmp[0];
      float w01 = filter_data_tmp[1];
      float w02 = filter_data_tmp[2];
      float w10 = filter_data_tmp[3];
      float w11 = filter_data_tmp[4];
      float w12 = filter_data_tmp[5];
      float w20 = filter_data_tmp[6];
      float w21 = filter_data_tmp[7];
      float w22 = filter_data_tmp[8];
      int h_mid = 0;
      for (; h_mid < out_h - 1; h_mid++) {
        input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
        output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
        for (int w4 = 0; w4 < w_times + 1; w4++) {
          if (h_mid == 0) {
            elewise_res1 = zero;
            elewise_res0 = zero;
            elewise_res2 = zero;
          } else {
            elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
            elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
            elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
          }
          input_buff_mid = vld2q_f32(input_row_ptr);
          input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
          elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
          elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
          elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
          elewise_res1 =
              vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
          elewise_res0 =
              vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
          elewise_res2 =
              vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
          res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
                           vaddq_f32(elewise_res0, elewise_res1));
          res3 = vmlaq_f32(vnewbias, vnewscale, res3);
          if (if_relu) {
            res3 = vmaxq_f32(res3, zero);
          }
          vst1q_lane_f32(output_row_ptr, res3, 0);
          vst1q_lane_f32(output_row_ptr + 1, res3, 1);
          vst1q_lane_f32(output_row_ptr + 2, res3, 2);
          input_row_ptr += 6;
          output_row_ptr += 3;
        }
      }
      // remain height: same accumulation as above, except that the bottom row
      // is skipped when if_pad_b is set and the last tile stores only
      // out_w - 2 - w_times * 3 lanes
      input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
      output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
      for (int w4 = 0; w4 < w_times + 1; w4++) {
        elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
        elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
        elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
        input_buff_mid = vld2q_f32(input_row_ptr);
        input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
        elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11);
        elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10);
        elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
        if (!if_pad_b) {
          elewise_res1 =
              vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21);
          elewise_res0 =
              vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
          elewise_res2 =
              vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
        }
        res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
                         vaddq_f32(elewise_res0, elewise_res1));
        res3 = vmlaq_f32(vnewbias, vnewscale, res3);
        if (if_relu) {
          res3 = vmaxq_f32(res3, zero);
        }
        if ((w4 != w_times)) {
          vst1q_lane_f32(output_row_ptr, res3, 0);
          vst1q_lane_f32(output_row_ptr + 1, res3, 1);
          vst1q_lane_f32(output_row_ptr + 2, res3, 2);
        } else {
          if (out_w - 2 - w_times * 3 == 1) {
            vst1q_lane_f32(output_row_ptr, res3, 0);
          } else if (out_w - 2 - w_times * 3 == 2) {
            vst1q_lane_f32(output_row_ptr, res3, 0);
            vst1q_lane_f32(output_row_ptr + 1, res3, 1);
          }
        }
        input_row_ptr += 6;
        output_row_ptr += 3;
      }
      // borders: the four corners and the first/last output columns are
      // handled in scalar code, guarded by if_pad_r / if_pad_b, then scaled by
      // newscale_data[j], shifted by newbias_data[j] and clamped when if_relu
      output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
                           input_const[in_w] * w21 +
                           input_const[in_w + 1] * w22;
      out2in_mid = (out_w - 1) * 2;
      output_data_tmp[out_w - 1] =
          w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
          w20 * input_const[out2in_mid + in_w - 1] +
          w21 * input_const[out2in_mid + in_w] +
          (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] +
                            w22 * input_const[out2in_mid + in_w + 1]);
      ...
      for (int i = 1; i < out_h - 1; i++) {
        out2in_mid = i * 2 * in_w;
        output_data_tmp[i * out_w] =
            w01 * input_const[out2in_mid - in_w] +
            w02 * input_const[out2in_mid - in_w + 1] +
            w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
            w21 * input_const[out2in_mid + in_w] +
            w22 * input_const[out2in_mid + in_w + 1];
        ...
      }
    }
    input_data += inhxw * c;
    output_data += outhxw * c;
  }
  // #endif
#endif
}

void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor *bias,
                          bool if_bias, bool if_relu) {
#if __ARM_NEON
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int input_channel = static_cast<int>(input->dims()[1]);
  const int input_height = static_cast<int>(input->dims()[2]);
  const int input_width = static_cast<int>(input->dims()[3]);
  const int output_height = static_cast<int>(output->dims()[2]);
  const int output_width = static_cast<int>(output->dims()[3]);
  const int inhxw = input_height * input_width;
  const int outhxw = output_height * output_width;

  output->mutable_data<float>();
  float32x4_t zero = vdupq_n_f32(0.0);
  for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
    for (int c = 0; c < input_channel; c++) {
      const float *filter_data = filter->data<float>() + c * 9;
      const float *input_data = input->data<float>() + c * inhxw;
      const float *bias_data;
      float32x4_t biasv;
      if (if_bias) {
        bias_data = bias->data<float>() + c;
        biasv = vld1q_dup_f32(bias_data);
      }
      float *output_data = output->data<float>() + c * outhxw;
      float w00 = filter_data[0];
      float w01 = filter_data[1];
      float w02 = filter_data[2];
      float w10 = filter_data[3];
      float w11 = filter_data[4];
      float w12 = filter_data[5];
      float w20 = filter_data[6];
      float w21 = filter_data[7];
      float w22 = filter_data[8];
      for (int i = 0; i < output_height; i += 1) {
        for (int m = 0; m < output_width - 2; m += 3) {
          float *output_ptr = output_data + i * output_width + m;
          float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{};
          float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
              tmp4, tmp5, out0;
          input_buff_top =
              vld2q_f32(input_data + (2 * i) * input_width + (2 * m));
          input_buff_mid =
              vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m));
          input_buff_bottom =
              vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m));
          in0 = input_buff_top.val[0];
          tmp0 = input_buff_top.val[1];
          tmp1 = vextq_f32(in0, zero, 1);
          in2 = input_buff_mid.val[0];
          tmp2 = input_buff_mid.val[1];
          tmp3 = vextq_f32(in2, zero, 1);
          in4 = input_buff_bottom.val[0];
          tmp4 = input_buff_bottom.val[1];
          tmp5 = vextq_f32(in4, zero, 1);
          out0 = vmulq_n_f32(in0, w00);
          out0 = vmlaq_n_f32(out0, tmp0, w01);
          out0 = vmlaq_n_f32(out0, tmp1, w02);
          out0 = vmlaq_n_f32(out0, in2, w10);
          out0 = vmlaq_n_f32(out0, tmp2, w11);
          out0 = vmlaq_n_f32(out0, tmp3, w12);
          out0 = vmlaq_n_f32(out0, in4, w20);
          out0 = vmlaq_n_f32(out0, tmp4, w21);
          out0 = vmlaq_n_f32(out0, tmp5, w22);
          if (if_bias) {
            out0 = vaddq_f32(out0, biasv);
          }
          if (if_relu) {
            out0 = vmaxq_f32(out0, zero);
          }
          vst1q_lane_f32(output_ptr, out0, 0);
          vst1q_lane_f32(output_ptr + 1, out0, 1);
          vst1q_lane_f32(output_ptr + 2, out0, 2);
        }
        int m;
        for (m = 0; m < output_width - 2; m += 3) {
        }
        for (int j = m; j < output_width; j++) {
          int index = i * output_width + j;
          output_data[index] =
              input_data[(2 * i) * input_width + 2 * j] * w00 +
              input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
              input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
              input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
              input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
              input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
              input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
              input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
              input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
          if (if_bias) {
            output_data[index] += *bias_data;
          }
          if (if_relu) {
            output_data[index] =
                output_data[index] < 0 ? 0 : output_data[index];
          }
        }
      }
    }
  }
#endif
}

...
// New stride-2 depthwise conv3x3 path introduced by this commit (shown in
// part): per-row handling of left/right padding, the vectorized valid region,
// the remaining output width/height, and the zero-padded bottom rows.
      // pad left
      if (padding_w) {
        for (int w = valid_w_start - 1; w >= 0; --w) {
          int padding = padding_w - (w << 1);
          if (padding >= 3) {
            output_ptr0[w] = 0;
          } else {
            float32x4_t row0 = vld1q_f32(input_ptr0 - padding);
            float32x4_t row1 = vld1q_f32(input_ptr1 - padding);
            float32x4_t row2 = vld1q_f32(input_ptr2 - padding);
            float32x4_t acc0 = vmulq_f32(row0, _ker[0]);
            acc0 = vmlaq_f32(acc0, row1, _ker[1]);
            acc0 = vmlaq_f32(acc0, row2, _ker[2]);
            float sum0 = vgetq_lane_f32(acc0, 2);
            if (padding == 1) {
              sum0 += vgetq_lane_f32(acc0, 1);
            }
            output_ptr0[w] = sum0;
          }
        }
      }
      // valid
      float32x4_t _result0, _ext;
      for (int loop = 0; loop < output_w_tiles; ++loop) {
        float32x4x2_t _row0 = vld2q_f32(input_ptr0);
        float32x4x2_t _row1 = vld2q_f32(input_ptr1);
        float32x4x2_t _row2 = vld2q_f32(input_ptr2);
        _ext = vextq_f32(_row0.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3);
        _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0);
        _ext = vextq_f32(_row1.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3);
        _result0 =
            vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0);
        _ext = vextq_f32(_row2.val[0], _ext, 1);
        _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0);
        _result0 =
            vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1);
        _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0);
        vst1q_f32(output_ptr0, _result0);
        input_ptr0 += 8;
        input_ptr1 += 8;
        input_ptr2 += 8;
        output_ptr0 += 4;
      }
      // remain w
      if (output_w_remain > 0) {
        float32x4x2_t _row0 = vld2q_f32(input_ptr0);
        float32x4x2_t _row1 = vld2q_f32(input_ptr1);
        float32x4x2_t _row2 = vld2q_f32(input_ptr2);
        ...
        switch (output_w_remain) {
          case 3:
            vst1q_lane_f32(output_ptr0 + 2, _result0, 2);
          case 2:
            vst1_f32(output_ptr0, vget_low_f32(_result0));
            break;
          case 1:
            vst1q_lane_f32(output_ptr0, _result0, 0);
            break;
        }
        input_ptr0 += output_w_remain * 2;
        input_ptr1 += output_w_remain * 2;
        input_ptr2 += output_w_remain * 2;
        output_ptr0 += output_w_remain;
      }
      // pad right
      for (int w = valid_w_end; w < output_w; ++w) {
        int padding = 2 * w + 3 - (padding_w + input_w);
        if (padding >= 3) {
          *output_ptr0 = 0;
        } else {
          float32x4_t row0 = vld1q_f32(input_ptr0);
          float32x4_t row1 = vld1q_f32(input_ptr1);
          float32x4_t row2 = vld1q_f32(input_ptr2);
          float32x4_t acc0;
          acc0 = vmulq_f32(row0, _ker[0]);
          acc0 = vmlaq_f32(acc0, row1, _ker[1]);
          acc0 = vmlaq_f32(acc0, row2, _ker[2]);
          float sum0 = vgetq_lane_f32(acc0, 0);
          if (padding == 1) {
            sum0 += vgetq_lane_f32(acc0, 1);
          }
          *output_ptr0 = sum0;
        }
        output_ptr0++;
      }
      // remain height
      int start_h = valid_h_start + (valid_h & 0xfffe);
      if (start_h < valid_h_end) {
        const float *input_ptr0 =
            input_ptr + (2 * start_h - padding_h) * input_w;
        const float *input_ptr1 = input_ptr0 + input_w;
        const float *input_ptr2 = input_ptr1 + input_w;
        float *output_ptr0 = output_ptr + start_h * output_w;
        ...
      }
      // pad bottom
      for (int h = valid_h_end; h < output_h; ++h) {
        DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
                                        input_w, padding_h, padding_w,
                                        output_w, output_ptr, _ker);
      }
    }
  }
}
#endif
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
#endif  // __ARM_NEON__
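
For reference, the arithmetic that the fused stride-2 kernels above implement can be written as a plain scalar loop. This is only a reading aid for the NEON code, not an entry point of this file; every name below is illustrative.

#include <algorithm>

// Illustrative scalar reference for one channel of a fused 3x3, stride-2,
// zero-padded depthwise convolution with per-channel scale/bias and ReLU.
// The NEON kernels above compute the same thing with de-interleaving column
// loads (vld2q_f32) and fused multiply-accumulate (vmlaq_*) instructions.
void DepthwiseConv3x3S2Reference(const float *in, const float *w9, float *out,
                                 int in_h, int in_w, int out_h, int out_w,
                                 int pad, float scale, float bias, bool relu) {
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      float acc = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          int ih = oh * 2 - pad + kh;  // stride 2 in both directions
          int iw = ow * 2 - pad + kw;
          if (ih >= 0 && ih < in_h && iw >= 0 && iw < in_w) {
            acc += in[ih * in_w + iw] * w9[kh * 3 + kw];
          }
        }
      }
      acc = acc * scale + bias;  // fused batch-norm style scale/bias
      out[oh * out_w + ow] = relu ? std::max(acc, 0.f) : acc;
    }
  }
}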
src/operators/math/depthwise_conv3x3.h 浏览文件 @ 532cff71
...
@@ -23,48 +23,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {

void DepthwiseConv3x3(const framework::Tensor *input,
                      const std::vector<int> &strides,
                      const std::vector<int> &paddings,
                      const framework::Tensor *filter, framework::Tensor *bias,
                      framework::Tensor *output, bool if_bias);

void DepthwiseConv3x3s1p1(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor *bias,
                          bool if_bias, bool if_relu);

void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu);

void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu);

void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                            const framework::Tensor *filter,
                            framework::Tensor *output, framework::Tensor *bias,
                            bool if_bias, bool if_relu);

void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
                                     const framework::Tensor *filter,
                                     framework::Tensor *output,
                                     const framework::Tensor *new_scale,
                                     const framework::Tensor *new_bias,
                                     bool if_relu);

void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor *bias,
                          bool if_bias, bool if_relu);

// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
...
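
The deleted declarations encode the supported stride/padding combination in their names (s1p1 = stride 1, pad 1; s2p1; s2p0), which this commit replaces with fewer generic entry points. The usual 3x3 output-size relation makes the distinction concrete; the helper below is only an illustration, not part of this header.

// out = (in + 2 * pad - 3) / stride + 1 for a 3x3 filter, e.g. with in = 112:
//   stride 1, pad 1 -> 112  (same size)
//   stride 2, pad 1 -> 56
//   stride 2, pad 0 -> 55
static inline int Conv3x3OutputSize(int in, int stride, int pad) {
  return (in + 2 * pad - 3) / stride + 1;
}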
src/operators/math/gemm/cblas.cc 浏览文件 @ 532cff71
...
@@ -31,16 +31,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
  //   return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
  // }
  CPUInfo *info = CPUInfo::Info();
  GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
  GemmExecutor<SgemmStrategy> exec(transA, transB, M, N, K);
  exec(alpha, A, lda, B, ldb, beta, C, ldc);
}

void cblas_sgemv(const bool trans, const int M, const int N, const float alpha,
                 const float *A, const int lda, const float *B,
                 const float beta, float *C) {
  CPUInfo *info = CPUInfo::Info();
  GemvExecutor<SgemvStrategy> exec(info, trans, M, N);
  GemvExecutor<SgemvStrategy> exec(trans, M, N);
  exec(alpha, A, lda, B, beta, C);
}
...
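
A hypothetical caller of the wrapper above, with the parameter order (transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) read off the definition shown here; row-major storage and the include path are assumptions.

#include "operators/math/gemm/cblas.h"  // assumed header path for the declaration

int main() {
  // C(2x2) = 1.0 * A(2x3) * B(3x2) + 0.0 * C, row-major layout assumed.
  float A[2 * 3] = {1, 2, 3, 4, 5, 6};
  float B[3 * 2] = {1, 0, 0, 1, 1, 1};
  float C[2 * 2] = {0};
  paddle_mobile::operators::math::cblas_sgemm(
      /*transA=*/false, /*transB=*/false, /*M=*/2, /*N=*/2, /*K=*/3,
      /*alpha=*/1.f, A, /*lda=*/3, B, /*ldb=*/2, /*beta=*/0.f, C, /*ldc=*/2);
  // Expected C: {4, 5, 10, 11}
  return 0;
}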
src/operators/math/gemm/executor.h 浏览文件 @ 532cff71
...
@@ -19,7 +19,6 @@ limitations under the License. */
#include <omp.h>
#endif
#include <sys/time.h>
#include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
...
@@ -29,6 +28,8 @@ namespace paddle_mobile {
namespace operators {
namespace math {

static CPUInfo *info = CPUInfo::Info();

int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
                          const int N, const int K) {
...
@@ -62,15 +63,9 @@ class GemmExecutor : public Executor {
  typedef typename Strategy::Otype Otype;

 public:
  GemmExecutor(const CPUInfo *info, const bool transA, const bool transB,
               const int M, const int N, const int K)
      : Executor(), info_(info), transA_(transA), transB_(transB), M_(M),
        N_(N), K_(K) {
  GemmExecutor(const bool transA, const bool transB, const int M, const int N,
               const int K)
      : Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) {
    unsigned int L1_size = 0;
    unsigned int L2_size = 0;
    if (M_ > N_) {
...
@@ -212,8 +207,6 @@ class GemmExecutor : public Executor {
  virtual ~GemmExecutor() {}

 private:
  const CPUInfo *info_;
  const unsigned int M_;
  const unsigned int N_;
  const unsigned int K_;
...
@@ -242,8 +235,8 @@ class GemvExecutor : public Executor {
  typedef typename Strategy::Otype Otype;

 public:
  GemvExecutor(const CPUInfo *info, const bool transA, const int M,
               const int N)
      : Executor(), info_(info), M_(M), N_(N) {}
  GemvExecutor(const bool transA, const int M, const int N)
      : Executor(), M_(M), N_(N) {}

  void operator()(const float alpha, const Itype *A, const int lda,
                  const Itype *B, const float beta, Otype *C) {
...
@@ -253,8 +246,6 @@ class GemvExecutor : public Executor {
  virtual ~GemvExecutor() {}

 private:
  const CPUInfo *const info_;
  const unsigned int M_;
  const unsigned int N_;
...
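
The change above drops the constructor-injected CPUInfo pointer in favour of a single namespace-scope pointer initialized once. A minimal sketch of that pattern is below; the names are illustrative, not the actual CPUInfo interface, and note that a namespace-scope static gives each including translation unit its own copy.

// Process-wide hardware-info lookup that executors read directly instead of
// receiving through their constructors.
struct HardwareInfo {
  unsigned l1_cache_bytes = 32 * 1024;
  static HardwareInfo *Get() {
    static HardwareInfo instance;  // constructed once, on first use
    return &instance;
  }
};

static HardwareInfo *g_info = HardwareInfo::Get();

class ExecutorSketch {
 public:
  ExecutorSketch(int m, int n) : m_(m), n_(n) {}  // no info parameter any more
  // Block size derived from the cache size read off the shared pointer.
  unsigned TileBytes() const { return g_info->l1_cache_bytes / 2; }

 private:
  int m_;
  int n_;
};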
src/operators/math/im2col.cpp 浏览文件 @ 532cff71
...
@@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
  for (int i = start_height; i < end_height; i += stride_h) {
    if (stride_w == 1) {
      memcpy(col_data, im_data, extract * sizeof(float));
      // memcpy(col_data, im_data, extract * sizeof(float));
      int s = 0;
#if __ARM_NEON
      for (; s < extract - 3; s += 4) {
        float32x4_t img = vld1q_f32(im_data + s);
        vst1q_f32(col_data + s, img);
      }
#endif
      for (; s < extract; ++s) {
        col_data[s] = im_data[s];
      }
    } else if (stride_w == 2) {
      int s = 0;
#if __ARM_NEON
...
@@ -109,325 +119,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
  const float *im_data = im.data<float>();
  float *col_data = col->data<float>();
#if __ARM_NEON
  const int osize = col_height;
  const int isize = im_height;
  bool pad1 = padding[0] > 0;
  bool pad2 =
      (pad1 && padding[1] &&
       (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
  int fill = isize % 2;
  if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
      dilation[0] == 1 && im_height > 2 && im_height == im_width) {
    // specialization for stride-1 3x3 with symmetric padding (deleted by this
    // commit)
    for (int c = 0; c < im_channels; ++c) {
      int oosize = osize * osize;
      int nk4 = osize / 4;
      int mk4 = osize % 4;
      float *col0 = col_data + 0 * oosize + 2 * osize + 2;
      float *col1 = col_data + 1 * oosize + 2 * osize + 1;
      float *col2 = col_data + 2 * oosize + 2 * osize;
      float *col3 = col_data + 3 * oosize + osize + 2;
      float *col4 = col_data + 4 * oosize + osize + 1;
      float *col5 = col_data + 5 * oosize + osize;
      float *col6 = col_data + 6 * oosize + 2;
      float *col7 = col_data + 7 * oosize + 1;
      float *col8 = col_data + 8 * oosize;
      float32x4_t im1;
      const float *im_tmp_data = im_data + osize + 1;
      int rrsize = oosize - osize - 1;
      int nr4 = rrsize / 4;
      int mr4 = rrsize % 4;
      for (int i = 0; i < nr4; ++i) {
        im1 = vld1q_f32(im_tmp_data);
        vst1q_f32(col0, im1);
        vst1q_f32(col1, im1);
        vst1q_f32(col2, im1);
        vst1q_f32(col3, im1);
        vst1q_f32(col4, im1);
        vst1q_f32(col5, im1);
        vst1q_f32(col6, im1);
        vst1q_f32(col7, im1);
        vst1q_f32(col8, im1);
        col0 += 4; col1 += 4; col2 += 4; col3 += 4; col4 += 4;
        col5 += 4; col6 += 4; col7 += 4; col8 += 4;
        im_tmp_data += 4;
      }
      for (int i = 0; i < mr4; ++i) {
        *col0 = *im_tmp_data;
        *col1 = *im_tmp_data;
        *col2 = *im_tmp_data;
        *col3 = *im_tmp_data;
        *col4 = *im_tmp_data;
        *col5 = *im_tmp_data;
        *col6 = *im_tmp_data;
        *col7 = *im_tmp_data;
        *col8 = *im_tmp_data;
        col0++; col1++; col2++; col3++; col4++;
        col5++; col6++; col7++; col8++;
        im_tmp_data++;
      }
      ...
      // fill 0 1 11: first row/column copies and zero-filled border planes
      for (int i = 0; i < osize; ++i) {
        col_data[0 * oosize + i * osize] = 0.0;
        col_data[3 * oosize + i * osize] = 0.0;
        col_data[6 * oosize + i * osize] = 0.0;
        col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
        col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
        col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
      }
      col_data[0 * oosize + osize + 1] = im_data[0];
      col_data[3 * oosize + 1] = im_data[0];
      col_data[6 * oosize + 1] = im_data[osize];
      col_data[1 * oosize + osize] = im_data[0];
      col_data[4 * oosize] = im_data[0];
      col_data[7 * oosize] = im_data[osize];
      ...
      col_data += 9 * oosize;
      im_data += isize * isize;
    }
  } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
             im_height > 2 && im_height == im_width) {
    // specialization for stride-2 3x3 using vld2q_f32 de-interleaving loads
    // (deleted by this commit)
    for (int c = 0; c < im_channels; ++c) {
      int oosize = osize * osize;
      int nk4 = osize / 4;
      int mk4 = osize % 4;
      // 3 2 3 1 0 1 3 2 3
      float *col0 = col_data + 0 * oosize + osize + 1;
      float *col1 = col_data + 1 * oosize + osize;
      float *col2 = col_data + 2 * oosize + osize;
      float *col3 = col_data + 3 * oosize + 1;
      float *col4 = col_data + 4 * oosize;
      float *col5 = col_data + 5 * oosize;
      float *col6 = col_data + 6 * oosize + 1;
      float *col7 = col_data + 7 * oosize;
      float *col8 = col_data + 8 * oosize;
      float32x4x2_t im01;
      float32x4x2_t im23;
      const float *im_tmp_data0 = im_data;
      const float *im_tmp_data2 = im_data + isize;
      for (int j = 0; j < osize; ++j) {
        for (int i = 0; i < nk4; ++i) {
          im01 = vld2q_f32(im_tmp_data0);
          im23 = vld2q_f32(im_tmp_data2);
          vst1q_f32(col0, im23.val[1]);
          vst1q_f32(col1, im23.val[0]);
          vst1q_f32(col2, im23.val[1]);
          vst1q_f32(col3, im01.val[1]);
          vst1q_f32(col4, im01.val[0]);
          vst1q_f32(col5, im01.val[1]);
          vst1q_f32(col6, im23.val[1]);
          vst1q_f32(col7, im23.val[0]);
          vst1q_f32(col8, im23.val[1]);
          col0 += 4; col1 += 4; col2 += 4; col3 += 4; col4 += 4;
          col5 += 4; col6 += 4; col7 += 4; col8 += 4;
          im_tmp_data0 += 8;
          im_tmp_data2 += 8;
        }
        ...
        im_tmp_data0 += (isize - fill);
        im_tmp_data2 += (isize - fill);
      }
      ...
      col_data[1 * oosize + osize] = im_data[isize];
      for (int i = 1; i < osize; ++i) {
        col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
      }
      col_data[4 * oosize] = im_data[0];
      col_data[7 * oosize] = im_data[isize];
      col_data += 9 * oosize;
      im_data += isize * isize;
    }
  } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
    int im_spatial_size = im_height * im_width;
    int col_spatial_size = col_height * col_width;
    // pad 0
...
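
The deleted block above was a hand-tuned im2col for 3x3 filters; the col0 .. col8 pointers address the nine output planes, one per kernel tap. A scalar sketch of that layout for a single channel is below; the names are illustrative and it is not part of this file.

// Produces 9 planes of size out_h * out_w (one per kernel position), each a
// shifted, zero-padded copy of the input sampled at the given stride.
void Im2Col3x3Reference(const float *im, int im_h, int im_w, int stride,
                        int pad, int out_h, int out_w, float *col) {
  for (int kh = 0; kh < 3; ++kh) {
    for (int kw = 0; kw < 3; ++kw) {
      float *plane = col + (kh * 3 + kw) * out_h * out_w;  // e.g. col0 .. col8
      for (int oh = 0; oh < out_h; ++oh) {
        for (int ow = 0; ow < out_w; ++ow) {
          int ih = oh * stride - pad + kh;
          int iw = ow * stride - pad + kw;
          plane[oh * out_w + ow] =
              (ih >= 0 && ih < im_h && iw >= 0 && iw < im_w)
                  ? im[ih * im_w + iw]
                  : 0.f;  // zero padding outside the image
        }
      }
    }
  }
}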
src/operators/op_param.h 浏览文件 @ 532cff71
...
@@ -441,10 +441,8 @@ class ConvParam : public OpParam {
  enum ExecMode {
    EXEC_INVALID = 0,
    EXEC_GEMM_FLOAT,
    EXEC_DEPTHWISE3x3S1P1_FLOAT,
    EXEC_DEPTHWISE3x3S1_FLOAT,
    EXEC_DEPTHWISE3x3S2P0_FLOAT,
    EXEC_DEPTHWISE3x3S2_FLOAT,
    EXEC_DEPTHWISE3x3S2P1_FLOAT,
    EXEC_DEPTHWISE3x3_FLOAT,
    EXEC_WINOGRAD3X3_FLOAT,
    EXEC_WINOGRAD5X5_FLOAT,
    EXEC_DEPTHWISE5x5_FLOAT,
...