Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
e3ba5be3
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e3ba5be3
编写于
5月 28, 2018
作者:
S
smilejames
提交者:
GitHub
5月 28, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #302 from smilejames/develop
update Gemm with implementation of 'C = alpha * A * B + beta * C'
上级
bfbca2c6
ebadfc3b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
45 addition
and
23 deletion
+45
-23
src/operators/math/gemm.cpp
src/operators/math/gemm.cpp
+25
-10
src/operators/math/gemm.h
src/operators/math/gemm.h
+5
-4
test/common/test_gemm.cpp.cpp
test/common/test_gemm.cpp.cpp
+15
-9
未找到文件。
src/operators/math/gemm.cpp
浏览文件 @
e3ba5be3
...
...
@@ -130,8 +130,9 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
}
// 分块矩阵乘法
void
InnerKernel
(
int
m
,
int
n
,
int
k
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
*
C
,
int
ldc
,
int
first_time
)
{
void
InnerKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
first_time
)
{
int
Buff_A_M
=
m
;
int
Buff_B_N
=
n
;
...
...
@@ -162,15 +163,15 @@ void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B,
// A 取 4 行,打包预热
for
(
i
=
0
;
i
<
Buff_A_M
;
i
+=
MR
)
{
mc
=
(
m
-
i
)
<
MR
?
_mc
:
MR
;
AddDot4x4
(
k
,
&
packedA
[
i
*
k
],
4
,
&
packedB
[
j
*
k
],
k
,
&
C
(
i
,
j
),
ldc
,
mc
,
nc
);
AddDot4x4
(
k
,
alpha
,
&
packedA
[
i
*
k
],
4
,
&
packedB
[
j
*
k
],
k
,
beta
,
&
C
(
i
,
j
),
ldc
,
mc
,
nc
);
}
}
}
// 计算一个更小的 4 * 4 的 C 矩阵分块
void
AddDot4x4
(
int
k
,
const
float
*
a
,
int
lda
,
const
float
*
b
,
int
ld
b
,
float
*
C
,
int
ldc
,
int
mc
,
int
nc
)
{
void
AddDot4x4
(
int
k
,
float
alpha
,
const
float
*
a
,
int
lda
,
const
float
*
b
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
mc
,
int
nc
)
{
float
c
[
16
]
=
{
0
};
float
reg_a0
,
reg_a1
,
reg_a2
,
reg_a3
,
reg_b0
,
reg_b1
,
reg_b2
,
reg_b3
;
...
...
@@ -218,7 +219,16 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb,
int
i
,
j
;
for
(
i
=
0
;
i
<
mc
;
++
i
)
{
for
(
j
=
0
;
j
<
nc
;
++
j
)
{
C
(
i
,
j
)
+=
c
[
i
*
4
+
j
];
if
(
beta
==
0.0
)
{
C
(
i
,
j
)
=
0.0
;
}
else
if
(
beta
!=
1.0
)
{
C
(
i
,
j
)
*=
beta
;
}
if
(
alpha
!=
1.0
)
{
C
(
i
,
j
)
+=
alpha
*
c
[
i
*
MR
+
j
];
}
else
{
C
(
i
,
j
)
+=
c
[
i
*
MR
+
j
];
}
}
}
}
...
...
@@ -227,15 +237,20 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb,
void
sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
)
{
int
i
,
j
,
p
,
mc
,
nc
,
kc
;
float
beta_
;
for
(
j
=
0
;
j
<
n
;
j
+=
NC
)
{
nc
=
s_min
(
n
-
j
,
NC
);
for
(
p
=
0
;
p
<
k
;
p
+=
KC
)
{
kc
=
s_min
(
k
-
p
,
KC
);
for
(
i
=
0
;
i
<
m
;
i
+=
MC
)
{
mc
=
s_min
(
m
-
i
,
MC
);
InnerKernel
(
mc
,
nc
,
kc
,
&
A
(
i
,
p
),
lda
,
&
B
(
p
,
j
),
ldb
,
&
C
(
i
,
j
),
ldc
,
i
==
0
);
if
(
p
!=
0
)
{
beta_
=
1.0
;
}
else
{
beta_
=
beta
;
}
InnerKernel
(
mc
,
nc
,
kc
,
alpha
,
&
A
(
i
,
p
),
lda
,
&
B
(
p
,
j
),
ldb
,
beta_
,
&
C
(
i
,
j
),
ldc
,
i
==
0
);
}
}
}
...
...
src/operators/math/gemm.h
浏览文件 @
e3ba5be3
...
...
@@ -49,12 +49,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
float
*
buffer
);
// 分块矩阵乘法
void
InnerKernel
(
int
m
,
int
n
,
int
k
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
*
C
,
int
ldc
,
int
first_time
);
void
InnerKernel
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
first_time
);
// 计算一个更小的 4 * 4 的 C 矩阵分块
void
AddDot4x4
(
int
k
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
*
C
,
int
ldc
,
int
mc
,
int
nc
);
void
AddDot4x4
(
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
int
mc
,
int
nc
);
// 32位 float 矩阵乘法
void
sgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
const
float
*
A
,
int
lda
,
...
...
test/common/test_gemm.cpp.cpp
浏览文件 @
e3ba5be3
...
...
@@ -20,27 +20,32 @@ limitations under the License. */
#define b(i, j) b[(i)*ldb + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
#define m
7
#define n
7
#define k 7
#define m
62
#define n
63
#define k 7
4
int
main
()
{
int
lda
=
k
;
int
ldb
=
n
;
int
ldc
=
n
;
float
a
[
7
*
7
];
float
b
[
7
*
7
];
float
c
[
7
*
7
]
=
{
0
};
float
c1
[
7
*
7
]
=
{
0
};
float
a
[
62
*
74
];
float
b
[
7
4
*
63
];
float
c
[
62
*
63
]
=
{
0
};
float
c1
[
62
*
63
]
=
{
0
};
for
(
int
i
=
0
;
i
<
m
*
k
;
++
i
)
{
a
[
i
]
=
2
;
}
for
(
int
i
=
0
;
i
<
k
*
n
;
++
i
)
{
b
[
i
]
=
2
;
}
for
(
int
i
=
0
;
i
<
m
*
n
;
++
i
)
{
c
[
i
]
=
2
;
c1
[
i
]
=
2
;
}
paddle_mobile
::
operators
::
math
::
sgemm
(
m
,
n
,
k
,
1
,
a
,
lda
,
b
,
ldb
,
0
,
c
,
ldc
);
paddle_mobile
::
operators
::
math
::
sgemm
(
m
,
n
,
k
,
0.9
,
a
,
lda
,
b
,
ldb
,
0.3
,
c
,
ldc
);
for
(
int
i
=
0
;
i
<
m
*
n
;
++
i
)
{
std
::
cout
<<
c
[
i
]
<<
" | "
;
if
(
i
%
n
==
(
n
-
1
))
{
...
...
@@ -49,8 +54,9 @@ int main() {
}
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
c1
(
i
,
j
)
*=
0.3
;
for
(
int
p
=
0
;
p
<
k
;
++
p
)
{
c1
(
i
,
j
)
+=
a
(
i
,
p
)
*
b
(
p
,
j
);
c1
(
i
,
j
)
+=
0.9
*
a
(
i
,
p
)
*
b
(
p
,
j
);
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录