提交 cfc83c14 编写于 作者: T tensor-tang

refine jitcodekey and enhance unit tests

test=develop
上级 6ff230a6
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h" #include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h" #include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h" #include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h" #include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h" #include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h" #include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h" #include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
......
...@@ -36,7 +36,7 @@ inline typename std::enable_if< ...@@ -36,7 +36,7 @@ inline typename std::enable_if<
const Kernel*>::type const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
using Attr = typename KernelTuple::attr_type; using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr); int64_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance(); auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
return codes.AllKernels().at(key).get(); return codes.AllKernels().at(key).get();
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h> #include <xxhash.h> // XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -21,73 +21,46 @@ namespace operators { ...@@ -21,73 +21,46 @@ namespace operators {
namespace jit { namespace jit {
template <> template <>
size_t JitCodeKey<int>(const int& d) { int64_t JitCodeKey<int>(const int& d) {
return d; return d;
} }
template <> template <>
size_t JitCodeKey<int64_t>(const int64_t& d) { int64_t JitCodeKey<int64_t>(const int64_t& d) {
return d; return d;
} }
// TODO(TJ): refine and benchmark JitCodeKey generation
constexpr int act_type_shift = 3; // bit width reserved per act type code: supports 2^3 act types
// Translates a supported act kernel type into a small integer code:
// kVIdentity -> 0, kVExp -> 1, kVRelu -> 2, kVSigmoid -> 3, kVTanh -> 4.
// Any other type is reported through PADDLE_THROW.
static inline int act_type_convert(KernelType type) {
  switch (type) {
    case kVIdentity:
      return 0;
    case kVExp:
      return 1;
    case kVRelu:
      return 2;
    case kVSigmoid:
      return 3;
    case kVTanh:
      return 4;
    default:
      break;
  }
  PADDLE_THROW("Unsupported act type %d", type);
  return 0;
}
template <> template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
// XXH64: 13.8 GB/s return XXH64(&attr, sizeof(gru_attr_t), 0);
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
} }
template <> template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d; int keys[5] = {
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) + attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
(act_type_convert(attr.act_cand) << act_type_shift); static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
return XXH64(keys, sizeof(int) * 5, 0);
} }
template <> template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w; int keys[2] = {attr.w, static_cast<int>(attr.type)};
constexpr int pool_type_shift = 3; return XXH64(keys, sizeof(int) * 2, 0);
return (key << pool_type_shift) + static_cast<int>(attr.type);
} }
template <> template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) { int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
size_t key = attr.m; return XXH64(&attr, sizeof(int) * 3, 0); // m, n, k
constexpr int shift = 21;
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
} }
template <> template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) { int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width; return attr.table_width;
} }
template <> template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) { int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width; return attr.grad_width;
} }
......
...@@ -46,7 +46,7 @@ struct KernelKey { ...@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from attribution // Every JitCode should have a method to get the key from attribution
template <typename Attr> template <typename Attr>
size_t JitCodeKey(const Attr& attr); int64_t JitCodeKey(const Attr& attr);
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
......
...@@ -30,7 +30,7 @@ namespace jit { ...@@ -30,7 +30,7 @@ namespace jit {
template <KernelType KT> template <KernelType KT>
class JitCodePool { class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr; typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap; typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
public: public:
JitCodePool() = default; JitCodePool() = default;
...@@ -41,9 +41,9 @@ class JitCodePool { ...@@ -41,9 +41,9 @@ class JitCodePool {
const JitCodeMap& AllKernels() { return codes_; } const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); } bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) { void Insert(int64_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value)); codes_.emplace(key, std::move(value));
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
......
...@@ -886,7 +886,11 @@ void TestKernelVBroadcast() { ...@@ -886,7 +886,11 @@ void TestKernelVBroadcast() {
// test pool // test pool
TEST(JITKernel_pool, jitcreator) { TEST(JITKernel_pool, jitcreator) {
const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators(); const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators();
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(jitcreators.size(), 0UL);
#else
EXPECT_EQ(jitcreators.size(), 25UL); EXPECT_EQ(jitcreators.size(), 25UL);
#endif
} }
TEST(JITKernel_pool, jitpool) { TEST(JITKernel_pool, jitpool) {
...@@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) { ...@@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) {
const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels(); const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
EXPECT_EQ(kers.size(), 0UL); EXPECT_EQ(kers.size(), 0UL);
jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3); jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
// after calling GetAllCandidateKernels, the jitcode is created automatically // after calling GetAllCandidateKernels, the jitcode is created automatically
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(kers.size(), 0UL);
#else
EXPECT_EQ(kers.size(), 1UL); EXPECT_EQ(kers.size(), 1UL);
#endif
} }
TEST(JITKernel_pool, more) { TEST(JITKernel_pool, more) {
const auto& kers = jit::KernelPool::Instance().AllKernels(); const auto& kers = jit::KernelPool::Instance().AllKernels();
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(kers.size(), 10UL);
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_EQ(kers.size(), 21UL); EXPECT_EQ(kers.size(), 21UL);
#else
EXPECT_EQ(kers.size(), 8UL);
#endif
#endif
} }
TEST(JITKernel_pool, refer) { TEST(JITKernel_pool, refer) {
...@@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) { ...@@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(fp_kers.size(), 1UL); // refer EXPECT_GE(fp_kers.size(), 1UL); // refer
#else #else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
#else
EXPECT_GE(fp_kers.size(), 2UL); // jitcode, refer
#endif
#endif #endif
auto db_kers = auto db_kers =
...@@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) { ...@@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(db_kers.size(), 1UL); // refer EXPECT_GE(db_kers.size(), 1UL); // refer
#else #else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
#else
EXPECT_GE(db_kers.size(), 1UL); // refer
#endif
#endif #endif
} }
TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) { TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
auto fp_kers = auto fp_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10); jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(fp_kers.size(), 1UL); // refer
#else
#if !defined(PADDLE_WITH_MKLML) || defined(_WIN32)
EXPECT_GE(fp_kers.size(), 2UL); // jitcode/mkl, refer
#else
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
#endif
#endif
auto db_kers = auto db_kers =
jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10); jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
#if defined(__APPLE__) || defined(__OSX__) || !defined(PADDLE_WITH_MKLML)
EXPECT_GE(db_kers.size(), 1UL); // refer
#else
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
#endif
}
// Verifies the KernelFuncs cache for VAddTuple<float> on CPUPlace:
// At(3) and operator[](3) must return the same non-null function, and the
// function cached for attribute 5 is compared against the one for 3.
TEST(JITKernel_helper, KernelFuncs) {
// Two lookups with the same attribute, through the two accessors.
auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
EXPECT_TRUE(f1 != nullptr);
EXPECT_TRUE(f1 == f2);
// A lookup with a different attribute value.
auto f3 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[5];
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
// On Windows/macOS the same function is expected for both attributes
// (presumably because no jitcode is generated there -- confirm).
EXPECT_TRUE(f2 == f3);
#else
// Elsewhere each attribute is expected to get its own function.
EXPECT_TRUE(f2 != f3);
#endif
} }
TEST(JITKernel_helper, GetAllCandidateFuncs) { TEST(JITKernel_helper, GetAllCandidateFuncs) {
...@@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) { ...@@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) {
EXPECT_EQ(out.str().size(), 14); EXPECT_EQ(out.str().size(), 14);
} }
// test keys
// Integer attributes: equal values give equal keys (for both the int and
// int64_t specializations), different values give different keys.
TEST(JITKernel_key, int) {
  const auto key_of_2 = jit::JitCodeKey<int>(2);
  const auto key_of_2_again = jit::JitCodeKey<int>(2);
  const auto key_of_2_as_int64 = jit::JitCodeKey<int64_t>(2);
  const auto key_of_3 = jit::JitCodeKey<int>(3);

  EXPECT_TRUE(key_of_2 == key_of_2_again);
  EXPECT_TRUE(key_of_2 == key_of_2_as_int64);
  EXPECT_TRUE(key_of_2 != key_of_3);
}
// GRU attributes: identical constructor arguments give identical keys;
// changing any one of the three arguments gives a different key.
TEST(JITKernel_key, gru) {
  using Attr = jit::gru_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(8, jit::kVSigmoid, jit::kVTanh);
  const Attr a_dup(8, jit::kVSigmoid, jit::kVTanh);
  const Attr b(9, jit::kVSigmoid, jit::kVTanh);
  const Attr c(9, jit::kVSigmoid, jit::kVIdentity);
  const Attr d(9, jit::kVTanh, jit::kVIdentity);

  const auto key_a = key_of(a);
  const auto key_a_dup = key_of(a_dup);
  const auto key_b = key_of(b);
  const auto key_c = key_of(c);
  const auto key_d = key_of(d);

  EXPECT_TRUE(key_a == key_a_dup);
  EXPECT_TRUE(key_a_dup != key_b);
  EXPECT_TRUE(key_a_dup != key_c);
  EXPECT_TRUE(key_a_dup != key_d);
  EXPECT_TRUE(key_b != key_c);
  EXPECT_TRUE(key_b != key_d);
  EXPECT_TRUE(key_c != key_d);
}
// LSTM attributes: identical constructor arguments give identical keys;
// a change in the first argument, the first act type, or the trailing
// boolean flag gives a different key.
TEST(JITKernel_key, lstm) {
  using Attr = jit::lstm_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr a_dup(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr b(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr c(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
  const Attr d(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);
  const Attr d_dup(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);

  const auto key_a = key_of(a);
  const auto key_a_dup = key_of(a_dup);
  const auto key_b = key_of(b);
  const auto key_c = key_of(c);
  const auto key_d = key_of(d);
  const auto key_d_dup = key_of(d_dup);

  EXPECT_TRUE(key_a == key_a_dup);
  EXPECT_TRUE(key_a_dup != key_b);
  EXPECT_TRUE(key_a_dup != key_c);
  EXPECT_TRUE(key_a_dup != key_d);
  EXPECT_TRUE(key_b != key_c);
  EXPECT_TRUE(key_b != key_d);
  EXPECT_TRUE(key_c != key_d);
  EXPECT_TRUE(key_d == key_d_dup);
}
// Sequence-pool attributes: the key is expected to be unaffected by the third
// constructor argument, but to change with the first argument or pool type.
TEST(JITKernel_key, seq_pool) {
  using Attr = jit::seq_pool_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(2, jit::SeqPoolType::kSum, 1);
  const Attr a_other_third(2, jit::SeqPoolType::kSum, 3);
  const Attr b(3, jit::SeqPoolType::kSum, 3);
  const Attr c(3, jit::SeqPoolType::kAvg, 3);

  const auto key_a = key_of(a);
  const auto key_a_other_third = key_of(a_other_third);
  const auto key_b = key_of(b);
  const auto key_c = key_of(c);

  EXPECT_TRUE(key_a == key_a_other_third);
  EXPECT_TRUE(key_a_other_third != key_b);
  EXPECT_TRUE(key_a_other_third != key_c);
  EXPECT_TRUE(key_b != key_c);
}
// Matmul attributes: identical (m, n, k) triples give identical keys;
// triples differing in any component give different keys.
TEST(JITKernel_key, matmul) {
  using Attr = jit::matmul_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(1, 2, 3);
  const Attr a_dup(1, 2, 3);
  const Attr b(1, 3, 3);
  const Attr c(2, 3, 4);

  const auto key_a = key_of(a);
  const auto key_a_dup = key_of(a_dup);
  const auto key_b = key_of(b);
  const auto key_c = key_of(c);

  EXPECT_TRUE(key_a == key_a_dup);
  EXPECT_TRUE(key_a_dup != key_b);
  EXPECT_TRUE(key_a_dup != key_c);
  EXPECT_TRUE(key_b != key_c);
}
// Embedding sequence-pool attributes: keys are expected to match whenever the
// second constructor argument matches, regardless of the other arguments, and
// to differ when it differs.
TEST(JITKernel_key, emb_seq_pool) {
  using Attr = jit::emb_seq_pool_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr second_is_2(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
  const Attr second_is_2_dup(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
  const Attr second_is_2_rest_differs(10, 2, 9, 8, 7, jit::SeqPoolType::kAvg);
  const Attr second_is_3(10, 3, 9, 8, 7, jit::SeqPoolType::kSum);
  const Attr second_is_6(1, 6, 3, 4, 5, jit::SeqPoolType::kSum);

  const auto key_2 = key_of(second_is_2);
  const auto key_2_dup = key_of(second_is_2_dup);
  const auto key_2_rest_differs = key_of(second_is_2_rest_differs);
  const auto key_3 = key_of(second_is_3);
  const auto key_6 = key_of(second_is_6);

  EXPECT_TRUE(key_2 == key_2_dup);
  EXPECT_TRUE(key_2_dup == key_2_rest_differs);
  EXPECT_TRUE(key_2_dup != key_3);
  EXPECT_TRUE(key_2_dup != key_6);
  EXPECT_TRUE(key_3 != key_6);
}
// SGD attributes: keys are expected to match whenever the fourth constructor
// argument matches, regardless of the other arguments, and to differ when it
// differs.
TEST(JITKernel_key, sgd) {
  using Attr = jit::sgd_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr fourth_is_4(1, 2, 3, 4, 5);
  const Attr fourth_is_4_dup(1, 2, 3, 4, 5);
  const Attr fourth_is_4_rest_differs(9, 8, 7, 4, 6);
  const Attr fourth_is_6(1, 2, 3, 6, 5);
  const Attr fourth_is_7(10, 9, 8, 7, 6);

  const auto key_4 = key_of(fourth_is_4);
  const auto key_4_dup = key_of(fourth_is_4_dup);
  const auto key_4_rest_differs = key_of(fourth_is_4_rest_differs);
  const auto key_6 = key_of(fourth_is_6);
  const auto key_7 = key_of(fourth_is_7);

  EXPECT_TRUE(key_4 == key_4_dup);
  EXPECT_TRUE(key_4_dup == key_4_rest_differs);
  EXPECT_TRUE(key_4_rest_differs != key_6);
  EXPECT_TRUE(key_4_rest_differs != key_7);
  EXPECT_TRUE(key_6 != key_7);
}
// test kernels // test kernels
#define TestKernelVMul TestKernelXYZN #define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN #define TestKernelVAdd TestKernelXYZN
...@@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul); ...@@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast); TEST_CPU_KERNEL(VBroadcast);
// The cached VAdd<float> CPU function for attribute 3 must be non-null and
// identical whether fetched through At() or operator[].
TEST(JITKernel, kernel_func) {
  using VAddFuncs = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>;
  auto via_at = VAddFuncs::Cache().At(3);
  auto via_index = VAddFuncs::Cache()[3];
  EXPECT_TRUE(via_at != nullptr);
  EXPECT_TRUE(via_at == via_index);
  // TODO(TJ): check not equal
}
// LSTM attributes: keys differ when the first argument or the first act type
// differs, and match for identical arguments.
TEST(JITKernel_key, lstm) {
  using Attr = jit::lstm_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr b(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr b_dup(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  const Attr c(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);

  const auto key_a = key_of(a);
  const auto key_b = key_of(b);
  const auto key_b_dup = key_of(b_dup);
  const auto key_c = key_of(c);

  EXPECT_TRUE(key_a != key_b);
  EXPECT_TRUE(key_b == key_b_dup);
  EXPECT_TRUE(key_b_dup != key_c);
}
// GRU attributes: keys differ when the first argument or the last act type
// differs, and match for identical arguments.
TEST(JITKernel_key, gru) {
  using Attr = jit::gru_attr_t;
  const auto key_of = [](const Attr& attr) {
    return jit::JitCodeKey<Attr>(attr);
  };

  const Attr a(8, jit::kVSigmoid, jit::kVTanh);
  const Attr b(9, jit::kVSigmoid, jit::kVTanh);
  const Attr b_dup(9, jit::kVSigmoid, jit::kVTanh);
  const Attr c(9, jit::kVSigmoid, jit::kVIdentity);

  const auto key_a = key_of(a);
  const auto key_b = key_of(b);
  const auto key_b_dup = key_of(b_dup);
  const auto key_c = key_of(c);

  EXPECT_TRUE(key_a != key_b);
  EXPECT_TRUE(key_b == key_b_dup);
  EXPECT_TRUE(key_b_dup != key_c);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册