Commit 81177258 authored by tensor-tang

add jit kernel hsum, hmax and softmax refer code

test=develop
Parent 67e4450c
...@@ -158,7 +158,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
Tensor x, y, z;
...@@ -175,7 +175,7 @@ void BenchXYZNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
...@@ -190,7 +190,17 @@ void BenchAXYNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXRNKernel() {
for (int d : TestSizes()) {
Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
Tensor x, y;
...@@ -203,7 +213,7 @@ void BenchXYNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
...@@ -240,7 +250,7 @@ void BenchLSTMKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
...@@ -262,7 +272,7 @@ void BenchGRUKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
...@@ -284,7 +294,7 @@ void BenchSeqPoolKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
...@@ -305,57 +315,64 @@ void BenchMatMulKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSoftmaxKernel() {
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
Tensor x, y;
x.Resize({bs, n});
y.Resize({bs, n});
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
bs);
}
}
}
using T = float;
using CPUPlace = paddle::platform::CPUPlace;
// xyzn
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
// axyn
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
// xrn
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
// xyn
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
// softmax
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
......
...@@ -49,6 +49,9 @@ const char* to_string(KernelType kt) {
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kSoftmax);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
......
...@@ -20,6 +20,7 @@ namespace paddle {
namespace operators {
namespace jit {
// TODO(TJ): reorder by alphabet
typedef enum {
kNone = 0,
kVMul = 1,
...@@ -44,6 +45,9 @@ typedef enum {
kNCHW16CMulNC,
kSeqPool,
kMatMul,
kHSum,  // horizontal sum
kHMax,  // horizontal max
kSoftmax,
} KernelType;
typedef enum {
...@@ -70,6 +74,10 @@ struct XYNTuples {
typedef void (*func_type)(const T*, T*, int);
};
// x, a scalar result, and int n
template <typename T>
struct XRNTuples : public XYNTuples<T> {};
typedef struct {
void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
...@@ -159,6 +167,13 @@ struct LayerNormTuples {
const float, int);
};
template <typename T>
struct SoftmaxTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
......
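The two new tuple types above only pin down kernel signatures: XRNTuples keeps the XYN signature void(const T* x, T* res, int n) for horizontal reductions, and SoftmaxTuples adds a batch-size argument, void(const T* x, T* y, int n, int bs). Below is a minimal usage sketch (not part of this commit), assuming it is compiled inside the Paddle tree with the refer kernels registered; it mirrors the GetRefer<> pattern used by the new tests further down.

#include <vector>  // other headers and registration macros as used elsewhere in the tree

namespace jit = paddle::operators::jit;

void RunNewReferKernels() {
  const int n = 3, bs = 1;
  std::vector<float> x = {1.f, 2.f, 3.f}, y(n);
  // horizontal reductions share the XRN signature: void(const T*, T*, int)
  auto hsum = jit::GetRefer<jit::kHSum, jit::XRNTuples<float>>();
  auto hmax = jit::GetRefer<jit::kHMax, jit::XRNTuples<float>>();
  float s = 0.f, m = 0.f;
  hsum(x.data(), &s, n);  // s == 6.f
  hmax(x.data(), &m, n);  // m == 3.f
  // softmax handles bs rows of length n: void(const T*, T*, int n, int bs)
  auto softmax = jit::GetRefer<jit::kSoftmax, jit::SoftmaxTuples<float>>();
  softmax(x.data(), y.data(), n, bs);  // each row of y now sums to 1
}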
...@@ -29,3 +29,6 @@ USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax)
...@@ -52,4 +52,9 @@ REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
REGISTER_REFER_KERNEL(kHMax, HMax);
REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
#undef REGISTER_REFER_KERNEL
...@@ -378,6 +378,40 @@ void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
}
}
template <typename T>
void HMax(const T* x, T* res, int n) {
res[0] = x[0];
for (int i = 1; i < n; ++i) {
res[0] = res[0] < x[i] ? x[i] : res[0];
}
}
template <typename T>
void HSum(const T* x, T* res, int n) {
res[0] = x[0];
for (int i = 1; i < n; ++i) {
res[0] += x[i];
}
}
// y = e^(x - max(x))
// y = y / sum(y)
template <typename T>
void Softmax(const T* x, T* y, int n, int bs = 1) {
for (int i = 0; i < bs; ++i) {
T scalar;
HMax(x, &scalar, n);
scalar = static_cast<T>(0) - scalar;
VAddBias(&scalar, x, y, n); // x - max
VExp(y, y, n);
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
x += n;
y += n;
}
}
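The max subtraction in the refer Softmax above is for numerical stability rather than for the math itself: softmax(x) equals softmax(x - c) for any constant c, but exp() of a large logit overflows float. A small standalone illustration (not code from this commit):

#include <cmath>
#include <cstdio>

int main() {
  const float logits[3] = {100.f, 101.f, 102.f};
  // Naive softmax: std::exp(102.f) overflows float to inf, so the result
  // degenerates to nan. Subtracting the max keeps every exponent in (0, 1].
  const float max_logit = logits[2];
  float e[3], sum = 0.f;
  for (int i = 0; i < 3; ++i) {
    e[i] = std::exp(logits[i] - max_logit);
    sum += e[i];
  }
  for (int i = 0; i < 3; ++i) {
    std::printf("%f\n", e[i] / sum);  // roughly 0.09, 0.24, 0.67
  }
  return 0;
}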
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
...@@ -421,6 +455,11 @@ DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
DECLARE_REFER_KERNEL(HMax, XRNTuples);
DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
#undef DECLARE_REFER_KERNEL
}  // namespace refer
......
...@@ -61,6 +61,7 @@ std::vector<int> TestSizes() {
}
namespace jit = paddle::operators::jit;
using CPUPlace = paddle::platform::CPUPlace;
template <typename KernelTuples, typename... Args>
struct TestFuncWithRefer {
...@@ -121,6 +122,40 @@ struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::SoftmaxTuples<T>, std::vector<T>, std::vector<T>,
int, int> {
void operator()(const typename jit::SoftmaxTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref, int n,
int bs) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
void operator()(const typename jit::XRNTuples<T>::func_type tgt,
const std::vector<T>& x, const T ref_res) {
EXPECT_TRUE(tgt != nullptr);
T tgt_res;
tgt(x.data(), &tgt_res, x.size());
ExpectEQ<T>(&tgt_res, &ref_res, 1);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::XYNTuples<T>::func_type tgt,
...@@ -172,7 +207,7 @@ struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
T* ht_data = ht.data();
T* checked_data = checked.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
...@@ -208,7 +243,7 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ht_data = ht.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
...@@ -255,8 +290,8 @@ struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
}
};
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
TestFuncWithRefer<KernelTuples, Args...> test;
// test jitcode
...@@ -286,9 +321,8 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
test(tgt, args...);
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
...@@ -320,9 +354,8 @@ void TestXYZNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestAXYNKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
...@@ -347,9 +380,23 @@ void TestAXYNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXRNKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d);
RandomVec<T>(d, x.data());
T ref_res;
ref(x.data(), &ref_res, d);
TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
ref_res);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYNKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
...@@ -373,9 +420,8 @@ void TestXYNKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
...@@ -424,9 +470,8 @@ void TestLSTMKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
...@@ -459,7 +504,7 @@ void TestGRUKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
...@@ -484,7 +529,7 @@ void TestSeqPoolKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = acc;
...@@ -510,7 +555,32 @@ void TestMatMulKernel() {
acc = last_acc;
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSoftmaxKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
std::vector<T> xinp(x.size()); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs);
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs);
ExpectEQ<T>(xinp_data, y_data, n * bs);
TestAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(n, x, y, n, bs);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10;
...@@ -565,129 +635,123 @@ void TestNCHW16CMulNCKernel() {
// XYZNTuple
TEST(JITKernel, kVMul) {
TestXYZNKernel<jit::kVMul, float, CPUPlace>();
TestXYZNKernel<jit::kVMul, double, CPUPlace>();
}
TEST(JITKernel, kVAdd) {
TestXYZNKernel<jit::kVAdd, float, CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, CPUPlace>();
}
TEST(JITKernel, kVAddRelu) {
TestXYZNKernel<jit::kVAddRelu, float, CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, CPUPlace>();
}
TEST(JITKernel, kVSub) {
TestXYZNKernel<jit::kVSub, float, CPUPlace>();
TestXYZNKernel<jit::kVSub, double, CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, kVScal) {
TestAXYNKernel<jit::kVScal, float, CPUPlace>();
TestAXYNKernel<jit::kVScal, double, CPUPlace>();
}
TEST(JITKernel, kVAddBias) {
TestAXYNKernel<jit::kVAddBias, float, CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, CPUPlace>();
}
// XRNTuples
TEST(JITKernel, kHMax) {
TestXRNKernel<jit::kHMax, float, CPUPlace>();
TestXRNKernel<jit::kHMax, double, CPUPlace>();
}
TEST(JITKernel, kHSum) {
TestXRNKernel<jit::kHSum, float, CPUPlace>();
TestXRNKernel<jit::kHSum, double, CPUPlace>();
}
// XYNTuples
TEST(JITKernel, kVRelu) {
TestXYNKernel<jit::kVRelu, float, CPUPlace>();
TestXYNKernel<jit::kVRelu, double, CPUPlace>();
}
TEST(JITKernel, kVIdentity) {
TestXYNKernel<jit::kVIdentity, float, CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, CPUPlace>();
}
TEST(JITKernel, kVSquare) {
TestXYNKernel<jit::kVSquare, float, CPUPlace>();
TestXYNKernel<jit::kVSquare, double, CPUPlace>();
}
TEST(JITKernel, kVExp) {
TestXYNKernel<jit::kVExp, float, CPUPlace>();
TestXYNKernel<jit::kVExp, double, CPUPlace>();
}
TEST(JITKernel, kVSigmoid) {
TestXYNKernel<jit::kVSigmoid, float, CPUPlace>();
TestXYNKernel<jit::kVSigmoid, double, CPUPlace>();
}
TEST(JITKernel, kVTanh) {
TestXYNKernel<jit::kVTanh, float, CPUPlace>();
TestXYNKernel<jit::kVTanh, double, CPUPlace>();
}
// LSTM
TEST(JITKernel, kLSTMCtHt) {
TestLSTMKernel<jit::kLSTMCtHt, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, double, CPUPlace>();
}
TEST(JITKernel, kLSTMC1H1) {
TestLSTMKernel<jit::kLSTMC1H1, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, double, CPUPlace>();
}
// GRU
TEST(JITKernel, kGRUH1) {
TestGRUKernel<jit::kGRUH1, float, CPUPlace>();
TestGRUKernel<jit::kGRUH1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart1) {
TestGRUKernel<jit::kGRUHtPart1, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, double, CPUPlace>();
}
TEST(JITKernel, kSeqPool) {
TestSeqPoolKernel<jit::kSeqPool, float, CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, CPUPlace>();
}
TEST(JITKernel, kMatMul) {
TestMatMulKernel<jit::kMatMul, float, CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, CPUPlace>();
}
TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, float, CPUPlace>();
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>();
}
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
......
...@@ -70,6 +70,8 @@ extern void* mklml_dso_handle;
__macro(cblas_ddot); \
__macro(cblas_sasum); \
__macro(cblas_dasum); \
__macro(cblas_isamax); \
__macro(cblas_idamax); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
......
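cblas_isamax and cblas_idamax, newly added to the dynamically loaded MKL symbols above, are the standard CBLAS routines that return the index of the element with the largest absolute value, presumably exposed here for a later MKL-backed horizontal max. A rough sketch of that idea (an assumption, not code from this commit); note that it only recovers max(x) when the inputs are non-negative:

#include <mkl.h>  // assumed available; declares cblas_isamax

// Only equivalent to the refer HMax for non-negative inputs, because
// cblas_isamax selects the largest |x[i]|, not the largest x[i].
float HMaxViaIsamax(const float* x, int n) {
  const CBLAS_INDEX idx = cblas_isamax(n, x, 1);
  return x[idx];
}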