提交 0f7e0f6e 编写于 作者: Z zhaojiaying01

add vector matrix multiplication in Gemm

上级 d868a0a1
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace memory {
const int MALLOC_ALIGN = 16;
const int MALLOC_ALIGN = 64;
void Copy(void *dst, const void *src, size_t num) {
std::memcpy(dst, src, num);
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/gemm.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#ifndef X86
#include <arm_neon.h>
#endif
......@@ -757,6 +759,10 @@ void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc) {
int i, j, p, mc, nc, kc;
float beta_;
if (m == 1) {
VectorKernel(1, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
return;
}
for (j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
for (p = 0; p < k; p += KC) {
......@@ -803,6 +809,223 @@ void sgemm_relu(int m, int n, int k, float alpha, const float *A, int lda,
}
}
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc) {
float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3;
float *c0, *C0;
int volatile kc1 = k / 4;
int volatile kc2 = k % 4;
int volatile nc1 = n / 16;
int _nc1 = n % 16;
int volatile nc2 = _nc1 / 4;
int volatile nc3 = _nc1 % 4;
// DLOG << "GEMM VECTOR kc1 = " << kc1 << ", kc2 = " << kc2;
// DLOG << "GEMM VECTOR nc1 = " << nc1 << ", nc2 = " << nc2 << ", nc3 = " <<
// nc3;
for (int i = 0; i < kc1; i++) {
a0 = A + i * 4;
b0 = B + i * 4 * ldb;
b1 = b0 + ldb;
b2 = b1 + ldb;
b3 = b2 + ldb;
c0 = bufferC;
asm volatile(
"pld [%[a0], #16] \n\t"
"vld1.32 {q0}, [%[a0]] \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"cmp %[i], #0 \n\t"
"beq i_eq0_%= \n\t"
"bne i_ne0_%= \n\t"
"i_eq0_%=: \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"b gemm_nc1_%= \n\t"
"i_ne0_%=: \n\t"
"pld [%[c0], #64] \n\t"
"vld1.32 {q10, q11}, [%[c0]]! \n\t"
"vld1.32 {q12, q13}, [%[c0]] \n\t"
"sub %[c0], %[c0], #32 \n\t"
"gemm_nc1_%=: \n\t"
"pld [%[b0], #64] \n\t"
"vld1.32 {q2, q3}, [%[b0]]! \n\t"
"vld1.32 {q4, q5}, [%[b0]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q3, d0[0] \n\t"
"vmla.f32 q12, q4, d0[0] \n\t"
"vmla.f32 q13, q5, d0[0] \n\t"
"pld [%[b1], #64] \n\t"
"vld1.32 {q2, q3}, [%[b1]]! \n\t"
"vld1.32 {q4, q5}, [%[b1]]! \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q4, d0[1] \n\t"
"vmla.f32 q13, q5, d0[1] \n\t"
"pld [%[b2], #64] \n\t"
"vld1.32 {q2, q3}, [%[b2]]! \n\t"
"vld1.32 {q4, q5}, [%[b2]]! \n\t"
"vmla.f32 q10, q2, d1[0] \n\t"
"vmla.f32 q11, q3, d1[0] \n\t"
"vmla.f32 q12, q4, d1[0] \n\t"
"vmla.f32 q13, q5, d1[0] \n\t"
"pld [%[b3], #64] \n\t"
"vld1.32 {q2, q3}, [%[b3]]! \n\t"
"vld1.32 {q4, q5}, [%[b3]]! \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q4, d1[1] \n\t"
"vmla.f32 q13, q5, d1[1] \n\t"
"vst1.32 {q10, q11}, [%[c0]]! \n\t"
"vst1.32 {q12, q13}, [%[c0]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"cmp %[i], #0 \n\t"
"beq ii_eq0_%= \n\t"
"bne ii_ne0_%= \n\t"
"ii_eq0_%=: \n\t"
"vmov.f32 q10, #0.0 \n\t"
"b gemm_nc2_%= \n\t"
"ii_ne0_%=: \n\t"
"pld [%[c0], #16] \n\t"
"vld1.32 {q10}, [%[c0]] \n\t"
"gemm_nc2_%=: \n\t"
"pld [%[b0], #16] \n\t"
"vld1.32 {q2}, [%[b0]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"pld [%[b1], #16] \n\t"
"vld1.32 {q3}, [%[b1]]! \n\t"
"vmla.f32 q10, q3, d0[1] \n\t"
"pld [%[b2], #16] \n\t"
"vld1.32 {q4}, [%[b2]]! \n\t"
"vmla.f32 q10, q4, d1[0] \n\t"
"pld [%[b3], #16] \n\t"
"vld1.32 {q5}, [%[b3]]! \n\t"
"vmla.f32 q10, q5, d1[1] \n\t"
"vst1.32 {q10}, [%[c0]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t"
"end_nc2_%=: \n\t"
: [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3),
[c0] "+r"(c0)
: [a0] "r"(a0), [i] "r"(i), [nc1] "r"(nc1), [nc2] "r"(nc2)
: "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13");
for (int j = 0; j < nc3; j++) {
if (i == 0) {
*c0 = (*a0) * (*b0++);
} else {
*c0 += (*a0) * (*b0++);
}
*c0 += (*(a0 + 1)) * (*b1++);
*c0 += (*(a0 + 2)) * (*b2++);
*c0 += (*(a0 + 3)) * (*b3++);
c0++;
}
}
for (int i = 0; i < kc2; ++i) {
a0 = A + 4 * kc1 + i;
b0 = B + (4 * kc1 + i) * ldb;
c0 = bufferC;
asm volatile(
"pld [%[a0], #16] \n\t"
"vld1.32 {d0}, [%[a0]] \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"pld [%[c0], #64] \n\t"
"vld1.32 {q10, q11}, [%[c0]]! \n\t"
"vld1.32 {q12, q13}, [%[c0]] \n\t"
"sub %[c0], %[c0], #32 \n\t"
"gemm_nc1_%=: \n\t"
"pld [%[b0], #64] \n\t"
"vld1.32 {q2, q3}, [%[b0]]! \n\t"
"vld1.32 {q4, q5}, [%[b0]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q3, d0[0] \n\t"
"vmla.f32 q12, q4, d0[0] \n\t"
"vmla.f32 q13, q5, d0[0] \n\t"
"vst1.32 {q10, q11}, [%[c0]]! \n\t"
"vst1.32 {q12, q13}, [%[c0]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"pld [%[c0], #16] \n\t"
"vld1.32 {q10}, [%[c0]] \n\t"
"gemm_nc2_%=: \n\t"
"vld1.32 {q2}, [%[b0]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vst1.32 {q10}, [%[c0]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t"
"end_nc2_%=: \n\t"
: [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3),
[c0] "+r"(c0)
: [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2)
: "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13");
for (int j = 0; j < nc3; j++) {
*c0 += (*a0) * (*b0++);
c0++;
}
}
c0 = bufferC;
C0 = C;
for (int i = 0; i < n; i++) {
if (beta == 1.0) {
*C0++ += *c0++;
} else {
*C0++ = *c0++;
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -53,6 +53,10 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc);
// 计算一个更小的 4 * 4 的 C 矩阵分块
void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B,
int ldb, float beta, float *C, int ldc, int mc, int nc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册