提交 c5623c87 编写于 作者: T tensor-tang

init jit matmul kernel

上级 a92860a3
......@@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : TestSizes()) {
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
c_data, m, n, k);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
......@@ -251,4 +269,7 @@ int main(int argc, char* argv[]) {
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
// matmul
BenchMatMulKernel<jit::kMatMul, T, PlaceType>();
}
......@@ -47,6 +47,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
......
......@@ -42,6 +42,7 @@ typedef enum {
kLayerNorm,
kNCHW16CMulNC,
kSeqPool,
kMatMul,
} KernelType;
typedef enum {
......@@ -135,6 +136,13 @@ struct SeqPoolTuples {
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
template <typename T>
struct MatMulTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int, int);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
......
......@@ -3,6 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
# USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
......
......@@ -27,3 +27,4 @@ USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
......@@ -49,4 +49,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
#undef REGISTER_REFER_KERNEL
......@@ -354,6 +354,10 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
}
}
// A(M,K) * B(K,N) = C(M,N)
template <typename T>
void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
......@@ -394,6 +398,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
......
......@@ -229,6 +229,25 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
const std::vector<T>& a, const std::vector<T>& b,
const std::vector<T>& cref, int m, int n, int k) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(a.size(), static_cast<size_t>(m * k));
EXPECT_EQ(b.size(), static_cast<size_t>(k * n));
EXPECT_EQ(cref.size(), static_cast<size_t>(m * n));
std::vector<T> c(cref.size());
const T* a_data = a.data();
const T* b_data = b.data();
const T* cref_data = cref.data();
T* c_data = c.data();
tgt(a_data, b_data, c_data, m, n, k);
ExpectEQ<T>(c_data, cref_data, m * n);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
......@@ -458,6 +477,28 @@ void TestSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
ref(a_data, b_data, c_data, m, n, k);
TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(k, a, b, c, m, n, k);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
......@@ -618,6 +659,12 @@ TEST(JITKernel, kSeqPool) {
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kMatMul) {
namespace jit = paddle::operators::jit;
TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册