提交 802f362a 编写于 作者: T tensor-tang

unify the kernelfuncs cache and add unit test

test=develop
上级 36e2d324
...@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>, auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>(tag_num); platform::CPUPlace>::Cache()
.At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
......
...@@ -110,8 +110,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -110,8 +110,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>, auto multiply =
platform::CPUPlace>(0); jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
......
...@@ -52,8 +52,10 @@ struct EmbeddingVSumFunctor { ...@@ -52,8 +52,10 @@ struct EmbeddingVSumFunctor {
out_width, jit::SeqPoolType::kSum); out_width, jit::SeqPoolType::kSum);
for (size_t i = 0; i != ids_lod.size() - 1; ++i) { for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i]; attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>, auto emb_seqpool =
platform::CPUPlace>(attr); jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width, emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr); &attr);
} }
...@@ -135,8 +137,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -135,8 +137,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace()); T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
const T *d_output_data = d_output->data<T>(); const T *d_output_data = d_output->data<T>();
auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>, auto vbroadcast =
platform::CPUPlace>(out_width); jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>,
platform::CPUPlace>::Cache()
.At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
const T *src = d_output_data + i * out_width; const T *src = d_output_data + i * out_width;
......
...@@ -195,12 +195,15 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -195,12 +195,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \ jit::gru_t one_step; \
auto ComputeH1 = \ auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ platform::CPUPlace>::Cache() \
auto ComputeHtPart1 = \ .At(attr); \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \
auto ComputeHtPart2 = \ platform::CPUPlace>::Cache() \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ .At(attr); \
auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
......
...@@ -257,10 +257,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -257,10 +257,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
jit::lstm_t one_step; \ jit::lstm_t one_step; \
one_step.wp = wp_data; \ one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \ one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \ auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \ platform::CPUPlace>::Cache() \
auto ComputeCtHt = \ .At(attr); \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr) auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
......
...@@ -81,10 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() { ...@@ -81,10 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
template <typename T> template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) { const jit::matmul_attr_t& attr) {
auto matmul = auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr); platform::CPUPlace>::Cache()
auto addbias_relu = .At(attr);
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n); auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.n);
matmul(x, w, y, &attr); matmul(x, w, y, &attr);
T* dst = y; T* dst = y;
for (int i = 0; i < attr.m; ++i) { for (int i = 0; i < attr.m; ++i) {
......
...@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt; attr.type = jit::SeqPoolType::kSqrt;
} }
auto seqpool = auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( platform::CPUPlace>::Cache()
attr); .At(attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w; size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
......
...@@ -93,20 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> { ...@@ -93,20 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
attr.n = y_dims[1]; attr.n = y_dims[1];
int o_numel = attr.m * attr.n; int o_numel = attr.m * attr.n;
auto vsquare_x = auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m * platform::CPUPlace>::Cache()
attr.k); .At(attr.m * attr.k);
auto vsquare_y = auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k * platform::CPUPlace>::Cache()
attr.n); .At(attr.k * attr.n);
auto vsquare_xy = auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel); platform::CPUPlace>::Cache()
auto vsub = .At(o_numel);
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel); auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>,
auto vscal = platform::CPUPlace>::Cache()
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel); .At(o_numel);
auto matmul = auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>,
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr); platform::CPUPlace>::Cache()
.At(o_numel);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
......
...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n") ...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc) list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
......
...@@ -142,7 +142,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -142,7 +142,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
} }
} }
// Test result from Get function // Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr); auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
if (!tgt) { if (!tgt) {
LOG(FATAL) << "Target can not be empty!"; LOG(FATAL) << "Target can not be empty!";
} }
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#pragma once #pragma once
extern "C" {
#include <xxhash.h>
}
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -127,23 +130,36 @@ class KernelFuncs { ...@@ -127,23 +130,36 @@ class KernelFuncs {
return g_func_cache; return g_func_cache;
} }
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); } // the exposed interface to use
typename KernelTuples::func_type At(
void Insert(int key, typename KernelTuples::func_type func) { const typename KernelTuples::attr_type& attr) {
funcs_.emplace(key, func); // XXH64: 13.8 GB/s
} int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0);
typename KernelTuples::func_type At(int key) {
if (Has(key)) { if (Has(key)) {
return funcs_.at(key); return funcs_.at(key);
} }
auto func = Get<KT, KernelTuples, PlaceType>(key); // If do not have this attr in cache,
// then we could run a runtime benchmark for this attr and save the best one.
// For now, just take the best one from the offline benchmark.
auto func = Get<KT, KernelTuples, PlaceType>(attr);
Insert(key, func); Insert(key, func);
return func; return func;
} }
// Subscript-style access to the cache: cache[attr] returns exactly what
// At(attr) returns for the same attr.
typename KernelTuples::func_type operator[](
    const typename KernelTuples::attr_type& attr) {
  auto func = this->At(attr);
  return func;
}
protected:
// Reports whether a function is already stored in funcs_ under `key`.
bool Has(int64_t key) const {
  return funcs_.count(key) != 0;
}
// Stores `func` in funcs_ under `key`. If an entry already exists under that
// key it is left as it is (same no-overwrite behaviour as emplace).
void Insert(int64_t key, typename KernelTuples::func_type func) {
  funcs_.insert({key, func});
}
private: private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_; std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs); DISABLE_COPY_AND_ASSIGN(KernelFuncs);
}; };
......
...@@ -462,7 +462,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -462,7 +462,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
} }
// test result from Get function // test result from Get function
// VLOG(10) << "Test Get function "; // VLOG(10) << "Test Get function ";
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr); auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
test(tgt, args...); test(tgt, args...);
} }
...@@ -845,7 +845,9 @@ void TestKernelNCHW16CMulNCTuples() { ...@@ -845,7 +845,9 @@ void TestKernelNCHW16CMulNCTuples() {
T* zjit_data = zjit.data(); T* zjit_data = zjit.data();
constexpr int simd_width = ZMM_FLOAT_BLOCK; constexpr int simd_width = ZMM_FLOAT_BLOCK;
int C = c / simd_width; int C = c / simd_width;
auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0); auto tgt =
jit::KernelFuncs<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>::Cache().At(
0);
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0); auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
EXPECT_TRUE(tgt != nullptr); EXPECT_TRUE(tgt != nullptr);
...@@ -970,7 +972,7 @@ void TestKernelVBroadcastTuples() { ...@@ -970,7 +972,7 @@ void TestKernelVBroadcastTuples() {
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \ #define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \ TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \ TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \ TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
} }
TEST_CPU_KERNEL(XYZNTuples, kVMul); TEST_CPU_KERNEL(XYZNTuples, kVMul);
...@@ -1041,4 +1043,18 @@ TEST(JITKernel_key, gru) { ...@@ -1041,4 +1043,18 @@ TEST(JITKernel_key, gru) {
EXPECT_TRUE(key2 == key3); EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4); EXPECT_TRUE(key3 != key4);
} }
// TODO(TJ): add more tests about key and pool
// Exercises the KernelFuncs cache for kVAdd<float> on CPUPlace:
//  * At(attr) and operator[](attr) are expected to agree for the same attr;
//  * lookups with attr 3 and attr 4 are expected to give different functions.
TEST(JITKernel, kernel_func) {
  using VAddFuncs =
      jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>;

  // Same attr, fetched through both access paths.
  auto via_at = VAddFuncs::Cache().At(3);
  auto via_subscript = VAddFuncs::Cache()[3];
  EXPECT_TRUE(via_at == via_subscript);

  // Two different attrs, both fetched through At().
  auto for_three = VAddFuncs::Cache().At(3);
  auto for_four = VAddFuncs::Cache().At(4);
  EXPECT_TRUE(for_three != for_four);
}
...@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>,
jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>( platform::CPUPlace>::Cache()
right); .At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right); static_cast<const float>(epsilon), right);
......
...@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit::seq_pool_attr_t attr( jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]), static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum); jit::SeqPoolType::kSum);
auto seqpool = auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( platform::CPUPlace>::Cache()
attr); .At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]); attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);
......
...@@ -47,8 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -47,8 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
int64_t rows_idx = 0; int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd = auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr); platform::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced. // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
...@@ -81,8 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -81,8 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
attr.selected_rows_size = grad_rows.size(); attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd = auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr); platform::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr); sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册