Commit 221fec0b authored by: 吴承辉

Merge branch 'gemm' into 'master'

Implement Gemm transpose and multi-dimensional batch support

See merge request !559
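
For orientation before the diff, here is a minimal usage sketch of the extended CPU Gemm entry point, modeled on the call pattern in gemm_test.cc further down; the function name and the concrete sizes are illustrative only and not part of the change. The new trailing transpose flags default to false, so existing call sites keep compiling; with transpose_b set, B is read as a width x K row-major matrix per batch:

#include <cstring>
#include <memory>

#include "mace/core/types.h"
#include "mace/kernels/gemm.h"

void ExampleGemmWithTransposedB() {
  const mace::index_t batch = 2, height = 31, K = 61, width = 67;
  std::unique_ptr<float[]> A(new float[batch * height * K]);      // height x K per batch
  std::unique_ptr<float[]> B(new float[batch * K * width]);       // width x K per batch (transposed layout)
  std::unique_ptr<float[]> C(new float[batch * height * width]);
  // ... fill A and B ...
  // Zero the output first; the MatMul functor does the same before calling Gemm.
  std::memset(C.get(), 0, sizeof(float) * batch * height * width);
  mace::kernels::Gemm(A.get(), B.get(), batch, height, K, width, C.get(),
                      /*transpose_a=*/false, /*transpose_b=*/true);
}
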
......@@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <cstring>
#include "mace/core/tensor.h"
#include "mace/kernels/gemm.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/core/macros.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/logging.h"
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
......@@ -37,13 +35,14 @@ inline void GemmBlock(const float *A,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_k,
const index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *C) {
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
for (int k = 0; k < K; ++k) {
C[i * stride_w + j] += A[i * stride_k + k] * B[k * stride_w + j];
C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j];
}
}
}
......@@ -75,8 +74,9 @@ inline void GemmBlock(const float *A,
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
......@@ -86,38 +86,38 @@ inline void Gemm884(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_k);
a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_a);
a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
c6 = vld1q_f32(c_ptr + 6 * stride_w);
c7 = vld1q_f32(c_ptr + 7 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c);
c7 = vld1q_f32(c_ptr + 7 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -140,25 +140,28 @@ inline void Gemm884(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 6 * stride_w, c6);
vst1q_f32(c_ptr + 7 * stride_w, c7);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
vst1q_f32(c_ptr + 6 * stride_c, c6);
vst1q_f32(c_ptr + 7 * stride_c, c7);
#else
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm184(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
MACE_UNUSED(stride_k);
#if defined(MACE_ENABLE_NEON)
MACE_UNUSED(stride_a);
MACE_UNUSED(stride_c);
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0;
......@@ -167,13 +170,13 @@ inline void Gemm184(const float *a_ptr,
a1 = vld1q_f32(a_ptr + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
......@@ -185,14 +188,15 @@ inline void Gemm184(const float *a_ptr,
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm284(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
......@@ -201,20 +205,20 @@ inline void Gemm284(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -225,16 +229,17 @@ inline void Gemm284(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm384(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
......@@ -243,23 +248,23 @@ inline void Gemm384(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -272,17 +277,18 @@ inline void Gemm384(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm484(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
......@@ -291,26 +297,26 @@ inline void Gemm484(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -325,18 +331,19 @@ inline void Gemm484(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm584(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9;
......@@ -345,29 +352,29 @@ inline void Gemm584(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -384,19 +391,20 @@ inline void Gemm584(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm684(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11;
......@@ -405,32 +413,32 @@ inline void Gemm684(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -449,21 +457,22 @@ inline void Gemm684(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm784(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13;
......@@ -472,35 +481,35 @@ inline void Gemm784(const float *a_ptr,
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
c6 = vld1q_f32(c_ptr + 6 * stride_w);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
......@@ -521,48 +530,49 @@ inline void Gemm784(const float *a_ptr,
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 6 * stride_w, c6);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
vst1q_f32(c_ptr + 6 * stride_c, c6);
#else
GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void GemmX84(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr,
int row) {
switch (row) {
case 1:
Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 2:
Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 3:
Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm384(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 4:
Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 5:
Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 6:
Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 7:
Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 8:
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
default:
MACE_NOT_IMPLEMENTED;
......@@ -574,14 +584,15 @@ inline void GemmTile(const float *A,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_k,
const index_t stride_w,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *C) {
#if defined(MACE_ENABLE_NEON)
index_t h, w, k;
for (h = 0; h < height - 7; h += 8) {
for (k = 0; k < K - 7; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
const float *a_ptr = A + (h * stride_a + k);
#if defined(__aarch64__) && defined(__clang__)
int nw = width >> 2;
if (nw > 0) {
......@@ -590,38 +601,38 @@ inline void GemmTile(const float *A,
a14, a15;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_k);
a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
const float *b_ptr0 = B + k * stride_w;
const float *b_ptr1 = B + (k + 1) * stride_w;
const float *b_ptr2 = B + (k + 2) * stride_w;
const float *b_ptr3 = B + (k + 3) * stride_w;
const float *b_ptr4 = B + (k + 4) * stride_w;
const float *b_ptr5 = B + (k + 5) * stride_w;
const float *b_ptr6 = B + (k + 6) * stride_w;
const float *b_ptr7 = B + (k + 7) * stride_w;
float *c_ptr0 = C + h * stride_w;
float *c_ptr1 = C + (h + 1) * stride_w;
float *c_ptr2 = C + (h + 2) * stride_w;
float *c_ptr3 = C + (h + 3) * stride_w;
float *c_ptr4 = C + (h + 4) * stride_w;
float *c_ptr5 = C + (h + 5) * stride_w;
float *c_ptr6 = C + (h + 6) * stride_w;
float *c_ptr7 = C + (h + 7) * stride_w;
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_a);
a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
const float *b_ptr0 = B + k * stride_b;
const float *b_ptr1 = B + (k + 1) * stride_b;
const float *b_ptr2 = B + (k + 2) * stride_b;
const float *b_ptr3 = B + (k + 3) * stride_b;
const float *b_ptr4 = B + (k + 4) * stride_b;
const float *b_ptr5 = B + (k + 5) * stride_b;
const float *b_ptr6 = B + (k + 6) * stride_b;
const float *b_ptr7 = B + (k + 7) * stride_b;
float *c_ptr0 = C + h * stride_c;
float *c_ptr1 = C + (h + 1) * stride_c;
float *c_ptr2 = C + (h + 2) * stride_c;
float *c_ptr3 = C + (h + 3) * stride_c;
float *c_ptr4 = C + (h + 4) * stride_c;
float *c_ptr5 = C + (h + 5) * stride_c;
float *c_ptr6 = C + (h + 6) * stride_c;
float *c_ptr7 = C + (h + 7) * stride_c;
asm volatile(
"prfm pldl1keep, [%9, #128] \n"
......@@ -824,53 +835,68 @@ inline void GemmTile(const float *A,
}
#else // gcc || armv7a
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
}
#endif // clang && armv8a
if (w < width) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c,
c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr);
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c,
c_ptr);
}
}
if (h < height) {
index_t remain_h = height - h;
for (k = 0; k < K - 7; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
const float *a_ptr = A + (h * stride_a + k);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h);
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
}
if (w < width) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w,
c_ptr);
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b,
stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w,
c_ptr);
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
#else
GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
GemmBlock(A, B, height, K, width, stride_a, stride_b, stride_c, C);
#endif // MACE_ENABLE_NEON
}
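// Transposes a height x width block read from src (row stride stride_w) into
// a contiguous width x height block at dst.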
void Transpose(const float *src,
index_t height,
index_t width,
index_t stride_w,
float *dst) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
dst[w * height + h] = src[h * stride_w + w];
}
}
}
} // namespace
// A: height x K, B: K x width, C: height x width
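// With transpose_a / transpose_b set, A / B is stored transposed in memory
// (K x height and width x K respectively); the logical shapes above are unchanged.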
......@@ -880,7 +906,9 @@ void Gemm(const float *A,
const index_t height,
const index_t K,
const index_t width,
float *C) {
float *C,
const bool transpose_a,
const bool transpose_b) {
if (width == 1) {
for (index_t b = 0; b < batch; ++b) {
Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
......@@ -898,41 +926,77 @@ void Gemm(const float *A,
const index_t block_tile_height = RoundUpDiv(height, block_size);
const index_t block_tile_width = RoundUpDiv(width, block_size);
const index_t block_tile_k = RoundUpDiv(K, block_size);
const index_t block_tile[3] = {block_tile_height, block_tile_width,
block_tile_k};
const index_t remain_height = height % block_size;
const index_t remain_width = width % block_size;
const index_t remain_k = K % block_size;
const index_t remain[3] = {remain_height, remain_width, remain_k};
#pragma omp parallel for collapse(3)
for (index_t n = 0; n < batch; ++n) {
for (index_t bh = 0; bh < block_tile_height; ++bh) {
for (index_t bw = 0; bw < block_tile_width; ++bw) {
for (index_t bh = 0; bh < block_tile[0]; ++bh) {
for (index_t bw = 0; bw < block_tile[1]; ++bw) {
const float *a_base = A + n * height * K;
const float *b_base = B + n * K * width;
float *c_base = C + n * height * width;
const index_t ih_begin = bh * block_size;
const index_t ih_end =
bh * block_size + (bh == block_tile_height - 1 && remain_height > 0
? remain_height
: block_size);
bh * block_size +
(bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size);
const index_t iw_begin = bw * block_size;
const index_t iw_end =
bw * block_size + (bw == block_tile_width - 1 && remain_width > 0
? remain_width
: block_size);
bw * block_size +
(bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size);
for (index_t bk = 0; bk < block_tile_k; ++bk) {
for (index_t bk = 0; bk < block_tile[2]; ++bk) {
const index_t ik_begin = bk * block_size;
const index_t ik_end =
bk * block_size +
(bk == block_tile_k - 1 && remain_k > 0 ? remain_k : block_size);
bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0
? remain[2]
: block_size);
Tensor trans_a;
Tensor trans_b;
const float *real_a = nullptr;
const float *real_b = nullptr;
float *real_c = c_base + (ih_begin * width + iw_begin);
index_t stride_a;
index_t stride_b;
index_t stride_c = width;
if (transpose_a) {
trans_a.Resize({block_size, block_size});
float *trans_a_data = trans_a.mutable_data<float>();
// A[K, H] -> A[H, K]
Transpose(a_base + (ik_begin * height + ih_begin),
ik_end - ik_begin, ih_end - ih_begin, height,
trans_a_data);
real_a = trans_a_data;
stride_a = ik_end - ik_begin;
} else {
real_a = a_base + (ih_begin * K + ik_begin);
stride_a = K;
}
if (transpose_b) {
trans_b.Resize({block_size, block_size});
float *trans_b_data = trans_b.mutable_data<float>();
// B[W, K] -> B[K, W]
Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin,
ik_end - ik_begin, K, trans_b_data);
real_b = trans_b_data;
stride_b = iw_end - iw_begin;
} else {
real_b = b_base + (ik_begin * width + iw_begin);
stride_b = width;
}
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
GemmTile(a_base + (ih_begin * K + ik_begin),
b_base + (ik_begin * width + iw_begin), ih_end - ih_begin,
ik_end - ik_begin, iw_end - iw_begin, K, width,
c_base + (ih_begin * width + iw_begin));
GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
} // bk
} // bw
} // bh
......@@ -946,14 +1010,47 @@ void GemmRef(const float *A,
const index_t height,
const index_t K,
const index_t width,
float *C) {
float *C,
const bool transpose_a,
const bool transpose_b) {
memset(C, 0, sizeof(float) * batch * height * width);
Tensor trans_a;
Tensor trans_b;
float *trans_a_data = nullptr;
float *trans_b_data = nullptr;
if (transpose_a) {
trans_a.Resize({height, K});
trans_a_data = trans_a.mutable_data<float>();
}
if (transpose_b) {
trans_b.Resize({K, width});
trans_b_data = trans_b.mutable_data<float>();
}
for (index_t b = 0; b < batch; ++b) {
const float *real_a = nullptr;
const float *real_b = nullptr;
float *real_c = C + b * height * width;
if (transpose_a) {
// A[K, H] -> A[H, K]
Transpose(A + b * height * K, K, height, height, trans_a_data);
real_a = trans_a_data;
} else {
real_a = A + b * height * K;
}
if (transpose_b) {
// B[W, K] -> B[K, W]
Transpose(B + b * width * K, width, K, K, trans_b_data);
real_b = trans_b_data;
} else {
real_b = B + b * width * K;
}
for (index_t i = 0; i < height; ++i) {
for (index_t j = 0; j < width; ++j) {
for (index_t k = 0; k < K; ++k) {
C[(b * height + i) * width + j] +=
A[(b * height + i) * K + k] * B[(b * K + k) * width + j];
real_c[i * width + j] += real_a[i * K + k] * real_b[k * width + j];
}
}
}
......
......@@ -30,7 +30,9 @@ void Gemm(const float *A,
const index_t height,
const index_t K,
const index_t width,
float *C);
float *C,
const bool transpose_a = false,
const bool transpose_b = false);
void GemmRef(const float *A,
const float *B,
......@@ -38,7 +40,9 @@ void GemmRef(const float *A,
const index_t height,
const index_t K,
const index_t width,
float *C);
float *C,
const bool transpose_a = false,
const bool transpose_b = false);
void Gemv(const float *m_ptr,
const float *v_ptr,
......
......@@ -13,17 +13,22 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include <memory>
#include <random>
#include "mace/kernels/gemm.h"
#include "mace/core/types.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace {
void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
void GemmTest(index_t batch,
index_t N,
index_t K,
index_t M,
bool transpose_a,
bool transpose_b) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
......@@ -34,15 +39,13 @@ void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] {
return nd(gen);
});
[&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] {
return nd(gen);
});
kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
[&gen, &nd] { return nd(gen); });
kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get(), transpose_a,
transpose_b);
kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
transpose_b);
for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
......@@ -59,14 +62,8 @@ void GemvTest(index_t batch, index_t N, index_t M) {
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * M,
[&gen, &nd] {
return nd(gen);
});
std::generate(B.get(), B.get() + batch * M,
[&gen, &nd] {
return nd(gen);
});
std::generate(A.get(), A.get() + N * M, [&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * M, [&gen, &nd] { return nd(gen); });
kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
......@@ -78,36 +75,36 @@ void GemvTest(index_t batch, index_t N, index_t M) {
} // namespace
TEST(GEMMTest, AlignedWithoutBatch) {
GemmTest(1, 1, 64, 128);
GemmTest(1, 2, 64, 128);
GemmTest(1, 3, 64, 128);
GemmTest(1, 4, 64, 128);
GemmTest(1, 5, 64, 128);
GemmTest(1, 6, 64, 128);
GemmTest(1, 7, 64, 128);
GemmTest(1, 17, 64, 128);
GemmTest(1, 1, 64, 128, false, false);
GemmTest(1, 2, 64, 128, false, true);
GemmTest(1, 3, 64, 128, true, false);
GemmTest(1, 4, 64, 128, true, true);
GemmTest(1, 5, 64, 128, false, false);
GemmTest(1, 6, 64, 128, false, true);
GemmTest(1, 7, 64, 128, true, false);
GemmTest(1, 17, 64, 128, true, true);
}
TEST(GEMMTest, UnalignedWithoutBatch) {
GemmTest(1, 1, 63, 127);
GemmTest(1, 2, 63, 127);
GemmTest(1, 3, 63, 127);
GemmTest(1, 4, 63, 127);
GemmTest(1, 5, 63, 127);
GemmTest(1, 6, 63, 127);
GemmTest(1, 7, 63, 127);
GemmTest(1, 17, 63, 127);
GemmTest(1, 1, 63, 127, false, false);
GemmTest(1, 2, 63, 127, false, true);
GemmTest(1, 3, 63, 127, true, false);
GemmTest(1, 4, 63, 127, true, true);
GemmTest(1, 5, 63, 127, false, false);
GemmTest(1, 6, 63, 127, false, true);
GemmTest(1, 7, 63, 127, true, false);
GemmTest(1, 17, 63, 127, true, true);
}
TEST(GEMMTest, UnalignedWithBatch) {
GemmTest(3, 1, 63, 127);
GemmTest(3, 2, 63, 127);
GemmTest(3, 3, 63, 127);
GemmTest(3, 4, 63, 127);
GemmTest(3, 5, 63, 127);
GemmTest(3, 6, 63, 127);
GemmTest(3, 7, 63, 127);
GemmTest(3, 17, 63, 127);
GemmTest(3, 1, 63, 127, false, false);
GemmTest(3, 2, 63, 127, false, true);
GemmTest(3, 3, 63, 127, true, false);
GemmTest(3, 4, 63, 127, true, true);
GemmTest(3, 5, 63, 127, false, false);
GemmTest(3, 6, 63, 127, false, true);
GemmTest(3, 7, 63, 127, true, false);
GemmTest(3, 17, 63, 127, true, true);
}
TEST(GEMMTest, gemv) {
......
......@@ -20,6 +20,8 @@
#endif
#include <algorithm>
#include <utility>
#include <functional>
#include <memory>
#include <string>
#include <vector>
......@@ -36,14 +38,39 @@
namespace mace {
namespace kernels {
template<DeviceType D, typename T>
template <DeviceType D, typename T>
struct MatMulFunctor {
MaceStatus operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
bool transpose_a,
bool transpose_b,
StatsFuture *future) {
MACE_UNUSED(future);
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
index_t batch;
index_t height;
index_t K;
index_t width;
index_t rank = A->dim_size();
height = A->dim(rank - 2);
K = A->dim(rank - 1);
if (transpose_a) {
std::swap(height, K);
}
if (transpose_b) {
width = B->dim(rank - 2);
} else {
width = B->dim(rank - 1);
}
batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
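// All leading dimensions fold into the batch, e.g. a shape of {2, 3, H, K} gives batch = 6.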
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
MACE_RETURN_IF_ERROR(C->Resize(c_shape));
Tensor::MappingGuard guarda(A);
......@@ -53,27 +80,26 @@ struct MatMulFunctor {
const T *b_ptr_base = B->data<T>();
T *c_ptr_base = C->mutable_data<T>();
const index_t batch = C->dim(0);
const index_t height = C->dim(1);
const index_t width = C->dim(2);
const index_t K = A->dim(2);
// It is better to use large block size if it fits for fast cache.
// Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
// the block size should be sqrt(32k / sizeof(T) / 3).
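// (Illustrative: for float that is sqrt(32 * 1024 / 4 / 3) ~= 52 elements per block side.)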
memset(c_ptr_base, 0, batch * height * width * sizeof(T));
Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base);
Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base,
transpose_a, transpose_b);
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template<typename T>
template <typename T>
struct MatMulFunctor<DeviceType::GPU, T> {
MaceStatus operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
bool transpose_a,
bool transpose_b,
StatsFuture *future);
cl::Kernel kernel_;
......
......@@ -134,8 +134,12 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
if (buffer->dim_size() < 4) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
}
b2f_kernel.setArg(idx++, *(image->opencl_image()));
const uint32_t kwg_size =
......
......@@ -76,19 +76,27 @@ void CalWinogradFilterImageShape(
// [W * C, N * RoundUp<4>(H)]
void CalInOutHeightImageShape(const std::vector<index_t> &shape, /* NHWC */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 4);
std::vector<index_t> padded_shape = shape;
while (padded_shape.size() < 4) {
padded_shape.push_back(1);
}
MACE_CHECK(padded_shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = shape[2] * shape[3];
(*image_shape)[1] = shape[0] * RoundUpDiv4(shape[1]);
(*image_shape)[0] = padded_shape[2] * padded_shape[3];
(*image_shape)[1] = padded_shape[0] * RoundUpDiv4(padded_shape[1]);
}
// [RoundUp<4>(W) * C, N * H]
void CalInOutWidthImageShape(const std::vector<index_t> &shape, /* NHWC */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 4);
std::vector<index_t> padded_shape = shape;
while (padded_shape.size() < 4) {
padded_shape.push_back(1);
}
MACE_CHECK(padded_shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = RoundUpDiv4(shape[2]) * shape[3];
(*image_shape)[1] = shape[0] * shape[1];
(*image_shape)[0] = RoundUpDiv4(padded_shape[2]) * padded_shape[3];
(*image_shape)[1] = padded_shape[0] * padded_shape[1];
}
// [Ic * H * W, (Oc + 3) / 4]
......@@ -150,10 +158,10 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
const BufferType type) {
if (type == WINOGRAD_FILTER) {
return {16, shape[0], shape[1], 1};
return {16, shape[0], shape[1]};
} else if (type == IN_OUT_HEIGHT) {
index_t out_width = shape[0] * ((shape[1] - 1) / 2) * ((shape[2] - 1) / 2);
return {16, shape[3], out_width, 1};
return {16, shape[3], out_width};
} else {
LOG(FATAL) << "Mace not supported yet.";
return std::vector<index_t>();
......
......@@ -122,8 +122,12 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
if (buffer->dim_size() < 4) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
}
b2f_kernel.setArg(idx++, *(image->opencl_image()));
const uint32_t kwg_size =
......
......@@ -24,17 +24,27 @@ template <typename T>
MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
bool transpose_a,
bool transpose_b,
StatsFuture *future) {
MACE_UNUSED(future);
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
MACE_CHECK(!transpose_a && !transpose_b,
"GPU does not support transpose matmul");
index_t rank = A->dim_size();
index_t height = A->dim(rank - 2);
index_t K = A->dim(rank - 1);
index_t width = B->dim(rank - 1);
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
std::vector<size_t> c_image_shape;
CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape);
MACE_RETURN_IF_ERROR(C->ResizeImage(c_shape, c_image_shape));
const index_t batch = C->dim(0);
const index_t height = C->dim(1);
const index_t width = C->dim(2);
const index_t height_blocks = RoundUpDiv4(height);
const index_t width_blocks = RoundUpDiv4(width);
const uint32_t gws[2] = {
......@@ -82,13 +92,12 @@ MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
kernel_.setArg(idx++, *(C->opencl_image()));
kernel_.setArg(idx++, static_cast<int>(height));
kernel_.setArg(idx++, static_cast<int>(width));
kernel_.setArg(idx++, static_cast<int>(A->dim(2)));
kernel_.setArg(idx++, static_cast<int>(K));
kernel_.setArg(idx++, static_cast<int>(height_blocks));
kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(K)));
const std::vector<uint32_t> lws = {kwg_size_ / 64, 64, 0};
std::string tuning_key = Concat("matmul_opencl_kernel", C->dim(0), C->dim(1),
C->dim(2), C->dim(3));
std::string tuning_key = Concat("matmul_opencl_kernel", batch, height, width);
TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
......
......@@ -74,7 +74,7 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(3)))};
if (!IsVecEqual(input_shape_, input_tensor->shape())) {
output_shape = {16, input_tensor->dim(3), out_width, 1};
output_shape = {16, input_tensor->dim(3), out_width};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape);
MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape));
......@@ -104,7 +104,7 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
std::string tuning_key = Concat("winograd_transform_kernel",
output_tensor->dim(0), output_tensor->dim(1),
output_tensor->dim(2), output_tensor->dim(3));
output_tensor->dim(2));
TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
......
......@@ -25,24 +25,37 @@ template <DeviceType D, class T>
class MatMulOp : public Operator<D, T> {
public:
MatMulOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
: Operator<D, T>(operator_def, ws),
transpose_a_(OperatorBase::GetOptionalArg<bool>("transpose_a", false)),
transpose_b_(OperatorBase::GetOptionalArg<bool>("transpose_b", false)) {
}
MaceStatus Run(StatsFuture *future) override {
const Tensor *A = this->Input(0);
const Tensor *B = this->Input(1);
Tensor *C = this->Output(0);
MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size())
<< "The dimension of A and B should be 4";
MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size";
MACE_CHECK(A->dim(2) == B->dim(1))
<< "the number of A's column " << A->dim(2)
<< " must be equal to B's row " << B->dim(1);
return functor_(A, B, C, future);
const Tensor *A = this->Input(INPUT_A);
const Tensor *B = this->Input(INPUT_B);
Tensor *C = this->Output(OUTPUT);
MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2,
"rank(A) should be equal to rank(B), rank should be greater "
"than or equal to 2");
index_t rank = A->dim_size();
for (index_t i = 0; i < rank - 2; ++i) {
MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal");
}
index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
MACE_CHECK(ak == bk, "the number of A's column ", ak,
" must be equal to B's row ", bk);
return functor_(A, B, C, transpose_a_, transpose_b_, future);
}
private:
MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B);
MACE_OP_OUTPUT_TAGS(OUTPUT);
kernels::MatMulFunctor<D, T> functor_;
bool transpose_a_;
bool transpose_b_;
};
} // namespace ops
......
......@@ -31,8 +31,8 @@ void MatMulBenchmark(
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("A", {batch, height, channels, 1});
net.AddRandomInput<D, float>("B", {batch, channels, out_width, 1});
net.AddRandomInput<D, float>("A", {batch, height, channels});
net.AddRandomInput<D, float>("B", {batch, channels, out_width});
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
......@@ -65,6 +65,41 @@ void MatMulBenchmark(
}
net.Sync();
}
template <DeviceType D, typename T>
void MatMulTransposeBenchmark(
int iters, int batch, int height, int channels, int out_width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("A", {batch, height, channels});
net.AddRandomInput<D, float>("B", {batch, out_width, channels});
if (D == DeviceType::CPU) {
OpDefBuilder("MatMul", "MatMulBM")
.Input("A")
.Input("B")
.AddIntArg("transpose_b", 1)
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
MACE_NOT_IMPLEMENTED;
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_MATMUL_MACRO(N, H, C, W, TYPE, DEVICE) \
......@@ -83,6 +118,20 @@ void MatMulBenchmark(
MACE_BM_MATMUL_MACRO(N, H, C, W, float, GPU); \
MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU);
#define MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, TYPE, DEVICE) \
static void MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t macc = static_cast<int64_t>(iters) * N * C * H * W; \
const int64_t tot = static_cast<int64_t>(iters) * N * (C * H + H * W); \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
MatMulTransposeBenchmark<DEVICE, TYPE>(iters, N, H, C, W); \
} \
MACE_BENCHMARK(MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_MATMUL_TRANPOSE(N, H, C, W) \
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);
MACE_BM_MATMUL(16, 32, 128, 49);
MACE_BM_MATMUL(16, 32, 128, 961);
MACE_BM_MATMUL(16, 32, 128, 3969);
......@@ -90,6 +139,13 @@ MACE_BM_MATMUL(16, 128, 128, 49);
MACE_BM_MATMUL(16, 128, 128, 961);
MACE_BM_MATMUL(16, 128, 128, 3969);
MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 49);
MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 961);
MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 3969);
MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 49);
MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 961);
MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 3969);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -72,46 +72,46 @@ void Simple(const std::vector<index_t> &A_shape,
} // namespace
TEST_F(MatMulOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1},
{1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64});
Simple<DeviceType::CPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
{1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
Simple<DeviceType::CPU>(
{1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
{1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
650, 765, 830, 895, 960, 1025, 1040, 1130, 1220,
1310, 1400, 1315, 1430, 1545, 1660, 1775});
}
TEST_F(MatMulOpTest, SimpleCPUWithBatch) {
Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
Simple<DeviceType::CPU>({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64});
}
TEST_F(MatMulOpTest, SimpleOPENCL) {
Simple<DeviceType::GPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1},
{1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64});
Simple<DeviceType::GPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
{1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
Simple<DeviceType::GPU>(
{1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
{1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
650, 765, 830, 895, 960, 1025, 1040, 1130, 1220,
1310, 1400, 1315, 1430, 1545, 1660, 1775});
}
TEST_F(MatMulOpTest, SimpleGPUWithBatch) {
Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
Simple<DeviceType::CPU>({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
{2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64});
}
namespace {
template <typename T>
void Complex(const index_t batch,
void Complex(const std::vector<index_t> &batch,
const index_t height,
const index_t channels,
const index_t out_width) {
......@@ -119,23 +119,14 @@ void Complex(const index_t batch,
// Construct graph
OpsTestNet net;
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.Input("B")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::GPU, float>("A", {batch, height, channels, 1});
net.AddRandomInput<DeviceType::GPU, float>("B",
{batch, channels, out_width, 1});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>());
net.AddRandomInput<DeviceType::GPU, float>("A",
{batch_count, height, channels});
net.AddRandomInput<DeviceType::GPU, float>(
"B", {batch_count, channels, out_width});
// Run on opencl
BufferToImage<DeviceType::GPU, T>(&net, "A", "AImage",
......@@ -150,11 +141,40 @@ void Complex(const index_t batch,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::GPU);
ImageToBuffer<DeviceType::GPU, float>(&net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_HEIGHT);
// run cpu
std::vector<index_t> shape_a = batch;
shape_a.push_back(height);
shape_a.push_back(channels);
std::vector<index_t> shape_b = batch;
shape_b.push_back(channels);
shape_b.push_back(out_width);
std::vector<index_t> expected_output_shape = batch;
expected_output_shape.push_back(height);
expected_output_shape.push_back(out_width);
net.GetTensor("A")->Reshape(shape_a);
net.GetTensor("B")->Reshape(shape_b);
OpDefBuilder("MatMul", "MatMulTest")
.Input("A")
.Input("B")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
// Check
EXPECT_EQ(expected_output_shape, net.GetOutput("Output")->shape());
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
expected.Reshape({batch_count, height, out_width});
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2,
1e-1);
......@@ -166,28 +186,36 @@ void Complex(const index_t batch,
} // namespace
TEST_F(MatMulOpTest, OPENCLAlignedWithoutBatch) {
Complex<float>(1, 64, 128, 32);
Complex<float>(1, 64, 32, 128);
Complex<float>({1}, 64, 128, 32);
Complex<float>({1}, 64, 32, 128);
Complex<float>({2, 3}, 64, 32, 128);
}
TEST_F(MatMulOpTest, OPENCLUnAlignedWithoutBatch) {
Complex<float>(1, 31, 113, 61);
Complex<float>(1, 113, 31, 73);
Complex<float>({1}, 31, 113, 61);
Complex<float>({1}, 113, 31, 73);
Complex<float>({2, 3}, 113, 31, 73);
}
TEST_F(MatMulOpTest, OPENCLUnAlignedWithBatch) {
Complex<float>(2, 3, 3, 3);
Complex<float>(16, 31, 61, 67);
Complex<float>(31, 31, 61, 67);
Complex<float>({2}, 3, 3, 3);
Complex<float>({16}, 31, 61, 67);
Complex<float>({31}, 31, 61, 67);
Complex<float>({2, 3}, 31, 61, 67);
}
TEST_F(MatMulOpTest, OPENCLHalfAlignedWithoutBatch) {
Complex<half>(1, 64, 128, 32);
Complex<half>(1, 64, 32, 128);
Complex<half>({1}, 64, 128, 32);
Complex<half>({1}, 64, 32, 128);
Complex<half>({2, 3}, 64, 32, 128);
}
TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(2, 31, 113, 61);
Complex<half>(16, 32, 64, 64);
Complex<half>(31, 31, 61, 67);
Complex<half>({2}, 31, 113, 61);
Complex<half>({16}, 32, 64, 64);
Complex<half>({31}, 31, 61, 67);
Complex<half>({2, 3}, 31, 61, 67);
}
// TODO(liyin): test transpose after implementing gpu runtime
// now transpose test is in kernels_test
} // namespace test
} // namespace ops
} // namespace mace
......@@ -518,7 +518,7 @@ class Transformer(base_converter.ConverterInterface):
wt_output_width = batch * (
(out_height + 1) / 2) * ((out_width + 1) / 2)
wt_output_shape.dims.extend(
[16, in_channels, wt_output_width, 1])
[16, in_channels, wt_output_width])
if ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_str) \
......@@ -543,7 +543,7 @@ class Transformer(base_converter.ConverterInterface):
matmul_op.output.extend([matmul_op.name])
matmul_output_shape = matmul_op.output_shape.add()
matmul_output_shape.dims.extend(
[16, out_channels, wt_output_width, 1])
[16, out_channels, wt_output_width])
arg = matmul_op.arg.add()
arg.name = MaceKeyword.mace_winograd_filter_transformed
......
......@@ -167,7 +167,7 @@ class GPUMemoryOptimizer(MemoryOptimizer):
def get_op_mem_block(self, op_type, output_shape):
mem_block = [0, 0]
if op_type == 'WinogradTransform' or op_type == 'MatMul':
mem_block[0] = output_shape[2] * output_shape[3]
mem_block[0] = output_shape[2]
mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
else:
mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
......