提交 eed54081 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/arm): add armv7 mk4 i8i8i16 gemm, optimized for A7

GitOrigin-RevId: d2f8290a8d6577b99adad16e42d57a6ca55a119e
上级 9c475fff
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/algos.h"
......@@ -733,7 +734,9 @@ void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x16(kern_size_param);
return can_be_treated_as_int8x8x16(kern_size_param) &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred(
......@@ -796,7 +799,9 @@ void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x16(kern_size_param);
return can_be_treated_as_int8x8x16(kern_size_param) &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred(
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/algos.h"
......@@ -526,6 +527,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
"AlgoInt8x8x16K4x8x8"_hash,
armv7::matmul::gemm_s8x8x16_4x8, int8_t,
int16_t);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
namespace {
void kern_int8x8x16_mk4_k8x8x4(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int16>();
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto trA = kern_param.trA, trB = kern_param.trB;
armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, kern_param.A_type,
kern_param.B_type,
kern_param.C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_mk4_8x8>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::usable(
const KernSizeParam& kern_size_param) const {
bool type_ok = can_be_treated_as_int8x8x16(kern_size_param);
return type_ok && kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_8x8>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern(
const KernSizeParam&) const {
return kern_int8x8x16_mk4_k8x8x4;
}
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::preferred(
const KernSizeParam& kern_size_param) const {
return kern_size_param.K >= 4;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16MK4_8x8x4"_hash,
armv7::matmul::gemm_s8x8x16_mk4_8x8,
int8_t, int16_t, int16_t);
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */
namespace {
......@@ -937,11 +1006,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern(
Bptr = kern_param.B<dt_float16>();
auto Cptr = kern_param.C<dt_float16>();
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<
armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA,
trB, strategy)
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_nopack_f16_4x8,
false>(M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -171,6 +172,18 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -6,13 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <arm_neon.h>
#include <cmath>
#include <cstdint>
#include <type_traits>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
......@@ -172,7 +174,6 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "memory");
}
template <typename T>
......@@ -183,12 +184,12 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"interleave_4x4_4_b only support uint8_t and int8_t");
asm volatile(
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3
"vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2
"vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3
"vst1.32 {d0-d1},[%[outptr]]!\n"
......@@ -323,10 +324,10 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1,
"vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17
"vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37
"vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57
"vst1.32 {d0-d1}, [%[outptr]]! \n"
"vst1.32 {d16}, [%[outptr]]! \n"
"vst1.32 {d0-d1}, [%[outptr]]! \n"
"vst1.32 {d16}, [%[outptr]]! \n"
"vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33
"vst1.32 {d8-d9}, [%[outptr]]! \n"
"vst1.32 {d8-d9}, [%[outptr]]! \n"
"vst1.32 {d17}, [%[outptr]]! \n"
"vst1.32 {d2-d3}, [%[outptr]]!\n"
"vst1.32 {d18}, [%[outptr]]!\n"
......@@ -810,15 +811,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1,
"interleave_12x4_1_h only support uint16_t and int16_t");
auto ldin_asm = ldin << 1;
asm volatile(
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3
"vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3
"add %[inptr9], %[inptr9], %[ldin_asm]\n"
"vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3
......@@ -854,17 +855,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1,
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8),
[inptr9] "+r"(inptr9), [outptr] "+r"(outptr)
:[ldin_asm] "r"(ldin_asm)
: [ldin_asm] "r"(ldin_asm)
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "memory");
inptr9 -= ldin_asm;
inptr9 += 4;
inptr9 -= ldin_asm;
inptr9 += 4;
inptr10 += 4;
inptr11 += 4;
}
template <typename T>
static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1038,7 +1037,7 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1,
"vst1.32 {d7}, [%[outptr]], %[stride]\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
[outptr] "+r"(outptr), [stride] "+r" (stride)
[outptr] "+r"(outptr), [stride] "+r"(stride)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory");
}
......@@ -1069,7 +1068,6 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1,
: "d0", "d1", "d2", "d3", "memory");
}
template <typename T>
static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1082,9 +1080,9 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1,
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"add %[inptr0],%[inptr0],#6 \n"
"add %[inptr1],%[inptr1],#6 \n"
"add %[inptr2],%[inptr2],#6 \n"
......@@ -1121,9 +1119,9 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"add %[inptr0],%[inptr0],#4 \n"
"add %[inptr1],%[inptr1],#4 \n"
"add %[inptr2],%[inptr2],#4 \n"
......@@ -1176,7 +1174,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"vst1.32 {d14-d15}, [%[outptr]]! \n"
"vst1.32 {d22-d23}, [%[outptr]]! \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "memory");
......@@ -1195,12 +1193,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
"vst1.32 {d4-d5}, [%[outptr]]! \n"
"vst1.32 {d2-d3}, [%[outptr]]! \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "memory");
}
template <typename T>
static inline void transpose_4(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3, T* outptr,
......@@ -1251,7 +1248,6 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1,
}
}
template <typename T>
static inline void transpose_4x1(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1375,7 +1371,68 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr,
: "q0", "q1", "q2", "q3", "memory");
}
} // armv7
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0,
const int8_t* inptr1,
int16_t* outptr) {
int8x16_t row0 = vld1q_s8(inptr0);
int16x8_t row0_01 = vmovl_low_s8(row0);
int16x8_t row0_23 = vmovl_high_s8(row0);
int16x4_t row0_0 = vget_low_s16(row0_01);
int16x4_t row0_1 = vget_high_s16(row0_01);
int16x4_t row0_2 = vget_low_s16(row0_23);
int16x4_t row0_3 = vget_high_s16(row0_23);
int8x16_t row1 = vld1q_s8(inptr1);
int16x8_t row1_01 = vmovl_low_s8(row1);
int16x8_t row1_23 = vmovl_high_s8(row1);
int16x4_t row1_0 = vget_low_s16(row1_01);
int16x4_t row1_1 = vget_high_s16(row1_01);
int16x4_t row1_2 = vget_low_s16(row1_23);
int16x4_t row1_3 = vget_high_s16(row1_23);
vst1_s16(outptr, row0_0);
vst1_s16(outptr + 1 * 4, row1_0);
vst1_s16(outptr + 2 * 4, row0_1);
vst1_s16(outptr + 3 * 4, row1_1);
vst1_s16(outptr + 4 * 4, row0_2);
vst1_s16(outptr + 5 * 4, row1_2);
vst1_s16(outptr + 6 * 4, row0_3);
vst1_s16(outptr + 7 * 4, row1_3);
};
static inline void transpos_8x4_int8(const int8_t* inptr0, int8_t* outptr) {
int8x8x4_t input = vld4_s8(inptr0);
vst1_s8(outptr, input.val[0]);
vst1_s8(outptr + 1 * 8, input.val[1]);
vst1_s8(outptr + 2 * 8, input.val[2]);
vst1_s8(outptr + 3 * 8, input.val[3]);
}
static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr,
int count) {
for (; count >= 32; count -= 32) {
int8x8_t in0 = vld1_s8(inptr);
int8x8_t in1 = vld1_s8(inptr + 1 * 8);
int8x8_t in2 = vld1_s8(inptr + 2 * 8);
int8x8_t in3 = vld1_s8(inptr + 3 * 8);
vst1q_s16(outptr, vmovl_s8(in0));
vst1q_s16(outptr + 1 * 8, vmovl_s8(in1));
vst1q_s16(outptr + 2 * 8, vmovl_s8(in2));
vst1q_s16(outptr + 3 * 8, vmovl_s8(in3));
inptr += 32;
outptr += 32;
}
for (; count >= 8; count -= 8) {
int8x8_t in0 = vld1_s8(inptr);
vst1q_s16(outptr, vmovl_s8(in0));
inptr += 8;
outptr += 8;
}
for (; count > 0; --count) {
*outptr++ = (int16_t)(*inptr++);
}
}
} // namespace armv7
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -102,60 +102,60 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"vld1.8 {d2}, [%[a_ptr]]!\n"
"vld1.8 {d4}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[a_ptr]]!\n"
"vld1.8 {d18}, [%[b_ptr]]!\n"
"vmovl.s8 q8, d16\n"
"vmovl.s8 q0, d0\n"
"vmovl.s8 q1, d2\n"
"vmovl.s8 q2, d4\n"
"vmovl.s8 q3, d6\n"
"vld1.8 {d18}, [%[b_ptr]]!\n"
"vmovl.s8 q9, d18\n"
"vld1.8 {d20}, [%[b_ptr]]!\n"
"vmla.s16 q4, q8, d0[0]\n"
"vmla.s16 q5, q8, d2[0]\n"
"vmla.s16 q6, q8, d4[0]\n"
"vmla.s16 q7, q8, d6[0]\n"
"vmovl.s8 q9, d18\n"
"vld1.8 {d20}, [%[b_ptr]]!\n"
"vmovl.s8 q10, d20\n"
"vld1.8 {d22}, [%[b_ptr]]!\n"
"vmla.s16 q4, q9, d0[1]\n"
"vmla.s16 q5, q9, d2[1]\n"
"vmla.s16 q6, q9, d4[1]\n"
"vmla.s16 q7, q9, d6[1]\n"
"vmovl.s8 q10, d20\n"
"vld1.8 {d22}, [%[b_ptr]]!\n"
"vmovl.s8 q11, d22\n"
"vld1.8 {d24}, [%[b_ptr]]!\n"
"vmla.s16 q4, q10, d0[2]\n"
"vmla.s16 q5, q10, d2[2]\n"
"vmla.s16 q6, q10, d4[2]\n"
"vmla.s16 q7, q10, d6[2]\n"
"vmovl.s8 q11, d22\n"
"vld1.8 {d24}, [%[b_ptr]]!\n"
"vmovl.s8 q12, d24\n"
"vld1.8 {d26}, [%[b_ptr]]!\n"
"vmla.s16 q4, q11, d0[3]\n"
"vmla.s16 q5, q11, d2[3]\n"
"vmla.s16 q6, q11, d4[3]\n"
"vmla.s16 q7, q11, d6[3]\n"
"vmovl.s8 q12, d24\n"
"vld1.8 {d26}, [%[b_ptr]]!\n"
"vmovl.s8 q13, d26\n"
"vld1.8 {d28}, [%[b_ptr]]!\n"
"vmla.s16 q4, q12, d1[0]\n"
"vmla.s16 q5, q12, d3[0]\n"
"vmla.s16 q6, q12, d5[0]\n"
"vmla.s16 q7, q12, d7[0]\n"
"vmovl.s8 q13, d26\n"
"vld1.8 {d28}, [%[b_ptr]]!\n"
"vmovl.s8 q14, d28\n"
"vld1.8 {d30}, [%[b_ptr]]!\n"
"vmla.s16 q4, q13, d1[1]\n"
"vmla.s16 q5, q13, d3[1]\n"
"vmla.s16 q6, q13, d5[1]\n"
"vmla.s16 q7, q13, d7[1]\n"
"vmovl.s8 q14, d28\n"
"vld1.8 {d30}, [%[b_ptr]]!\n"
"vmovl.s8 q15, d30\n"
"vmla.s16 q4, q14, d1[2]\n"
"vmla.s16 q5, q14, d3[2]\n"
"vmla.s16 q6, q14, d5[2]\n"
"vmla.s16 q7, q14, d7[2]\n"
"vmovl.s8 q15, d30\n"
"vmla.s16 q4, q15, d1[3]\n"
"vmla.s16 q5, q15, d3[3]\n"
......
/**
* \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
namespace megdnn {
namespace armv7 {
namespace matmul_mk4_8x8x4 {
//! optimize for A7
/**
* Overview of register layout:
*
* A 8x8x8 cell of Lhs is stored in 16bit in q0, q1
* A 8x8x8 cell of Rhs is stored in 8bit in q2, q3
* A 8x8 block of accumulators is stored in 16bit in q8-q15
*
* +--------+
* | q4[0-8]|
* Rhs +--------+
* Lhs | |
*
* +--------+ - - - - +---------
* |q0[0]| | q8 [0-8]|
* |q0[1]| | q9 [0-8]|
* |q0[2]| | q10[0-8]|
* |q0[3]| | q11[0-8]|
* |q0[4]| | q12[0-8]|
* |q0[5]| | q13[0-8]|
* |q0[6]| | q14[0-8]|
* |q0[7]| | q15[0-8]|
* +--------+ - - - - +---------
*
* Accumulator
*/
static void kern_8x8(const int16_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int remain_n) {
K /= 4;
const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB;
LDC = LDC * sizeof(int16_t);
int x0 = 0;
// clang-format off
#define STORE_LINE(reg_index1, reg_index2) \
"cmp %[x0], #0 \n" \
"beq 101f\n" \
"vst1.16 {d" reg_index1 "}, [r0]!\n" \
"vst1.16 {d" reg_index2 "}, [r1]!\n" \
"subs %[x0], %[x0], #1\n"
#define STORE_C \
"mov %[x0], %[remain_n]\n" \
STORE_LINE("16", "17") \
STORE_LINE("18", "19") \
STORE_LINE("20", "21") \
STORE_LINE("22", "23") \
STORE_LINE("24", "25") \
STORE_LINE("26", "27") \
STORE_LINE("28", "29") \
STORE_LINE("30", "31") \
"101:\n"
// clang-format on
register int16_t* outptr asm("r0") = output;
asm volatile(
// load accumulator C
"add r1, r0, %[LDC]\n"
"cmp %[is_first_k], #1\n"
"beq 1f\n"
"b 2f\n"
"1:\n"
"veor.s32 q8 , q8 , q8 \n"
"veor.s32 q9 , q9 , q9 \n"
"veor.s32 q10, q10, q10\n"
"veor.s32 q11, q11, q11\n"
"veor.s32 q12, q12, q12\n"
"veor.s32 q13, q13, q13\n"
"veor.s32 q14, q14, q14\n"
"veor.s32 q15, q15, q15\n"
"2: \n"
"vld1.8 {d4}, [%[b_ptr]]!\n"
"vld1.16 {d0, d1}, [%[a_ptr]]!\n"
"vmovl.s8 q2, d4\n"
"vld1.16 {d2, d3}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[b_ptr]]!\n"
//! k0
"vmla.s16 q8, q0, d4[0]\n"
"vmla.s16 q9, q0, d4[1]\n"
"vmla.s16 q10, q0, d4[2]\n"
"vmla.s16 q11, q0, d4[3]\n"
"vmovl.s8 q3, d6\n"
"vmla.s16 q12, q0, d5[0]\n"
"vmla.s16 q13, q0, d5[1]\n"
"vmla.s16 q14, q0, d5[2]\n"
"vmla.s16 q15, q0, d5[3]\n"
//! k1
"vld1.16 {d0, d1}, [%[a_ptr]]!\n"
"vld1.8 {d4}, [%[b_ptr]]!\n"
"vmla.s16 q8, q1, d6[0]\n"
"vmla.s16 q9, q1, d6[1]\n"
"vmla.s16 q10, q1, d6[2]\n"
"vmla.s16 q11, q1, d6[3]\n"
"vmovl.s8 q2, d4\n"
"vmla.s16 q12, q1, d7[0]\n"
"vmla.s16 q13, q1, d7[1]\n"
"vmla.s16 q14, q1, d7[2]\n"
"vmla.s16 q15, q1, d7[3]\n"
//! k2
"vld1.16 {d2, d3}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[b_ptr]]!\n"
"vmla.s16 q8, q0, d4[0]\n"
"vmla.s16 q9, q0, d4[1]\n"
"vmla.s16 q10, q0, d4[2]\n"
"vmla.s16 q11, q0, d4[3]\n"
"vmovl.s8 q3, d6\n"
"vmla.s16 q12, q0, d5[0]\n"
"vmla.s16 q13, q0, d5[1]\n"
"vmla.s16 q14, q0, d5[2]\n"
"vmla.s16 q15, q0, d5[3]\n"
//! k3
"vmla.s16 q8, q1, d6[0]\n"
"vmla.s16 q9, q1, d6[1]\n"
"vmla.s16 q10, q1, d6[2]\n"
"vmla.s16 q11, q1, d6[3]\n"
"vmla.s16 q12, q1, d7[0]\n"
"vmla.s16 q13, q1, d7[1]\n"
"vmla.s16 q14, q1, d7[2]\n"
"vmla.s16 q15, q1, d7[3]\n"
"subs %[K], %[K], #1\n"
"bne 2b\n"
"3:\n"
"cmp %[remain_n], #8\n"
"bne 4f\n"
"vstr d16, [r0]\n"
"vstr d18, [r0, #8]\n"
"vstr d20, [r0, #16]\n"
"vstr d22, [r0, #24]\n"
"vstr d24, [r0, #32]\n"
"vstr d26, [r0, #40]\n"
"vstr d28, [r0, #48]\n"
"vstr d30, [r0, #56]\n"
"vstr d17, [r1]\n"
"vstr d19, [r1, #8]\n"
"vstr d21, [r1, #16]\n"
"vstr d23, [r1, #24]\n"
"vstr d25, [r1, #32]\n"
"vstr d27, [r1, #40]\n"
"vstr d29, [r1, #48]\n"
"vstr d31, [r1, #56]\n"
"b 101f\n"
"4:\n " STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28",
"d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory");
#undef STORE_C
#undef STORE_LINE
}
/**
* Overview of register layout:
*
* A 8x8x8 cell of Lhs is stored in 16bit in d0, d2
* A 8x8x8 cell of Rhs is stored in 8bit in q2, q3
* A 8x8 block of accumulators is stored in 16bit in q8-11
*
* +--------+
* | q4[0-8]|
* Rhs +--------+
* Lhs | |
*
* +--------+ - - - - +---------
* |d0[0]| | q8 [0-8]|
* |d0[1]| | q9 [0-8]|
* |d0[2]| | q10[0-8]|
* |d0[3]| | q11[0-8]|
* +--------+ - - - - +---------
*
* Accumulator
*/
static void kern_4x8(const int16_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int remain_n) {
K /= 4;
const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB;
LDC = LDC * sizeof(int16_t);
int x0 = 0;
// clang-format off
#define STORE_LINE(reg_index1) \
"cmp %[x0], #0 \n" \
"beq 101f\n" \
"vst1.16 {d" reg_index1 "}, [r0]!\n" \
"subs %[x0], %[x0], #1\n"
#define STORE_C \
"mov %[x0], %[remain_n]\n" \
STORE_LINE("16") \
STORE_LINE("18") \
STORE_LINE("20") \
STORE_LINE("22") \
STORE_LINE("24") \
STORE_LINE("26") \
STORE_LINE("28") \
STORE_LINE("30") \
"101:\n"
// clang-format on
register int16_t* outptr asm("r0") = output;
asm volatile(
//! load accumulator C
"add r1, r0, %[LDC]\n"
"cmp %[is_first_k], #1\n"
"beq 1f\n"
"b 2f\n"
"1:\n"
"veor.s32 q8 , q8 , q8 \n"
"veor.s32 q9 , q9 , q9 \n"
"veor.s32 q10, q10, q10\n"
"veor.s32 q11, q11, q11\n"
"veor.s32 q12, q12, q12\n"
"veor.s32 q13, q13, q13\n"
"veor.s32 q14, q14, q14\n"
"veor.s32 q15, q15, q15\n"
"2: \n"
"vld1.8 {d4}, [%[b_ptr]]!\n"
"vld1.16 {d0}, [%[a_ptr]]!\n"
"vmovl.s8 q2, d4\n"
"vld1.16 {d2}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[b_ptr]]!\n"
//! k0
"vmla.s16 d16, d0, d4[0]\n"
"vmla.s16 d18, d0, d4[1]\n"
"vmla.s16 d20, d0, d4[2]\n"
"vmla.s16 d22, d0, d4[3]\n"
"vmovl.s8 q3, d6\n"
"vmla.s16 d24, d0, d5[0]\n"
"vmla.s16 d26, d0, d5[1]\n"
"vmla.s16 d28, d0, d5[2]\n"
"vmla.s16 d30, d0, d5[3]\n"
//! k1
"vld1.16 {d0}, [%[a_ptr]]!\n"
"vld1.8 {d4}, [%[b_ptr]]!\n"
"vmla.s16 d16, d2, d6[0]\n"
"vmla.s16 d18, d2, d6[1]\n"
"vmla.s16 d20, d2, d6[2]\n"
"vmla.s16 d22, d2, d6[3]\n"
"vmovl.s8 q2, d4\n"
"vmla.s16 d24, d2, d7[0]\n"
"vmla.s16 d26, d2, d7[1]\n"
"vmla.s16 d28, d2, d7[2]\n"
"vmla.s16 d30, d2, d7[3]\n"
//! k2
"vld1.16 {d2}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[b_ptr]]!\n"
"vmla.s16 d16, d0, d4[0]\n"
"vmla.s16 d18, d0, d4[1]\n"
"vmla.s16 d20, d0, d4[2]\n"
"vmla.s16 d22, d0, d4[3]\n"
"vmovl.s8 q3, d6\n"
"vmla.s16 d24, d0, d5[0]\n"
"vmla.s16 d26, d0, d5[1]\n"
"vmla.s16 d28, d0, d5[2]\n"
"vmla.s16 d30, d0, d5[3]\n"
//! k3
"vmla.s16 d16, d2, d6[0]\n"
"vmla.s16 d18, d2, d6[1]\n"
"vmla.s16 d20, d2, d6[2]\n"
"vmla.s16 d22, d2, d6[3]\n"
"vmla.s16 d24, d2, d7[0]\n"
"vmla.s16 d26, d2, d7[1]\n"
"vmla.s16 d28, d2, d7[2]\n"
"vmla.s16 d30, d2, d7[3]\n"
"subs %[K], %[K], #1\n"
"bne 2b\n"
"3:\n"
"cmp %[remain_n], #8\n"
"bne 4f\n"
"vstr d16, [r0]\n"
"vstr d18, [r0, #8]\n"
"vstr d20, [r0, #16]\n"
"vstr d22, [r0, #24]\n"
"vstr d24, [r0, #32]\n"
"vstr d26, [r0, #40]\n"
"vstr d28, [r0, #48]\n"
"vstr d30, [r0, #56]\n"
"b 101f\n"
"4:\n " STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
:
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
"d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28",
"d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory");
#undef STORE_C
#undef STORE_LINE
}
static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr,
const dt_int8* inptr, int ldin,
int m0, int mmax, int k0, int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int pack_m = 8;
constexpr int pack_k = 4;
constexpr int pack_size = 4;
const int m_size = mmax - m0;
const int m_end = m_size / pack_m * pack_m + m0;
const int remain_m = mmax - m_end;
for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) {
const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) {
interleave_4x4_8x4_s8_s16(inptr0, inptr1, outptr);
inptr0 += pack_size * pack_size;
inptr1 += pack_size * pack_size;
outptr += pack_m * pack_k;
}
}
if (remain_m > 0) {
const int8_t* inptr0 = inptr + m_end / pack_size * ldin + k0;
const int k_size = kmax - k0;
memcpy_s8_s16(inptr0, outptr, k_size * pack_size);
}
}
static void gemm_s8x8x16_mk4_8x8_pack_B_n(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
int8_t tmpbuff[32] = {0};
constexpr int pack_n = 8;
constexpr int pack_size = 4;
const int ksize = kmax - k0;
const int nsize = nmax - n0;
const int n_end = nsize / pack_n * pack_n + n0;
const int remain_n = nsize % pack_n;
int output_stride = ksize * pack_n;
int8_t* outptr_base = out;
for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) {
const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size;
prefetch_3x(inptr);
auto outptr = outptr_base;
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
transpos_8x4_int8(inptr, outptr);
inptr += pack_n * pack_size;
outptr += output_stride;
}
if (remain_n > 0) {
memcpy(tmpbuff, inptr, sizeof(int8_t) * remain_n * pack_size);
transpos_8x4_int8(tmpbuff, outptr);
outptr += output_stride;
}
outptr_base += pack_n * pack_size;
}
}
} // namespace matmul_mk4_8x8x4
} // namespace armv7
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,14 +6,16 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
......@@ -108,7 +110,7 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_4x4==================================
// ===========================gemm_s8x8x16_4x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8);
void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
......@@ -179,4 +181,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_mk4_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8);
void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
constexpr size_t pack_size = 4;
constexpr size_t pack_m = 8;
constexpr size_t pack_n = 8;
const size_t remain_n = N % pack_n;
const size_t remain_m = M % pack_m;
size_t m_idx = 0;
for (; m_idx + pack_m <= M; m_idx += pack_m) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
packA += pack_m * K;
}
if (remain_m > 0) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
}
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
......@@ -21,6 +22,10 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true,
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true,
gemm_s8x8x16_4x8);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8,
8, 4, false, false,
gemm_s8x8x16_mk4_8x8);
} // namespace matmul
} // namespace armv7
} // namespace megdnn
......
......@@ -6,10 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/opr_impl.h"
#include "src/armv7/matrix_mul/algos.h"
#include "src/armv7/matrix_mul/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
......@@ -21,7 +22,7 @@ using namespace armv7;
class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32 f32;
AlgoF32MK4Pack4x12 f32_mk4_pack_4x12;
AlgoF32MK4_4x8 f32_mk4_4x8;
AlgoF32MK4_4x8 f32_mk4_4x8;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K4x16x1 f16_k4x16x1;
AlgoF16MK8_4x8 f16_mk8_4x8;
......@@ -38,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K4x8x8 quint8_k4x8x8;
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4;
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;
......@@ -62,8 +64,10 @@ public:
all_algos.emplace_back(&int8x8x32_k4x2x16);
all_algos.emplace_back(&int8x8x32_k4x8x8);
all_algos.emplace_back(&quint8_k4x8x8);
all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
all_algos.emplace_back(&int8x8x16_k4x2x16);
all_algos.emplace_back(&int8x8x16_k4x8x8);
all_algos.emplace_back(&int16x16x32_k12x4x1);
all_algos.emplace_back(&int16x16x32_mk8_4x8);
}
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
......@@ -19,26 +20,28 @@ public:
using arm_common::MatrixMulImpl::MatrixMulImpl;
SmallVector<AlgoBase*> algo_pack() override;
private:
class AlgoF32; // Armv7 F32
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack
class AlgoF32Gemv; // Armv7 F32 Gemv
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16
class AlgoF32; // Armv7 F32
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack
class AlgoF32Gemv; // Armv7 F32 Gemv
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16
class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1
class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8
#endif
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4
class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4
// DotProduct
#endif
......
......@@ -10,9 +10,9 @@
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include "src/fallback/conv_bias/opr_impl.h"
......@@ -194,10 +194,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
PW = param.filter_meta.padding[1];
size_t SH = param.filter_meta.stride[0],
SW = param.filter_meta.stride[1];
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1)
return false;
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
......@@ -216,6 +214,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
//! is identity otherwise return false mean that 8x8x32 and 8x8x16
//! not support PostProcess
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <unordered_map>
......@@ -226,10 +227,10 @@ public:
PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
auto matmul_block = matmul_algo->get_inner_block_size();
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1
//! im2col+pack fuse
if ((matmul_block.m == 8 || matmul_block.m == 4) &&
matmul_block.n == 12 && matmul_block.k == 1 &&
param.filter_meta.spatial[0] == 3 &&
......@@ -297,9 +298,21 @@ public:
break;
case StrategyType::INT8x8x16:
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x16"_hash);
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x16"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x16"_hash);
} else {
megdnn_throw(
ssprintf("Current only support layout "
"NCHW44/NCHW for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
......@@ -421,10 +434,11 @@ public:
dt_int32, dt_int8, PostprocessMode::QUANTIZED,
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
} else {
megdnn_throw(ssprintf("Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
megdnn_throw(
ssprintf("Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
}
......
......@@ -6,11 +6,12 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/matrix_mul/opr_impl.h"
namespace megdnn {
namespace fallback {
......@@ -66,7 +67,8 @@ public:
};
typedef void (*kern_t)(const KernParam&);
typedef void (*kern_naked_t)(const KernParam& , const void* a_panel, const void *b_panel);
typedef void (*kern_naked_t)(const KernParam&, const void* a_panel,
const void* b_panel);
class AlgoBase : public Algorithm {
protected:
virtual ~AlgoBase() = default;
......@@ -83,18 +85,19 @@ public:
bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const {
return param.A_type.enumv() == param.B_type.enumv() &&
param.A_type.enumv() == DTypeEnum::Int8 &&
param.C_type.enumv() == DTypeEnum::Int16 &&
param.format == param::MatrixMul::Format::DEFAULT &&
param.compute_mode == Param::ComputeMode::DEFAULT;
(param.A_type.enumv() == DTypeEnum::Int8 ||
param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(param.C_type.enumv() == DTypeEnum::Int16 ||
param.C_type.enumv() == DTypeEnum::QuantizedS16);
}
public:
enum class AlgoSet:uint32_t {
enum class AlgoSet : uint32_t {
ALGO_TYPE_GEMM = 0,
ALGO_TYPE_GEMV = 1,
};
enum class PackMode:uint32_t {
enum class PackMode : uint32_t {
DEFAULT = 0,
NO_PACK = 1,
ONLY_PACKA = 2,
......
......@@ -489,25 +489,26 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle,
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
const char* im2col_name, Handle* handle,
size_t kernel, size_t pack_size = 1) {
auto&& args = get_winograd_benchmark_args(kernel, pack_size);
size_t kernel, DType src_type,
DType dst_type) {
auto&& args = get_winograd_benchmark_args(kernel, 4);
using namespace conv_bias;
constexpr size_t RUN = 10;
Benchmarker<ConvBias> benchmark(handle);
benchmark.set_display(false);
benchmark.set_times(RUN);
benchmark.set_dtype(0, dtype::Int8());
benchmark.set_dtype(1, dtype::Int8());
benchmark.set_dtype(2, dtype::Int32());
benchmark.set_dtype(4, dtype::Int32());
benchmark.set_dtype(0, src_type);
benchmark.set_dtype(1, src_type);
benchmark.set_dtype(2, dst_type);
benchmark.set_dtype(4, dst_type);
Benchmarker<ConvBias> benchmark_im2col(handle);
benchmark_im2col.set_display(false);
benchmark_im2col.set_times(RUN);
benchmark_im2col.set_dtype(0, dtype::Int8());
benchmark_im2col.set_dtype(1, dtype::Int8());
benchmark_im2col.set_dtype(2, dtype::Int32());
benchmark_im2col.set_dtype(4, dtype::Int32());
benchmark_im2col.set_dtype(0, src_type);
benchmark_im2col.set_dtype(1, src_type);
benchmark_im2col.set_dtype(2, dst_type);
benchmark_im2col.set_dtype(4, dst_type);
for (auto&& arg : args) {
TensorLayout dst_layout;
......@@ -556,6 +557,7 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
computations / used_im2col, used / used_im2col);
}
}
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
printf("=========================compare "
......@@ -563,7 +565,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n");
BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16",
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16",
handle(), 3, 4);
handle(), 3, dtype::Int8(), dtype::Int32());
}
#endif
#if MEGDNN_ARMV7
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) {
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8";
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4";
printf("compare %s vs %s \n", default_algo, mk4_algo);
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3,
dtype::Int8(), dtype::Int16());
}
#endif
......@@ -1860,15 +1872,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
param.format = param::ConvBias::Format::NCHW44_DOT;
//! channel bias
args.emplace_back(param, TensorShape{1, ic/4, h, w, 4},
TensorShape{oc/4, ic/4, kernel, kernel, 4, 4},
TensorShape{1, oc/4, 1, 1, 4});
args.emplace_back(param, TensorShape{1, ic / 4, h, w, 4},
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4},
TensorShape{1, oc / 4, 1, 1, 4});
};
for (size_t stride : {1, 2})
for (size_t kernel : {2, 3, 5, 7})
for(size_t oc : {64})
for (size_t oc : {64})
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) {
run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode);
run(oc, oc, 56, 56, kernel, kernel / 2, stride,
nonline_mode);
}
constexpr size_t RUN = 50;
......@@ -1880,7 +1893,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
benchmark0.set_display(false);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8DIRECT_NCHW44"));
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"ARMDOTS8DIRECT_NCHW44"));
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f))
......@@ -2002,15 +2016,20 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(
void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
DType stype, DType matmul_dtype, DType bias_type,
DType conv_dtype) {
DType conv_dtype, bool is_mk4 = false) {
using namespace conv_bias;
int pack_size = is_mk4 ? 4 : 1;
std::vector<TestArg> conv_bias_1x1_args =
get_conv_bias_1x1_benchmark_args();
get_conv_bias_1x1_benchmark_args(pack_size);
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
if (is_mk4) {
param.format = MatrixMul::Param::Format::MK4;
}
Benchmarker<MatrixMul> benchmark_matmul(handle);
benchmark_matmul.set_before_exec_callback(
AlgoChecker<MatrixMul>(matmul_algo_name));
......@@ -2038,8 +2057,8 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
size_t OH = arg.src[2];
size_t OW = arg.src[3];
size_t OC = arg.filter[0];
size_t M = OC;
size_t K = IC;
size_t M = OC * pack_size;
size_t K = IC * pack_size;
size_t N = OH * OW;
float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3;
......@@ -2047,6 +2066,10 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{K, N};
if (is_mk4) {
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, N, 4};
}
auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) /
......@@ -2133,6 +2156,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) {
dtype::Int16{}, dtype::Int16{}, dtype::Int16{});
benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{},
dtype::Int16{}, dtype::Int16{}, dtype::Int16{});
benchmark_conv1x1("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), dtype::Int8{},
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}, true);
#endif
}
......@@ -2145,13 +2170,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
conv_param.pad_h = 0;
conv_param.pad_w = 0;
conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
auto run = [&](size_t M, size_t K){
auto run = [&](size_t M, size_t K) {
args.emplace_back(conv_param, TensorShape{1, K, 1, 1},
TensorShape{M, K, 1, 1}, TensorShape{});
};
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K);
run(M, K);
constexpr size_t RUNS = 50;
param::MatrixMul param;
......
......@@ -850,7 +850,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
......@@ -1131,7 +1132,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
1e-3f);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) {
using namespace conv_bias;
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
......@@ -2089,6 +2091,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args_nchw44 =
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true,
false, false, false, false, true);
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 =
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false,
false, false, true);
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
......@@ -2098,6 +2106,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); \
checker_conv_bias(args_nchw44_1x1s2, handle(), &rng, epsilon, \
dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
dtype::Int16{}, name);
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
......@@ -2106,8 +2121,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4");
#endif
#undef cb
#undef cb_nchw44
}
#endif
......@@ -2516,19 +2534,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args(
{1}, 1, true, true, true, false, false, false, false, true);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
cb_nchw44("CONV1x1:ARMV7_INT8X8X16_MK4_K8X8X4:48");
#endif
cb("CONV1x1:ARM_COMMON_INT8X8X16:48");
#undef cb
#undef cb_nchw44
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/armv7/fixture.h"
#include "test/common/benchmarker.h"
......@@ -51,9 +52,15 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
handle(), "ARMV7_INT8X8X16_K4X8X8");
}
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "ARMV7_INT8X8X16_MK4_K8X8X4",
param::MatrixMul::Format::MK4, 1);
}
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) {
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(),"ARMV7_INT16X16X32_K12X4X1");
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "ARMV7_INT16X16X32_K12X4X1");
}
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) {
......@@ -83,7 +90,8 @@ TEST_F(ARMV7, MATRIX_MUL_SDOT) {
TEST_F(ARMV7, MATRIX_MUL_UDOT) {
matrix_mul::check_matrix_mul(
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)),
dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
}
......@@ -103,7 +111,9 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
#if MEGDNN_WITH_BENCHMARK
namespace {
void run_8x8x16_benchmark(const char* algo, Handle* handle) {
void run_8x8x16_benchmark(
const char* algo, Handle* handle,
MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
Benchmarker<MatrixMul> benchmarker_int(handle);
......@@ -116,21 +126,31 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) {
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
param::MatrixMul target_param;
target_param.format = format;
benchmarker_int_kern_4x2x16.set_before_exec_callback(
AlgoChecker<MatrixMul>(algo));
benchmarker_int_kern_4x2x16.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_param(target_param)
.set_display(false);
Benchmarker<MatrixMul> benchmarker_float(handle);
benchmarker_float.set_display(false).set_times(RUNS);
auto run = [&](size_t M, size_t N, size_t K) {
auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
auto int_kern_used =
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS;
auto int_kern_used = 1e10;
if (format == MatrixMul::Param::Format::MK4) {
int_kern_used = benchmarker_int_kern_4x2x16.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
} else {
int_kern_used =
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) /
RUNS;
}
auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f "
......@@ -145,6 +165,7 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) {
};
run(256, 12 * 24, 256);
run(256, 256, 256);
//////////////////////// gemv //////////////////////////
for (size_t M : {8, 64, 112, 256}) {
......@@ -185,7 +206,8 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) {
"int: %f ms %f Gflops %s: \n"
"speedup(%s/arm_common, %s/float): %f\n",
M, K, N, float_used, computations / float_used, int_used,
computations / int_used,algo,algo,algo,float_used / int_used);
computations / int_used, algo, algo, algo,
float_used / int_used);
};
run(256, 12 * 24, 256);
......@@ -231,7 +253,8 @@ void run_8x8x32_benchmark(const char* algo, Handle* handle) {
"int: %f ms %f Gflops %s: \n"
"speedup(%s/arm_common, %s/float): %f\n",
M, K, N, float_used, computations / float_used, int_used,
computations / int_used,algo,algo,algo,float_used / int_used);
computations / int_used, algo, algo, algo,
float_used / int_used);
};
run(256, 12 * 24, 256);
......@@ -252,9 +275,11 @@ void run_8x8x32_quint_benchmark(Handle* handle) {
benchmarker_quint8_dot.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4"));
benchmarker_quint8_dot.set_times(RUNS)
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f))
.set_dtype(0,
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
.set_dtype(1,
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
.set_param(param)
.set_display(false);
......@@ -262,14 +287,17 @@ void run_8x8x32_quint_benchmark(Handle* handle) {
benchmarker_quint8.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8"));
benchmarker_quint8.set_times(RUNS)
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f))
.set_dtype(0,
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
.set_dtype(1,
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
.set_param(param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS;
auto dot_used =
benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS;
auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n"
......@@ -351,11 +379,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(),
MatrixMul::Param::Format::MK4);
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
}
......
......@@ -6,12 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/common/matrix_mul.h"
#include "src/common/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
using namespace megdnn;
using namespace test;
......@@ -39,9 +40,9 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() {
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
size_t nbase) {
std::vector<TestArg> args;
for (size_t m : {1, 2, 3, 4, 5})
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11})
for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
for (size_t k : {1, 2, 3, 4, 5, 9, 10})
for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11})
args.emplace_back(m, n * nbase, k, 0);
return args;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册