Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
1cff3bfe
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1cff3bfe
编写于
7月 10, 2018
作者:
W
WangLiu
提交者:
GitHub
7月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #542 from cocodark/develop
accelerate with openmp
上级
866ab5fc
f312f389
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
371 addition
and
288 deletion
+371
-288
CMakeLists.txt
CMakeLists.txt
+1
-1
src/io/executor.cpp
src/io/executor.cpp
+15
-0
src/io/executor.h
src/io/executor.h
+2
-0
src/operators/kernel/central-arm-func/conv_add_arm_func.h
src/operators/kernel/central-arm-func/conv_add_arm_func.h
+31
-3
src/operators/kernel/lrn_kernel.h
src/operators/kernel/lrn_kernel.h
+4
-1
src/operators/math/gemm.cpp
src/operators/math/gemm.cpp
+55
-52
src/operators/math/gemm.h
src/operators/math/gemm.h
+105
-81
src/operators/math/math_function.cpp
src/operators/math/math_function.cpp
+7
-25
src/operators/math/pool_3x3.cpp
src/operators/math/pool_3x3.cpp
+140
-121
src/operators/math/pool_3x3.h
src/operators/math/pool_3x3.h
+3
-0
src/operators/math/pooling.cpp
src/operators/math/pooling.cpp
+4
-1
test/net/test_googlenet.cpp
test/net/test_googlenet.cpp
+4
-3
未找到文件。
CMakeLists.txt
浏览文件 @
1cff3bfe
...
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
...
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
project
(
paddle-mobile
)
project
(
paddle-mobile
)
option
(
DEBUGING
"enable debug mode"
ON
)
option
(
DEBUGING
"enable debug mode"
ON
)
option
(
USE_OPENMP
"openmp support"
O
FF
)
option
(
USE_OPENMP
"openmp support"
O
N
)
option
(
USE_EXCEPTION
"use std exception"
ON
)
option
(
USE_EXCEPTION
"use std exception"
ON
)
option
(
LOG_PROFILE
"log profile"
ON
)
option
(
LOG_PROFILE
"log profile"
ON
)
# select the platform to build
# select the platform to build
...
...
src/io/executor.cpp
浏览文件 @
1cff3bfe
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "io/executor.h"
#include "io/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
#include "common/enforce.h"
#include "common/enforce.h"
...
@@ -25,6 +26,9 @@ limitations under the License. */
...
@@ -25,6 +26,9 @@ limitations under the License. */
#include "framework/program/var_desc.h"
#include "framework/program/var_desc.h"
#include "framework/scope.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/tensor.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <queue>
#include <utility>
#include <utility>
...
@@ -403,6 +407,17 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
...
@@ -403,6 +407,17 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
return
result_vector
;
return
result_vector
;
}
}
template
<
typename
Dtype
,
Precision
P
>
void
Executor
<
Dtype
,
P
>::
SetThreadNum
(
int
num
)
{
for
(
int
k
=
0
;
k
<
std
::
max
(
num
,
3
);
++
k
)
{
operators
::
math
::
Gemmer
::
gemmers
.
push_back
(
new
operators
::
math
::
Gemmer
());
}
#ifdef _OPENMP
// omp_set_dynamic(0);
omp_set_num_threads
(
num
);
#endif
}
template
class
Executor
<
CPU
,
Precision
::
FP32
>;
template
class
Executor
<
CPU
,
Precision
::
FP32
>;
template
class
Executor
<
FPGA
,
Precision
::
FP32
>;
template
class
Executor
<
FPGA
,
Precision
::
FP32
>;
template
class
Executor
<
GPU_MALI
,
Precision
::
FP32
>;
template
class
Executor
<
GPU_MALI
,
Precision
::
FP32
>;
...
...
src/io/executor.h
浏览文件 @
1cff3bfe
...
@@ -58,6 +58,8 @@ class Executor {
...
@@ -58,6 +58,8 @@ class Executor {
std
::
vector
<
Ptype
>
Predict
(
const
std
::
vector
<
Ptype
>
&
input
,
std
::
vector
<
Ptype
>
Predict
(
const
std
::
vector
<
Ptype
>
&
input
,
const
std
::
vector
<
int64_t
>
&
dims
);
const
std
::
vector
<
int64_t
>
&
dims
);
void
SetThreadNum
(
int
num
);
protected:
protected:
Executor
()
=
default
;
Executor
()
=
default
;
void
InitMemory
();
void
InitMemory
();
...
...
src/operators/kernel/central-arm-func/conv_add_arm_func.h
浏览文件 @
1cff3bfe
...
@@ -14,10 +14,14 @@ limitations under the License. */
...
@@ -14,10 +14,14 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#ifdef FUSION_CONVADD_OP
#pragma once
#pragma once
#if _OPENMP
#include <omp.h>
#endif
#include <vector>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/gemm.h"
#include "operators/math/im2col.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/math/vol2col.h"
...
@@ -106,9 +110,33 @@ void ConvAddBasic(const FusionConvAddParam ¶m) {
...
@@ -106,9 +110,33 @@ void ConvAddBasic(const FusionConvAddParam ¶m) {
// gemm
// gemm
Tensor
out_slice
=
out_batch
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
out_slice
=
out_batch
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
filter_slice
=
filter
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
filter_slice
=
filter
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
math
::
matmul
<
float
>
(
filter_slice
,
false
,
col_matrix
,
false
,
static_cast
<
float
>
(
1
),
&
out_slice
,
auto
dim_a
=
filter_slice
.
dims
();
static_cast
<
float
>
(
1
));
auto
dim_b
=
col_matrix
.
dims
();
auto
dim_out
=
out_slice
.
dims
();
int
m
=
dim_out
[
0
];
int
n
=
dim_out
[
1
];
int
k
=
dim_a
[
1
];
float
*
output_data
=
out_slice
.
data
<
float
>
();
int
thread_num
=
4
;
int
m1
=
m
/
thread_num
;
int
m2
=
m
%
thread_num
;
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
thread_num
;
++
j
)
{
int
row_count
=
m1
;
if
(
j
==
thread_num
-
1
)
{
row_count
=
m1
+
m2
;
}
math
::
Gemmer
::
gemmers
[
j
]
->
Sgemm
(
row_count
,
n
,
k
,
1
,
filter_slice
.
data
<
float
>
()
+
j
*
m1
*
k
,
k
,
col_matrix
.
data
<
float
>
(),
n
,
1
,
output_data
+
j
*
m1
*
n
,
n
,
false
);
}
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(1));
}
}
}
}
}
}
...
...
src/operators/kernel/lrn_kernel.h
浏览文件 @
1cff3bfe
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifdef LRN_OP
#ifdef LRN_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/operator.h"
#include "framework/operator.h"
#include "operators/op_param.h"
#include "operators/op_param.h"
...
@@ -47,6 +49,7 @@ struct LRNFunctor {
...
@@ -47,6 +49,7 @@ struct LRNFunctor {
std
::
fill
(
sqr_buffer_ptr
,
sqr_buffer_ptr
+
sqr_buffer
.
numel
(),
0.0
);
std
::
fill
(
sqr_buffer_ptr
,
sqr_buffer_ptr
+
sqr_buffer
.
numel
(),
0.0
);
for
(
int
a
=
0
;
a
<
N
;
a
++
)
{
for
(
int
a
=
0
;
a
<
N
;
a
++
)
{
#pragma parallel for
for
(
int
b
=
0
;
b
<
C
;
b
++
)
{
for
(
int
b
=
0
;
b
<
C
;
b
++
)
{
for
(
int
index
=
start
;
index
<
end
;
index
++
)
{
for
(
int
index
=
start
;
index
<
end
;
index
++
)
{
int
channel
=
b
+
index
;
int
channel
=
b
+
index
;
...
...
src/operators/math/gemm.cpp
浏览文件 @
1cff3bfe
...
@@ -22,17 +22,11 @@ limitations under the License. */
...
@@ -22,17 +22,11 @@ limitations under the License. */
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
int
MC
=
0
;
int
KC
=
0
;
std
::
vector
<
Gemmer
*>
Gemmer
::
gemmers
;
int
NC
=
0
;
float
*
packedA
;
float
*
packedB
;
float
*
packedC
;
float
*
zero
;
// 将A矩阵分块复制到连续内存(ColMajor)
// 将A矩阵分块复制到连续内存(ColMajor)
void
PackMatrixA
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
void
Gemmer
::
PackMatrixA
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
float
*
buffer
)
{
float
*
buffer
)
{
int
i
,
j
;
int
i
,
j
;
const
float
*
Aij
;
const
float
*
Aij
;
for
(
i
=
0
;
i
<
m
-
m_tail
;
i
+=
MR
)
{
for
(
i
=
0
;
i
<
m
-
m_tail
;
i
+=
MR
)
{
...
@@ -58,8 +52,8 @@ void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
...
@@ -58,8 +52,8 @@ void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
}
}
// 将A矩阵分块复制到连续内存(RowMajor)
// 将A矩阵分块复制到连续内存(RowMajor)
void
PackMatrixA_
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
void
Gemmer
::
PackMatrixA_
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
float
*
buffer
)
{
float
*
buffer
)
{
const
float
*
a0
,
*
a1
,
*
a2
,
*
a3
;
const
float
*
a0
,
*
a1
,
*
a2
,
*
a3
;
for
(
int
i
=
0
;
i
<
m
-
m_tail
;
i
+=
MR
)
{
for
(
int
i
=
0
;
i
<
m
-
m_tail
;
i
+=
MR
)
{
a0
=
A
+
i
*
lda
;
a0
=
A
+
i
*
lda
;
...
@@ -98,8 +92,8 @@ void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
...
@@ -98,8 +92,8 @@ void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
}
}
// 将B矩阵分块复制到连续内存(ColMajor)
// 将B矩阵分块复制到连续内存(ColMajor)
void
PackMatrixB
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
void
Gemmer
::
PackMatrixB
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
float
*
buffer
)
{
float
*
buffer
)
{
int
i
,
j
;
int
i
,
j
;
const
float
*
Bj
,
*
Bj1
,
*
Bj2
,
*
Bj3
;
const
float
*
Bj
,
*
Bj1
,
*
Bj2
,
*
Bj3
;
for
(
j
=
0
;
j
<
n
-
n_tail
;
j
+=
NR
)
{
for
(
j
=
0
;
j
<
n
-
n_tail
;
j
+=
NR
)
{
...
@@ -127,8 +121,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
...
@@ -127,8 +121,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
}
}
// 将B矩阵分块复制到连续内存(RowMajor)
// 将B矩阵分块复制到连续内存(RowMajor)
void
PackMatrixB_
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
void
Gemmer
::
PackMatrixB_
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
float
*
buffer
)
{
float
*
buffer
)
{
const
float
*
b0
;
const
float
*
b0
;
for
(
int
j
=
0
;
j
<
n
-
n_tail
;
j
+=
NR
)
{
for
(
int
j
=
0
;
j
<
n
-
n_tail
;
j
+=
NR
)
{
for
(
int
i
=
0
;
i
<
k
;
++
i
)
{
for
(
int
i
=
0
;
i
<
k
;
++
i
)
{
...
@@ -156,8 +150,9 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
...
@@ -156,8 +150,9 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
}
}
// 分块矩阵乘法
// 分块矩阵乘法
void
InnerKernel
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
const
float
*
b
,
void
Gemmer
::
InnerKernel
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
)
{
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
)
{
for
(
int
j
=
0
;
j
<
nc
;
j
+=
NR
)
{
for
(
int
j
=
0
;
j
<
nc
;
j
+=
NR
)
{
for
(
int
i
=
0
;
i
<
mc
;
i
+=
MR
)
{
for
(
int
i
=
0
;
i
<
mc
;
i
+=
MR
)
{
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
...
@@ -184,9 +179,10 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
...
@@ -184,9 +179,10 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
}
}
// 分块矩阵乘法
// 分块矩阵乘法
void
InnerKernelWithBn
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
void
Gemmer
::
InnerKernelWithBn
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
for
(
int
j
=
0
;
j
<
nc
;
j
+=
NR
)
{
for
(
int
j
=
0
;
j
<
nc
;
j
+=
NR
)
{
for
(
int
i
=
0
;
i
<
mc
;
i
+=
MR
)
{
for
(
int
i
=
0
;
i
<
mc
;
i
+=
MR
)
{
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
...
@@ -202,7 +198,8 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
...
@@ -202,7 +198,8 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
}
}
#if defined(IOS)
#if defined(IOS)
void
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
C
,
int
ldc
)
{
// init C
// init C
float32x4_t
cv0
=
vdupq_n_f32
(
0.0
);
float32x4_t
cv0
=
vdupq_n_f32
(
0.0
);
float32x4_t
cv1
=
vdupq_n_f32
(
0.0
);
float32x4_t
cv1
=
vdupq_n_f32
(
0.0
);
...
@@ -253,7 +250,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
...
@@ -253,7 +250,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
}
// namespace math
}
// namespace math
#elif defined(ARMV7)
#elif defined(ARMV7)
void
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
void
Gemmer
::
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
const
float
*
a_ptr
,
*
b_ptr
;
const
float
*
a_ptr
,
*
b_ptr
;
a_ptr
=
a
;
a_ptr
=
a
;
b_ptr
=
b
;
b_ptr
=
b
;
...
@@ -324,7 +322,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
...
@@ -324,7 +322,8 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
}
}
#else
#else
void
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
void
Gemmer
::
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
float
*
c0
,
*
c1
,
*
c2
,
*
c3
;
float
*
c0
,
*
c1
,
*
c2
,
*
c3
;
c0
=
c
;
c0
=
c
;
c1
=
c
+
ldc
;
c1
=
c
+
ldc
;
...
@@ -363,8 +362,9 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
...
@@ -363,8 +362,9 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
#endif
#endif
// 32位 float 矩阵乘法
// 32位 float 矩阵乘法
void
Sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
void
Gemmer
::
Sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
)
{
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
)
{
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int
L1
=
30
*
1024
;
int
L1
=
30
*
1024
;
...
@@ -415,9 +415,10 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
...
@@ -415,9 +415,10 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile
::
memory
::
Free
(
zero
);
paddle_mobile
::
memory
::
Free
(
zero
);
}
}
void
SgemmWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
void
Gemmer
::
SgemmWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int
L1
=
30
*
1024
;
int
L1
=
30
*
1024
;
...
@@ -468,9 +469,9 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
...
@@ -468,9 +469,9 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile
::
memory
::
Free
(
zero
);
paddle_mobile
::
memory
::
Free
(
zero
);
}
}
void
VectorKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
void
Gemmer
::
VectorKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
bool
relu
)
{
float
*
C
,
int
ldc
,
bool
relu
)
{
float
*
bufferC
=
static_cast
<
float
*>
(
memory
::
Alloc
(
sizeof
(
float
)
*
n
));
float
*
bufferC
=
static_cast
<
float
*>
(
memory
::
Alloc
(
sizeof
(
float
)
*
n
));
const
float
*
a0
,
*
b0
,
*
b1
,
*
b2
,
*
b3
;
const
float
*
a0
,
*
b0
,
*
b1
,
*
b2
,
*
b3
;
...
@@ -690,9 +691,10 @@ void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
...
@@ -690,9 +691,10 @@ void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
}
}
}
}
void
VectorKernelWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
void
Gemmer
::
VectorKernelWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
)
{
float
*
bufferC
=
static_cast
<
float
*>
(
memory
::
Alloc
(
sizeof
(
float
)
*
n
));
float
*
bufferC
=
static_cast
<
float
*>
(
memory
::
Alloc
(
sizeof
(
float
)
*
n
));
const
float
*
a0
,
*
b0
,
*
b1
,
*
b2
,
*
b3
;
const
float
*
a0
,
*
b0
,
*
b1
,
*
b2
,
*
b3
;
...
@@ -901,7 +903,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
...
@@ -901,7 +903,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
}
}
}
}
void
AddDot4x8
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
void
Gemmer
::
AddDot4x8
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
)
{
const
float
*
a_ptr
,
*
b_ptr
;
const
float
*
a_ptr
,
*
b_ptr
;
a_ptr
=
a
;
a_ptr
=
a
;
b_ptr
=
b
;
b_ptr
=
b
;
...
@@ -1009,7 +1012,7 @@ void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
...
@@ -1009,7 +1012,7 @@ void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
}
}
// C = A * B
// C = A * B
void
WriteBasic
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
WriteBasic
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
nc
/
16
;
int
nc1
=
nc
/
16
;
int
_nc1
=
nc
%
16
;
int
_nc1
=
nc
%
16
;
int
step
=
4
*
ldc
;
int
step
=
4
*
ldc
;
...
@@ -1066,10 +1069,10 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
...
@@ -1066,10 +1069,10 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = alpha * A * B + beta * C
// C = alpha * A * B + beta * C
void
WriteWithAlphaBeta
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{}
void
Gemmer
::
WriteWithAlphaBeta
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{}
// C = A * B + C
// C = A * B + C
void
WriteWithAdd
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
WriteWithAdd
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
nc
/
16
;
int
nc1
=
nc
/
16
;
int
_nc1
=
nc
%
16
;
int
_nc1
=
nc
%
16
;
int
step
=
4
*
ldc
;
int
step
=
4
*
ldc
;
...
@@ -1133,7 +1136,7 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
...
@@ -1133,7 +1136,7 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + C, relu(C)
// C = A * B + C, relu(C)
void
WriteWithAddRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
WriteWithAddRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
nc
/
16
;
int
nc1
=
nc
/
16
;
int
_nc1
=
nc
%
16
;
int
_nc1
=
nc
%
16
;
int
step
=
4
*
ldc
;
int
step
=
4
*
ldc
;
...
@@ -1207,8 +1210,8 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
...
@@ -1207,8 +1210,8 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B, batchnorm(C)
// C = A * B, batchnorm(C)
void
WriteWithBn
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
scale
,
void
Gemmer
::
WriteWithBn
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
bias
)
{
float
*
scale
,
float
*
bias
)
{
int
nc1
=
nc
/
16
;
int
nc1
=
nc
/
16
;
int
_nc1
=
nc
%
16
;
int
_nc1
=
nc
%
16
;
int
nc2
=
_nc1
/
4
;
int
nc2
=
_nc1
/
4
;
...
@@ -1293,8 +1296,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
...
@@ -1293,8 +1296,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
}
}
// C = A * B, batchnorm(C), relu(C)
// C = A * B, batchnorm(C), relu(C)
void
WriteWithBnRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
scale
,
void
Gemmer
::
WriteWithBnRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
bias
)
{
float
*
scale
,
float
*
bias
)
{
int
nc1
=
nc
/
16
;
int
nc1
=
nc
/
16
;
int
_nc1
=
nc
%
16
;
int
_nc1
=
nc
%
16
;
int
nc2
=
_nc1
/
4
;
int
nc2
=
_nc1
/
4
;
...
@@ -1386,7 +1389,7 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
...
@@ -1386,7 +1389,7 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
}
}
// C = A * B
// C = A * B
void
VecWriteBasic
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
VecWriteBasic
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
n
/
16
;
int
nc1
=
n
/
16
;
int
_nc1
=
n
%
16
;
int
_nc1
=
n
%
16
;
int
nc2
=
_nc1
/
4
;
int
nc2
=
_nc1
/
4
;
...
@@ -1432,10 +1435,10 @@ void VecWriteBasic(int n, float *c, float *C, int ldc) {
...
@@ -1432,10 +1435,10 @@ void VecWriteBasic(int n, float *c, float *C, int ldc) {
}
}
// C = alpha * A * B + beta * C
// C = alpha * A * B + beta * C
void
VecWriteWithAlphaBeta
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{}
void
Gemmer
::
VecWriteWithAlphaBeta
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{}
// C = A * B + C
// C = A * B + C
void
VecWriteWithAdd
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
VecWriteWithAdd
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
n
/
16
;
int
nc1
=
n
/
16
;
int
_nc1
=
n
%
16
;
int
_nc1
=
n
%
16
;
...
@@ -1473,7 +1476,7 @@ void VecWriteWithAdd(int n, float *c, float *C, int ldc) {
...
@@ -1473,7 +1476,7 @@ void VecWriteWithAdd(int n, float *c, float *C, int ldc) {
}
}
// C = A * B + C, relu(C)
// C = A * B + C, relu(C)
void
VecWriteWithAddRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
void
Gemmer
::
VecWriteWithAddRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
)
{
int
nc1
=
n
/
16
;
int
nc1
=
n
/
16
;
int
_nc1
=
n
%
16
;
int
_nc1
=
n
%
16
;
...
@@ -1521,8 +1524,8 @@ void VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
...
@@ -1521,8 +1524,8 @@ void VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
}
}
// C = A * B, batchnorm(C)
// C = A * B, batchnorm(C)
void
VecWriteWithBn
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
scale
,
void
Gemmer
::
VecWriteWithBn
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
scale
,
float
*
bias
)
{
float
*
bias
)
{
int
nc1
=
n
/
16
;
int
nc1
=
n
/
16
;
int
_nc1
=
n
%
16
;
int
_nc1
=
n
%
16
;
int
nc2
=
_nc1
/
4
;
int
nc2
=
_nc1
/
4
;
...
@@ -1588,8 +1591,8 @@ void VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
...
@@ -1588,8 +1591,8 @@ void VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
}
}
// C = A * B, batchnorm(C), relu(C)
// C = A * B, batchnorm(C), relu(C)
void
VecWriteWithBnRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
scale
,
void
Gemmer
::
VecWriteWithBnRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
bias
)
{
float
*
scale
,
float
*
bias
)
{
int
nc1
=
n
/
16
;
int
nc1
=
n
/
16
;
int
_nc1
=
n
%
16
;
int
_nc1
=
n
%
16
;
int
nc2
=
_nc1
/
4
;
int
nc2
=
_nc1
/
4
;
...
...
src/operators/math/gemm.h
浏览文件 @
1cff3bfe
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <vector>
// 矩阵取值运算宏,假设矩阵按行存储
// 矩阵取值运算宏,假设矩阵按行存储
#define A(i, j) A[(i)*lda + (j)]
#define A(i, j) A[(i)*lda + (j)]
...
@@ -27,88 +28,111 @@ limitations under the License. */
...
@@ -27,88 +28,111 @@ limitations under the License. */
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
struct
Gemmer
{
int
MC
=
0
;
int
KC
=
0
;
int
NC
=
0
;
// 将 A 矩阵分块复制到连续内存(ColMajor)
float
*
packedA
;
void
PackMatrixA
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
float
*
packedB
;
float
*
buffer
);
float
*
packedC
;
float
*
zero
;
// 将 B 矩阵分块复制到连续内存(ColMajor)
static
std
::
vector
<
Gemmer
*>
gemmers
;
void
PackMatrixB
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
float
*
buffer
);
// 将 A 矩阵分块复制到连续内存(ColMajor)
void
PackMatrixA
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
// 将 A 矩阵分块复制到连续内存(RowMajor)
float
*
buffer
);
void
PackMatrixA_
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
float
*
buffer
);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void
PackMatrixB
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
// 将 B 矩阵分块复制到连续内存(RowMajor)
float
*
buffer
);
void
PackMatrixB_
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
float
*
buffer
);
// 将 A 矩阵分块复制到连续内存(RowMajor)
void
PackMatrixA_
(
int
m
,
int
k
,
int
m_tail
,
const
float
*
A
,
int
lda
,
// 分块矩阵乘法
float
*
buffer
);
void
InnerKernel
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void
PackMatrixB_
(
int
k
,
int
n
,
int
n_tail
,
const
float
*
B
,
int
ldb
,
void
InnerKernelWithBn
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
float
*
buffer
);
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
);
// 分块矩阵乘法
void
InnerKernel
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
const
float
*
b
,
// 向量矩阵乘法 (M = 1)
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
);
void
VectorKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
void
InnerKernelWithBn
(
int
mc
,
int
nc
,
float
alpha
,
const
float
*
a
,
bool
relu
);
const
float
*
b
,
float
beta
,
float
*
c
,
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
);
void
VectorKernelWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
// 向量矩阵乘法 (M = 1)
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
);
void
VectorKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
// 计算一个更小的 C 矩阵分块
bool
relu
);
void
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
);
void
AddDot4x8
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
);
void
VectorKernelWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
// 分块矩阵乘法结果回写
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
// C = A * B
float
*
new_bias
);
void
WriteBasic
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
// C = alpha * A * B + beta * C
// 计算一个更小的 C 矩阵分块
void
WriteWithAlphaBeta
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
void
AddDot4x4
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
);
// C = A * B + C
void
WriteWithAdd
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
void
AddDot4x8
(
int
k
,
const
float
*
a
,
const
float
*
b
,
float
*
c
,
int
ldc
);
// C = A * B + C, relu(C)
void
WriteWithAddRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
// 分块矩阵乘法结果回写
// C = A * B, batchnorm(C)
// C = A * B
void
WriteWithBn
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_scale
,
void
WriteBasic
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
float
*
new_bias
);
// C = A * B, batchnorm(C), relu(C)
// C = alpha * A * B + beta * C
void
WriteWithBnRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
void
WriteWithAlphaBeta
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
float
*
new_scale
,
float
*
new_bias
);
// C = A * B + C
// 向量矩阵乘法结果回写
void
WriteWithAdd
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
// C = A * B
void
VecWriteBasic
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// C = A * B + C, relu(C)
// C = alpha * A * B + beta * C
void
WriteWithAddRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
);
void
VecWriteWithAlphaBeta
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// C = A * B + C
// C = A * B, batchnorm(C)
void
VecWriteWithAdd
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
void
WriteWithBn
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
// C = A * B + C, relu(C)
float
*
new_scale
,
float
*
new_bias
);
void
VecWriteWithAddRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// C = A * B, batchnorm(C)
// C = A * B, batchnorm(C), relu(C)
void
VecWriteWithBn
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_scale
,
void
WriteWithBnRelu
(
int
mc
,
int
nc
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_bias
);
float
*
new_scale
,
float
*
new_bias
);
// C = A * B, batchnorm(C), relu(C)
void
VecWriteWithBnRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_scale
,
// 向量矩阵乘法结果回写
float
*
new_bias
);
// C = A * B
void
VecWriteBasic
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// 32位 float 矩阵乘法
void
Sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
// C = alpha * A * B + beta * C
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
);
void
VecWriteWithAlphaBeta
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
// C = A * B + C
void
SgemmWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
void
VecWriteWithAdd
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
);
// C = A * B + C, relu(C)
void
VecWriteWithAddRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
);
// 64位 double 矩阵乘法
void
dgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
double
*
A
,
int
lda
,
// C = A * B, batchnorm(C)
const
double
*
B
,
int
ldb
,
float
beta
,
double
*
C
,
int
ldc
);
void
VecWriteWithBn
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_scale
,
float
*
new_bias
);
// C = A * B, batchnorm(C), relu(C)
void
VecWriteWithBnRelu
(
int
n
,
float
*
c
,
float
*
C
,
int
ldc
,
float
*
new_scale
,
float
*
new_bias
);
// 32位 float 矩阵乘法
void
Sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void
SgemmWithBn
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
bool
relu
,
float
*
new_scale
,
float
*
new_bias
);
// 64位 double 矩阵乘法
void
dgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
double
*
A
,
int
lda
,
const
double
*
B
,
int
ldb
,
float
beta
,
double
*
C
,
int
ldc
);
};
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
src/operators/math/math_function.cpp
浏览文件 @
1cff3bfe
...
@@ -26,23 +26,14 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
...
@@ -26,23 +26,14 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
auto
dim_out
=
matrix_out
->
dims
();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int
M
=
dim_out
[
0
];
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
N
=
dim_out
[
1
];
int
K
=
(
!
trans_a
)
?
dim_a
[
1
]
:
dim_a
[
0
];
int
K
=
(
!
trans_a
)
?
dim_a
[
1
]
:
dim_a
[
0
];
Sgemm
(
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
K
,
matrix_b
.
data
<
float
>
(),
N
,
Gemmer
::
gemmers
[
0
]
->
Sgemm
(
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
K
,
beta
,
matrix_out
->
data
<
float
>
(),
N
,
relu
);
matrix_b
.
data
<
float
>
(),
N
,
beta
,
matrix_out
->
data
<
float
>
(),
N
,
relu
);
}
}
template
<
>
template
<
>
...
@@ -54,24 +45,15 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
...
@@ -54,24 +45,15 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
auto
dim_out
=
matrix_out
->
dims
();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int
M
=
dim_out
[
0
];
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
N
=
dim_out
[
1
];
int
K
=
(
!
trans_a
)
?
dim_a
[
1
]
:
dim_a
[
0
];
int
K
=
(
!
trans_a
)
?
dim_a
[
1
]
:
dim_a
[
0
];
SgemmWithBn
(
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
K
,
matrix_b
.
data
<
float
>
(),
Gemmer
::
gemmers
[
0
]
->
SgemmWithBn
(
N
,
beta
,
matrix_out
->
data
<
float
>
(),
N
,
relu
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
K
,
matrix_b
.
data
<
float
>
(),
N
,
new_scale
->
data
<
float
>
(),
new_bias
->
data
<
float
>
());
beta
,
matrix_out
->
data
<
float
>
(),
N
,
relu
,
new_scale
->
data
<
float
>
(),
new_bias
->
data
<
float
>
());
}
}
}
// namespace math
}
// namespace math
...
...
src/operators/math/pool_3x3.cpp
浏览文件 @
1cff3bfe
此差异已折叠。
点击以展开。
src/operators/math/pool_3x3.h
浏览文件 @
1cff3bfe
...
@@ -15,6 +15,9 @@ limitations under the License. */
...
@@ -15,6 +15,9 @@ limitations under the License. */
#ifdef POOL_OP
#ifdef POOL_OP
#pragma once
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
#include "framework/tensor.h"
#include "framework/tensor.h"
...
...
src/operators/math/pooling.cpp
浏览文件 @
1cff3bfe
...
@@ -16,6 +16,9 @@ limitations under the License. */
...
@@ -16,6 +16,9 @@ limitations under the License. */
#include "pooling.h"
#include "pooling.h"
#include "common/types.h"
#include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace
paddle_mobile
{
namespace
paddle_mobile
{
namespace
operators
{
namespace
operators
{
...
@@ -57,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
...
@@ -57,8 +60,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
T
*
output_data
=
output
->
mutable_data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// #pragma omp parallel for
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
#pragma omp parallel for
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
...
...
test/net/test_googlenet.cpp
浏览文件 @
1cff3bfe
...
@@ -26,16 +26,17 @@ int main() {
...
@@ -26,16 +26,17 @@ int main() {
auto
time2
=
time
();
auto
time2
=
time
();
DLOG
<<
"load cost :"
<<
time_diff
(
time1
,
time2
)
<<
"ms
\n
"
;
DLOG
<<
"load cost :"
<<
time_diff
(
time1
,
time2
)
<<
"ms
\n
"
;
paddle_mobile
::
Executor
<
paddle_mobile
::
CPU
>
executor
(
program
,
1
,
optimize
);
paddle_mobile
::
Executor
<
paddle_mobile
::
CPU
>
executor
(
program
,
1
,
optimize
);
executor
.
SetThreadNum
(
4
);
std
::
vector
<
float
>
input
;
std
::
vector
<
float
>
input
;
std
::
vector
<
int64_t
>
dims
{
1
,
3
,
224
,
224
};
std
::
vector
<
int64_t
>
dims
{
1
,
3
,
224
,
224
};
GetInput
<
float
>
(
g_test_image_1x3x224x224
,
&
input
,
dims
);
GetInput
<
float
>
(
g_test_image_1x3x224x224
,
&
input
,
dims
);
auto
time3
=
time
();
auto
time3
=
time
();
int
count
=
1
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
executor
.
Predict
(
input
,
dims
);
executor
.
Predict
(
input
,
dims
);
}
}
auto
time4
=
time
();
auto
time4
=
time
();
DLOG
<<
"predict cost :"
<<
time_diff
(
time3
,
time4
)
<<
"ms
\n
"
;
DLOG
<<
"predict cost :"
<<
time_diff
(
time3
,
time4
)
/
count
<<
"ms
\n
"
;
return
0
;
return
0
;
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录