提交 6e70fa7a 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/arm): add fp32 asm gemm for a53 a55 and i8i8i16 gemm for a72 a53

GitOrigin-RevId: a049c33f2bf1e737630de263161d7e32be2ba645
上级 dbaf84b0
......@@ -23,6 +23,9 @@
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include "midout.h"
MIDOUT_DECL(megdnn_aarch64_matmul_kern)
......@@ -80,6 +83,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
}
MIDOUT_END();
};
return f32_kern_8x12;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
......@@ -837,6 +841,159 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
int16_t);
/* ===================== Int8x8x16 K16x12x4 algo ===================== */
namespace {
//! Kernel entry point for the int8x8x16 MK4 16x12 gemm (strategy named for
//! Cortex-A53 tuning): unpacks the kernel parameters and runs the
//! interleaved gemm using the caller-provided workspace.
void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        //! A/B are int8, C accumulates into int16 (i8i8i16 gemm).
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        //! GemmInterleaved packs A/B panels into workspace_ptr and calls the
        //! strategy's micro-kernel over the packed panels.
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                             strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable(
        const KernSizeParam& kern_size_param) const {
    //! Accept only non-transposed MK4 i8i8i16 problems whose M and K are
    //! multiples of 4 (the MK4 packing granularity).
    if (!can_be_treated_as_int8x8x16(kern_size_param))
        return false;
    if (kern_size_param.format != param::MatrixMul::Format::MK4)
        return false;
    if (kern_size_param.compute_mode != Param::ComputeMode::DEFAULT)
        return false;
    if (kern_size_param.trA || kern_size_param.trB)
        return false;
    return kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred(
        const KernSizeParam&) const {
#if MGB_ENABLE_CPUINFO
    //! This kernel targets in-order "little" cores: prefer it only when the
    //! current core is a Cortex-A53 or Cortex-A55.
    const auto uarch = cpuinfo_get_current_core()->uarch;
    switch (uarch) {
        case cpuinfo_uarch_cortex_a53:
        case cpuinfo_uarch_cortex_a55:
            return true;
        default:
            return false;
    }
#else
    //! Without cpuinfo the core cannot be identified, so never prefer.
    return false;
#endif
}
size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Workspace holds the packed A/B panels used by GemmInterleaved; the
    //! size is computed from the same strategy the kernel will run with so
    //! that the two always agree.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                           strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
        const KernSizeParam&) const {
    //! The kernel is size-independent; always return the single entry point.
    return int8x8x16_mk4_16x12x4_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_16x12x4Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t);
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
namespace {
//! Kernel entry point for the int8x8x16 MK4 4x4 gemm (strategy named for
//! Cortex-A72 tuning): unpacks the kernel parameters and runs the
//! interleaved gemm using the caller-provided workspace.
void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        //! A/B are int8, C accumulates into int16 (i8i8i16 gemm).
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        //! GemmInterleaved packs A/B panels into workspace_ptr and calls the
        //! strategy's micro-kernel over the packed panels.
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                           strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable(
        const KernSizeParam& kern_size_param) const {
    //! Accept only non-transposed MK4 i8i8i16 problems whose M and K are
    //! multiples of 4 (the MK4 packing granularity).
    if (!can_be_treated_as_int8x8x16(kern_size_param))
        return false;
    if (kern_size_param.format != param::MatrixMul::Format::MK4)
        return false;
    if (kern_size_param.compute_mode != Param::ComputeMode::DEFAULT)
        return false;
    if (kern_size_param.trA || kern_size_param.trB)
        return false;
    return kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred(
        const KernSizeParam&) const {
#if MGB_ENABLE_CPUINFO
    //! This kernel targets out-of-order "big" cores: prefer it on anything
    //! that is not a Cortex-A53/A55 little core.
    const auto uarch = cpuinfo_get_current_core()->uarch;
    const bool on_little_core = uarch == cpuinfo_uarch_cortex_a53 ||
                                uarch == cpuinfo_uarch_cortex_a55;
    return !on_little_core;
#else
    //! Without cpuinfo the core cannot be identified, so never prefer.
    return false;
#endif
}
size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Workspace holds the packed A/B panels used by GemmInterleaved; the
    //! size is computed from the same strategy the kernel will run with so
    //! that the two always agree.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
        const KernSizeParam&) const {
    //! The kernel is size-independent; always return the single entry point.
    return int8x8x16_mk4_4x4x8_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
int8_t, int16_t);
/* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace {
void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) {
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
*/
#pragma once
......@@ -121,12 +122,9 @@ public:
#else
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_4X4X16";
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......@@ -188,6 +186,36 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x16 gemm on MK4-packed data with a 16x12 blocked kernel; the
//! backing strategy is named for Cortex-A53-class tuning (see preferred()).
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
    //! deterministic: the same inputs always produce the same output
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X16_MK4_16X12X4";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    //! uses the default pack-A/pack-B flow of the im2col framework
    PackMode packmode() const override { return PackMode::DEFAULT; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! int8x8x16 gemm on MK4-packed data with a 4x4 blocked kernel; the backing
//! strategy is named for Cortex-A72-class tuning (see preferred()).
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
    //! deterministic: the same inputs always produce the same output
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    //! uses the default pack-A/pack-B flow of the im2col framework
    PackMode packmode() const override { return PackMode::DEFAULT; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
*/
#pragma once
#include <cmath>
......@@ -1140,8 +1141,8 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1,
"stp q2, q6, [%[outptr], #64]\n"
"stp q3, q7, [%[outptr], #96]\n"
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ outptr ] "+r"(outptr)
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory");
}
......@@ -1153,7 +1154,7 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
......@@ -1550,7 +1551,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
"stp q2, q6, [%[outptr], #96] \n"
"stp q10, q3, [%[outptr], #128] \n"
"stp q7, q11, [%[outptr], #160] \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "memory");
......@@ -1564,7 +1565,7 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
......@@ -1681,13 +1682,12 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1,
"st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n"
"st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n"
"st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
[inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
[inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9),
[inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11),
[outptr] "+r"(outptr)
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8),
[inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10),
[inptr11] "+r"(inptr11), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
......@@ -1972,6 +1972,135 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr,
: "v0", "v1", "v2", "v3", "v4", "memory");
}
//! Widen four rows of 16 int8 values to int16 and interleave them into
//! sixteen 4-element groups, ordered group-major then row-major:
//! row0_g0, row1_g0, row2_g0, row3_g0, row0_g1, row1_g1, ...
//! Reads 16 bytes from each input pointer and writes 64 int16 values.
//! (Fix: dropped the stray ';' after the closing brace, which was an empty
//! declaration warned about by -Wextra-semi.)
static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0,
                                              const int8_t* inptr1,
                                              const int8_t* inptr2,
                                              const int8_t* inptr3,
                                              int16_t* outptr) {
    //! each row: load 16 x s8, widen to 16 x s16, split into four 4-lane
    //! groups (g0..g3)
    int8x16_t row0 = vld1q_s8(inptr0);
    int16x8_t row0_01 = vmovl_low_s8(row0);
    int16x8_t row0_23 = vmovl_high_s8(row0);
    int16x4_t row0_0 = vget_low_s16(row0_01);
    int16x4_t row0_1 = vget_high_s16(row0_01);
    int16x4_t row0_2 = vget_low_s16(row0_23);
    int16x4_t row0_3 = vget_high_s16(row0_23);

    int8x16_t row1 = vld1q_s8(inptr1);
    int16x8_t row1_01 = vmovl_low_s8(row1);
    int16x8_t row1_23 = vmovl_high_s8(row1);
    int16x4_t row1_0 = vget_low_s16(row1_01);
    int16x4_t row1_1 = vget_high_s16(row1_01);
    int16x4_t row1_2 = vget_low_s16(row1_23);
    int16x4_t row1_3 = vget_high_s16(row1_23);

    int8x16_t row2 = vld1q_s8(inptr2);
    int16x8_t row2_01 = vmovl_low_s8(row2);
    int16x8_t row2_23 = vmovl_high_s8(row2);
    int16x4_t row2_0 = vget_low_s16(row2_01);
    int16x4_t row2_1 = vget_high_s16(row2_01);
    int16x4_t row2_2 = vget_low_s16(row2_23);
    int16x4_t row2_3 = vget_high_s16(row2_23);

    int8x16_t row3 = vld1q_s8(inptr3);
    int16x8_t row3_01 = vmovl_low_s8(row3);
    int16x8_t row3_23 = vmovl_high_s8(row3);
    int16x4_t row3_0 = vget_low_s16(row3_01);
    int16x4_t row3_1 = vget_high_s16(row3_01);
    int16x4_t row3_2 = vget_low_s16(row3_23);
    int16x4_t row3_3 = vget_high_s16(row3_23);

    //! scatter: for each group g, store rows 0..3 consecutively
    vst1_s16(outptr, row0_0);
    vst1_s16(outptr + 1 * 4, row1_0);
    vst1_s16(outptr + 2 * 4, row2_0);
    vst1_s16(outptr + 3 * 4, row3_0);
    vst1_s16(outptr + 4 * 4, row0_1);
    vst1_s16(outptr + 5 * 4, row1_1);
    vst1_s16(outptr + 6 * 4, row2_1);
    vst1_s16(outptr + 7 * 4, row3_1);
    vst1_s16(outptr + 8 * 4, row0_2);
    vst1_s16(outptr + 9 * 4, row1_2);
    vst1_s16(outptr + 10 * 4, row2_2);
    vst1_s16(outptr + 11 * 4, row3_2);
    vst1_s16(outptr + 12 * 4, row0_3);
    vst1_s16(outptr + 13 * 4, row1_3);
    vst1_s16(outptr + 14 * 4, row2_3);
    vst1_s16(outptr + 15 * 4, row3_3);
}
//! Widen two rows of 16 int8 values to int16 and interleave them into eight
//! 4-element groups, alternating between the rows:
//! row0_g0, row1_g0, row0_g1, row1_g1, ...
//! Reads 16 bytes from each input pointer and writes 32 int16 values.
//! (Fix: dropped the stray ';' after the closing brace, which was an empty
//! declaration warned about by -Wextra-semi.)
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0,
                                             const int8_t* inptr1,
                                             int16_t* outptr) {
    //! each row: load 16 x s8, widen to 16 x s16, split into four 4-lane
    //! groups (g0..g3)
    int8x16_t row0 = vld1q_s8(inptr0);
    int16x8_t row0_01 = vmovl_low_s8(row0);
    int16x8_t row0_23 = vmovl_high_s8(row0);
    int16x4_t row0_0 = vget_low_s16(row0_01);
    int16x4_t row0_1 = vget_high_s16(row0_01);
    int16x4_t row0_2 = vget_low_s16(row0_23);
    int16x4_t row0_3 = vget_high_s16(row0_23);

    int8x16_t row1 = vld1q_s8(inptr1);
    int16x8_t row1_01 = vmovl_low_s8(row1);
    int16x8_t row1_23 = vmovl_high_s8(row1);
    int16x4_t row1_0 = vget_low_s16(row1_01);
    int16x4_t row1_1 = vget_high_s16(row1_01);
    int16x4_t row1_2 = vget_low_s16(row1_23);
    int16x4_t row1_3 = vget_high_s16(row1_23);

    //! scatter: for each group g, store row0 then row1
    vst1_s16(outptr, row0_0);
    vst1_s16(outptr + 1 * 4, row1_0);
    vst1_s16(outptr + 2 * 4, row0_1);
    vst1_s16(outptr + 3 * 4, row1_1);
    vst1_s16(outptr + 4 * 4, row0_2);
    vst1_s16(outptr + 5 * 4, row1_2);
    vst1_s16(outptr + 6 * 4, row0_3);
    vst1_s16(outptr + 7 * 4, row1_3);
}
//! Widening copy: convert `count` int8 values at `inptr` to int16 at
//! `outptr`, preserving order. Processes vectorized chunks of 32 then 8
//! elements, with a scalar loop for the remaining tail.
static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr,
                                 int count) {
    //! main chunk: 4 x 8 lanes per iteration
    for (; count >= 32; count -= 32) {
        int8x8_t in0 = vld1_s8(inptr);
        int8x8_t in1 = vld1_s8(inptr + 1 * 8);
        int8x8_t in2 = vld1_s8(inptr + 2 * 8);
        int8x8_t in3 = vld1_s8(inptr + 3 * 8);
        vst1q_s16(outptr, vmovl_s8(in0));
        vst1q_s16(outptr + 1 * 8, vmovl_s8(in1));
        vst1q_s16(outptr + 2 * 8, vmovl_s8(in2));
        vst1q_s16(outptr + 3 * 8, vmovl_s8(in3));
        inptr += 32;
        outptr += 32;
    }
    //! medium tail: one 8-lane vector at a time
    for (; count >= 8; count -= 8) {
        int8x8_t in0 = vld1_s8(inptr);
        vst1q_s16(outptr, vmovl_s8(in0));
        inptr += 8;
        outptr += 8;
    }
    //! scalar tail for the last 0..7 elements
    for (; count > 0; --count) {
        *outptr++ = (int16_t)(*inptr++);
    }
}
//! Transpose a contiguous 12x4 int8 tile (12 rows of 4 elements) into a
//! 4x12 tile. The first 8 input rows are de-interleaved with vld4; the last
//! 4 rows (16 bytes) are transposed via a table lookup, and each output row
//! is assembled as 8 bytes + one 4-byte lane store.
static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) {
    static const uint8_t src_idx_buffer[16] = {0, 4, 8, 12, 1, 5, 9, 13,
                                               2, 6, 10, 14, 3, 7, 11, 15};
    //! Load the permutation on every call instead of caching it in a local
    //! `static` vector: a function-local static with a runtime initializer
    //! carries a thread-safe init-guard check on each call, which is pure
    //! overhead in this hot packing helper, while reloading 16 bytes that
    //! stay resident in L1 is essentially free.
    const uint8x16_t vtbl = vld1q_u8(&src_idx_buffer[0]);
    int8x8x4_t input = vld4_s8(inptr0);
    //! transpose the trailing 4x4 block so lane i holds output row i's tail
    int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl);
    vst1_s8(outptr, input.val[0]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 8),
                   vreinterpretq_s32_s8(input2), 0);
    vst1_s8(outptr + 1 * 12, input.val[1]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 1 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 1);
    vst1_s8(outptr + 2 * 12, input.val[2]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 2 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 2);
    vst1_s8(outptr + 3 * 12, input.val[3]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 3 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 3);
}
} // namespace aarch64
} // namespace megdnn
......
此差异已折叠。
此差异已折叠。
......@@ -6,42 +6,55 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
*/
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h"
#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}
void sgemm_4x16::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
......@@ -61,14 +74,16 @@ void sgemm_4x16::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K16;
}
for (; n < N; n += 4) {
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
matmul_general_4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
......@@ -80,8 +95,8 @@ void sgemm_4x16::kern(const float* packA, const float* packB,
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
......@@ -102,16 +117,10 @@ void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
}
}
void sgemm_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
template <typename gemm_class>
static inline void sgemm_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
......@@ -126,15 +135,13 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k,
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
......@@ -146,17 +153,16 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k,
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_general_8x12::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
......@@ -164,6 +170,33 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
}
}
//! Top-level fp32 8x12 gemm kernel: validates dtypes, then dispatches to
//! the kernel variant scheduled for the current core's microarchitecture
//! (dedicated tunings exist for Cortex-A53 and Cortex-A55).
void sgemm_8x12::kern(const float* packA, const float* packB, size_t M,
                      size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
                      const float*, float*) const {
    //! all three operands must be fp32; dtypes are only used for this check
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
#if !MGB_ENABLE_CPUINFO
    //! without cpuinfo we cannot detect the uarch: use the generic kernels
    sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
                                           is_first_k);
#else
    auto arch = cpuinfo_get_current_core()->uarch;
    if (arch == cpuinfo_uarch_cortex_a53) {
        sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else if (arch == cpuinfo_uarch_cortex_a55) {
        sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else {
        sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
                                               is_first_k);
    }
#endif
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
......@@ -180,25 +213,17 @@ void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
template <typename gemm_name>
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
constexpr size_t PACK_C_SIZE = 4;
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m / PACK_C_SIZE * LDC);
......@@ -206,15 +231,14 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
......@@ -225,19 +249,45 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K4;
}
}
//! Top-level fp32 MK4 8x12 gemm kernel: validates dtypes and MK4 shape
//! constraints, then dispatches to the kernel variant scheduled for the
//! current core's microarchitecture (dedicated tunings for A53/A55).
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
                          size_t N, size_t K, float* C, size_t LDC,
                          bool is_first_k, const float*, float*) const {
    //! all three operands must be fp32; dtypes are only used for this check
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    //! MK4 layout packs 4 channels per block, so M and K must be aligned
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
#if !MGB_ENABLE_CPUINFO
    //! without cpuinfo we cannot detect the uarch: use the generic kernels
    sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
                                           is_first_k);
#else
    auto arch = cpuinfo_get_current_core()->uarch;
    if (arch == cpuinfo_uarch_cortex_a53) {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else if (arch == cpuinfo_uarch_cortex_a55) {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C,
                                                   LDC, is_first_k);
    } else {
        sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
                                               is_first_k);
    }
#endif
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
......
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
*/
#include <inttypes.h>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_4x4x8_a72 {
//! optimize for A72
// clang-format off
/**
* Overview of register layout:
*
* A 4x4x8 cell of Lhs is stored in 8bit in q0-q3, q4-q7
* A 4x4x8 cell of Rhs is stored in 8bit in q8-q11, q12-q15
* A 4x4 block of accumulators is stored in 16bit in q16-q31
*
* +------------------------+
* | q8 | q9 | q10 | q11 |
* Rhs +------------------------+
* Lhs | | | | |
* +--------+ - - - - +------------------------+
* | q0 | | q16 | q20 | q24 | q28 |
* | q1 | | q17 | q21 | q25 | q29 |
* | q2 | | q18 | q22 | q26 | q30 |
* | q3 | | q19 | q23 | q27 | q31 |
* +--------+ - - - - +------------------------+
*
* Accumulator
*/
// clang-format on
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool, int remain_n) {
K = div_ceil(K, 8);
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
LDC = LDC * sizeof(int8_t);
// clang-format off
#define STORE_LINE(reg0) \
"cmp w10, #0 \n" \
"beq 101f\n" \
"st1 {v" reg0 ".4h}, [x0], #8\n" \
"subs w10, w10, #1\n"
#define STORE_C \
"mov w10, %w[remain_n]\n" \
STORE_LINE("16") \
STORE_LINE("20") \
STORE_LINE("24") \
STORE_LINE("28")
// clang-format on
register int16_t* outptr asm("x0") = output;
asm volatile(
// load accumulator C
"1:\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v20.16b, v20.16b\n"
"eor v21.16b, v21.16b, v21.16b\n"
"eor v22.16b, v22.16b, v22.16b\n"
"eor v23.16b, v23.16b, v23.16b\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
"2: \n"
"ld1 {v0.8b, v1.8b}, [%[a_ptr]], #16\n"
"ld1 {v2.8b, v3.8b}, [%[a_ptr]], #16\n"
"ld1 {v8.8b, v9.8b}, [%[b_ptr]], #16\n"
"ld1 {v10.8b, v11.8b}, [%[b_ptr]], #16\n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3: \n"
//! k = 0
"smlal v16.8h, v0.8b, v8.8b\n"
"ld1 {v4.8b}, [%[a_ptr]], #8\n"
"smlal v17.8h, v1.8b, v8.8b\n"
"smlal v18.8h, v2.8b, v8.8b\n"
"ld1 {v5.8b}, [%[a_ptr]], #8\n"
"smlal v19.8h, v3.8b, v8.8b\n"
"smlal v20.8h, v0.8b, v9.8b\n"
"ld1 {v6.8b}, [%[a_ptr]], #8\n"
"smlal v21.8h, v1.8b, v9.8b\n"
"smlal v22.8h, v2.8b, v9.8b\n"
"ld1 {v7.8b}, [%[a_ptr]], #8\n"
"smlal v23.8h, v3.8b, v9.8b\n"
"smlal v24.8h, v0.8b, v10.8b\n"
"ld1 {v12.8b}, [%[b_ptr]], #8\n"
"smlal v25.8h, v1.8b, v10.8b\n"
"smlal v26.8h, v2.8b, v10.8b\n"
"ld1 {v13.8b}, [%[b_ptr]], #8\n"
"smlal v27.8h, v3.8b, v10.8b\n"
"smlal v28.8h, v0.8b, v11.8b\n"
"ld1 {v14.8b}, [%[b_ptr]], #8\n"
"smlal v29.8h, v1.8b, v11.8b\n"
"smlal v30.8h, v2.8b, v11.8b\n"
"ld1 {v15.8b}, [%[b_ptr]], #8\n"
"smlal v31.8h, v3.8b, v11.8b\n"
//! k = 8
"smlal v16.8h, v4.8b, v12.8b\n"
"ld1 {v0.8b}, [%[a_ptr]], #8\n"
"smlal v17.8h, v5.8b, v12.8b\n"
"smlal v18.8h, v6.8b, v12.8b\n"
"ld1 {v1.8b}, [%[a_ptr]], #8\n"
"smlal v19.8h, v7.8b, v12.8b\n"
"smlal v20.8h, v4.8b, v13.8b\n"
"ld1 {v2.8b}, [%[a_ptr]], #8\n"
"smlal v21.8h, v5.8b, v13.8b\n"
"smlal v22.8h, v6.8b, v13.8b\n"
"ld1 {v3.8b}, [%[a_ptr]], #8\n"
"smlal v23.8h, v7.8b, v13.8b\n"
"smlal v24.8h, v4.8b, v14.8b\n"
"ld1 {v8.8b}, [%[b_ptr]], #8\n"
"smlal v25.8h, v5.8b, v14.8b\n"
"smlal v26.8h, v6.8b, v14.8b\n"
"ld1 {v9.8b}, [%[b_ptr]], #8\n"
"smlal v27.8h, v7.8b, v14.8b\n"
"smlal v28.8h, v4.8b, v15.8b\n"
"ld1 {v10.8b}, [%[b_ptr]], #8\n"
"smlal v29.8h, v5.8b, v15.8b\n"
"smlal v30.8h, v6.8b, v15.8b\n"
"ld1 {v11.8b}, [%[b_ptr]], #8\n"
"smlal v31.8h, v7.8b, v15.8b\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
//! even tail
//! k = 0
"smlal v16.8h, v0.8b, v8.8b\n"
"ld1 {v4.8b}, [%[a_ptr]], #8\n"
"smlal v17.8h, v1.8b, v8.8b\n"
"smlal v18.8h, v2.8b, v8.8b\n"
"ld1 {v5.8b}, [%[a_ptr]], #8\n"
"smlal v19.8h, v3.8b, v8.8b\n"
"smlal v20.8h, v0.8b, v9.8b\n"
"ld1 {v6.8b}, [%[a_ptr]], #8\n"
"smlal v21.8h, v1.8b, v9.8b\n"
"smlal v22.8h, v2.8b, v9.8b\n"
"ld1 {v7.8b}, [%[a_ptr]], #8\n"
"smlal v23.8h, v3.8b, v9.8b\n"
"smlal v24.8h, v0.8b, v10.8b\n"
"ld1 {v12.8b}, [%[b_ptr]], #8\n"
"smlal v25.8h, v1.8b, v10.8b\n"
"smlal v26.8h, v2.8b, v10.8b\n"
"ld1 {v13.8b}, [%[b_ptr]], #8\n"
"smlal v27.8h, v3.8b, v10.8b\n"
"smlal v28.8h, v0.8b, v11.8b\n"
"ld1 {v14.8b}, [%[b_ptr]], #8\n"
"smlal v29.8h, v1.8b, v11.8b\n"
"smlal v30.8h, v2.8b, v11.8b\n"
"ld1 {v15.8b}, [%[b_ptr]], #8\n"
"smlal v31.8h, v3.8b, v11.8b\n"
//! k = 8
"smlal v16.8h, v4.8b, v12.8b\n"
"smlal v17.8h, v5.8b, v12.8b\n"
"smlal v18.8h, v6.8b, v12.8b\n"
"smlal v19.8h, v7.8b, v12.8b\n"
"smlal v20.8h, v4.8b, v13.8b\n"
"smlal v21.8h, v5.8b, v13.8b\n"
"smlal v22.8h, v6.8b, v13.8b\n"
"smlal v23.8h, v7.8b, v13.8b\n"
"smlal v24.8h, v4.8b, v14.8b\n"
"smlal v25.8h, v5.8b, v14.8b\n"
"smlal v26.8h, v6.8b, v14.8b\n"
"smlal v27.8h, v7.8b, v14.8b\n"
"smlal v28.8h, v4.8b, v15.8b\n"
"smlal v29.8h, v5.8b, v15.8b\n"
"smlal v30.8h, v6.8b, v15.8b\n"
"smlal v31.8h, v7.8b, v15.8b\n"
"b 6f\n"
"5:\n"
//! odd tail
"smlal v16.8h, v0.8b, v8.8b\n"
"smlal v17.8h, v1.8b, v8.8b\n"
"smlal v18.8h, v2.8b, v8.8b\n"
"smlal v19.8h, v3.8b, v8.8b\n"
"smlal v20.8h, v0.8b, v9.8b\n"
"smlal v21.8h, v1.8b, v9.8b\n"
"smlal v22.8h, v2.8b, v9.8b\n"
"smlal v23.8h, v3.8b, v9.8b\n"
"smlal v24.8h, v0.8b, v10.8b\n"
"smlal v25.8h, v1.8b, v10.8b\n"
"smlal v26.8h, v2.8b, v10.8b\n"
"smlal v27.8h, v3.8b, v10.8b\n"
"smlal v28.8h, v0.8b, v11.8b\n"
"smlal v29.8h, v1.8b, v11.8b\n"
"smlal v30.8h, v2.8b, v11.8b\n"
"smlal v31.8h, v3.8b, v11.8b\n"
"6:\n"
//! reduce
"addp v16.8h, v16.8h, v17.8h\n"
"addp v18.8h, v18.8h, v19.8h\n"
"addp v20.8h, v20.8h, v21.8h\n"
"addp v22.8h, v22.8h, v23.8h\n"
"addp v24.8h, v24.8h, v25.8h\n"
"addp v26.8h, v26.8h, v27.8h\n"
"addp v16.8h, v16.8h, v18.8h\n"
"addp v28.8h, v28.8h, v29.8h\n"
"addp v30.8h, v30.8h, v31.8h\n"
"addp v20.8h, v20.8h, v22.8h\n"
"addp v16.8h, v16.8h, v16.8h\n"
"addp v20.8h, v20.8h, v20.8h\n"
"addp v24.8h, v24.8h, v26.8h\n"
"addp v24.8h, v24.8h, v24.8h\n"
"addp v28.8h, v28.8h, v30.8h\n"
"addp v28.8h, v28.8h, v28.8h\n"
"cmp %w[remain_n], #4\n"
"bne 7f\n"
"st1 {v16.4h}, [x0], #8\n"
"st1 {v20.4h}, [x0], #8\n"
"st1 {v24.4h}, [x0], #8\n"
"st1 {v28.4h}, [x0], #8\n"
"b 101f\n"
"7:\n" STORE_C
"101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
#undef STORE_C
#undef STORE_LINE
}
static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) {
int8x8x4_t in0 = vld4_s8(inptr);
vst1_s8(outptr + 0 * 8, in0.val[0]);
vst1_s8(outptr + 1 * 8, in0.val[1]);
vst1_s8(outptr + 2 * 8, in0.val[2]);
vst1_s8(outptr + 3 * 8, in0.val[3]);
}
static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2,
dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vld1q_s8(inptr2);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}
static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vdupq_n_s8(0);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}
//! Pack a panel of A for the a72 4x4x8 kernel. The source is in MK4-style
//! layout (one "row" of the input covers pack_size = 4 real rows), so each
//! step consumes pack_m * pack_k = 32 contiguous bytes (4 rows x 8 k values)
//! and de-interleaves them with transpose_8x4_b. Because k0/kmax are
//! multiples of 4 and pack_k is 8, the K tail is either 0 or 4 and is
//! zero-padded through tmpbuff, so every output tile is a full 32 bytes.
static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
                                          int ldin, int m0, int mmax, int k0,
                                          int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 4;
    constexpr int pack_k = 8;
    constexpr int pack_size = 4;
    const int ksize = kmax - k0;
    const int remain_k = ksize % pack_k;  // 0 or 4 (k bounds are x4)
    const int kend = kmax - remain_k;
    //! zero-initialized once; only the first 16 bytes are overwritten in the
    //! tail path below, so the upper 16 bytes stay zero and act as K padding
    int8_t tmpbuff[pack_m * pack_k]{0};
    for (int m_idx = m0; m_idx < mmax; m_idx += pack_m) {
        //! one input row stride (ldin) covers pack_size real rows
        const int8_t* inptr0 = in + m_idx / pack_size * ldin + k0;
        for (int k_idx = k0; k_idx < kend; k_idx += pack_k) {
            transpose_8x4_b(inptr0, out);
            inptr0 += pack_m * pack_k;
            out += pack_m * pack_k;
        }
        if (remain_k > 0) {
            //! copy the last 16 source bytes (4 rows x 4 k) into the zeroed
            //! staging buffer, then transpose the padded 32-byte tile
            int8x16_t tmp = vld1q_s8(inptr0);
            vst1q_s8(&tmpbuff[0], tmp);
            transpose_8x4_b(&tmpbuff[0], out);
            inptr0 += pack_m * pack_size;
            out += pack_m * pack_k;
        }
    }
}
//! Pack a panel of B for the a72 4x4x8 kernel. For every k-block of 8, two
//! 16-byte rows -- the k group and the k+4 group -- are interleaved at int32
//! granularity (interleve_8x4_b). Short N tails are staged through
//! zero-initialized buffers; a K tail of 4 is paired with a zero vector
//! (interleve_8x4_b_pad), so each output panel spans round_up(K, 8) rows.
static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_n = 4;
    constexpr int pack_k = 8;
    constexpr int pack_size = 4;
    const int ksize = kmax - k0;
    const int packed_ksize = round_up(ksize, pack_k);  // K padded to x8
    const int remain_k = ksize % pack_k;               // 0 or 4
    const int kend = kmax - remain_k;
    const int nsize = nmax - n0;
    const int remain_n = nsize % pack_n;
    const int nend = nmax - remain_n;
    //! offset from the k group row to the (k+4) group row.
    //! NOTE(review): computed as pack_size * nsize rather than ldin, i.e. it
    //! assumes consecutive k-groups of B are contiguous -- confirm callers
    //! always pass B laid out that way.
    const int stride_input = pack_size * nsize;
    int8_t tmpbuff[pack_n * pack_k]{0};
    int8_t tmpbuff2[pack_n * pack_k]{0};
    for (int k_idx = k0; k_idx < kend; k_idx += pack_k) {
        const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr2 = inptr + stride_input;
        int8_t* outptr = out + k_idx * pack_n;
        for (int n_idx = n0; n_idx < nend; n_idx += pack_n) {
            interleve_8x4_b(inptr, inptr2, outptr);
            inptr += pack_n * pack_size;
            inptr2 += pack_n * pack_size;
            //! panels for successive n-tiles are packed_ksize * pack_n apart
            outptr += pack_n * packed_ksize;
        }
        if (remain_n > 0) {
            //! stage the short N tail in zeroed buffers so the interleave
            //! always consumes full 16-byte vectors
            memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t));
            memcpy(&tmpbuff2[0], inptr2, remain_n * pack_size * sizeof(int8_t));
            interleve_8x4_b(&tmpbuff[0], &tmpbuff2[0], outptr);
            outptr += pack_n * packed_ksize;
        }
    }
    if (remain_k > 0) {
        //! K tail: pair the final k group with zeros
        const int8_t* inptr = in + kend / pack_size * ldin + n0 * pack_size;
        int8_t* outptr = out + kend * pack_n;
        for (int n_idx = n0; n_idx < nend; n_idx += pack_n) {
            interleve_8x4_b_pad(inptr, outptr);
            inptr += pack_n * pack_size;
            outptr += pack_n * packed_ksize;
        }
        if (remain_n > 0) {
            memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t));
            interleve_8x4_b_pad(&tmpbuff[0], outptr);
            outptr += pack_n * packed_ksize;
        }
    }
}
} // namespace matmul_mk4_4x4x8_a72
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,12 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
......@@ -197,4 +200,161 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
packA += K4;
}
}
// ===========================gemm_s8x8x16_mk4_16x12==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53);
//! Pack A for the A53-tuned 16x12 kernel; forwards to the mk4_16x12x4
//! packer. Note the packed-A element type is dt_int16 (this strategy is
//! registered with a separate pack-A type). The trailing unnamed bool
//! (presumably the transpose flag) is ignored.
void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in,
                                        int ldin, int y0, int ymax, int k0,
                                        int kmax, bool) const {
    matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0,
                                                          ymax, k0, kmax);
}
//! Pack B for the A53-tuned 16x12 kernel; forwards to the mk4_16x12x4
//! packer. The trailing unnamed bool (presumably the transpose flag) is
//! ignored.
void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in,
                                        int ldin, int x0, int xmax, int k0,
                                        int kmax, bool) const {
    matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0,
                                                          xmax, k0, kmax);
}
//! Drive the A53 int8x8x16 MK4 micro-kernels over the packed operands.
//! M is tiled in 16-row panels; the tail (M % 16, which is 0/4/8/12 since
//! M % 4 == 0) is handled by an 8-row kernel and then a 4-row kernel.
//! N is tiled in 12-column panels with a partial-width call for the N tail.
void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
                                      const dt_int8* packB, size_t M, size_t N,
                                      size_t K, dt_int16* C, size_t LDC,
                                      bool is_first_k, const dt_int16*,
                                      dt_int16*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  C_dtype.enumv() == DTypeEnum::Int16 &&
                  A_dtype.enumv() == DTypeEnum::Int8);
    megdnn_assert(is_first_k == true, "only impl is_first_k");
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
    constexpr size_t pack_size = 4;
    constexpr size_t pack_m = 16;
    constexpr size_t pack_n = 12;
    const size_t remain_n = N % pack_n;
    size_t remain_m = M % pack_m;  // 0, 4, 8 or 12
    size_t m_idx = 0;
    //! main loop: full 16-row panels of A
    for (; m_idx + pack_m <= M; m_idx += pack_m) {
        //! one MK4 output row of C covers pack_size real rows
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
                                               is_first_k, pack_n);
            output += pack_n * pack_size;
            cur_packB += pack_n * K;
        }
        if (remain_n > 0) {
            matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
                                               is_first_k, remain_n);
            output += remain_n * pack_size;
            //! the B tail panel still occupies a full pack_n * K stride
            //! (presumably zero-padded by pack_B -- confirm there)
            cur_packB += pack_n * K;
        }
        packA += pack_m * K;
    }
    //! M tail, first pass: 8-row kernel (covers remain_m of 8 or 12)
    if (remain_m >= 8) {
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
                                              is_first_k, pack_n);
            output += pack_n * pack_size;
            cur_packB += pack_n * K;
        }
        if (remain_n > 0) {
            matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
                                              is_first_k, remain_n);
            output += remain_n * pack_size;
            cur_packB += pack_n * K;
        }
        packA += 8 * K;
        m_idx += 8;
        remain_m -= 8;  // a remain_m of 12 falls through to the 4-row case
    }
    //! M tail, second pass: 4-row kernel
    if (remain_m == 4) {
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
                                              is_first_k, pack_n);
            output += pack_n * pack_size;
            cur_packB += pack_n * K;
        }
        if (remain_n > 0) {
            matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
                                              is_first_k, remain_n);
            output += remain_n * pack_size;
            cur_packB += pack_n * K;
        }
    }
}
// ===========================gemm_s8x8x16_mk4_4x4_a72==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72);
//! Pack A for the a72 4x4x8 kernel; forwards to the mk4_4x4x8 packer.
//! The trailing unnamed bool (presumably the transpose flag) is ignored.
void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin,
                                      int y0, int ymax, int k0, int kmax,
                                      bool) const {
    matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax,
                                                        k0, kmax);
}
//! Pack B for the a72 4x4x8 kernel; forwards to the mk4_4x4x8 packer.
//! The trailing unnamed bool (presumably the transpose flag) is ignored.
void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                      int x0, int xmax, int k0, int kmax,
                                      bool) const {
    matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax,
                                                        k0, kmax);
}
void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k,
const dt_int16*, dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
constexpr size_t pack_size = 4;
constexpr size_t pack_m = 4;
constexpr size_t pack_n = 4;
constexpr size_t pack_k = 8;
const size_t remain_n = N % pack_n;
const size_t nend = N - remain_n;
const size_t packed_k = round_up(K, pack_k);
for (size_t m_idx = 0; m_idx < M; m_idx += pack_m) {
int16_t* output = C + (m_idx / pack_size * LDC);
const int8_t* cur_packB = packB;
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * packed_k;
}
if (remain_n > 0) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * packed_k;
}
packA += pack_m * packed_k;
}
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -20,6 +21,11 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s8x8x16_8x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true,
gemm_s8x8x16_4x4);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false,
gemm_s8x8x16_mk4_4x4_a72);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16,
16, 12, 4, false, false,
gemm_s8x8x16_mk4_16x12_a53);
} // namespace matmul
} // namespace aarch64
......
......@@ -6,10 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/matrix_mul/algos.h"
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
......@@ -36,6 +37,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4;
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8;
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
......@@ -70,6 +73,8 @@ public:
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
all_algos.emplace_back(&int16x16x32_k12x8x1);
all_algos.emplace_back(&int16x16x32_mk8_8x8);
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
......@@ -43,6 +44,8 @@ private:
#endif
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16
class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x4
class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 Kernel 4x4x8
class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1
class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8
......
......@@ -214,7 +214,6 @@ void* const ConvBiasImpl::sm_arm_common_algo_type =
bool ConvBiasImpl::is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& param) const {
// fallback::ConvBiasImpl::NCBKernParam conv_ncb_param;
fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
param, 0, param::MatrixMul::Format::DEFAULT, {}, 0,
BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);
......
......@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo_arch_vendor.h"
......
......@@ -11,8 +11,8 @@
*/
#pragma once
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include <cpuinfo.h>
......
此差异已折叠。
......@@ -736,15 +736,21 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
}
#endif
#if MEGDNN_ARMV7
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) {
#if MEGDNN_ARMV7
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8";
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4";
printf("compare %s vs %s \n", default_algo, mk4_algo);
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3,
dtype::Int8(), dtype::Int16());
}
#else
const char* default_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16";
const char* mk4_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8";
printf("compare %s vs %s \n", default_algo, mk4_algo);
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3,
dtype::Int8(), dtype::Int16());
#endif
}
TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) {
BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44",
......
......@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include <cpuinfo.h>
#include <inttypes.h>
#include "gtest/gtest.h"
......@@ -18,7 +19,6 @@ namespace megdnn {
namespace test {
TEST(ARM_RUNTIME, CPUINFO_KIRIN980) {
ASSERT_TRUE(cpuinfo_initialize());
int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980");
......@@ -68,7 +68,6 @@ TEST(ARM_RUNTIME, CPUINFO_KIRIN980) {
}
TEST(ARM_RUNTIME, CPUINFO_SDM8150) {
ASSERT_TRUE(cpuinfo_initialize());
int right_soc =
......@@ -119,7 +118,6 @@ TEST(ARM_RUNTIME, CPUINFO_SDM8150) {
}
TEST(ARM_RUNTIME, CPUINFO_SDM660) {
ASSERT_TRUE(cpuinfo_initialize());
int right_soc =
......@@ -173,4 +171,3 @@ TEST(ARM_RUNTIME, CPUINFO_SDM660) {
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册