未验证 提交 e3ba5be3 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #302 from smilejames/develop

update Gemm with implementation of 'C = alpha * A * B + beta * C'
......@@ -130,8 +130,9 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
}
// 分块矩阵乘法
void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B,
int ldb, float *C, int ldc, int first_time) {
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time) {
int Buff_A_M = m;
int Buff_B_N = n;
......@@ -162,15 +163,15 @@ void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B,
// A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) {
mc = (m - i) < MR ? _mc : MR;
AddDot4x4(k, &packedA[i * k], 4, &packedB[j * k], k, &C(i, j), ldc, mc,
nc);
AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc);
}
}
}
// 计算一个更小的 4 * 4 的 C 矩阵分块
void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb,
float *C, int ldc, int mc, int nc) {
void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
int ldb, float beta, float *C, int ldc, int mc, int nc) {
float c[16] = {0};
float reg_a0, reg_a1, reg_a2, reg_a3, reg_b0, reg_b1, reg_b2, reg_b3;
......@@ -218,7 +219,16 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb,
int i, j;
for (i = 0; i < mc; ++i) {
for (j = 0; j < nc; ++j) {
C(i, j) += c[i * 4 + j];
if (beta == 0.0) {
C(i, j) = 0.0;
} else if (beta != 1.0) {
C(i, j) *= beta;
}
if (alpha != 1.0) {
C(i, j) += alpha * c[i * MR + j];
} else {
C(i, j) += c[i * MR + j];
}
}
}
}
......@@ -227,15 +237,20 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb,
void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc) {
int i, j, p, mc, nc, kc;
float beta_;
for (j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
for (p = 0; p < k; p += KC) {
kc = s_min(k - p, KC);
for (i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
InnerKernel(mc, nc, kc, &A(i, p), lda, &B(p, j), ldb, &C(i, j), ldc,
i == 0);
if (p != 0) {
beta_ = 1.0;
} else {
beta_ = beta;
}
InnerKernel(mc, nc, kc, alpha, &A(i, p), lda, &B(p, j), ldb, beta_,
&C(i, j), ldc, i == 0);
}
}
}
......
......@@ -49,12 +49,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B,
int ldb, float *C, int ldc, int first_time);
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time);
// 计算一个更小的 4 * 4 的 C 矩阵分块
void AddDot4x4(int k, const float *A, int lda, const float *B, int ldb,
float *C, int ldc, int mc, int nc);
void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B,
int ldb, float beta, float *C, int ldc, int mc, int nc);
// 32位 float 矩阵乘法
void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
......
......@@ -20,27 +20,32 @@ limitations under the License. */
#define b(i, j) b[(i)*ldb + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]
#define m 7
#define n 7
#define k 7
#define m 62
#define n 63
#define k 74
int main() {
int lda = k;
int ldb = n;
int ldc = n;
float a[7 * 7];
float b[7 * 7];
float c[7 * 7] = {0};
float c1[7 * 7] = {0};
float a[62 * 74];
float b[74 * 63];
float c[62 * 63] = {0};
float c1[62 * 63] = {0};
for (int i = 0; i < m * k; ++i) {
a[i] = 2;
}
for (int i = 0; i < k * n; ++i) {
b[i] = 2;
}
for (int i = 0; i < m * n; ++i) {
c[i] = 2;
c1[i] = 2;
}
paddle_mobile::operators::math::sgemm(m, n, k, 1, a, lda, b, ldb, 0, c, ldc);
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc);
for (int i = 0; i < m * n; ++i) {
std::cout << c[i] << " | ";
if (i % n == (n - 1)) {
......@@ -49,8 +54,9 @@ int main() {
}
for (int j = 0; j < n; ++j) {
for (int i = 0; i < m; ++i) {
c1(i, j) *= 0.3;
for (int p = 0; p < k; ++p) {
c1(i, j) += a(i, p) * b(p, j);
c1(i, j) += 0.9 * a(i, p) * b(p, j);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册