提交 81177258 编写于 作者: T tensor-tang

add jit kernel hsum, hmax and softmax refer code

test=develop
上级 67e4450c
......@@ -158,7 +158,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor;
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
Tensor x, y, z;
......@@ -175,7 +175,7 @@ void BenchXYZNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
......@@ -190,7 +190,17 @@ void BenchAXYNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXRNKernel() {
for (int d : TestSizes()) {
Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
Tensor x, y;
......@@ -203,7 +213,7 @@ void BenchXYNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
......@@ -240,7 +250,7 @@ void BenchLSTMKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
......@@ -262,7 +272,7 @@ void BenchGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
......@@ -284,7 +294,7 @@ void BenchSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
......@@ -305,57 +315,64 @@ void BenchMatMulKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSoftmaxKernel() {
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
Tensor x, y;
x.Resize({bs, n});
y.Resize({bs, n});
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
bs);
}
}
}
using T = float;
using PlaceType = paddle::platform::CPUPlace;
using CPUPlace = paddle::platform::CPUPlace;
// xyzn
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, PlaceType>(); }
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, PlaceType>(); }
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
// axyn
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, PlaceType>(); }
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); }
// xrn
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
// xyn
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, PlaceType>(); }
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, PlaceType>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, PlaceType>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, PlaceType>(); }
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); }
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); }
BENCH_FP32_CPU(kGRUHtPart1) {
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
}
BENCH_FP32_CPU(kGRUHtPart2) {
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
}
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); }
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); }
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
// softmax
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
......
......@@ -49,6 +49,9 @@ const char* to_string(KernelType kt) {
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kSoftmax);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
......
......@@ -20,6 +20,7 @@ namespace paddle {
namespace operators {
namespace jit {
// TODO(TJ): reorder by alphabet
typedef enum {
kNone = 0,
kVMul = 1,
......@@ -44,6 +45,9 @@ typedef enum {
kNCHW16CMulNC,
kSeqPool,
kMatMul,
kHSum, // horizontal max
kHMax, // horizontal sum
kSoftmax,
} KernelType;
typedef enum {
......@@ -70,6 +74,10 @@ struct XYNTuples {
typedef void (*func_type)(const T*, T*, int);
};
// x, return and int
template <typename T>
struct XRNTuples : public XYNTuples<T> {};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
......@@ -159,6 +167,13 @@ struct LayerNormTuples {
const float, int);
};
template <typename T>
struct SoftmaxTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
......
......@@ -29,3 +29,6 @@ USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax)
......@@ -52,4 +52,9 @@ REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
REGISTER_REFER_KERNEL(kHMax, HMax);
REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
#undef REGISTER_REFER_KERNEL
......@@ -378,6 +378,40 @@ void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
}
}
template <typename T>
void HMax(const T* x, T* res, int n) {
res[0] = x[0];
for (int i = 1; i < n; ++i) {
res[0] = res[0] < x[i] ? x[i] : res[0];
}
}
template <typename T>
void HSum(const T* x, T* res, int n) {
res[0] = x[0];
for (int i = 1; i < n; ++i) {
res[0] += x[i];
}
}
// y = e^(x - max(x))
// y = y / sum(y)
template <typename T>
void Softmax(const T* x, T* y, int n, int bs = 1) {
for (int i = 0; i < bs; ++i) {
T scalar;
HMax(x, &scalar, n);
scalar = static_cast<T>(0) - scalar;
VAddBias(&scalar, x, y, n); // x - max
VExp(y, y, n);
HSum(y, &scalar, n);
scalar = static_cast<T>(1) / scalar;
VScal(&scalar, y, y, n);
x += n;
y += n;
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
......@@ -421,6 +455,11 @@ DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
DECLARE_REFER_KERNEL(HMax, XRNTuples);
DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
......
......@@ -61,6 +61,7 @@ std::vector<int> TestSizes() {
}
namespace jit = paddle::operators::jit;
using CPUPlace = paddle::platform::CPUPlace;
template <typename KernelTuples, typename... Args>
struct TestFuncWithRefer {
......@@ -121,6 +122,40 @@ struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::SoftmaxTuples<T>, std::vector<T>, std::vector<T>,
int, int> {
void operator()(const typename jit::SoftmaxTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref, int n,
int bs) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
const T* x_data = x.data();
const T* yref_data = yref.data();
std::vector<T> ytgt(n * bs);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, n, bs);
ExpectEQ<T>(ytgt_data, yref_data, n * bs);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
void operator()(const typename jit::XRNTuples<T>::func_type tgt,
const std::vector<T>& x, const T ref_res) {
EXPECT_TRUE(tgt != nullptr);
T tgt_res;
tgt(x.data(), &tgt_res, x.size());
ExpectEQ<T>(&tgt_res, &ref_res, 1);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::XYNTuples<T>::func_type tgt,
......@@ -172,7 +207,7 @@ struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
T* ht_data = ht.data();
T* checked_data = checked.data();
paddle::operators::jit::lstm_t step;
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
......@@ -208,7 +243,7 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ht_data = ht.data();
paddle::operators::jit::gru_t step;
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
......@@ -255,8 +290,8 @@ struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
TestFuncWithRefer<KernelTuples, Args...> test;
// test jitcode
......@@ -286,9 +321,8 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
test(tgt, args...);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
......@@ -320,9 +354,8 @@ void TestXYZNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestAXYNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
......@@ -347,9 +380,23 @@ void TestAXYNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXRNKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d);
RandomVec<T>(d, x.data());
T ref_res;
ref(x.data(), &ref_res, d);
TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
ref_res);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
......@@ -373,9 +420,8 @@ void TestXYNKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
......@@ -424,9 +470,8 @@ void TestLSTMKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
......@@ -459,7 +504,7 @@ void TestGRUKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
......@@ -484,7 +529,7 @@ void TestSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = acc;
......@@ -510,7 +555,32 @@ void TestMatMulKernel() {
acc = last_acc;
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSoftmaxKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
std::vector<T> xinp(x.size()); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
ref(x_data, y_data, n, bs);
T* xinp_data = xinp.data();
ref(xinp_data, xinp_data, n, bs);
ExpectEQ<T>(xinp_data, y_data, n * bs);
TestAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(n, x, y, n, bs);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10;
......@@ -565,129 +635,123 @@ void TestNCHW16CMulNCKernel() {
// XYZNTuple
TEST(JITKernel, kVMul) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVMul, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVMul, double, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVMul, float, CPUPlace>();
TestXYZNKernel<jit::kVMul, double, CPUPlace>();
}
TEST(JITKernel, kVAdd) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVAdd, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAdd, float, CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, CPUPlace>();
}
TEST(JITKernel, kVAddRelu) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVAddRelu, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, float, CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, CPUPlace>();
}
TEST(JITKernel, kVSub) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVSub, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVSub, double, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVSub, float, CPUPlace>();
TestXYZNKernel<jit::kVSub, double, CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, kVScal) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::kVScal, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVScal, double, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVScal, float, CPUPlace>();
TestAXYNKernel<jit::kVScal, double, CPUPlace>();
}
TEST(JITKernel, kVAddBias) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::kVAddBias, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVAddBias, float, CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, CPUPlace>();
}
// XRNTuples
TEST(JITKernel, kHMax) {
TestXRNKernel<jit::kHMax, float, CPUPlace>();
TestXRNKernel<jit::kHMax, double, CPUPlace>();
}
TEST(JITKernel, kHSum) {
TestXRNKernel<jit::kHSum, float, CPUPlace>();
TestXRNKernel<jit::kHSum, double, CPUPlace>();
}
// XYNTuples
TEST(JITKernel, kVRelu) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVRelu, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVRelu, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVRelu, float, CPUPlace>();
TestXYNKernel<jit::kVRelu, double, CPUPlace>();
}
TEST(JITKernel, kVIdentity) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVIdentity, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVIdentity, float, CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, CPUPlace>();
}
TEST(JITKernel, kVSquare) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVSquare, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSquare, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSquare, float, CPUPlace>();
TestXYNKernel<jit::kVSquare, double, CPUPlace>();
}
TEST(JITKernel, kVExp) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVExp, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVExp, float, CPUPlace>();
TestXYNKernel<jit::kVExp, double, CPUPlace>();
}
TEST(JITKernel, kVSigmoid) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVSigmoid, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSigmoid, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSigmoid, float, CPUPlace>();
TestXYNKernel<jit::kVSigmoid, double, CPUPlace>();
}
TEST(JITKernel, kVTanh) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVTanh, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVTanh, double, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVTanh, float, CPUPlace>();
TestXYNKernel<jit::kVTanh, double, CPUPlace>();
}
// LSTM
TEST(JITKernel, kLSTMCtHt) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::kLSTMCtHt, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, double, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, double, CPUPlace>();
}
TEST(JITKernel, kLSTMC1H1) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::kLSTMC1H1, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, double, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, double, CPUPlace>();
}
// GRU
TEST(JITKernel, kGRUH1) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUH1, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUH1, double, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUH1, float, CPUPlace>();
TestGRUKernel<jit::kGRUH1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart1) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUHtPart1, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart2) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUHtPart2, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, double, CPUPlace>();
}
TEST(JITKernel, kSeqPool) {
namespace jit = paddle::operators::jit;
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, float, CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, CPUPlace>();
}
TEST(JITKernel, kMatMul) {
namespace jit = paddle::operators::jit;
TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>();
TestMatMulKernel<jit::kMatMul, float, CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, CPUPlace>();
}
TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, float, CPUPlace>();
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
paddle::platform::CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double,
paddle::platform::CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>();
}
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
......
......@@ -70,6 +70,8 @@ extern void* mklml_dso_handle;
__macro(cblas_ddot); \
__macro(cblas_sasum); \
__macro(cblas_dasum); \
__macro(cblas_isamax); \
__macro(cblas_idamax); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册