diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index bc484c2f0d10b8f2d41083ec47aab7f2c1390832..0c0ae8e3dd84f38218d03a761c58a664b927f161 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -130,8 +130,9 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, } // 分块矩阵乘法 -void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B, - int ldb, float *C, int ldc, int first_time) { +void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + int first_time) { int Buff_A_M = m; int Buff_B_N = n; @@ -162,15 +163,15 @@ void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B, // A 取 4 行,打包预热 for (i = 0; i < Buff_A_M; i += MR) { mc = (m - i) < MR ? _mc : MR; - AddDot4x4(k, &packedA[i * k], 4, &packedB[j * k], k, &C(i, j), ldc, mc, - nc); + AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta, + &C(i, j), ldc, mc, nc); } } } // 计算一个更小的 4 * 4 的 C 矩阵分块 -void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb, - float *C, int ldc, int mc, int nc) { +void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, + int ldb, float beta, float *C, int ldc, int mc, int nc) { float c[16] = {0}; float reg_a0, reg_a1, reg_a2, reg_a3, reg_b0, reg_b1, reg_b2, reg_b3; @@ -218,7 +219,16 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb, int i, j; for (i = 0; i < mc; ++i) { for (j = 0; j < nc; ++j) { - C(i, j) += c[i * 4 + j]; + if (beta == 0.0) { + C(i, j) = 0.0; + } else if (beta != 1.0) { + C(i, j) *= beta; + } + if (alpha != 1.0) { + C(i, j) += alpha * c[i * MR + j]; + } else { + C(i, j) += c[i * MR + j]; + } } } } @@ -227,15 +237,20 @@ void AddDot4x4(int k, const float *a, int lda, const float *b, int ldb, void sgemm(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { int i, j, p, mc, nc, kc; - + float beta_; for (j = 0; j < n; j += NC) { nc = s_min(n - j, NC); for (p = 0; p < k; p += KC) { kc = s_min(k - p, KC); for (i = 0; i < m; i += MC) { mc = s_min(m - i, MC); - InnerKernel(mc, nc, kc, &A(i, p), lda, &B(p, j), ldb, &C(i, j), ldc, - i == 0); + if (p != 0) { + beta_ = 1.0; + } else { + beta_ = beta; + } + InnerKernel(mc, nc, kc, alpha, &A(i, p), lda, &B(p, j), ldb, beta_, + &C(i, j), ldc, i == 0); } } } diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 2eea23a3b104c372305ba0b96fd020bf6c5e47dd..87d65bdd28a42c4510668345ad7ce7058eb2cdf8 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -49,12 +49,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, float *buffer); // 分块矩阵乘法 -void InnerKernel(int m, int n, int k, const float *A, int lda, const float *B, - int ldb, float *C, int ldc, int first_time); +void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + int first_time); // 计算一个更小的 4 * 4 的 C 矩阵分块 -void AddDot4x4(int k, const float *A, int lda, const float *B, int ldb, - float *C, int ldc, int mc, int nc); +void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B, + int ldb, float beta, float *C, int ldc, int mc, int nc); // 32位 float 矩阵乘法 void sgemm(int m, int n, int k, float alpha, const float *A, int lda, diff --git a/test/common/test_gemm.cpp.cpp b/test/common/test_gemm.cpp.cpp index 0e32a87c72151dfd105ab50145c9eacc9f70f8a2..f385bf960e266df1ddfd317c3281904fea1a21ee 100644 --- a/test/common/test_gemm.cpp.cpp +++ b/test/common/test_gemm.cpp.cpp @@ -20,27 +20,32 @@ limitations under the License. */ #define b(i, j) b[(i)*ldb + (j)] #define c1(i, j) c1[(i)*ldc + (j)] -#define m 7 -#define n 7 -#define k 7 +#define m 62 +#define n 63 +#define k 74 int main() { int lda = k; int ldb = n; int ldc = n; - float a[7 * 7]; - float b[7 * 7]; - float c[7 * 7] = {0}; - float c1[7 * 7] = {0}; + float a[62 * 74]; + float b[74 * 63]; + float c[62 * 63] = {0}; + float c1[62 * 63] = {0}; for (int i = 0; i < m * k; ++i) { a[i] = 2; } for (int i = 0; i < k * n; ++i) { b[i] = 2; } + for (int i = 0; i < m * n; ++i) { + c[i] = 2; + c1[i] = 2; + } - paddle_mobile::operators::math::sgemm(m, n, k, 1, a, lda, b, ldb, 0, c, ldc); + paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, + ldc); for (int i = 0; i < m * n; ++i) { std::cout << c[i] << " | "; if (i % n == (n - 1)) { @@ -49,8 +54,9 @@ int main() { } for (int j = 0; j < n; ++j) { for (int i = 0; i < m; ++i) { + c1(i, j) *= 0.3; for (int p = 0; p < k; ++p) { - c1(i, j) += a(i, p) * b(p, j); + c1(i, j) += 0.9 * a(i, p) * b(p, j); } } }