Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
cd1b6c08
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cd1b6c08
编写于
3月 12, 2019
作者:
H
hjchen2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize vector-matrix and matrix-vector multiply
上级
1d475a2c
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
370 addition
and
149 deletion
+370
-149
src/operators/kernel/arm/convolution/conv_kernel.cpp
src/operators/kernel/arm/convolution/conv_kernel.cpp
+0
-2
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
...perators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
+0
-2
src/operators/kernel/central-arm-func/conv_arm_func.cpp
src/operators/kernel/central-arm-func/conv_arm_func.cpp
+21
-17
src/operators/math/depthwise_conv3x3.cpp
src/operators/math/depthwise_conv3x3.cpp
+37
-32
src/operators/math/depthwise_conv3x3_int8.cpp
src/operators/math/depthwise_conv3x3_int8.cpp
+1
-2
src/operators/math/depthwise_conv5x5.cpp
src/operators/math/depthwise_conv5x5.cpp
+14
-11
src/operators/math/gemm/cblas.cc
src/operators/math/gemm/cblas.cc
+8
-6
src/operators/math/gemm/executor.h
src/operators/math/gemm/executor.h
+3
-2
src/operators/math/gemm/gemm_kernel.h
src/operators/math/gemm/gemm_kernel.h
+194
-0
src/operators/math/gemm/pack_kernel.h
src/operators/math/gemm/pack_kernel.h
+27
-10
src/operators/math/gemm/strategy.h
src/operators/math/gemm/strategy.h
+5
-12
src/operators/math/math.h
src/operators/math/math.h
+12
-0
test/common/test_gemm_accuracy.cpp
test/common/test_gemm_accuracy.cpp
+48
-53
未找到文件。
src/operators/kernel/arm/convolution/conv_kernel.cpp
浏览文件 @
cd1b6c08
...
...
@@ -18,8 +18,6 @@ limitations under the License. */
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <iostream>
namespace
paddle_mobile
{
namespace
operators
{
...
...
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
浏览文件 @
cd1b6c08
...
...
@@ -65,14 +65,12 @@ void DWConvBNReluKernel<CPU, float>::Compute(
case
ConvParam
<
CPU
>::
EXEC_DEPTHWISE3x3S2_FLOAT
:
DepthwiseConv3x3
<
float
,
float
>
(
param
);
break
;
#ifndef __aarch64__
case
ConvParam
<
CPU
>::
EXEC_DEPTHWISE5x5_FLOAT
:
DepthwiseConv5x5
<
float
,
float
>
(
param
);
break
;
case
ConvParam
<
CPU
>::
EXEC_WINOGRAD3X3_FLOAT
:
WinogradConv3x3
<
8
,
3
>
(
param
);
break
;
#endif // __aarch64__
case
ConvParam
<
CPU
>::
EXEC_GEMM_FLOAT
:
GemmConv
<
float
,
float
>
(
param
);
break
;
...
...
src/operators/kernel/central-arm-func/conv_arm_func.cpp
浏览文件 @
cd1b6c08
...
...
@@ -190,19 +190,23 @@ void DepthwiseConv3x3(const ConvParam<CPU> ¶m) {
Tensor
*
output
=
param
.
Output
();
output
->
mutable_data
<
Otype
>
();
if
(
strides
[
0
]
==
1
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
in_batch
=
input
->
Slice
(
i
,
i
+
1
);
Tensor
out_batch
=
output
->
Slice
(
i
,
i
+
1
);
if
(
strides
[
0
]
==
1
)
{
math
::
DepthwiseConv3x3S1
<
Itype
,
Otype
>
(
in_batch
,
*
filter
,
paddings
,
&
out_batch
);
}
}
else
if
(
strides
[
0
]
==
2
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
in_batch
=
input
->
Slice
(
i
,
i
+
1
);
Tensor
out_batch
=
output
->
Slice
(
i
,
i
+
1
);
math
::
DepthwiseConv3x3S2
<
Itype
,
Otype
>
(
in_batch
,
*
filter
,
paddings
,
&
out_batch
);
}
}
else
{
GemmConv
<
Itype
,
Otype
>
(
param
);
}
}
}
template
<
typename
Itype
,
typename
Otype
>
...
...
@@ -215,16 +219,16 @@ void DepthwiseConv5x5(const ConvParam<CPU> ¶m) {
Tensor
*
output
=
param
.
Output
();
output
->
mutable_data
<
Otype
>
();
//
if (strides[0] == 1) {
//
for (int i = 0; i < batch_size; i++) {
//
Tensor in_batch = input->Slice(i, i + 1);
//
Tensor out_batch = output->Slice(i, i + 1);
//
math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
//
&out_batch);
//
}
//
} else {
if
(
strides
[
0
]
==
1
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
in_batch
=
input
->
Slice
(
i
,
i
+
1
);
Tensor
out_batch
=
output
->
Slice
(
i
,
i
+
1
);
math
::
DepthwiseConv5x5S1
<
Itype
,
Otype
>
(
in_batch
,
*
filter
,
paddings
,
&
out_batch
);
}
}
else
{
GemmConv
<
Itype
,
Otype
>
(
param
);
//
}
}
}
template
void
GemmConv
<
float
,
float
>(
const
ConvParam
<
CPU
>
&
param
);
...
...
src/operators/math/depthwise_conv3x3.cpp
浏览文件 @
cd1b6c08
...
...
@@ -73,8 +73,11 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
h_end
=
h_in_end
<
input_h
?
h_in_end
:
input_h
;
const
int
valid_w_start
=
(
padding_w
+
Stride_w
-
1
)
/
Stride_w
;
const
int
valid_w_end
=
(
input_w
+
padding_w
-
3
)
/
Stride_w
+
1
;
int
valid_w_start
=
(
padding_w
+
Stride_w
-
1
)
/
Stride_w
;
int
valid_w_end
=
(
input_w
+
padding_w
-
3
)
/
Stride_w
+
1
;
if
(
valid_w_end
<
valid_w_start
)
{
valid_w_end
=
valid_w_start
;
}
// const int valid_w_end = output_w - valid_w_start;
float
*
output_ptr
=
output
+
h_output
*
output_w
;
// border left
...
...
@@ -120,7 +123,7 @@ inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter,
vst1_f32
(
output_ptr0
,
vget_low_f32
(
_sum
));
break
;
case
1
:
vst1
_lane_f32
(
output_ptr0
,
vget_low_f32
(
_sum
)
,
0
);
vst1
q_lane_f32
(
output_ptr0
,
_sum
,
0
);
break
;
}
}
...
...
@@ -136,20 +139,21 @@ void DepthwiseConv3x3S1<float, float>(const framework::Tensor &input,
const
float
*
input_data
=
input
.
data
<
float
>
();
const
float
*
filter_data
=
filter
.
data
<
float
>
();
float
*
out_data
=
output
->
mutable_data
<
float
>
();
int
input_h
=
input
.
dims
()[
2
];
int
input_w
=
input
.
dims
()[
3
];
int
output_h
=
output
->
dims
()[
2
];
int
output_w
=
output
->
dims
()[
3
];
int
padding_h
=
paddings
[
0
];
int
padding_w
=
paddings
[
1
];
int
image_size
=
input_h
*
input_w
;
int
out_image_size
=
output_h
*
output_w
;
int
valid_h_start
=
padding_h
;
int
valid_h_end
=
output_h
-
valid_h_start
;
int
valid_h
=
valid_h_end
-
valid_h_start
;
int
valid_w_start
=
padding_w
;
int
valid_w_end
=
output_w
-
valid_w_start
;
int
valid_w
=
valid_w_end
-
valid_w_start
;
const
int
input_h
=
input
.
dims
()[
2
];
const
int
input_w
=
input
.
dims
()[
3
];
const
int
output_h
=
output
->
dims
()[
2
];
const
int
output_w
=
output
->
dims
()[
3
];
const
int
padding_h
=
paddings
[
0
];
const
int
padding_w
=
paddings
[
1
];
const
int
image_size
=
input_h
*
input_w
;
const
int
out_image_size
=
output_h
*
output_w
;
const
int
valid_h_start
=
padding_h
;
const
int
valid_h_end
=
output_h
-
valid_h_start
;
const
int
valid_h
=
valid_h_end
-
valid_h_start
;
const
int
valid_w_start
=
padding_w
;
const
int
valid_w_end
=
output_w
-
valid_w_start
;
const
int
valid_w
=
valid_w_end
-
valid_w_start
;
#pragma omp parallel for
for
(
int
g
=
0
;
g
<
input
.
dims
()[
1
];
++
g
)
{
...
...
@@ -643,21 +647,22 @@ void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
const
float
*
input_data
=
input
.
data
<
float
>
();
const
float
*
filter_data
=
filter
.
data
<
float
>
();
float
*
out_data
=
output
->
mutable_data
<
float
>
();
int
input_h
=
input
.
dims
()[
2
];
int
input_w
=
input
.
dims
()[
3
];
int
output_h
=
output
->
dims
()[
2
];
int
output_w
=
output
->
dims
()[
3
];
int
padding_h
=
paddings
[
0
];
int
padding_w
=
paddings
[
1
];
int
image_size
=
input_h
*
input_w
;
int
out_image_size
=
output_h
*
output_w
;
int
valid_h_start
=
(
padding_h
+
1
)
/
2
;
int
valid_h_end
=
(
input_h
+
padding_h
-
1
)
/
2
;
int
valid_h
=
valid_h_end
-
valid_h_start
;
int
valid_w_start
=
(
padding_w
+
1
)
/
2
;
int
valid_w_end
=
(
input_w
+
padding_w
-
1
)
/
2
;
int
valid_w
=
valid_w_end
-
valid_w_start
;
int
input_w_start
=
2
*
valid_w_start
-
padding_w
;
const
int
input_h
=
input
.
dims
()[
2
];
const
int
input_w
=
input
.
dims
()[
3
];
const
int
output_h
=
output
->
dims
()[
2
];
const
int
output_w
=
output
->
dims
()[
3
];
const
int
padding_h
=
paddings
[
0
];
const
int
padding_w
=
paddings
[
1
];
const
int
image_size
=
input_h
*
input_w
;
const
int
out_image_size
=
output_h
*
output_w
;
const
int
valid_h_start
=
(
padding_h
+
1
)
/
2
;
const
int
valid_h_end
=
(
input_h
+
padding_h
-
1
)
/
2
;
const
int
valid_h
=
valid_h_end
-
valid_h_start
;
const
int
valid_w_start
=
(
padding_w
+
1
)
/
2
;
const
int
valid_w_end
=
(
input_w
+
padding_w
-
1
)
/
2
;
const
int
valid_w
=
valid_w_end
-
valid_w_start
;
const
int
input_w_start
=
2
*
valid_w_start
-
padding_w
;
#pragma omp parallel for
for
(
int
g
=
0
;
g
<
input
.
dims
()[
1
];
++
g
)
{
...
...
src/operators/math/depthwise_conv3x3_int8.cpp
浏览文件 @
cd1b6c08
...
...
@@ -69,9 +69,8 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
// border left
DEPTHWISE_CONV_NORMAL_BORDER
(
0
,
valid_w_start
)
// middle
int
remain_start
=
valid_w_start
;
int
output_tiles
=
(
valid_w_end
-
valid_w_start
)
/
6
;
remain_start
=
valid_w_start
+
output_tiles
*
6
;
int
remain_start
=
valid_w_start
+
output_tiles
*
6
;
int32x4_t
_sum0
,
_sum1
;
int16x8_t
_y
[
3
];
for
(
int
w
=
0
;
w
<
output_tiles
*
6
;
w
+=
6
)
{
...
...
src/operators/math/depthwise_conv5x5.cpp
浏览文件 @
cd1b6c08
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
#include <iostream>
namespace
paddle_mobile
{
namespace
operators
{
...
...
@@ -48,7 +49,7 @@ inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
y
[
4
]
=
vextq_f32
(
y
[
0
],
y
[
0
],
2
);
}
#define DEPTHWISE_CONV
_NORMAL_BORDER(start, end)
\
#define DEPTHWISE_CONV
5X5_NORMAL_BORDER(start, end)
\
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 5; \
...
...
@@ -77,10 +78,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
const
int
h_end
=
h_in_end
<
input_h
?
h_in_end
:
input_h
;
int
valid_w_start
=
(
padding_w
+
Stride_w
-
1
)
/
Stride_w
;
int
valid_w_end
=
output_w
-
valid_w_start
;
int
valid_w_end
=
(
input_w
+
padding_w
-
5
)
/
Stride_w
+
1
;
if
(
valid_w_end
<
valid_w_start
)
{
valid_w_end
=
valid_w_start
;
}
float
*
output_ptr
=
output
+
h_output
*
output_w
;
// border left
DEPTHWISE_CONV_NORMAL_BORDER
(
0
,
valid_w_start
)
DEPTHWISE_CONV
5X5
_NORMAL_BORDER
(
0
,
valid_w_start
)
// middle
int
output_tiles
=
(
valid_w_end
-
valid_w_start
)
>>
2
;
float32x4_t
_sum
,
_x
[
5
];
...
...
@@ -120,20 +125,18 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
_sum
=
vmlaq_lane_f32
(
_sum
,
_x
[
4
],
vget_high_f32
(
ker
[
index
]),
1
);
}
switch
(
remain
)
{
case
1
:
vst1_lane_f32
(
output_ptr0
,
vget_low_f32
(
_sum
),
0
);
break
;
case
3
:
vst1q_lane_f32
(
output_ptr0
+
2
,
_sum
,
2
);
case
2
:
vst1_f32
(
output_ptr0
,
vget_low_f32
(
_sum
));
break
;
case
3
:
vst1_f32
(
output_ptr0
,
vget_low_f32
(
_sum
));
vst1_lane_f32
(
output_ptr0
+
2
,
vget_high_f32
(
_sum
),
0
);
case
1
:
vst1q_lane_f32
(
output_ptr0
,
_sum
,
0
);
break
;
}
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER
(
valid_w_end
,
output_w
)
DEPTHWISE_CONV
5X5
_NORMAL_BORDER
(
valid_w_end
,
output_w
)
}
template
<
>
...
...
@@ -161,7 +164,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
const
int
valid_w
=
valid_w_end
-
valid_w_start
;
#pragma omp parallel for
for
(
int
g
=
0
;
g
<
input
.
dims
()[
1
];
++
g
)
{
for
(
int
g
=
0
;
g
<
output
->
dims
()[
1
];
++
g
)
{
const
float
*
input_ptr
=
input_data
+
g
*
image_size
;
const
float
*
filter_ptr
=
filter_data
+
g
*
25
;
float
*
output_ptr
=
out_data
+
g
*
out_image_size
;
...
...
src/operators/math/gemm/cblas.cc
浏览文件 @
cd1b6c08
...
...
@@ -27,12 +27,14 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
// if (N == 1) {
// return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
// }
if
(
N
==
1
)
{
return
cblas_sgemv
(
transA
,
M
,
K
,
alpha
,
A
,
lda
,
B
,
beta
,
C
);
}
else
if
(
M
==
1
)
{
return
cblas_sgemv
(
!
transB
,
N
,
K
,
alpha
,
B
,
ldb
,
A
,
beta
,
C
);
}
else
{
GemmExecutor
<
SgemmStrategy
>
exec
(
transA
,
transB
,
M
,
N
,
K
);
exec
(
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
void
cblas_sgemv
(
const
bool
trans
,
const
int
M
,
const
int
N
,
const
float
alpha
,
...
...
src/operators/math/gemm/executor.h
浏览文件 @
cd1b6c08
...
...
@@ -239,11 +239,11 @@ class GemvExecutor : public Executor {
public:
GemvExecutor
(
const
bool
transA
,
const
int
M
,
const
int
N
)
:
Executor
(),
M_
(
M
),
N_
(
N
)
{}
:
Executor
(),
M_
(
M
),
N_
(
N
)
,
trans_
(
transA
)
{}
void
operator
()(
const
float
alpha
,
const
Itype
*
A
,
const
int
lda
,
const
Itype
*
B
,
const
float
beta
,
Otype
*
C
)
{
// strategy_.kernel(
);
strategy_
.
kernel
(
trans_
,
M_
,
N_
,
alpha
,
A
,
lda
,
B
,
beta
,
C
);
}
virtual
~
GemvExecutor
()
{}
...
...
@@ -251,6 +251,7 @@ class GemvExecutor : public Executor {
private:
const
unsigned
int
M_
;
const
unsigned
int
N_
;
const
bool
trans_
;
Strategy
strategy_
;
};
...
...
src/operators/math/gemm/gemm_kernel.h
浏览文件 @
cd1b6c08
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include "operators/math/math.h"
namespace
paddle_mobile
{
namespace
operators
{
...
...
@@ -325,6 +326,199 @@ void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output,
}
#endif // __aarch64__
void
sgemv_notrans_mx1
(
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
uint32_t
mask
[
4
]
=
{
0
,
1
,
2
,
3
};
int
remain_n
=
N
&
0x3
;
uint32x4_t
vmask
=
vcltq_u32
(
vld1q_u32
(
mask
),
vdupq_n_u32
(
remain_n
));
float32x4_t
_sum0
,
_sum1
,
_sum2
,
_sum3
;
float32x4_t
_valpha
=
vdupq_n_f32
(
alpha
);
#pragma omp parallel for
for
(
int
m
=
0
;
m
<
M
-
3
;
m
+=
4
)
{
const
float
*
in0
=
A
+
m
*
lda
;
const
float
*
in1
=
in0
+
lda
;
const
float
*
in2
=
in1
+
lda
;
const
float
*
in3
=
in2
+
lda
;
float
*
output
=
C
+
m
;
_sum0
=
vdupq_n_f32
(
0.
f
);
_sum1
=
vdupq_n_f32
(
0.
f
);
_sum2
=
vdupq_n_f32
(
0.
f
);
_sum3
=
vdupq_n_f32
(
0.
f
);
int
n
=
0
;
for
(;
n
<
N
-
3
;
n
+=
4
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
n
);
float32x4_t
_r1
=
vld1q_f32
(
in1
+
n
);
float32x4_t
_r2
=
vld1q_f32
(
in2
+
n
);
float32x4_t
_r3
=
vld1q_f32
(
in3
+
n
);
float32x4_t
_b
=
vld1q_f32
(
B
+
n
);
_sum0
=
vmlaq_f32
(
_sum0
,
_r0
,
_b
);
_sum1
=
vmlaq_f32
(
_sum1
,
_r1
,
_b
);
_sum2
=
vmlaq_f32
(
_sum2
,
_r2
,
_b
);
_sum3
=
vmlaq_f32
(
_sum3
,
_r3
,
_b
);
}
if
(
n
<
N
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
n
);
float32x4_t
_r1
=
vld1q_f32
(
in1
+
n
);
float32x4_t
_r2
=
vld1q_f32
(
in2
+
n
);
float32x4_t
_r3
=
vld1q_f32
(
in3
+
n
);
float32x4_t
_b
=
vld1q_f32
(
B
+
n
);
_r0
=
vandq_f32_u32
(
_r0
,
vmask
);
_r1
=
vandq_f32_u32
(
_r1
,
vmask
);
_r2
=
vandq_f32_u32
(
_r2
,
vmask
);
_r3
=
vandq_f32_u32
(
_r3
,
vmask
);
_b
=
vandq_f32_u32
(
_b
,
vmask
);
_sum0
=
vmlaq_f32
(
_sum0
,
_r0
,
_b
);
_sum1
=
vmlaq_f32
(
_sum1
,
_r1
,
_b
);
_sum2
=
vmlaq_f32
(
_sum2
,
_r2
,
_b
);
_sum3
=
vmlaq_f32
(
_sum3
,
_r3
,
_b
);
}
_sum0
=
vpaddq_f32
(
_sum0
,
_sum1
);
_sum2
=
vpaddq_f32
(
_sum2
,
_sum3
);
_sum0
=
vpaddq_f32
(
_sum0
,
_sum2
);
_sum0
=
vmulq_f32
(
_sum0
,
_valpha
);
if
(
beta
!=
0.
f
)
{
_sum2
=
vmulq_n_f32
(
vld1q_f32
(
output
),
beta
);
_sum0
=
vaddq_f32
(
_sum0
,
_sum2
);
}
// restore
vst1q_f32
(
output
,
_sum0
);
}
// remain m
for
(
int
m
=
(
M
&
0xfffc
);
m
<
M
;
++
m
)
{
const
float
*
in0
=
A
+
m
*
lda
;
float
*
output
=
C
+
m
;
_sum0
=
vdupq_n_f32
(
0.
f
);
int
n
=
0
;
for
(;
n
<
N
-
3
;
n
+=
4
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
n
);
float32x4_t
_b
=
vld1q_f32
(
B
+
n
);
_sum0
=
vmlaq_f32
(
_sum0
,
_r0
,
_b
);
}
if
(
n
<
N
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
n
);
float32x4_t
_b
=
vld1q_f32
(
B
+
n
);
_r0
=
vandq_f32_u32
(
_r0
,
vmask
);
_b
=
vandq_f32_u32
(
_b
,
vmask
);
_sum0
=
vmlaq_f32
(
_sum0
,
_r0
,
_b
);
}
_sum0
=
vpaddq_f32
(
_sum0
,
_sum0
);
_sum0
=
vmulq_f32
(
_sum0
,
_valpha
);
if
(
beta
!=
0.
f
)
{
_sum2
=
vmulq_n_f32
(
vld1q_f32
(
output
),
beta
);
_sum0
=
vpaddq_f32
(
_sum0
,
_sum2
);
}
// restore
*
output
=
vgetq_lane_f32
(
_sum0
,
0
)
+
vgetq_lane_f32
(
_sum0
,
1
);
}
}
void
sgemv_trans_mx1
(
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
float32x4_t
_valpha
=
vdupq_n_f32
(
alpha
);
if
(
beta
==
0.
f
)
{
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
for
(
int
m
=
0
;
m
<
M
-
3
;
m
+=
4
)
{
vst1q_f32
(
C
+
m
,
vzero
);
}
for
(
int
m
=
(
M
&
0xfffc
);
m
<
M
;
++
m
)
{
C
[
m
]
=
0.
f
;
}
}
else
{
float32x4_t
vbeta
=
vdupq_n_f32
(
beta
);
for
(
int
m
=
0
;
m
<
M
-
3
;
m
+=
4
)
{
float32x4_t
_vc
=
vld1q_f32
(
C
+
m
);
_vc
=
vmulq_f32
(
_vc
,
vbeta
);
vst1q_f32
(
C
+
m
,
_vc
);
}
for
(
int
m
=
(
M
&
0xfffc
);
m
<
M
;
++
m
)
{
C
[
m
]
*=
beta
;
}
}
#pragma omp parallel for
for
(
int
n
=
0
;
n
<
N
-
3
;
n
+=
4
)
{
const
float
*
in0
=
A
+
n
*
lda
;
const
float
*
in1
=
in0
+
lda
;
const
float
*
in2
=
in1
+
lda
;
const
float
*
in3
=
in2
+
lda
;
float32x4_t
_b
=
vld1q_f32
(
B
+
n
);
float32x4_t
_sum0
;
int
m
=
0
;
for
(;
m
<
M
-
3
;
m
+=
4
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
m
);
float32x4_t
_r1
=
vld1q_f32
(
in1
+
m
);
float32x4_t
_r2
=
vld1q_f32
(
in2
+
m
);
float32x4_t
_r3
=
vld1q_f32
(
in3
+
m
);
float32x4_t
_vc
=
vld1q_f32
(
C
+
m
);
_sum0
=
vmulq_lane_f32
(
_r0
,
vget_low_f32
(
_b
),
0
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r1
,
vget_low_f32
(
_b
),
1
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r2
,
vget_high_f32
(
_b
),
0
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r3
,
vget_high_f32
(
_b
),
1
);
_sum0
=
vmulq_f32
(
_sum0
,
_valpha
);
_sum0
=
vaddq_f32
(
_sum0
,
_vc
);
vst1q_f32
(
C
+
m
,
_sum0
);
}
if
(
m
<
M
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
m
);
float32x4_t
_r1
=
vld1q_f32
(
in1
+
m
);
float32x4_t
_r2
=
vld1q_f32
(
in2
+
m
);
float32x4_t
_r3
=
vld1q_f32
(
in3
+
m
);
float32x4_t
_vc
=
vld1q_f32
(
C
+
m
);
_sum0
=
vmulq_lane_f32
(
_r0
,
vget_low_f32
(
_b
),
0
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r1
,
vget_low_f32
(
_b
),
1
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r2
,
vget_high_f32
(
_b
),
0
);
_sum0
=
vmlaq_lane_f32
(
_sum0
,
_r3
,
vget_high_f32
(
_b
),
1
);
_sum0
=
vmulq_f32
(
_sum0
,
_valpha
);
_sum0
=
vaddq_f32
(
_sum0
,
_vc
);
switch
(
M
-
m
)
{
case
3
:
vst1q_lane_f32
(
C
+
m
+
2
,
_sum0
,
2
);
case
2
:
vst1_f32
(
C
+
m
,
vget_low_f32
(
_sum0
));
break
;
case
1
:
vst1q_lane_f32
(
C
+
m
,
_sum0
,
0
);
break
;
}
}
}
// remain n
for
(
int
n
=
(
N
&
0xfffc
);
n
<
N
;
++
n
)
{
const
float
*
in0
=
A
+
n
*
lda
;
float32x4_t
_b
=
vld1q_dup_f32
(
B
+
n
);
float32x4_t
_sum0
;
int
m
=
0
;
for
(;
m
<
M
-
3
;
m
+=
4
)
{
float32x4_t
_r0
=
vld1q_f32
(
in0
+
m
);
_sum0
=
vld1q_f32
(
C
+
m
);
_r0
=
vmulq_f32
(
_r0
,
_b
);
_r0
=
vmulq_f32
(
_valpha
,
_r0
);
_sum0
=
vaddq_f32
(
_sum0
,
_r0
);
vst1q_f32
(
C
+
m
,
_sum0
);
}
for
(;
m
<
M
;
++
m
)
{
C
[
m
]
+=
alpha
*
(
in0
[
m
]
*
B
[
n
]);
}
}
}
void
sgemv_mx1
(
const
bool
trans
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
if
(
trans
)
{
sgemv_trans_mx1
(
M
,
N
,
alpha
,
A
,
lda
,
B
,
beta
,
C
);
}
else
{
sgemv_notrans_mx1
(
M
,
N
,
alpha
,
A
,
lda
,
B
,
beta
,
C
);
}
}
}
// namespace math
}
// namespace operators
}
// namespace paddle_mobile
...
...
src/operators/math/gemm/pack_kernel.h
浏览文件 @
cd1b6c08
...
...
@@ -20,15 +20,12 @@ limitations under the License. */
#ifdef _OPENMP
#include <omp.h>
#endif
#include "operators/math/math.h"
namespace
paddle_mobile
{
namespace
operators
{
namespace
math
{
inline
float32x4_t
vandq_f32_u32
(
float32x4_t
x
,
uint32x4_t
mask
)
{
return
vreinterpretq_f32_u32
(
vandq_u32
(
vreinterpretq_u32_f32
(
x
),
mask
));
}
void
pack_lhs_6r
(
const
int
m
,
const
int
k
,
const
float
*
A
,
const
int
lda
,
float
*
output
,
const
bool
unroll
)
{
uint32_t
mask
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
4
,
5
};
...
...
@@ -218,15 +215,21 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
vst1q_f32
(
out_ptr
+
18
,
_d3
);
vst1_f32
(
out_ptr
+
22
,
vget_high_f32
(
_d5
));
a0
+=
4
;
a1
+=
4
;
a2
+=
4
;
a3
+=
4
;
a4
+=
4
;
a5
+=
4
;
out_ptr
+=
24
;
#else
asm
volatile
(
"vld1.32 {d0-d1}, [%[a0]]
\n
"
"vld1.32 {d2-d3}, [%[a1]]
\n
"
"vld1.32 {d4-d5}, [%[a2]]
\n
"
"vld1.32 {d6-d7}, [%[a3]]
\n
"
"vld1.32 {d8-d9}, [%[a4]]
\n
"
"vld1.32 {d10-d11}, [%[a5]]
\n
"
"vld1.32 {d0-d1}, [%[a0]]
!
\n
"
"vld1.32 {d2-d3}, [%[a1]]
!
\n
"
"vld1.32 {d4-d5}, [%[a2]]
!
\n
"
"vld1.32 {d6-d7}, [%[a3]]
!
\n
"
"vld1.32 {d8-d9}, [%[a4]]
!
\n
"
"vld1.32 {d10-d11}, [%[a5]]
!
\n
"
"vtrn.32 q0, q1
\n
"
"vtrn.32 q2, q3
\n
"
"vtrn.32 q4, q5
\n
"
...
...
@@ -255,6 +258,20 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
#endif
}
// remain k
switch
(
remain_m
)
{
case
1
:
a1
=
zerobuff
;
case
2
:
a2
=
zerobuff
;
case
3
:
a3
=
zerobuff
;
case
4
:
a4
=
zerobuff
;
case
5
:
a5
=
zerobuff
;
default:
break
;
}
for
(;
lk
<
k
;
++
lk
)
{
*
out_ptr
++
=
*
a0
++
;
*
out_ptr
++
=
*
a1
++
;
...
...
src/operators/math/gemm/strategy.h
浏览文件 @
cd1b6c08
...
...
@@ -88,19 +88,12 @@ struct SgemvStrategy {
typedef
float
Itype
;
typedef
float
Otype
;
typedef
void
(
*
kern_type
)(
const
Itype
*
,
const
Itype
*
,
const
int
,
Otype
*
,
const
int
);
kern_type
kernel
;
static
int
out_width
()
{
return
1
;
}
typedef
void
(
*
kernelFunc
)(
const
bool
,
const
int
,
const
int
,
const
float
,
const
Itype
*
,
const
int
,
const
Itype
*
,
const
float
,
Otype
*
);
kernelFunc
kernel
;
static
int
out_height
()
{
#if __aarch64__
return
12
;
#else
return
6
;
#endif
}
SgemvStrategy
()
{
kernel
=
sgemv_mx1
;
}
};
struct
I8o32gemvStrategy
{
...
...
src/operators/math/math.h
浏览文件 @
cd1b6c08
...
...
@@ -327,4 +327,16 @@ static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
return
exp_ps
(
vmulq_f32
(
b
,
log_ps
(
a
)));
}
#ifndef __aarch64__
inline
float32x4_t
vpaddq_f32
(
float32x4_t
r0
,
float32x4_t
r1
)
{
float32x2_t
sum0
=
vpadd_f32
(
vget_low_f32
(
r0
),
vget_high_f32
(
r0
));
float32x2_t
sum1
=
vpadd_f32
(
vget_low_f32
(
r1
),
vget_high_f32
(
r1
));
return
vcombine_f32
(
sum0
,
sum1
);
}
#endif
inline
float32x4_t
vandq_f32_u32
(
float32x4_t
x
,
uint32x4_t
mask
)
{
return
vreinterpretq_f32_u32
(
vandq_u32
(
vreinterpretq_u32_f32
(
x
),
mask
));
}
#endif // __ARM_NEON__
test/common/test_gemm_accuracy.cpp
浏览文件 @
cd1b6c08
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
#include "operators/math/gemm
/cblas
.h"
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
...
...
@@ -36,10 +36,12 @@ void print_matrix(int m, int n, int ldc, float *c) {
std
::
cout
<<
std
::
endl
;
}
int
do_sgemm
(
int
m
,
int
n
,
int
k
,
bool
relu
,
int
t1
,
int
t2
,
int
pr
)
{
int
lda
=
k
;
int
ldb
=
n
;
int
ldc
=
n
;
int
do_sgemm
(
int
m
,
int
n
,
int
k
,
int
pr
)
{
const
float
alpha
=
1.
f
;
const
float
beta
=
0.
f
;
const
int
lda
=
k
;
const
int
ldb
=
n
;
const
int
ldc
=
n
;
float
*
a
=
static_cast
<
float
*>
(
paddle_mobile
::
memory
::
Alloc
(
sizeof
(
float
)
*
m
*
k
));
...
...
@@ -49,24 +51,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
static_cast
<
float
*>
(
paddle_mobile
::
memory
::
Alloc
(
sizeof
(
float
)
*
m
*
n
));
float
*
c1
=
static_cast
<
float
*>
(
paddle_mobile
::
memory
::
Alloc
(
sizeof
(
float
)
*
m
*
n
));
float
*
scale
=
static_cast
<
float
*>
(
paddle_mobile
::
memory
::
Alloc
(
sizeof
(
float
)
*
m
));
float
*
bias
=
static_cast
<
float
*>
(
paddle_mobile
::
memory
::
Alloc
(
sizeof
(
float
)
*
m
));
srand
(
unsigned
(
time
(
0
)));
std
::
mt19937
rng
(
111
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
const
float
lower
=
-
10.
f
;
const
float
upper
=
10.
f
;
for
(
int
i
=
0
;
i
<
m
*
k
;
++
i
)
{
a
[
i
]
=
t1
+
rand
()
%
t2
;
a
[
i
]
=
static_cast
<
float
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
)
;
}
for
(
int
i
=
0
;
i
<
k
*
n
;
++
i
)
{
b
[
i
]
=
t1
+
rand
()
%
t2
;
}
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
scale
[
i
]
=
t1
+
rand
()
%
t2
;
}
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
bias
[
i
]
=
t1
+
rand
()
%
t2
;
b
[
i
]
=
static_cast
<
float
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
}
memcpy
(
c
,
c1
,
sizeof
(
float
)
*
m
*
n
);
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
...
...
@@ -74,25 +71,20 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
for
(
int
p
=
0
;
p
<
k
;
p
++
)
{
r
+=
a
(
i
,
p
)
*
b
(
p
,
j
);
}
r
*=
scale
[
i
];
r
+=
bias
[
i
];
if
(
relu
&&
(
r
<
0
))
{
r
=
0
;
}
c1
(
i
,
j
)
=
r
;
c1
(
i
,
j
)
=
alpha
*
r
;
}
}
paddle_mobile
::
operators
::
math
::
Gemm
gemm
;
gemm
.
SgemmWithBn
(
m
,
n
,
k
,
1
,
a
,
lda
,
b
,
ldb
,
0.3
,
c
,
ldc
,
relu
,
scale
,
bias
,
nullptr
);
int
eq
=
0
;
int
neq
=
0
;
std
::
cout
<<
"run cblas_sgemm..."
<<
std
::
endl
;
paddle_mobile
::
operators
::
math
::
cblas_sgemm
(
false
,
false
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
0.
f
,
c
,
ldc
);
std
::
cout
<<
"compare results..."
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
m
*
n
;
++
i
)
{
if
(
static_cast
<
int
>
(
c
[
i
])
==
static_cast
<
int
>
(
c1
[
i
])
)
{
++
eq
;
}
else
{
++
neq
;
if
(
abs
(
c
[
i
]
-
c1
[
i
])
>=
1e-2
)
{
std
::
cout
<<
"c["
<<
i
<<
"] != c1["
<<
i
<<
"]: "
<<
c
[
i
]
<<
" vs "
<<
c1
[
i
]
<<
std
::
endl
;
exit
(
1
)
;
}
}
...
...
@@ -107,33 +99,36 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
print_matrix
(
m
,
n
,
ldc
,
c1
);
}
std
::
cout
<<
"mnk="
<<
m
<<
" "
<<
n
<<
" "
<<
k
<<
" relu="
<<
relu
<<
" eq="
<<
eq
<<
" neq="
<<
neq
<<
std
::
endl
;
PADDLE_MOBILE_ENFORCE
(
neq
==
0
,
"The execution of do_sgemm is failed!"
);
paddle_mobile
::
memory
::
Free
(
a
);
paddle_mobile
::
memory
::
Free
(
b
);
paddle_mobile
::
memory
::
Free
(
c
);
paddle_mobile
::
memory
::
Free
(
c1
);
paddle_mobile
::
memory
::
Free
(
scale
);
paddle_mobile
::
memory
::
Free
(
bias
);
return
0
;
}
int
main
()
{
do_sgemm
(
9
,
9
,
9
,
true
,
10
,
10
,
10
);
do_sgemm
(
10
,
6
,
12
,
false
,
10
,
10
,
0
);
do_sgemm
(
512
,
256
,
384
,
false
,
10
,
10
,
0
);
do_sgemm
(
1366
,
768
,
256
,
false
,
10
,
10
,
0
);
do_sgemm
(
1255
,
755
,
333
,
false
,
10
,
10
,
0
);
do_sgemm
(
555
,
777
,
999
,
false
,
10
,
10
,
0
);
do_sgemm
(
10
,
6
,
12
,
true
,
-
4
,
10
,
0
);
do_sgemm
(
512
,
256
,
384
,
true
,
-
4
,
10
,
0
);
do_sgemm
(
1366
,
768
,
256
,
true
,
-
4
,
10
,
0
);
do_sgemm
(
1255
,
755
,
333
,
true
,
-
4
,
10
,
0
);
do_sgemm
(
555
,
777
,
999
,
true
,
-
4
,
10
,
0
);
int
main
(
int
argc
,
char
*
argv
[])
{
do_sgemm
(
1
,
1
,
1
,
1
);
do_sgemm
(
9
,
9
,
1
,
1
);
do_sgemm
(
999
,
99
,
1
,
0
);
do_sgemm
(
999
,
1
,
1
,
0
);
do_sgemm
(
1
,
9
,
9
,
1
);
do_sgemm
(
1
,
99
,
999
,
0
);
do_sgemm
(
1
,
1
,
999
,
0
);
do_sgemm
(
9
,
9
,
9
,
1
);
do_sgemm
(
10
,
6
,
12
,
1
);
do_sgemm
(
512
,
256
,
384
,
0
);
do_sgemm
(
1366
,
768
,
256
,
0
);
do_sgemm
(
1255
,
755
,
333
,
0
);
do_sgemm
(
555
,
777
,
999
,
0
);
do_sgemm
(
10
,
6
,
12
,
1
);
do_sgemm
(
512
,
256
,
384
,
0
);
do_sgemm
(
1366
,
768
,
256
,
0
);
do_sgemm
(
1255
,
755
,
333
,
0
);
do_sgemm
(
555
,
777
,
999
,
0
);
return
0
;
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录