From cfc83c14452a4776c2d8c4fe5c04c6e1ef05e945 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 11 Mar 2019 06:49:30 +0000 Subject: [PATCH] refine jitcodekey and enhance unit tests test=develop --- paddle/fluid/operators/jit/gen/act.cc | 1 + paddle/fluid/operators/jit/gen/blas.cc | 1 + paddle/fluid/operators/jit/gen/embseqpool.cc | 1 + paddle/fluid/operators/jit/gen/gru.cc | 1 + paddle/fluid/operators/jit/gen/hopv.cc | 1 + paddle/fluid/operators/jit/gen/lstm.cc | 1 + paddle/fluid/operators/jit/gen/matmul.cc | 2 +- paddle/fluid/operators/jit/gen/seqpool.cc | 1 + paddle/fluid/operators/jit/gen/sgd.cc | 1 + paddle/fluid/operators/jit/helper.h | 2 +- paddle/fluid/operators/jit/kernel_key.cc | 61 ++--- paddle/fluid/operators/jit/kernel_key.h | 2 +- paddle/fluid/operators/jit/kernel_pool.h | 6 +- paddle/fluid/operators/jit/registry.h | 1 + paddle/fluid/operators/jit/test.cc | 220 +++++++++++++++---- 15 files changed, 211 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc index 5cac219f95f..ad68e792c7a 100644 --- a/paddle/fluid/operators/jit/gen/act.cc +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/gen/act.h" +#include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index e764a7983d3..c126b9077ae 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/gen/blas.h" +#include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/embseqpool.cc b/paddle/fluid/operators/jit/gen/embseqpool.cc index 6e8ecc07e74..331a4b0d075 100644 --- a/paddle/fluid/operators/jit/gen/embseqpool.cc +++ b/paddle/fluid/operators/jit/gen/embseqpool.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/jit/gen/embseqpool.h" #include // offsetof +#include #include #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/registry.h" diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc index 4bc9247f6f0..b5b0cffa806 100644 --- a/paddle/fluid/operators/jit/gen/gru.cc +++ b/paddle/fluid/operators/jit/gen/gru.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/jit/gen/gru.h" #include // offsetof +#include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/hopv.cc b/paddle/fluid/operators/jit/gen/hopv.cc index 3383f17df8f..462ac68a932 100644 --- a/paddle/fluid/operators/jit/gen/hopv.cc +++ b/paddle/fluid/operators/jit/gen/hopv.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/gen/hopv.h" +#include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc index 5e7789aede1..2c3bc985e9a 100644 --- a/paddle/fluid/operators/jit/gen/lstm.cc +++ b/paddle/fluid/operators/jit/gen/lstm.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/jit/gen/lstm.h" #include // offsetof +#include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/matmul.cc b/paddle/fluid/operators/jit/gen/matmul.cc index ca50f26ce57..d9955c8cc65 100644 --- a/paddle/fluid/operators/jit/gen/matmul.cc +++ b/paddle/fluid/operators/jit/gen/matmul.cc @@ -14,8 +14,8 @@ #include "paddle/fluid/operators/jit/gen/matmul.h" #include // offsetof +#include #include - #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc index ceca104cc98..d9e5904add4 100644 --- a/paddle/fluid/operators/jit/gen/seqpool.cc +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/gen/seqpool.h" +#include #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc index a40da9b9932..e65d3500b49 100644 --- a/paddle/fluid/operators/jit/gen/sgd.cc +++ b/paddle/fluid/operators/jit/gen/sgd.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/jit/gen/sgd.h" #include // offsetof +#include #include #include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index d98eada81c0..1ac5318d461 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -36,7 +36,7 @@ inline typename std::enable_if< const Kernel*>::type GetJitCode(const typename KernelTuple::attr_type& attr) { using Attr = typename KernelTuple::attr_type; - size_t key = JitCodeKey(attr); + int64_t key = JitCodeKey(attr); auto& codes = JitCodePool::Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key).get(); diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 6987c893de4..1ad220b3972 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -13,7 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/kernel_key.h" -#include +#include // XXH64: 13.8 GB/s #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -21,73 +21,46 @@ namespace operators { namespace jit { template <> -size_t JitCodeKey(const int& d) { +int64_t JitCodeKey(const int& d) { return d; } template <> -size_t JitCodeKey(const int64_t& d) { +int64_t JitCodeKey(const int64_t& d) { return d; } -// TODO(TJ): refine and benchmark JitCodeKey generatation -constexpr int act_type_shift = 3; // suppot 2^3 act types -static inline int act_type_convert(KernelType type) { - if (type == kVIdentity) { - return 0; - } else if (type == kVExp) { - return 1; - } else if (type == kVRelu) { - return 2; - } else if (type == kVSigmoid) { - return 3; - } else if (type == kVTanh) { - return 4; - } - PADDLE_THROW("Unsupported act type %d", type); - return 0; -} - template <> -size_t JitCodeKey(const lstm_attr_t& attr) { - // XXH64: 13.8 GB/s - - size_t key = attr.d; - int gate_key = act_type_convert(attr.act_gate) << 1; - int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift); - int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2); - return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + - attr.use_peephole; +int64_t JitCodeKey(const gru_attr_t& attr) { + return XXH64(&attr, sizeof(gru_attr_t), 0); } template <> -size_t JitCodeKey(const gru_attr_t& attr) { - size_t key = attr.d; - return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) + - (act_type_convert(attr.act_cand) << act_type_shift); +int64_t JitCodeKey(const lstm_attr_t& attr) { + int keys[5] = { + attr.d, static_cast(attr.act_gate), static_cast(attr.act_cand), + static_cast(attr.act_cell), static_cast(attr.use_peephole)}; + return XXH64(keys, sizeof(int) * 5, 0); } template <> -size_t JitCodeKey(const seq_pool_attr_t& attr) { - size_t key = attr.w; - constexpr int pool_type_shift = 3; - return (key << pool_type_shift) + static_cast(attr.type); +int64_t JitCodeKey(const seq_pool_attr_t& attr) { + int keys[2] = {attr.w, static_cast(attr.type)}; + return XXH64(keys, sizeof(int) * 2, 0); } template <> -size_t JitCodeKey(const matmul_attr_t& attr) { - size_t key = attr.m; - constexpr int shift = 21; - return (key << shift * 2) + ((static_cast(attr.n)) << shift) + attr.k; +int64_t JitCodeKey(const matmul_attr_t& attr) { + return XXH64(&attr, sizeof(int) * 3, 0); // m, n, k } template <> -size_t JitCodeKey(const emb_seq_pool_attr_t& attr) { +int64_t JitCodeKey(const emb_seq_pool_attr_t& attr) { return attr.table_width; } template <> -size_t JitCodeKey(const sgd_attr_t& attr) { +int64_t JitCodeKey(const sgd_attr_t& attr) { return attr.grad_width; } diff --git a/paddle/fluid/operators/jit/kernel_key.h b/paddle/fluid/operators/jit/kernel_key.h index 611a0210d61..b2cf92f23e8 100644 --- a/paddle/fluid/operators/jit/kernel_key.h +++ b/paddle/fluid/operators/jit/kernel_key.h @@ -46,7 +46,7 @@ struct KernelKey { // Every JitCode should have a method to get the key from attribution template -size_t JitCodeKey(const Attr& attr); +int64_t JitCodeKey(const Attr& attr); } // namespace jit } // namespace operators diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h index 3e15242af28..ec5c2be55b2 100644 --- a/paddle/fluid/operators/jit/kernel_pool.h +++ b/paddle/fluid/operators/jit/kernel_pool.h @@ -30,7 +30,7 @@ namespace jit { template class JitCodePool { typedef std::unique_ptr GenBasePtr; - typedef std::unordered_map JitCodeMap; + typedef std::unordered_map JitCodeMap; public: JitCodePool() = default; @@ -41,9 +41,9 @@ class JitCodePool { const JitCodeMap& AllKernels() { return codes_; } - bool Has(size_t key) const { return codes_.find(key) != codes_.end(); } + bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); } - void Insert(size_t key, GenBasePtr value) { + void Insert(int64_t key, GenBasePtr value) { codes_.emplace(key, std::move(value)); } diff --git a/paddle/fluid/operators/jit/registry.h b/paddle/fluid/operators/jit/registry.h index c8da92c0c53..567a9032369 100644 --- a/paddle/fluid/operators/jit/registry.h +++ b/paddle/fluid/operators/jit/registry.h @@ -17,6 +17,7 @@ #include #include #include +#include // for std::move #include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/fluid/platform/place.h" diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 898133a03b5..068b0ba7aea 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -886,7 +886,11 @@ void TestKernelVBroadcast() { // test pool TEST(JITKernel_pool, jitcreator) { const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators(); +#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) + EXPECT_EQ(jitcreators.size(), 0UL); +#else EXPECT_EQ(jitcreators.size(), 25UL); +#endif } TEST(JITKernel_pool, jitpool) { @@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) { const auto& kers = jit::JitCodePool().Instance().AllKernels(); EXPECT_EQ(kers.size(), 0UL); jit::GetAllCandidateKernels, CPUPlace>(3); - // after call GetAllCandidateKernels, it will create jitcode Automatically +// after call GetAllCandidateKernels, it will create jitcode Automatically +#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) + EXPECT_EQ(kers.size(), 0UL); +#else EXPECT_EQ(kers.size(), 1UL); +#endif } TEST(JITKernel_pool, more) { const auto& kers = jit::KernelPool::Instance().AllKernels(); +#if defined(__APPLE__) || defined(__OSX__) + EXPECT_EQ(kers.size(), 10UL); +#else +#ifdef PADDLE_WITH_MKLML EXPECT_EQ(kers.size(), 21UL); +#else + EXPECT_EQ(kers.size(), 8UL); +#endif +#endif } TEST(JITKernel_pool, refer) { @@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) { #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) EXPECT_GE(fp_kers.size(), 1UL); // refer #else +#ifdef PADDLE_WITH_MKLML EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer +#else + EXPECT_GE(fp_kers.size(), 2UL); // jitcode, refer +#endif #endif auto db_kers = @@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) { #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) EXPECT_GE(db_kers.size(), 1UL); // refer #else +#ifdef PADDLE_WITH_MKLML EXPECT_GE(db_kers.size(), 2UL); // mkl, refer +#else + EXPECT_GE(db_kers.size(), 1UL); // refer +#endif #endif } TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) { auto fp_kers = jit::GetAllCandidateFuncsWithTypes, CPUPlace>(10); +#if defined(__APPLE__) || defined(__OSX__) + EXPECT_GE(fp_kers.size(), 1UL); // refer +#else +#if !defined(PADDLE_WITH_MKLML) || defined(_WIN32) + EXPECT_GE(fp_kers.size(), 2UL); // jitcode/mkl, refer +#else EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer +#endif +#endif auto db_kers = jit::GetAllCandidateFuncsWithTypes, CPUPlace>(10); +#if defined(__APPLE__) || defined(__OSX__) || !defined(PADDLE_WITH_MKLML) + EXPECT_GE(db_kers.size(), 1UL); // refer +#else EXPECT_GE(db_kers.size(), 2UL); // mkl, refer +#endif +} + +TEST(JITKernel_helper, KernelFuncs) { + auto f1 = jit::KernelFuncs, CPUPlace>::Cache().At(3); + auto f2 = jit::KernelFuncs, CPUPlace>::Cache()[3]; + EXPECT_TRUE(f1 != nullptr); + EXPECT_TRUE(f1 == f2); + + auto f3 = jit::KernelFuncs, CPUPlace>::Cache()[5]; +#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) + EXPECT_TRUE(f2 == f3); +#else + EXPECT_TRUE(f2 != f3); +#endif } TEST(JITKernel_helper, GetAllCandidateFuncs) { @@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) { EXPECT_EQ(out.str().size(), 14); } +// test keys +TEST(JITKernel_key, int) { + EXPECT_TRUE(jit::JitCodeKey(2) == jit::JitCodeKey(2)); + EXPECT_TRUE(jit::JitCodeKey(2) == jit::JitCodeKey(2)); + EXPECT_TRUE(jit::JitCodeKey(2) != jit::JitCodeKey(3)); +} + +TEST(JITKernel_key, gru) { + jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr2(8, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity); + jit::gru_attr_t attr5(9, jit::kVTanh, jit::kVIdentity); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + auto key5 = jit::JitCodeKey(attr5); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key2 != key4); + EXPECT_TRUE(key2 != key5); + EXPECT_TRUE(key3 != key4); + EXPECT_TRUE(key3 != key5); + EXPECT_TRUE(key4 != key5); +} + +TEST(JITKernel_key, lstm) { + jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr2(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr5(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true); + jit::lstm_attr_t attr6(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + auto key5 = jit::JitCodeKey(attr5); + auto key6 = jit::JitCodeKey(attr6); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key2 != key4); + EXPECT_TRUE(key2 != key5); + EXPECT_TRUE(key3 != key4); + EXPECT_TRUE(key3 != key5); + EXPECT_TRUE(key4 != key5); + EXPECT_TRUE(key5 == key6); +} + +TEST(JITKernel_key, seq_pool) { + jit::seq_pool_attr_t attr1(2, jit::SeqPoolType::kSum, 1); + jit::seq_pool_attr_t attr2(2, jit::SeqPoolType::kSum, 3); + jit::seq_pool_attr_t attr3(3, jit::SeqPoolType::kSum, 3); + jit::seq_pool_attr_t attr4(3, jit::SeqPoolType::kAvg, 3); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key2 != key4); + EXPECT_TRUE(key3 != key4); +} + +TEST(JITKernel_key, matmul) { + jit::matmul_attr_t attr1(1, 2, 3); + jit::matmul_attr_t attr2(1, 2, 3); + jit::matmul_attr_t attr3(1, 3, 3); + jit::matmul_attr_t attr4(2, 3, 4); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key2 != key4); + EXPECT_TRUE(key3 != key4); +} + +TEST(JITKernel_key, emb_seq_pool) { + jit::emb_seq_pool_attr_t attr1(1, 2, 3, 4, 5, jit::SeqPoolType::kSum); + jit::emb_seq_pool_attr_t attr2(1, 2, 3, 4, 5, jit::SeqPoolType::kSum); + jit::emb_seq_pool_attr_t attr3(10, 2, 9, 8, 7, jit::SeqPoolType::kAvg); + jit::emb_seq_pool_attr_t attr4(10, 3, 9, 8, 7, jit::SeqPoolType::kSum); + jit::emb_seq_pool_attr_t attr5(1, 6, 3, 4, 5, jit::SeqPoolType::kSum); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + auto key5 = jit::JitCodeKey(attr5); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 == key3); + EXPECT_TRUE(key2 != key4); + EXPECT_TRUE(key2 != key5); + EXPECT_TRUE(key4 != key5); +} + +TEST(JITKernel_key, sgd) { + jit::sgd_attr_t attr1(1, 2, 3, 4, 5); + jit::sgd_attr_t attr2(1, 2, 3, 4, 5); + jit::sgd_attr_t attr3(9, 8, 7, 4, 6); + jit::sgd_attr_t attr4(1, 2, 3, 6, 5); + jit::sgd_attr_t attr5(10, 9, 8, 7, 6); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + auto key5 = jit::JitCodeKey(attr5); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 == key3); + EXPECT_TRUE(key3 != key4); + EXPECT_TRUE(key3 != key5); + EXPECT_TRUE(key4 != key5); +} + // test kernerls #define TestKernelVMul TestKernelXYZN #define TestKernelVAdd TestKernelXYZN @@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul); TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(VBroadcast); - -TEST(JITKernel, kernel_func) { - auto f1 = jit::KernelFuncs, CPUPlace>::Cache().At(3); - auto f2 = jit::KernelFuncs, CPUPlace>::Cache()[3]; - EXPECT_TRUE(f1 != nullptr); - EXPECT_TRUE(f1 == f2); - // TODO(TJ): check not equal -} - -TEST(JITKernel_key, lstm) { - jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); - jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); - jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); - jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh); - - auto key1 = jit::JitCodeKey(attr1); - auto key2 = jit::JitCodeKey(attr2); - auto key3 = jit::JitCodeKey(attr3); - auto key4 = jit::JitCodeKey(attr4); - - EXPECT_TRUE(key1 != key2); - EXPECT_TRUE(key2 == key3); - EXPECT_TRUE(key3 != key4); -} - -TEST(JITKernel_key, gru) { - jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh); - jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh); - jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh); - jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity); - - auto key1 = jit::JitCodeKey(attr1); - auto key2 = jit::JitCodeKey(attr2); - auto key3 = jit::JitCodeKey(attr3); - auto key4 = jit::JitCodeKey(attr4); - - EXPECT_TRUE(key1 != key2); - EXPECT_TRUE(key2 == key3); - EXPECT_TRUE(key3 != key4); -} -- GitLab