提交 266a5d2f 编写于 作者: T tensor-tang

implement matmul refer and mkl kernel

上级 c5623c87
......@@ -213,7 +213,7 @@ void BenchSeqPoolKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
......
......@@ -3,7 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
# USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
......
......@@ -24,6 +24,20 @@ namespace jit {
namespace more {
namespace mkl {
template <>
void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
int k) {
platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.f, a, k, b, n, 0.f, c, n);
}
template <>
void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
int k) {
platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.0, a, k, b, n, 0.0, c, n);
}
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
......@@ -93,6 +107,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool MatMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx);
}
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
......@@ -139,6 +158,7 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(MatMul);
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
......@@ -159,6 +179,7 @@ namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
......
......@@ -24,6 +24,9 @@ namespace jit {
namespace more {
namespace mkl {
template <typename T>
void MatMul(const T* a, const T* b, T* c, int m, int n, int k);
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
......@@ -93,6 +96,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
......
......@@ -356,7 +356,20 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
// A(M,K) * B(K,N) = C(M,N)
template <typename T>
void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {}
void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
for (int m = 0; m < M; ++m) {
const T* pa = A + m * K;
T* pc = C + m * N;
for (int n = 0; n < N; ++n) {
const T* pb = B + n;
T sum = static_cast<T>(0);
for (int k = 0; k < K; ++k) {
sum += (pa[k] * pb[k * N]);
}
*(pc + n) = sum;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
......
......@@ -230,7 +230,8 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
};
template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> {
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, int, int, int> {
void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
const std::vector<T>& a, const std::vector<T>& b,
const std::vector<T>& cref, int m, int n, int k) {
......@@ -486,8 +487,8 @@ void TestMatMulKernel() {
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
RandomVec<T>(m * k, a.data(), -0.2f, 0.2f);
RandomVec<T>(k * n, b.data(), -0.2f, 0.2f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册