提交 266a5d2f 编写于 作者: T tensor-tang

implement matmul refer and mkl kernel

上级 c5623c87
...@@ -213,7 +213,7 @@ void BenchSeqPoolKernel() { ...@@ -213,7 +213,7 @@ void BenchSeqPoolKernel() {
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() { void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) { for (int n : TestSizes()) {
for (int k : TestSizes()) { for (int k : TestSizes()) {
std::vector<T> a(m * k), b(k * n), c(m * n); std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f); RandomVec<T>(m * k, a.data(), -2.f, 2.f);
......
...@@ -3,7 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) ...@@ -3,7 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type # use mkl kernels by name and type
# USE_JITKERNEL_MORE(kMatMul, mkl) USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl) USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl) USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVScal, mkl)
......
...@@ -24,6 +24,20 @@ namespace jit { ...@@ -24,6 +24,20 @@ namespace jit {
namespace more { namespace more {
namespace mkl { namespace mkl {
// Single-precision row-major GEMM via MKL: c(m x n) = a(m x k) * b(k x n).
// alpha = 1.f and beta = 0.f, so any prior contents of c are overwritten.
template <>
void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
                   int k) {
  // Leading dimensions k, n, n match dense row-major storage with no padding.
  platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
                                 n, k, 1.f, a, k, b, n, 0.f, c, n);
}
// Double-precision row-major GEMM via MKL: c(m x n) = a(m x k) * b(k x n).
// alpha = 1.0 and beta = 0.0, so any prior contents of c are overwritten.
template <>
void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
                    int k) {
  // Leading dimensions k, n, n match dense row-major storage with no padding.
  platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
                                 n, k, 1.0, a, k, b, n, 0.0, c, n);
}
template <> template <>
void VMul<float>(const float* x, const float* y, float* z, int n) { void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z); platform::dynload::vsMul(n, x, y, z);
...@@ -93,6 +107,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) { ...@@ -93,6 +107,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
} }
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// Prefer the MKL matmul whenever AVX is available, for any problem size d.
// NOTE(review): unlike VMul<float> below (avx512f && d > 512), no size
// threshold is applied here — presumably GEMM beats the refer kernel even for
// tiny shapes; confirm with the benchmark before tuning.
template <>
bool MatMulKernel<float>::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}
template <> template <>
bool VMulKernel<float>::UseMe(const int& d) const { bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512; return platform::MayIUse(platform::avx512f) && d > 512;
...@@ -139,6 +158,7 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const { ...@@ -139,6 +158,7 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; \ return true; \
} }
AWALYS_USE_ME_WITH_DOUBLE(MatMul);
AWALYS_USE_ME_WITH_DOUBLE(VMul); AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd); AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal); AWALYS_USE_ME_WITH_DOUBLE(VScal);
...@@ -159,6 +179,7 @@ namespace mkl = paddle::operators::jit::more::mkl; ...@@ -159,6 +179,7 @@ namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \ REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>) mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul); REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVScal, VScal);
......
...@@ -24,6 +24,9 @@ namespace jit { ...@@ -24,6 +24,9 @@ namespace jit {
namespace more { namespace more {
namespace mkl { namespace mkl {
// Row-major matmul: c(m x n) = a(m x k) * b(k x n), overwriting c.
// Specialized for float/double elsewhere via cblas_{s,d}gemm.
template <typename T>
void MatMul(const T* a, const T* b, T* c, int m, int n, int k);
template <typename T> template <typename T>
void VMul(const T* x, const T* y, T* z, int n); void VMul(const T* x, const T* y, T* z, int n);
...@@ -93,6 +96,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { ...@@ -93,6 +96,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
const char* ImplType() const override { return "MKL"; } \ const char* ImplType() const override { return "MKL"; } \
} }
// ABCMNK: func_type signature is (a, b, c, m, n, k) — see MatMulTuples.
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
// XYZN // XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples); DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples); DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
......
...@@ -356,7 +356,20 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { ...@@ -356,7 +356,20 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
// Naive reference GEMM: C(M,N) = A(M,K) * B(K,N), all matrices dense
// row-major; C is overwritten (beta = 0 semantics). This is the correctness
// baseline the JIT/MKL kernels are tested against, so clarity beats speed.
template <typename T>
void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    const T* pa = A + m * K;  // row m of A
    T* pc = C + m * N;        // row m of C
    for (int n = 0; n < N; ++n) {
      const T* pb = B + n;  // column n of B; consecutive rows are N apart
      T sum = static_cast<T>(0);
      for (int k = 0; k < K; ++k) {
        sum += pa[k] * pb[k * N];
      }
      pc[n] = sum;
    }
  }
}
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
......
...@@ -230,7 +230,8 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, ...@@ -230,7 +230,8 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
}; };
template <typename T> template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> { struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, int, int, int> {
void operator()(const typename jit::MatMulTuples<T>::func_type tgt, void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
const std::vector<T>& a, const std::vector<T>& b, const std::vector<T>& a, const std::vector<T>& b,
const std::vector<T>& cref, int m, int n, int k) { const std::vector<T>& cref, int m, int n, int k) {
...@@ -486,8 +487,8 @@ void TestMatMulKernel() { ...@@ -486,8 +487,8 @@ void TestMatMulKernel() {
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>(); auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n); std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f); RandomVec<T>(m * k, a.data(), -0.2f, 0.2f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f); RandomVec<T>(k * n, b.data(), -0.2f, 0.2f);
const T* a_data = a.data(); const T* a_data = a.data();
const T* b_data = b.data(); const T* b_data = b.data();
T* c_data = c.data(); T* c_data = c.data();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册