提交 8bc63815 编写于 作者: T tensor-tang

fix jitcodekey and refine test

test=develop
上级 7044cfa7
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) { ...@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return d; return d;
} }
// TODO(TJ): refine and benchmark JitCodeKey generation
constexpr int act_type_shift = 3; // support 2^3 act types constexpr int act_type_shift = 3; // support 2^3 act types
// Translates an activation KernelType into a small dense index:
//   kVIdentity -> 0, kVExp -> 1, kVRelu -> 2, kVSigmoid -> 3, kVTanh -> 4.
// Any other kernel type is reported through PADDLE_THROW.
static inline int act_type_convert(KernelType type) {
  switch (type) {
    case kVIdentity:
      return 0;
    case kVExp:
      return 1;
    case kVRelu:
      return 2;
    case kVSigmoid:
      return 3;
    case kVTanh:
      return 4;
    default:
      break;
  }
  PADDLE_THROW("Unsupported act type %d", type);
  // Fallback return value placed after the throw statement.
  return 0;
}
// JitCodeKey specialization for lstm_attr_t. Each line below carries two
// copies of the statement from the side-by-side diff: the left one casts the
// raw activation enum, the right one goes through act_type_convert().
// Layout produced by the right-hand version (with act_type_shift == 3):
//   bit 0      : attr.use_peephole
//   from bit 1 : act_type_convert(attr.act_gate)
//   from bit 4 : act_type_convert(attr.act_cand)
//   from bit 7 : act_type_convert(attr.act_cell)
//   from bit 10: attr.d
template <> template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d; size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1; int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift); int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2); int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole; attr.use_peephole;
} }
...@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { ...@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
// JitCodeKey specialization for gru_attr_t. Each line below carries two
// copies of the statement from the side-by-side diff: the left one casts the
// raw activation enum, the right one goes through act_type_convert().
// Layout produced by the right-hand version (with act_type_shift == 3):
//   from bit 0: act_type_convert(attr.act_gate)
//   from bit 3: act_type_convert(attr.act_cand)
//   from bit 6: attr.d
template <> template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d; size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) + return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift); (act_type_convert(attr.act_cand) << act_type_shift);
} }
template <> template <>
......
...@@ -40,11 +40,11 @@ template <typename T> ...@@ -40,11 +40,11 @@ template <typename T>
// Checks the first n elements of target against refer, element by element.
// Floating-point T is compared with EXPECT_NEAR using the FLAGS_acc tolerance;
// every other T is compared with EXPECT_EQ. Each line below carries two copies
// of the statement from the side-by-side diff; the right-hand copy also
// streams the element index into the assertion message.
void ExpectEQ(const T* target, const T* refer, size_t n) { void ExpectEQ(const T* target, const T* refer, size_t n) {
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
EXPECT_NEAR(target[i], refer[i], FLAGS_acc); EXPECT_NEAR(target[i], refer[i], FLAGS_acc) << " at index : " << i;
} }
} else { } else {
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
EXPECT_EQ(target[i], refer[i]); EXPECT_EQ(target[i], refer[i]) << " at index : " << i;
} }
} }
} }
...@@ -447,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -447,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() { void TestKernelXYZNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) { for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>(); auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
...@@ -480,7 +480,7 @@ void TestXYZNKernel() { ...@@ -480,7 +480,7 @@ void TestXYZNKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestAXYNKernel() { void TestKernelAXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) { for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>(); auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
...@@ -506,7 +506,7 @@ void TestAXYNKernel() { ...@@ -506,7 +506,7 @@ void TestAXYNKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestXRNKernel() { void TestKernelXRNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc; auto last_acc = FLAGS_acc;
FLAGS_acc = 1e-4; FLAGS_acc = 1e-4;
...@@ -524,7 +524,7 @@ void TestXRNKernel() { ...@@ -524,7 +524,7 @@ void TestXRNKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYNKernel() { void TestKernelXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) { for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>(); auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
...@@ -549,10 +549,12 @@ void TestXYNKernel() { ...@@ -549,10 +549,12 @@ void TestXYNKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() { void TestKernelLSTMTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"}; std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) { auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (bool use_peephole : {true, false}) { for (bool use_peephole : {true, false}) {
for (auto& act_gate : all_acts) { for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) { for (auto& act_cand : all_acts) {
...@@ -599,10 +601,12 @@ void TestLSTMKernel() { ...@@ -599,10 +601,12 @@ void TestLSTMKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() { void TestKernelGRUTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"}; std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) { auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (auto& act_gate : all_acts) { for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) { for (auto& act_cand : all_acts) {
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
...@@ -633,14 +637,16 @@ void TestGRUKernel() { ...@@ -633,14 +637,16 @@ void TestGRUKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() { void TestKernelSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = { std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (auto type : pool_types) { for (auto type : pool_types) {
for (int w : TestSizes()) { for (int w : test_sizes) {
jit::seq_pool_attr_t attr(w, type); jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) { for (int h : test_sizes) {
attr.h = h; attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
...@@ -658,11 +664,11 @@ void TestSeqPoolKernel() { ...@@ -658,11 +664,11 @@ void TestSeqPoolKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() { void TestKernelMatMulTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc; auto last_acc = FLAGS_acc;
// TODO(intel): fix MKL acc issue // export MKL_CBWR=AVX would make MKL force to use AVX
// https://github.com/PaddlePaddle/Paddle/issues/15447 // export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
FLAGS_acc = 1e-3; FLAGS_acc = 1e-3;
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) { for (int n : {1, 2, 3, 4}) {
...@@ -686,7 +692,7 @@ void TestMatMulKernel() { ...@@ -686,7 +692,7 @@ void TestMatMulKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestSoftmaxKernel() { void TestKernelSoftmaxTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int bs : {1, 2, 10}) { for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
...@@ -711,12 +717,14 @@ void TestSoftmaxKernel() { ...@@ -711,12 +717,14 @@ void TestSoftmaxKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestEmbSeqPoolKernel() { void TestKernelEmbSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
int64_t tbl_h = 1e4; int64_t tbl_h = 1e4;
std::vector<jit::SeqPoolType> pool_types = { std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum}; // only support sum yet jit::SeqPoolType::kSum}; // only support sum yet
for (int tbl_w : TestSizes()) { auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int tbl_w : test_sizes) {
std::vector<T> table(tbl_h * tbl_w); std::vector<T> table(tbl_h * tbl_w);
RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f); RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
const T* table_data = table.data(); const T* table_data = table.data();
...@@ -745,7 +753,7 @@ void TestEmbSeqPoolKernel() { ...@@ -745,7 +753,7 @@ void TestEmbSeqPoolKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestSgdKernel() { void TestKernelSgdTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
...@@ -799,7 +807,7 @@ void TestSgdKernel() { ...@@ -799,7 +807,7 @@ void TestSgdKernel() {
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() { void TestKernelNCHW16CMulNCTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10; const int n = 3, c = 16 * 4, h = 10, w = 10;
auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>(); auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
...@@ -852,7 +860,7 @@ void TestNCHW16CMulNCKernel() { ...@@ -852,7 +860,7 @@ void TestNCHW16CMulNCKernel() {
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestLayerNormKernel() { void TestKernelLayerNormTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T epsilon = 9.99999975e-06; const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) { for (int n : {1, 2, 10}) {
...@@ -891,11 +899,13 @@ void TestLayerNormKernel() { ...@@ -891,11 +899,13 @@ void TestLayerNormKernel() {
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestCRFDecodingKernel() { void TestKernelCRFDecodingTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
constexpr int state_trans_base_idx = 2; constexpr int state_trans_base_idx = 2;
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int seq_len : {1, 11, 17, 50}) { for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) { for (int tag_num : test_sizes) {
auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>(); auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
int x_sz = seq_len * tag_num; int x_sz = seq_len * tag_num;
...@@ -916,148 +926,76 @@ void TestCRFDecodingKernel() { ...@@ -916,148 +926,76 @@ void TestCRFDecodingKernel() {
} }
} }
// Right-hand side: TEST_CPU_KERNEL(test_tuple, kernel_type) defines
// TEST(JITKernel, kernel_type), which runs TestKernel<test_tuple> for
// jit::kernel_type on CPUPlace with T = float and then T = double. The second
// call must use double: repeating float would leave the double kernels
// untested, unlike the hand-written tests shown on the left-hand side.
// XYZNTuple #define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kVMul) { TEST(JITKernel, kernel_type) { \
TestXYZNKernel<jit::kVMul, float, CPUPlace>(); TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestXYZNKernel<jit::kVMul, double, CPUPlace>(); TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
} }
TEST(JITKernel, kVAdd) {
TestXYZNKernel<jit::kVAdd, float, CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, CPUPlace>();
}
TEST(JITKernel, kVAddRelu) {
TestXYZNKernel<jit::kVAddRelu, float, CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, CPUPlace>();
}
TEST(JITKernel, kVSub) {
TestXYZNKernel<jit::kVSub, float, CPUPlace>();
TestXYZNKernel<jit::kVSub, double, CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, kVScal) {
TestAXYNKernel<jit::kVScal, float, CPUPlace>();
TestAXYNKernel<jit::kVScal, double, CPUPlace>();
}
TEST(JITKernel, kVAddBias) {
TestAXYNKernel<jit::kVAddBias, float, CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, CPUPlace>();
}
// XRNTuples
TEST(JITKernel, kHMax) {
TestXRNKernel<jit::kHMax, float, CPUPlace>();
TestXRNKernel<jit::kHMax, double, CPUPlace>();
}
TEST(JITKernel, kHSum) {
TestXRNKernel<jit::kHSum, float, CPUPlace>();
TestXRNKernel<jit::kHSum, double, CPUPlace>();
}
// XYNTuples
TEST(JITKernel, kVRelu) {
TestXYNKernel<jit::kVRelu, float, CPUPlace>();
TestXYNKernel<jit::kVRelu, double, CPUPlace>();
}
TEST(JITKernel, kVIdentity) {
TestXYNKernel<jit::kVIdentity, float, CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, CPUPlace>();
}
TEST(JITKernel, kVSquare) {
TestXYNKernel<jit::kVSquare, float, CPUPlace>();
TestXYNKernel<jit::kVSquare, double, CPUPlace>();
}
TEST(JITKernel, kVExp) { TEST_CPU_KERNEL(XYZNTuples, kVMul);
TestXYNKernel<jit::kVExp, float, CPUPlace>(); TEST_CPU_KERNEL(XYZNTuples, kVAdd);
TestXYNKernel<jit::kVExp, double, CPUPlace>(); TEST_CPU_KERNEL(XYZNTuples, kVAddRelu);
} TEST_CPU_KERNEL(XYZNTuples, kVSub);
TEST(JITKernel, kVSigmoid) { TEST_CPU_KERNEL(AXYNTuples, kVScal);
TestXYNKernel<jit::kVSigmoid, float, CPUPlace>(); TEST_CPU_KERNEL(AXYNTuples, kVAddBias);
TestXYNKernel<jit::kVSigmoid, double, CPUPlace>();
}
TEST(JITKernel, kVTanh) { TEST_CPU_KERNEL(XRNTuples, kHMax);
TestXYNKernel<jit::kVTanh, float, CPUPlace>(); TEST_CPU_KERNEL(XRNTuples, kHSum);
TestXYNKernel<jit::kVTanh, double, CPUPlace>();
}
// LSTM TEST_CPU_KERNEL(XYNTuples, kVRelu);
TEST(JITKernel, kLSTMCtHt) { TEST_CPU_KERNEL(XYNTuples, kVIdentity);
TestLSTMKernel<jit::kLSTMCtHt, float, CPUPlace>(); TEST_CPU_KERNEL(XYNTuples, kVSquare);
TestLSTMKernel<jit::kLSTMCtHt, double, CPUPlace>(); TEST_CPU_KERNEL(XYNTuples, kVExp);
} TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
TEST_CPU_KERNEL(XYNTuples, kVTanh);
TEST(JITKernel, kLSTMC1H1) { TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
TestLSTMKernel<jit::kLSTMC1H1, float, CPUPlace>(); TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
TestLSTMKernel<jit::kLSTMC1H1, double, CPUPlace>();
}
// GRU TEST_CPU_KERNEL(GRUTuples, kGRUH1);
TEST(JITKernel, kGRUH1) { TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1);
TestGRUKernel<jit::kGRUH1, float, CPUPlace>(); TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2);
TestGRUKernel<jit::kGRUH1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart1) { TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC);
TestGRUKernel<jit::kGRUHtPart1, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, CPUPlace>();
}
TEST(JITKernel, kGRUHtPart2) { TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool);
TestGRUKernel<jit::kGRUHtPart2, float, CPUPlace>(); TEST_CPU_KERNEL(MatMulTuples, kMatMul);
TestGRUKernel<jit::kGRUHtPart2, double, CPUPlace>(); TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax);
} TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
TEST_CPU_KERNEL(SgdTuples, kSgd);
TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
// Right-hand side: TEST(JITKernel_key, lstm) builds four lstm_attr_t values.
// attr1 and attr2 differ only in the first constructor argument (8 vs 9),
// attr2 and attr3 are identical, and attr4 differs from attr3 only in the
// second argument (kVRelu vs kVIdentity). It then expects
// key1 != key2, key2 == key3 and key3 != key4.
// The left-hand text on these lines belongs to other TEST bodies shown in
// the side-by-side diff.
TEST(JITKernel, kSeqPool) { TEST(JITKernel_key, lstm) {
TestSeqPoolKernel<jit::kSeqPool, float, CPUPlace>(); jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
TestSeqPoolKernel<jit::kSeqPool, double, CPUPlace>(); jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
} jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
TEST(JITKernel, kMatMul) {
TestMatMulKernel<jit::kMatMul, float, CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, CPUPlace>();
}
TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, float, CPUPlace>();
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
}
TEST(JITKernel, kEmbSeqPool) { auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
TestEmbSeqPoolKernel<jit::kEmbSeqPool, float, CPUPlace>(); auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
TestEmbSeqPoolKernel<jit::kEmbSeqPool, double, CPUPlace>(); auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
} auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);
TEST(JITKernel, kSgd) { EXPECT_TRUE(key1 != key2);
TestSgdKernel<jit::kSgd, float, CPUPlace>(); EXPECT_TRUE(key2 == key3);
TestSgdKernel<jit::kSgd, double, CPUPlace>(); EXPECT_TRUE(key3 != key4);
} }
// Right-hand side: TEST(JITKernel_key, gru) builds four gru_attr_t values.
// attr1 and attr2 differ only in the first constructor argument (8 vs 9),
// attr2 and attr3 are identical, and attr4 differs from attr3 only in the
// last argument (kVIdentity vs kVTanh). It then expects
// key1 != key2, key2 == key3 and key3 != key4.
// The left-hand text on these lines belongs to other TEST bodies shown in
// the side-by-side diff.
TEST(JITKernel, kNCHW16CMulNC) { TEST(JITKernel_key, gru) {
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>(); jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>(); jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh);
} jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
TEST(JITKernel, kLayerNorm) { auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
TestLayerNormKernel<jit::kLayerNorm, float, paddle::platform::CPUPlace>(); auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
TestLayerNormKernel<jit::kLayerNorm, double, paddle::platform::CPUPlace>(); auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
} auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
TEST(JITKernel, kCRFDecoding) {
TestCRFDecodingKernel<jit::kCRFDecoding, float, paddle::platform::CPUPlace>();
TestCRFDecodingKernel<jit::kCRFDecoding, double,
paddle::platform::CPUPlace>();
}
TEST(JITKernel, pool) { EXPECT_TRUE(key1 != key2);
// TODO(TJ): add some test EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
} }
// TODO(TJ): add more test about key and pool
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册