提交 cfc83c14 编写于 作者: T tensor-tang

refine jitcodekey and enhance unit tests

test=develop
上级 6ff230a6
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -36,7 +36,7 @@ inline typename std::enable_if<
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr);
int64_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key).get();
......
......@@ -13,7 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include <xxhash.h> // XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -21,73 +21,46 @@ namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
int64_t JitCodeKey<int>(const int& d) {
return d;
}
template <>
size_t JitCodeKey<int64_t>(const int64_t& d) {
int64_t JitCodeKey<int64_t>(const int64_t& d) {
return d;
}
// TODO(TJ): refine and benchmark JitCodeKey generation
constexpr int act_type_shift = 3; // support 2^3 act types
// Maps an activation kernel type to a small integer code:
// identity -> 0, exp -> 1, relu -> 2, sigmoid -> 3, tanh -> 4.
// Any other kernel type is reported through PADDLE_THROW.
static inline int act_type_convert(KernelType type) {
  switch (type) {
    case kVIdentity:
      return 0;
    case kVExp:
      return 1;
    case kVRelu:
      return 2;
    case kVSigmoid:
      return 3;
    case kVTanh:
      return 4;
    default:
      break;
  }
  PADDLE_THROW("Unsupported act type %d", type);
  return 0;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
// XXH64: 13.8 GB/s
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
return XXH64(&attr, sizeof(gru_attr_t), 0);
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(act_type_convert(attr.act_cand) << act_type_shift);
int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
int keys[5] = {
attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
return XXH64(keys, sizeof(int) * 5, 0);
}
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w;
constexpr int pool_type_shift = 3;
return (key << pool_type_shift) + static_cast<int>(attr.type);
int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
int keys[2] = {attr.w, static_cast<int>(attr.type)};
return XXH64(keys, sizeof(int) * 2, 0);
}
template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
size_t key = attr.m;
constexpr int shift = 21;
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
return XXH64(&attr, sizeof(int) * 3, 0); // m, n, k
}
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width;
}
template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}
......
......@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from its attributes
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
int64_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
......
......@@ -30,7 +30,7 @@ namespace jit {
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
public:
JitCodePool() = default;
......@@ -41,9 +41,9 @@ class JitCodePool {
const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) {
void Insert(int64_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value));
}
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
......
......@@ -886,7 +886,11 @@ void TestKernelVBroadcast() {
// test pool
// Checks how many JIT code creators are registered in the creator pool.
TEST(JITKernel_pool, jitcreator) {
  const auto& creators = jit::JitCodeCreatorPool::Instance().AllCreators();
  const auto registered = creators.size();
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
  // No creators are expected to be registered on Windows / macOS builds.
  EXPECT_EQ(registered, 0UL);
#else
  EXPECT_EQ(registered, 25UL);
#endif
}
TEST(JITKernel_pool, jitpool) {
......@@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) {
const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
EXPECT_EQ(kers.size(), 0UL);
jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
// after call GetAllCandidateKernels, it will create jitcode Automatically
// after call GetAllCandidateKernels, it will create jitcode Automatically
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(kers.size(), 0UL);
#else
EXPECT_EQ(kers.size(), 1UL);
#endif
}
// Checks the number of entries in the global KernelPool for each build
// flavour (macOS, with MKLML, without MKLML).
TEST(JITKernel_pool, more) {
  const auto& pool_kernels = jit::KernelPool::Instance().AllKernels();
  const auto count = pool_kernels.size();
#if defined(__APPLE__) || defined(__OSX__)
  EXPECT_EQ(count, 10UL);
#elif defined(PADDLE_WITH_MKLML)
  EXPECT_EQ(count, 21UL);
#else
  EXPECT_EQ(count, 8UL);
#endif
}
TEST(JITKernel_pool, refer) {
......@@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(fp_kers.size(), 1UL); // refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE(fp_kers.size(), 3UL); // jitcode, mkl, refer
#else
EXPECT_GE(fp_kers.size(), 2UL); // jitcode, refer
#endif
#endif
auto db_kers =
......@@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE(db_kers.size(), 1UL); // refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE(db_kers.size(), 2UL); // mkl, refer
#else
EXPECT_GE(db_kers.size(), 1UL); // refer
#endif
#endif
}
// Checks the minimum number of candidate implementations returned for VExp
// with size 10, for float and double, on each platform / build configuration.
TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
  auto float_funcs =
      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
#if defined(__APPLE__) || defined(__OSX__)
  EXPECT_GE(float_funcs.size(), 1UL);  // refer
#elif !defined(PADDLE_WITH_MKLML) || defined(_WIN32)
  EXPECT_GE(float_funcs.size(), 2UL);  // jitcode/mkl, refer
#else
  EXPECT_GE(float_funcs.size(), 3UL);  // jitcode, mkl, refer
#endif

  auto double_funcs =
      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
#if defined(PADDLE_WITH_MKLML) && !defined(__APPLE__) && !defined(__OSX__)
  EXPECT_GE(double_funcs.size(), 2UL);  // mkl, refer
#else
  EXPECT_GE(double_funcs.size(), 1UL);  // refer
#endif
}
// Checks that the KernelFuncs cache returns the same non-null function for
// the same size through both At() and operator[], and compares the functions
// obtained for two different sizes.
TEST(JITKernel_helper, KernelFuncs) {
  using VAddFloatFuncs = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>;

  auto via_at = VAddFloatFuncs::Cache().At(3);
  auto via_index = VAddFloatFuncs::Cache()[3];
  EXPECT_TRUE(via_at != nullptr);
  EXPECT_TRUE(via_at == via_index);

  auto other_size = VAddFloatFuncs::Cache()[5];
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
  // On these platforms sizes 3 and 5 are expected to share one function.
  EXPECT_TRUE(via_index == other_size);
#else
  // Elsewhere each size is expected to get its own function.
  EXPECT_TRUE(via_index != other_size);
#endif
}
TEST(JITKernel_helper, GetAllCandidateFuncs) {
......@@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) {
EXPECT_EQ(out.str().size(), 14);
}
// test keys
// Integer attributes: equal values must give equal keys (for both int and
// int64_t), and different values must give different keys.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, int) {
  EXPECT_EQ(jit::JitCodeKey<int>(2), jit::JitCodeKey<int>(2));
  EXPECT_EQ(jit::JitCodeKey<int>(2), jit::JitCodeKey<int64_t>(2));
  EXPECT_NE(jit::JitCodeKey<int>(2), jit::JitCodeKey<int>(3));
}
// GRU attributes: identical attributes must give identical keys, and a change
// in any constructor argument must give a different key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, gru) {
  jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr2(8, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
  jit::gru_attr_t attr5(9, jit::kVTanh, jit::kVIdentity);

  auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
  auto key5 = jit::JitCodeKey<jit::gru_attr_t>(attr5);

  EXPECT_EQ(key1, key2);
  EXPECT_NE(key2, key3);
  EXPECT_NE(key2, key4);
  EXPECT_NE(key2, key5);
  EXPECT_NE(key3, key4);
  EXPECT_NE(key3, key5);
  EXPECT_NE(key4, key5);
}
// LSTM attributes: identical attributes must give identical keys, and a
// change in any constructor argument (including the trailing bool, presumably
// use_peephole) must give a different key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, lstm) {
  jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr2(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr5(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);
  jit::lstm_attr_t attr6(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);

  auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);
  auto key5 = jit::JitCodeKey<jit::lstm_attr_t>(attr5);
  auto key6 = jit::JitCodeKey<jit::lstm_attr_t>(attr6);

  EXPECT_EQ(key1, key2);
  EXPECT_NE(key2, key3);
  EXPECT_NE(key2, key4);
  EXPECT_NE(key2, key5);
  EXPECT_NE(key3, key4);
  EXPECT_NE(key3, key5);
  EXPECT_NE(key4, key5);
  EXPECT_EQ(key5, key6);
}
// Sequence-pool attributes: attr1 and attr2 differ only in the third
// constructor argument and must share a key; changing the first argument or
// the pool type must change the key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, seq_pool) {
  jit::seq_pool_attr_t attr1(2, jit::SeqPoolType::kSum, 1);
  jit::seq_pool_attr_t attr2(2, jit::SeqPoolType::kSum, 3);
  jit::seq_pool_attr_t attr3(3, jit::SeqPoolType::kSum, 3);
  jit::seq_pool_attr_t attr4(3, jit::SeqPoolType::kAvg, 3);

  auto key1 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr4);

  EXPECT_EQ(key1, key2);
  EXPECT_NE(key2, key3);
  EXPECT_NE(key2, key4);
  EXPECT_NE(key3, key4);
}
// Matmul attributes: identical constructor arguments must give identical
// keys, and changing any of them must give a different key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, matmul) {
  jit::matmul_attr_t attr1(1, 2, 3);
  jit::matmul_attr_t attr2(1, 2, 3);
  jit::matmul_attr_t attr3(1, 3, 3);
  jit::matmul_attr_t attr4(2, 3, 4);

  auto key1 = jit::JitCodeKey<jit::matmul_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::matmul_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::matmul_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::matmul_attr_t>(attr4);

  EXPECT_EQ(key1, key2);
  EXPECT_NE(key2, key3);
  EXPECT_NE(key2, key4);
  EXPECT_NE(key3, key4);
}
// Embedding sequence-pool attributes: attr1..attr3 share the second
// constructor argument (presumably table_width - confirm against the attr
// definition) and must share a key even though other arguments differ;
// attr4 and attr5 change that argument and must get distinct keys.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, emb_seq_pool) {
  jit::emb_seq_pool_attr_t attr1(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
  jit::emb_seq_pool_attr_t attr2(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
  jit::emb_seq_pool_attr_t attr3(10, 2, 9, 8, 7, jit::SeqPoolType::kAvg);
  jit::emb_seq_pool_attr_t attr4(10, 3, 9, 8, 7, jit::SeqPoolType::kSum);
  jit::emb_seq_pool_attr_t attr5(1, 6, 3, 4, 5, jit::SeqPoolType::kSum);

  auto key1 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr4);
  auto key5 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr5);

  EXPECT_EQ(key1, key2);
  EXPECT_EQ(key2, key3);
  EXPECT_NE(key2, key4);
  EXPECT_NE(key2, key5);
  EXPECT_NE(key4, key5);
}
// SGD attributes: attr1..attr3 share the fourth constructor argument
// (presumably grad_width - confirm against the attr definition) and must
// share a key even though other arguments differ; attr4 and attr5 change that
// argument and must get distinct keys.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, sgd) {
  jit::sgd_attr_t attr1(1, 2, 3, 4, 5);
  jit::sgd_attr_t attr2(1, 2, 3, 4, 5);
  jit::sgd_attr_t attr3(9, 8, 7, 4, 6);
  jit::sgd_attr_t attr4(1, 2, 3, 6, 5);
  jit::sgd_attr_t attr5(10, 9, 8, 7, 6);

  auto key1 = jit::JitCodeKey<jit::sgd_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::sgd_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::sgd_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::sgd_attr_t>(attr4);
  auto key5 = jit::JitCodeKey<jit::sgd_attr_t>(attr5);

  EXPECT_EQ(key1, key2);
  EXPECT_EQ(key2, key3);
  EXPECT_NE(key3, key4);
  EXPECT_NE(key3, key5);
  EXPECT_NE(key4, key5);
}
// test kernels
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
......@@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast);
// Checks that the KernelFuncs cache returns the same non-null function for
// size 3 through both At() and operator[].
TEST(JITKernel, kernel_func) {
  using VAddFloatFuncs = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>;

  auto via_at = VAddFloatFuncs::Cache().At(3);
  auto via_index = VAddFloatFuncs::Cache()[3];
  EXPECT_TRUE(via_at != nullptr);
  EXPECT_TRUE(via_at == via_index);
  // TODO(TJ): check not equal
}
// LSTM attributes: identical attributes (attr2, attr3) must give identical
// keys; changing the first argument or the first activation type must give a
// different key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, lstm) {
  jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
  jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);

  auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);

  EXPECT_NE(key1, key2);
  EXPECT_EQ(key2, key3);
  EXPECT_NE(key3, key4);
}
// GRU attributes: identical attributes (attr2, attr3) must give identical
// keys; changing the first argument or the last activation type must give a
// different key.
// EXPECT_EQ / EXPECT_NE are used so a failure prints both key values.
TEST(JITKernel_key, gru) {
  jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
  jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);

  auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
  auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
  auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
  auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);

  EXPECT_NE(key1, key2);
  EXPECT_EQ(key2, key3);
  EXPECT_NE(key3, key4);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册